NumPy sélectionnant un index de colonne spécifique par ligne à l'aide d'une liste d'index


90

J'ai du mal à sélectionner les colonnes spécifiques par ligne d'une NumPymatrice.

Supposons que j'ai la matrice suivante que j'appellerais X:

[1, 2, 3]
[4, 5, 6]
[7, 8, 9]

J'ai également un listindex de colonnes pour chaque ligne que j'appellerais Y:

[1, 0, 2]

J'ai besoin d'obtenir les valeurs:

[2]
[4]
[9]

Au lieu d'un listavec des index Y, je peux également produire une matrice avec la même forme que Xoù chaque colonne est un bool/ intdans la valeur 0-1, indiquant s'il s'agit de la colonne requise.

[0, 1, 0]
[1, 0, 0]
[0, 0, 1]

Je sais que cela peut être fait en itérant sur le tableau et en sélectionnant les valeurs de colonne dont j'ai besoin. Cependant, cela sera fréquemment exécuté sur de grands tableaux de données et c'est pourquoi il doit fonctionner aussi vite que possible.

Je me demandais donc s'il y avait une meilleure solution?

Merci.


La réponse est-elle meilleure pour vous? stackoverflow.com/a/17081678/5046896
GoingMyWay

Réponses:


102

Si vous avez un tableau booléen, vous pouvez faire une sélection directe basée sur cela comme ceci:

>>> a = np.array([True, True, True, False, False])
>>> b = np.array([1,2,3,4,5])
>>> b[a]
array([1, 2, 3])

Pour accompagner votre exemple initial, vous pouvez procéder comme suit:

>>> a = np.array([[1,2,3], [4,5,6], [7,8,9]])
>>> b = np.array([[False,True,False],[True,False,False],[False,False,True]])
>>> a[b]
array([2, 4, 9])

Vous pouvez également ajouter une arangesélection directe et effectuer une sélection directe à ce sujet, selon la façon dont vous générez votre tableau booléen et à quoi ressemble votre code YMMV.

>>> a = np.array([[1,2,3], [4,5,6], [7,8,9]])
>>> a[np.arange(len(a)), [1,0,2]]
array([2, 4, 9])

J'espère que cela vous aidera, faites-moi savoir si vous avez d'autres questions.


11
+1 pour l'exemple utilisant arange. Cela m'a été particulièrement utile pour récupérer différents blocs de plusieurs matrices (donc fondamentalement le cas 3D de cet exemple)
Griddo

1
Salut, pouvez-vous expliquer pourquoi nous devons utiliser à la arangeplace de :? Je sais que votre méthode fonctionne et la mienne ne fonctionne pas, mais j'aimerais comprendre pourquoi.
marcotama

@tamzord parce que c'est un tableau numpy et non une liste python vanille, donc la :syntaxe ne fonctionne pas de la même manière.
Slater Victoroff

1
@SlaterTyranus, merci d'avoir répondu. Ma compréhension, après quelques lectures, est que mélanger :avec l'indexation avancée signifie: "pour chaque sous-espace le long :, appliquer l'indexation avancée donnée". Ma compréhension est-elle correcte?
marcotama

@tamzord explique ce que vous entendez par "sous-espace"
Slater Victoroff

35

Vous pouvez faire quelque chose comme ceci:

In [7]: a = np.array([[1, 2, 3],
   ...: [4, 5, 6],
   ...: [7, 8, 9]])

In [8]: lst = [1, 0, 2]

In [9]: a[np.arange(len(a)), lst]
Out[9]: array([2, 4, 9])

En savoir plus sur l'indexation de tableaux multidimensionnels: http://docs.scipy.org/doc/numpy/user/basics.indexing.html#indexing-multi-dimensional-arrays


1
du mal à comprendre pourquoi la plage est nécessaire au lieu de simplement «:» ou plage.
MadmanLee

@MadmanLee Hi, en utilisant :affichera plusieurs len(a)fois les résultats, à la place, indiquant que l'index de chaque ligne imprimera les résultats attendus.
GoingMyWay

