3
votes

Comptez le nombre de valeurs non nulles dans un tableau numpy dans Numba

Très simple. J'essaie de compter le nombre de valeurs non nulles dans un tableau dans NumPy jit compilé avec Numba ( njit () ). Ce que j'ai essayé n'est pas autorisé par Numba.

  1. a [a! = 0] .size
  2. np.count_nonzero(a)
  3. len (a [a! = 0])
  4. len (a) - len (a [a == 0])

Je ne veux pas utiliser pour les boucles s'il existe encore un moyen plus rapide, plus pythonique et élégant.

Pour ce commentateur qui voulait voir un exemple de code complet ...

import numpy as np
from numba import njit

@njit()
def n_nonzero(a):
    return a[a != 0].size


3 commentaires

Veuillez afficher au moins un morceau de code réel et complet que vous avez essayé, y compris les instructions import , les décorateurs et un exemple de faisceau de test.


[list (filter ((0) .__ ne__, l)) pour l dans a]


@MarkSetchell Bien sûr ...


4 Réponses :


1
votes

Vous pouvez utiliser np.nonzero et induire la longueur de celui-ci:

@njit
def count_non_zero(np_arr):
    return len(np.nonzero(np_arr)[0])

count_non_zero(np.array([0,1,0,1]))
# 2


2 commentaires

cela [0] semble être la chose qui l'a fait. Merci beaucoup!


Le plus drôle est que np.nonzero utilise np.count_nonzero (au niveau c-api) pour déterminer la taille des tableaux qu'il remplira lors d'une deuxième itération. Je pensais que tout l'intérêt d'utiliser numba était de pouvoir itérer en toute impunité. :)



1
votes

Je ne sais pas si j'ai fait une erreur ici, mais cela semble 6 fois plus rapide:

# Make something worth checking
a=np.random.randint(0,3,1000000000,dtype=np.uint8)  

In [41]: @njit() 
    ...: def methodA(a): 
    ...:     return len(np.nonzero(a)[0])                                                                                           

# Call and check result
In [42]: methodA(a)                                                                                 
Out[42]: 666644445

In [43]: %timeit methodA(a)                                                                         
4.65 s ± 28.6 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

In [44]: @njit() 
    ...: def methodB(a): 
    ...:     return (a!=0).sum()                                                                                         

# Call and check result    
In [45]: methodB(a)                                                                                 
Out[45]: 666644445

In [46]: %timeit methodB(a)                                                                         
724 ms ± 14 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


0 commentaires

4
votes

Vous pouvez également envisager, eh bien, de compter les valeurs différentes de zéro:

%timeit np.count_nonzero(a)
# 4.36 s ± 69.4 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

Je sais que cela semble faux, mais soyez patient:

import numpy as np
import numba as nb

@nb.njit()
def count_loop(a):
    s = 0
    for i in a:
        if i != 0:
            s += 1
    return s

@nb.njit()
def count_len_nonzero(a):
    return len(np.nonzero(a)[0])

@nb.njit()
def count_sum_neq_zero(a):
    return (a != 0).sum()

np.random.seed(100)
a = np.random.randint(0, 3, 1000000000, dtype=np.uint8)
c = np.count_nonzero(a)
assert count_len_nonzero(a) == c
assert count_sum_neq_zero(a) == c
assert count_loop(a) == c

%timeit count_len_nonzero(a)
# 5.94 s ± 141 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
%timeit count_sum_neq_zero(a)
# 848 ms ± 80.3 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
%timeit count_loop(a)
# 189 ms ± 4.41 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

Il est en fait plus rapide que np.count_nonzero , qui peut devenir assez lent pour une raison quelconque:

import numba as nb

@nb.njit()
def count_loop(a):
    s = 0
    for i in a:
        if i != 0:
            s += 1
    return s


1 commentaires

Ouais, numba excelle vraiment quand il voit une boucle qu'il peut optimiser. +1



3
votes

Si vous en avez besoin très rapidement pour les grands tableaux, vous pouvez même utiliser numbas prange pour traiter le décompte en parallèle (pour les petits tableaux, il sera plus lent en raison de la surcharge de traitement parallèle).

import numpy as np
from numba import njit, prange

@njit
def n_nonzero(a):
    return a[a != 0].size

@njit
def count_non_zero(np_arr):
    return len(np.nonzero(np_arr)[0])

@njit() 
def methodB(a): 
    return (a!=0).sum()

@njit(parallel=True)
def parallel_nonzero_count(arr):
    flattened = arr.ravel()
    sum_ = 0
    for i in prange(flattened.size):
        sum_ += flattened[i] != 0
    return sum_

@njit()
def count_loop(a):
    s = 0
    for i in a:
        if i != 0:
            s += 1
    return s

from simple_benchmark import benchmark

args = {}
for exp in range(2, 20):
    size = 2**exp
    arr = np.random.random(size)
    arr[arr < 0.3] = 0.0
    args[size] = arr

b = benchmark(
    funcs=(n_nonzero, count_non_zero, methodB, np.count_nonzero, parallel_nonzero_count, count_loop),
    arguments=args,
    argument_name='array size',
    warmups=(n_nonzero, count_non_zero, methodB, np.count_nonzero, parallel_nonzero_count, count_loop)
)


2 commentaires

Est-ce le partage sécurisé de sum_ entre des itérations de boucles parallèles? (Je ne connais pas grand chose aux garanties de Numba parallélisé)


Oui, numba a quelques réductions qu'il peut paralléliser en toute sécurité. la sommation et les multiplications en font partie. C'est parce que numba se rend compte qu'il peut traiter en parallèle en utilisant sum_ = 0 pour chacun d'eux, puis les ajouter simplement après la fin de chaque processus. J'ai également vérifié la cohérence par rapport à np.count_nonzero .