1
votes

Algorithme rapide pour la fonction log gamma

J'essaie d'écrire un algorithme rapide pour calculer la fonction log gamma . Actuellement, mon implémentation semble naïve, et répète juste 10 millions de fois pour calculer le journal de la fonction gamma (j'utilise également numba pour optimiser le code).

import numpy as np
from numba import njit
EULER_MAS = 0.577215664901532 # euler mascheroni constant
HARMONC_10MIL = 16.695311365860007 # sum of 1/k from 1 to 10,000,000

@njit(fastmath=True)
def gammaln(z):
"""Compute log of gamma function for some real positive float z"""
    out = -EULER_MAS*z - np.log(z) + z*HARMONC_10MIL
    n = 10000000 # number of iters
    for k in range(1,n+1,4):
        # loop unrolling
        v1 = np.log(1 + z/k)
        v2 = np.log(1 + z/(k+1))
        v3 = np.log(1 + z/(k+2))
        v4 = np.log(1 + z/(k+3))
        out -= v1 + v2 + v3 + v4

    return out

J'ai chronométré mon code par rapport au scipy.special.gammaln et la mienne est littéralement 100 000 fois plus lent. Donc je fais quelque chose de très mal ou de très naïf (probablement les deux). Bien que mes réponses soient au moins correctes à moins de 4 décimales au pire par rapport à scipy.

J'ai essayé de lire le code _ufunc implémentant la fonction gammaln de scipy, mais je ne comprends pas le code cython que le _gammaln est écrite.

Existe-t-il un moyen plus rapide et plus optimisé de calculer la fonction log gamma? Comment puis-je comprendre l'implémentation de scipy pour pouvoir l'intégrer à la mienne?


12 commentaires

Qu'est-ce qu'un exemple d'entrée pour z ? Je ne connais pas la formule, mais cela ne signifie pas que les gens ne peuvent pas essayer de vectoriser cela - nous devons savoir comment appeler la fonction à tester, cependant.


De plus, si nous parlons de 100000 fois plus lent que Scipy, assurez-vous qu'il ne nous faut pas un âge pour l'exécuter avec l'exemple d'entrée :)


. @ roganjosh L'exécution de la fonction avec l'argument 1 a pris environ 50 ms sur ma machine, donc je suppose que ce serait sûr de partir


@ user8408080 oki doki. L'entrée est-elle censée être un int ou un tableau le savez-vous?


Autant que je sache, il peut s'agir de n'importe quel nombre complexe (voir ici ). Mais un seul numéro


@ user8408080 ok, merci. Quelque chose n'allait pas avec votre timing, car j'obtiens % timeit gammaln (1) 23,2 s ± 1,79 s par boucle (moyenne ± dev. Standard de 7 courses, 1 boucle chacune) . La suppression de n est une solution assez simple pour cela, cependant, pour les tests.


@roganjosh J'ai réessayé et je n'ai toujours eu que 50 ms environ. Quelle version numpy / numba utilisez-vous? Je travaille uniquement sur un i5-3470


Par intérêt: pourquoi ne souhaitez-vous pas utiliser la fonction fournie par scipy ? Publié ci-dessous une réponse qui devrait vous aider.


@ user8408080 les horaires étaient sans numba mais j'y suis retourné et j'ai réessayé avec numba et cela prend encore des siècles. Comment chronométrez-vous cela? J'ai le sentiment que vous ne capturez que le wrapper njit et la définition de la fonction, et pas réellement le temps de traitement. numpy 1.14.5 , numba 0.38.0 .


@roganjosh J'ai utilisé la magie % timeit d'IPython comme ceci: % timeit gammaln (1.5) . Est-ce une mauvaise pratique?


@ user8408080 non, c'est exactement ce que je fais! Que se passe t-il ici?! Vous avez conservé n = 10000000 ? Je veux dire, je fais ça sur un ordinateur portable, mais cet écart est fou.


L'implémentation gammaln utilisée par scipy est écrite en C. github .com / scipy / scipy / blob / master / scipy / special / cephes / gamm‌ ac (lgam). Le nom de l'implémentation backend peut être trouvé ici github.com/ scipy / scipy / blob / master / scipy / special / functions.j‌ fils


3 Réponses :


2
votes