1
Je pense que c'est exactement la bonne et élégante façon de résoudre ce problème.
GoingMyWay

6

Un moyen simple pourrait ressembler à:

In [1]: a = np.array([[1, 2, 3],
   ...: [4, 5, 6],
   ...: [7, 8, 9]])

In [2]: y = [1, 0, 2]  #list of indices we want to select from matrix 'a'

range(a.shape[0]) reviendra array([0, 1, 2])

In [3]: a[range(a.shape[0]), y] #we're selecting y indices from every row
Out[3]: array([2, 4, 9])

1
S'il vous plaît, pensez à ajouter des explications.
souki

@souki J'ai ajouté une explication maintenant. Merci
Dhaval Mayatra

6

Les numpyversions récentes ont ajouté un take_along_axis(et put_along_axis) qui effectue cette indexation proprement.

In [101]: a = np.arange(1,10).reshape(3,3)                                                             
In [102]: b = np.array([1,0,2])                                                                        
In [103]: np.take_along_axis(a, b[:,None], axis=1)                                                     
Out[103]: 
array([[2],
       [4],
       [9]])

Il fonctionne de la même manière que:

In [104]: a[np.arange(3), b]                                                                           
Out[104]: array([2, 4, 9])

mais avec une manipulation d'axe différente. Il vise particulièrement à appliquer les résultats de argsortet argmax.


3

Vous pouvez le faire en utilisant iterator. Comme ça:

np.fromiter((row[index] for row, index in zip(X, Y)), dtype=int)

Temps:

N = 1000
X = np.zeros(shape=(N, N))
Y = np.arange(N)

#@Aशwini चhaudhary
%timeit X[np.arange(len(X)), Y]
10000 loops, best of 3: 30.7 us per loop

#mine
%timeit np.fromiter((row[index] for row, index in zip(X, Y)), dtype=int)
1000 loops, best of 3: 1.15 ms per loop

#mine
%timeit np.diag(X.T[Y])
10 loops, best of 3: 20.8 ms per loop

1
OP a mentionné qu'il devrait fonctionner rapidement sur les grands tableaux, donc vos benchmarks ne sont pas très représentatifs. Je suis curieux de savoir comment votre dernière méthode fonctionne pour des tableaux (beaucoup) plus grands!

@moarningsun: mis à jour. np.diag(X.T[Y])est si lent ... Mais np.diag(X.T)est si rapide (10us). Je ne sais pas pourquoi.
Kei Minagawa

0

Un autre moyen astucieux consiste à d'abord transposer le tableau et à l'indexer par la suite. Enfin, prenez la diagonale, c'est toujours la bonne réponse.

X = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12]])
Y = np.array([1, 0, 2, 2])

np.diag(X.T[Y])

Pas à pas:

Tableaux originaux:

>>> X
array([[ 1,  2,  3],
       [ 4,  5,  6],
       [ 7,  8,  9],
       [10, 11, 12]])

>>> Y
array([1, 0, 2, 2])

Transposer pour permettre de l'indexer correctement.

>>> X.T
array([[ 1,  4,  7, 10],
       [ 2,  5,  8, 11],
       [ 3,  6,  9, 12]])

Obtenez les lignes dans l'ordre Y.

>>> X.T[Y]
array([[ 2,  5,  8, 11],
       [ 1,  4,  7, 10],
       [ 3,  6,  9, 12],
       [ 3,  6,  9, 12]])

La diagonale devrait maintenant devenir claire.

>>> np.diag(X.T[Y])
array([ 2,  4,  9, 12]

1
Cela fonctionne techniquement et semble très élégant. Cependant, je trouve que cette approche explose complètement lorsque vous avez affaire à de grands tableaux. Dans mon cas, NumPy a avalé 30 Go de swap et rempli mon SSD. Je recommande d'utiliser plutôt l'approche d'indexation avancée.
5nefarious
En utilisant notre site, vous reconnaissez avoir lu et compris notre politique liée aux cookies et notre politique de confidentialité.
Licensed under cc by-sa 3.0 with attribution required.