Comment fonctionne python numpy.where ()?


94

Je joue avec numpyet fouille dans la documentation et je suis tombé sur de la magie. A savoir, je parle de numpy.where():

>>> x = np.arange(9.).reshape(3, 3)
>>> np.where( x > 5 )
(array([2, 2, 2]), array([0, 1, 2]))

Comment parviennent-ils en interne à ce que vous puissiez passer quelque chose comme x > 5une méthode? Je suppose que cela a quelque chose à voir avec, __gt__mais je cherche une explication détaillée.

Réponses:


75

Comment parviennent-ils en interne à ce que vous puissiez passer quelque chose comme x> 5 dans une méthode?

La réponse courte est que non.

Toute sorte d'opération logique sur un tableau numpy renvoie un tableau booléen. (c'est-à __gt__- dire __lt__,, etc. tous renvoient des tableaux booléens où la condition donnée est vraie).

Par exemple

x = np.arange(9).reshape(3,3)
print x > 5

donne:

array([[False, False, False],
       [False, False, False],
       [ True,  True,  True]], dtype=bool)

C'est la même raison pour laquelle quelque chose comme if x > 5:lève une ValueError si xest un tableau numpy. C'est un tableau de valeurs True / False, pas une seule valeur.

De plus, les tableaux numpy peuvent être indexés par des tableaux booléens. Par exemple, les x[x>5]rendements [6 7 8], dans ce cas.

Honnêtement, il est assez rare que vous en ayez réellement besoin, numpy.wheremais cela renvoie simplement les indications où se trouve un tableau booléen True. Vous pouvez généralement faire ce dont vous avez besoin avec une simple indexation booléenne.


10
Juste pour souligner qu'il numpy.wherey a 2 `` modes de fonctionnement '', le premier retourne les paramètres optionnels indices, where condition is Trueet if xet ysont présents (même forme que condition, ou diffusables à une telle forme!), Il retournera des valeurs de xquand condition is Trueautrement y. Cela le rend donc whereplus polyvalent et lui permet d'être utilisé plus souvent. Merci
manger le

1
Il peut également y avoir une surcharge dans certains cas en utilisant la __getitem__syntaxe []over soit numpy.whereou numpy.take. Étant donné qu'il __getitem__doit également prendre en charge le découpage, il y a des frais généraux. J'ai constaté des différences de vitesse notables lors de l'utilisation des structures de données Python Pandas et de l'indexation logique de très grandes colonnes. Dans ces cas, si vous n'avez pas besoin de trancher, alors takeet wherec'est mieux.
ely

24

Old Answer, c'est un peu déroutant. Il vous donne les LIEUX (tous) où votre déclaration est vraie.

alors:

>>> a = np.arange(100)
>>> np.where(a > 30)
(array([31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47,
       48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64,
       65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81,
       82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98,
       99]),)
>>> np.where(a == 90)
(array([90]),)

a = a*40
>>> np.where(a > 1000)
(array([26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42,
       43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59,
       60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76,
       77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93,
       94, 95, 96, 97, 98, 99]),)
>>> a[25]
1000
>>> a[26]
1040

Je l'utilise comme alternative à list.index (), mais il a également de nombreuses autres utilisations. Je ne l'ai jamais utilisé avec des tableaux 2D.

http://docs.scipy.org/doc/numpy/reference/generated/numpy.where.html

Nouvelle réponse Il semble que la personne demandait quelque chose de plus fondamental.

La question était de savoir comment vous pourriez implémenter quelque chose qui permet à une fonction (comme où) de savoir ce qui a été demandé.

Notez tout d'abord qu'appeler l'un des opérateurs de comparaison fait une chose intéressante.

a > 1000
array([False, False, False, False, False, False, False, False, False,
       False, False, False, False, False, False, False, False, False,
       False, False, False, False, False, False, False, False,  True,
        True,  True,  True,  True,  True,  True,  True,  True,  True,
        True,  True,  True,  True,  True,  True,  True,  True,  True,
        True,  True,  True,  True,  True,  True,  True,  True,  True,
        True,  True,  True,  True,  True,  True,  True,  True,  True,
        True,  True,  True,  True,  True,  True,  True,  True,  True,
        True,  True,  True,  True,  True,  True,  True,  True,  True,
        True,  True,  True,  True,  True,  True,  True,  True,  True,
        True`,  True,  True,  True,  True,  True,  True,  True,  True,  True], dtype=bool)`

Cela se fait en surchargeant la méthode "__gt__". Par exemple:

>>> class demo(object):
    def __gt__(self, item):
        print item


>>> a = demo()
>>> a > 4
4

Comme vous pouvez le voir, "a> 4" était un code valide.

Vous pouvez obtenir une liste complète et une documentation de toutes les fonctions surchargées ici: http://docs.python.org/reference/datamodel.html

Ce qui est incroyable, c'est à quel point il est simple de le faire. TOUTES les opérations en python sont effectuées de cette manière. Dire a> b équivaut à a. gt (b)!


3
Cette surcharge d'opérateurs de comparaison ne semble pas bien fonctionner avec des expressions logiques plus complexes - par exemple, je ne peux pas le faire np.where(a > 30 and a < 50)ou np.where(30 < a < 50)parce qu'elle finit par essayer d'évaluer le ET logique de deux tableaux de booléens, ce qui n'a pas de sens. Existe-t-il un moyen d'écrire une telle condition avec np.where?
davidA

@meowsqueaknp.where((a > 30) & (a < 50))
tibalt

Pourquoi np.where () renvoie-t-il une liste dans votre exemple?
Andreas Yankopolus

0

np.whererenvoie un tuple de longueur égale à la dimension du ndarray numpy sur lequel il est appelé (en d'autres termes ndim) et chaque élément du tuple est un ndarray numpy d'indices de toutes ces valeurs dans le ndarray initial pour lequel la condition est True. (Veuillez ne pas confondre dimension avec forme)

Par exemple:

x=np.arange(9).reshape(3,3)
print(x)
array([[0, 1, 2],
      [3, 4, 5],
      [6, 7, 8]])
y = np.where(x>4)
print(y)
array([1, 2, 2, 2], dtype=int64), array([2, 0, 1, 2], dtype=int64))


y est un tuple de longueur 2 parce qu'il x.ndimest 2. Le premier élément du tuple contient les numéros de ligne de tous les éléments supérieurs à 4 et le deuxième élément contient les numéros de colonne de tous les éléments supérieurs à 4. Comme vous pouvez le voir, [1,2,2 , 2] correspond aux numéros de ligne de 5,6,7,8 et [2,0,1,2] correspond aux numéros de colonne de 5,6,7,8. Notez que le ndarray est parcouru le long de la première dimension (en ligne ).

De même,

x=np.arange(27).reshape(3,3,3)
np.where(x>4)


renverra un tuple de longueur 3 car x a 3 dimensions.

Mais attendez, il y a plus à np.where!

lorsque deux arguments supplémentaires sont ajoutés np.where; il effectuera une opération de remplacement pour toutes ces combinaisons ligne-colonne par paires qui sont obtenues par le tuple ci-dessus.

x=np.arange(9).reshape(3,3)
y = np.where(x>4, 1, 0)
print(y)
array([[0, 0, 0],
   [0, 0, 1],
   [1, 1, 1]])
En utilisant notre site, vous reconnaissez avoir lu et compris notre politique liée aux cookies et notre politique de confidentialité.
Licensed under cc by-sa 3.0 with attribution required.