Il numpy.einsum()
est très facile de saisir l'idée de si vous la comprenez intuitivement. À titre d'exemple, commençons par une description simple impliquant la multiplication matricielle .
Pour l'utiliser numpy.einsum()
, tout ce que vous avez à faire est de passer la soi-disant chaîne d'indices en argument, suivie de vos tableaux d'entrée .
Disons que vous avez deux tableaux 2D, A
et B
, et que vous voulez faire la multiplication de matrices. Alors, vous faites:
np.einsum("ij, jk -> ik", A, B)
Ici, la chaîne de l' indiceij
correspond au tableau A
tandis que la chaîne de l' indicejk
correspond au tableau B
. En outre, la chose la plus importante à noter ici est que le nombre de caractères dans chaque chaîne d'indice doit correspondre aux dimensions du tableau. (c'est-à-dire deux caractères pour les tableaux 2D, trois caractères pour les tableaux 3D, et ainsi de suite.) Et si vous répétez les caractères entre les chaînes d'indice ( j
dans notre cas), cela signifie que vous voulez que la ein
somme se produise le long de ces dimensions. Ainsi, ils seront réduits en somme. (c'est-à-dire que cette dimension aura disparu )
La chaîne d'indice après cela ->
, sera notre tableau résultant. Si vous le laissez vide, tout sera additionné et une valeur scalaire est renvoyée en résultat. Sinon, le tableau résultant aura des dimensions en fonction de la chaîne d'indice . Dans notre exemple, ce sera le cas ik
. C'est intuitif car nous savons que pour la multiplication matricielle, le nombre de colonnes dans le tableau A
doit correspondre au nombre de lignes dans le tableau, B
ce qui se passe ici (c'est-à-dire que nous encodons cette connaissance en répétant le caractère j
dans la chaîne d'indice )
Voici quelques exemples supplémentaires illustrant l'utilisation / la puissance de np.einsum()
dans la mise en œuvre de certaines opérations courantes de tenseurs ou de nd-tableaux , de manière succincte.
Contributions
In [197]: vec
Out[197]: array([0, 1, 2, 3])
In [198]: A
Out[198]:
array([[11, 12, 13, 14],
[21, 22, 23, 24],
[31, 32, 33, 34],
[41, 42, 43, 44]])
In [199]: B
Out[199]:
array([[1, 1, 1, 1],
[2, 2, 2, 2],
[3, 3, 3, 3],
[4, 4, 4, 4]])
1) Multiplication matricielle (similaire à np.matmul(arr1, arr2)
)
In [200]: np.einsum("ij, jk -> ik", A, B)
Out[200]:
array([[130, 130, 130, 130],
[230, 230, 230, 230],
[330, 330, 330, 330],
[430, 430, 430, 430]])
2) Extraire les éléments le long de la diagonale principale (similaire à np.diag(arr)
)
In [202]: np.einsum("ii -> i", A)
Out[202]: array([11, 22, 33, 44])
3) Produit Hadamard (c.-à-d. Produit élément par élément de deux tableaux) (similaire à arr1 * arr2
)
In [203]: np.einsum("ij, ij -> ij", A, B)
Out[203]:
array([[ 11, 12, 13, 14],
[ 42, 44, 46, 48],
[ 93, 96, 99, 102],
[164, 168, 172, 176]])
4) Mise au carré élément par élément (similaire à np.square(arr)
ou arr ** 2
)
In [210]: np.einsum("ij, ij -> ij", B, B)
Out[210]:
array([[ 1, 1, 1, 1],
[ 4, 4, 4, 4],
[ 9, 9, 9, 9],
[16, 16, 16, 16]])
5) Trace (c'est-à-dire somme des éléments de la diagonale principale) (similaire à np.trace(arr)
)
In [217]: np.einsum("ii -> ", A)
Out[217]: 110
6) Transposition de matrice (similaire à np.transpose(arr)
)
In [221]: np.einsum("ij -> ji", A)
Out[221]:
array([[11, 21, 31, 41],
[12, 22, 32, 42],
[13, 23, 33, 43],
[14, 24, 34, 44]])
7) Produit extérieur (de vecteurs) (similaire à np.outer(vec1, vec2)
)
In [255]: np.einsum("i, j -> ij", vec, vec)
Out[255]:
array([[0, 0, 0, 0],
[0, 1, 2, 3],
[0, 2, 4, 6],
[0, 3, 6, 9]])
8) Produit intérieur (de vecteurs) (similaire à np.inner(vec1, vec2)
)
In [256]: np.einsum("i, i -> ", vec, vec)
Out[256]: 14
9) Somme le long de l'axe 0 (similaire à np.sum(arr, axis=0)
)
In [260]: np.einsum("ij -> j", B)
Out[260]: array([10, 10, 10, 10])
10) Somme le long de l'axe 1 (similaire à np.sum(arr, axis=1)
)
In [261]: np.einsum("ij -> i", B)
Out[261]: array([ 4, 8, 12, 16])
11) Multiplication de la matrice par lots
In [287]: BM = np.stack((A, B), axis=0)
In [288]: BM
Out[288]:
array([[[11, 12, 13, 14],
[21, 22, 23, 24],
[31, 32, 33, 34],
[41, 42, 43, 44]],
[[ 1, 1, 1, 1],
[ 2, 2, 2, 2],
[ 3, 3, 3, 3],
[ 4, 4, 4, 4]]])
In [289]: BM.shape
Out[289]: (2, 4, 4)
In [292]: BMM = np.einsum("bij, bjk -> bik", BM, BM)
In [293]: BMM
Out[293]:
array([[[1350, 1400, 1450, 1500],
[2390, 2480, 2570, 2660],
[3430, 3560, 3690, 3820],
[4470, 4640, 4810, 4980]],
[[ 10, 10, 10, 10],
[ 20, 20, 20, 20],
[ 30, 30, 30, 30],
[ 40, 40, 40, 40]]])
In [294]: BMM.shape
Out[294]: (2, 4, 4)
12) Somme le long de l'axe 2 (similaire à np.sum(arr, axis=2)
)
In [330]: np.einsum("ijk -> ij", BM)
Out[330]:
array([[ 50, 90, 130, 170],
[ 4, 8, 12, 16]])
13) Somme tous les éléments du tableau (similaire à np.sum(arr)
)
In [335]: np.einsum("ijk -> ", BM)
Out[335]: 480
14) Somme sur plusieurs axes (c.-à-d. Marginalisation)
(similaire à np.sum(arr, axis=(axis0, axis1, axis2, axis3, axis4, axis6, axis7))
)
In [354]: R = np.random.standard_normal((3,5,4,6,8,2,7,9))
In [363]: esum = np.einsum("ijklmnop -> n", R)
In [364]: nsum = np.sum(R, axis=(0,1,2,3,4,6,7))
In [365]: np.allclose(esum, nsum)
Out[365]: True
15) Double Dot Products (similaire à np.sum (hadamard-product) cf. 3 )
In [772]: A
Out[772]:
array([[1, 2, 3],
[4, 2, 2],
[2, 3, 4]])
In [773]: B
Out[773]:
array([[1, 4, 7],
[2, 5, 8],
[3, 6, 9]])
In [774]: np.einsum("ij, ij -> ", A, B)
Out[774]: 124
16) Multiplication de tableaux 2D et 3D
Une telle multiplication pourrait être très utile lors de la résolution d'un système d'équations linéaires ( Ax = b ) où vous souhaitez vérifier le résultat.
In [115]: A = np.random.rand(3,3)
In [116]: b = np.random.rand(3, 4, 5)
In [117]: x = np.linalg.solve(A, b.reshape(b.shape[0], -1)).reshape(b.shape)
In [118]: Ax = np.einsum('ij, jkl', A, x)
In [119]: np.allclose(Ax, b)
Out[119]: True
Au contraire, si l'on doit utiliser np.matmul()
pour cette vérification, nous devons faire quelques reshape
opérations pour obtenir le même résultat comme:
In [123]: Ax_matmul = np.matmul(A, x.reshape(x.shape[0], -1)).reshape(x.shape)
In [124]: np.allclose(Ax, Ax_matmul)
Out[124]: True
Bonus : Lisez plus de maths ici: Einstein-Summation et certainement ici: Tensor-Notation
(A * B)^T
, ou de manière équivalenteB^T * A^T
.