Le temps d'exécution de votre fonction sera mis à l'échelle linéairement (jusqu'à une surcharge constante) avec le nombre d'itérations. Réduire le nombre d'itérations est donc essentiel pour accélérer l'algorithme. Bien que calculer à l'avance HARMONIC_10MIL soit une idée intelligente, cela conduit en fait à une moins bonne précision lorsque vous tronquez la série; le calcul d'une partie seulement de la série s'avère donner une plus grande précision.

Le code ci-dessous est une version modifiée du code posté ci-dessus (bien qu'en utilisant cython au lieu de numba ) .

t - log1p(t) = t ** 2 / 2 + ...

Il est capable d'obtenir une approximation proche même après 100 approximations comme le montre la figure ci-dessous.

 entrez la description de l'image ici

À 100 itérations, son exécution est du même ordre de grandeur que scipy.special.gammaln:

log1p(t) = t - t ** 2 / 2 + ...

La question restante est bien sûr le nombre d'itérations à utiliser. La fonction log1p (t) peut être étendue comme une série de Taylor pour les petits t (ce qui est pertinent dans la limite des grands k ). En particulier,

%timeit special.gammaln(5)
# 932 ns ± 19 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each)
%timeit gammaln(5, 100)
# 1.25 µs ± 20.3 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each)

de telle sorte que, pour un grand k , l'argument de la somme devienne

from libc.math cimport log, log1p
cimport cython
cdef:
    float EULER_MAS = 0.577215664901532 # euler mascheroni constant

@cython.cdivision(True)
def gammaln(float z, int n=1000):
    """Compute log of gamma function for some real positive float z"""
    cdef:
        float out = -EULER_MAS*z - log(z)
        int k
        float t
    for k in range(1, n):
        t = z / k
        out += t - log1p(t)

    return out

Par conséquent, l'argument de la somme est nul jusqu'au second ordre dans t ce qui est négligeable si t est suffisamment petit. En d'autres termes, le nombre d'itérations doit être au moins aussi grand que z , de préférence au moins un ordre de grandeur plus grand.

Cependant, je m'en tiens à l'implémentation bien testée de scipy si possible.


4 commentaires

excellente réponse, cela semble vraiment fonctionner rapidement dans votre exemple! une question stupide, comment puis-je obtenir la bibliothèque libc.math? J'ai déjà installé Cython par pip, mais je n'arrive pas à trouver la bibliothèque libc.math.


libc.math devrait être inclus par défaut, je pense. Cependant, je fais régulièrement l'erreur d'écrire import plutôt que cimport pour cython includes. Est-ce peut-être le problème?


Cela ne semble pas fonctionner avec une permutation de import ou cimport dans le code que vous avez ci-dessus ... Je peux exécuter import cython mais pas de libc.math cimport ... (cela entraîne une erreur de syntaxe) ou de libc.math import ... (cela entraîne ModuleNotFoundError)


mon erreur @Till Hoffmann, j'exécutais cela dans un ordinateur portable Jupyter sans la configuration appropriée. Ça marche. Merci beaucoup!



0
votes

J'ai réussi à obtenir une augmentation des performances d'environ 3x en essayant le mode parallèle de numba et en utilisant principalement des fonctions vectorisées (malheureusement, numba ne comprend pas numpy.substract.reduce )

#Your function:
%timeit gammaln(1.5)
48.6 ms ± 1.23 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)

#My function:
%timeit gammaln_vec(1.5)
15 ms ± 340 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)

#scpiy's function
%timeit gammaln_sp(1.5)
1.07 µs ± 18.7 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each)

Times:

from functools import reduce
import numpy as np
from numba import njit

@njit(fastmath=True, parallel=True)
def gammaln_vec(z):
    out = -EULER_MAS*z - np.log(z) + z*HARMONC_10MIL
    n = 10000000

    v = np.log(1 + z/np.arange(1, n+1))

    return out-reduce(lambda x1, x2: x1-x2, v, 0)

Alors encore, vous serez beaucoup mieux en utilisant la fonction de scipy. Sans code C, je ne sais pas comment le décomposer davantage


0 commentaires

0
votes

En ce qui concerne vos questions précédentes, je suppose qu'un exemple d'encapsulation des fonctions scipy.special dans Numba est également utile.

Exemple

Emballage Les fonctions cdef de Cython sont assez simples et portables tant qu'il n'y a que des types de données simples impliqués (int, double, double *, ...). Pour la documentation sur la façon d'appeler les fonctions scipy.special consultez ceci . Les noms de fonction dont vous avez réellement besoin pour envelopper la fonction se trouvent dans scipy.special.cython_special .__ pyx_capi__ . Les noms de fonction, qui peuvent être appelés avec différents types de données sont mutilés, mais déterminer le bon est assez facile (il suffit de regarder les types de données)

data=np.random.rand(1_000_000)
Test_func(A): 39.1ms
gammaln(A):   39.1ms

Utilisation dans Numba

#Numba example with loops
import numba as nb
import numpy as np
@nb.njit()
def Test_func(A):
  out=np.empty(A.shape[0])
  for i in range(A.shape[0]):
    out[i]=numba_gammaln(A[i])
  return out

Timings

#slightly modified version of https://github.com/numba/numba/issues/3086
from numba.extending import get_cython_function_address
from numba import vectorize, njit
import ctypes
import numpy as np

_PTR = ctypes.POINTER
_dble = ctypes.c_double
_ptr_dble = _PTR(_dble)

addr = get_cython_function_address("scipy.special.cython_special", "gammaln")
functype = ctypes.CFUNCTYPE(_dble, _dble)
gammaln_float64 = functype(addr)

@njit
def numba_gammaln(x):
  return gammaln_float64(x)

Bien sûr, vous pouvez facilement paralléliser cette fonction et surpassent l'implémentation de gammaln à thread unique dans scipy et vous pouvez appeler cette fonction efficacement dans n'importe quelle fonction compilée par Numba.


0 commentaires