Teste si le tableau numpy ne contient que des zéros


92

Nous initialisons un tableau numpy avec des zéros comme ci-dessous:

np.zeros((N,N+1))

Mais comment vérifier si tous les éléments d'une matrice de tableau n * n numpy donnée sont égaux à zéro.
La méthode a juste besoin de renvoyer un True si toutes les valeurs sont effectivement nulles.

Réponses:



161

Les autres réponses publiées ici fonctionneront, mais la fonction la plus claire et la plus efficace à utiliser est numpy.any():

>>> all_zeros = not np.any(a)

ou

>>> all_zeros = not a.any()
  • Ceci est préféré numpy.all(a==0)car il utilise moins de RAM. (Il ne nécessite pas le tableau temporaire créé par le a==0terme.)
  • En outre, il est plus rapide que numpy.count_nonzero(a)parce qu'il peut revenir immédiatement lorsque le premier élément différent de zéro a été trouvé.
    • Edit: Comme @Rachel l'a souligné dans les commentaires, np.any()n'utilise plus la logique de "court-circuit", donc vous ne verrez pas d'avantage de vitesse pour les petits tableaux.

2
Comme il y a une minute, numpy de anyet allne pas court-circuit. Je crois qu'ils sont du sucre pour logical_or.reduceet logical_and.reduce. Comparez les uns aux autres et mon court-circuit is_in: all_false = np.zeros(10**8) all_true = np.ones(10**8) %timeit np.any(all_false) 91.5 ms ± 1.82 ms per loop %timeit np.any(all_true) 93.7 ms ± 6.16 ms per loop %timeit is_in(1, all_true) 293 ns ± 1.65 ns per loop
Rachel

2
C'est un excellent point, merci. Il ressemble à un court-circuit utilisé pour le comportement, mais qui a été perdu à un moment donné. Il y a une discussion intéressante dans les réponses à cette question .
Stuart Berg

50

J'utiliserais np.all ici, si vous avez un tableau a:

>>> np.all(a==0)

3
J'aime que cette réponse vérifie également les valeurs non nulles. Par exemple, on peut vérifier si tous les éléments d'un tableau sont identiques en faisant np.all(a==a[0]). Merci beaucoup!
aignas

9

Comme le dit une autre réponse, vous pouvez tirer parti des évaluations véridiques / fausses si vous savez que 0c'est le seul élément faux éventuellement dans votre tableau. Tous les éléments d'un tableau sont faux ssi il n'y a aucun élément de vérité. *

>>> a = np.zeros(10)
>>> not np.any(a)
True

Cependant, la réponse affirmait que anyc'était plus rapide que les autres options en raison en partie du court-circuit. À partir de 2018, Numpy allet any ne court-circuitent pas .

Si vous faites souvent ce genre de chose, il est très facile de créer vos propres versions de court-circuit en utilisant numba:

import numba as nb

# short-circuiting replacement for np.any()
@nb.jit(nopython=True)
def sc_any(array):
    for x in array.flat:
        if x:
            return True
    return False

# short-circuiting replacement for np.all()
@nb.jit(nopython=True)
def sc_all(array):
    for x in array.flat:
        if not x:
            return False
    return True

Celles-ci ont tendance à être plus rapides que les versions de Numpy même lorsqu'elles ne sont pas en court-circuit. count_nonzeroest le plus lent.

Quelques entrées pour vérifier les performances:

import numpy as np

n = 10**8
middle = n//2
all_0 = np.zeros(n, dtype=int)
all_1 = np.ones(n, dtype=int)
mid_0 = np.ones(n, dtype=int)
mid_1 = np.zeros(n, dtype=int)
np.put(mid_0, middle, 0)
np.put(mid_1, middle, 1)
# mid_0 = [1 1 1 ... 1 0 1 ... 1 1 1]
# mid_1 = [0 0 0 ... 0 1 0 ... 0 0 0]

Vérifier:

## count_nonzero
%timeit np.count_nonzero(all_0) 
# 220 ms ± 8.73 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
%timeit np.count_nonzero(all_1)
# 150 ms ± 4.56 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)

### all
# np.all
%timeit np.all(all_1)
%timeit np.all(mid_0)
%timeit np.all(all_0)
# 56.8 ms ± 3.41 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
# 57.4 ms ± 1.76 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
# 55.9 ms ± 2.13 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)

# sc_all
%timeit sc_all(all_1)
%timeit sc_all(mid_0)
%timeit sc_all(all_0)
# 44.4 ms ± 2.49 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
# 22.7 ms ± 599 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
# 288 ns ± 6.36 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each)

### any
# np.any
%timeit np.any(all_0)
%timeit np.any(mid_1)
%timeit np.any(all_1)
# 60.7 ms ± 1.38 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
# 60 ms ± 287 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
# 57.7 ms ± 1.12 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)

# sc_any
%timeit sc_any(all_0)
%timeit sc_any(mid_1)
%timeit sc_any(all_1)
# 41.7 ms ± 1.24 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
# 22.4 ms ± 1.51 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
# 287 ns ± 12.7 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each)

* Utile allet anyéquivalences:

np.all(a) == np.logical_not(np.any(np.logical_not(a)))
np.any(a) == np.logical_not(np.all(np.logical_not(a)))
not np.all(a) == np.any(np.logical_not(a))
not np.any(a) == np.all(np.logical_not(a))

-9

Si vous testez tous les zéros pour éviter un avertissement sur une autre fonction numpy, encapsulez la ligne dans un essai, sauf que le bloc vous évitera d'avoir à faire le test des zéros avant l'opération qui vous intéresse ie

try: # removes output noise for empty slice 
    mean = np.mean(array)
except:
    mean = 0
En utilisant notre site, vous reconnaissez avoir lu et compris notre politique liée aux cookies et notre politique de confidentialité.
Licensed under cc by-sa 3.0 with attribution required.