Comment obtenir l'index d'un élément maximum dans un tableau numpy le long d'un axe


118

J'ai un tableau NumPy à 2 dimensions. Je sais comment obtenir les valeurs maximales sur les axes:

>>> a = array([[1,2,3],[4,3,1]])
>>> amax(a,axis=0)
array([4, 3, 3])

Comment puis-je obtenir les indices des éléments maximum? Donc je voudrais comme sortiearray([1,1,0])

Réponses:


142
>>> a.argmax(axis=0)

array([1, 1, 0])

1
cela fonctionne bien pour les entiers mais que puis-je faire pour les valeurs flottantes et les nombres entre 0 et 1
Priyom saha

100
>>> import numpy as np
>>> a = np.array([[1,2,3],[4,3,1]])
>>> i,j = np.unravel_index(a.argmax(), a.shape)
>>> a[i,j]
4

11
Notez que cette réponse est trompeuse. Il calcule l'indice de l'élément maximum du tableau sur tous les axes, pas le long d'un axe donné comme le demande l'OP: c'est faux. De plus, s'il y a plus d'un maximum, il ne récupère les indices que du premier maximum: il faut le signaler. Essayez avec a = np.array([[1,4,3],[4,3,1]])pour voir qu'il revient i,j==0,1, et néglige la solution à i,j==1,0. Pour les indices de tous les maxima, utilisez à la place i,j = where(a==a.max().
gg349 le

36

argmax()ne renverra que la première occurrence pour chaque ligne. http://docs.scipy.org/doc/numpy/reference/generated/numpy.argmax.html

Si jamais vous avez besoin de le faire pour un tableau mis en forme, cela fonctionne mieux que unravel:

import numpy as np
a = np.array([[1,2,3], [4,3,1]])  # Can be of any shape
indices = np.where(a == a.max())

Vous pouvez également modifier vos conditions:

indices = np.where(a >= 1.5)

Ce qui précède vous donne les résultats sous la forme que vous avez demandée. Alternativement, vous pouvez convertir en une liste de coordonnées x, y en:

x_y_coords =  zip(indices[0], indices[1])

2
Cela n'a pas fonctionné pour moi ... Voulez-vous dire indices = np.where(a==a.max())à la ligne 3?
atomh33ls

Vous avez raison, atomh33ls! Merci d'avoir remarqué ça. J'ai corrigé cette instruction pour inclure le deuxième signe égal pour le conditionnel approprié.
SevakPrime

@SevakPrime, il y a eu une deuxième erreur signalée par @ atomh33ls, .max()au lieu de .argmax(). Veuillez modifier la réponse
gg349

@ gg349, cela dépend de ce que vous voulez. argmax le fournit le long d'un axe qui semble être la façon dont l'OP veut qu'il ait approuvé cette réponse d'eumiro.
SevakPrime

Je vois que la correction @ atomh33ls et que je propose conduit à l'index du ou des plus gros élément (s) du tableau, alors que l'OP demandait les plus grands éléments le long d'un certain axe. Notez cependant que votre solution actuelle conduit à x_y_coord = [(0, 2), (1, 1)]cela ne correspond PAS à la réponse @eumiro, et est erronée. Par exemple, essayez avec a = array([[7,8,9],[10,11,12]])pour voir que votre code n'a aucun hit sur cette entrée. Vous mentionnez également que cela fonctionne mieux que unravel, mais la solution postée par @blas répond au problème du maximum absolu, pas seulement le long d'un axe.
gg349 le

3
v = alli.max()
index = alli.argmax()
x, y = index/8, index%8
En utilisant notre site, vous reconnaissez avoir lu et compris notre politique liée aux cookies et notre politique de confidentialité.
Licensed under cc by-sa 3.0 with attribution required.