En termes simples, torch.Tensor.view()
qui est inspiré par numpy.ndarray.reshape()
ou numpy.reshape()
, crée une nouvelle vue du tenseur, tant que la nouvelle forme est compatible avec la forme du tenseur d'origine.
Comprenons cela en détail à l'aide d'un exemple concret.
In [43]: t = torch.arange(18)
In [44]: t
Out[44]:
tensor([ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17])
Avec ce tenseur t
de forme (18,)
, de nouvelles vues ne peuvent être créées que pour les formes suivantes:
(1, 18)
ou de manière équivalente (1, -1)
ou ou équivalente ou ou équivalente ou ou équivalente ou ou équivalente ou ou équivalente ou(-1, 18)
(2, 9)
(2, -1)
(-1, 9)
(3, 6)
(3, -1)
(-1, 6)
(6, 3)
(6, -1)
(-1, 3)
(9, 2)
(9, -1)
(-1, 2)
(18, 1)
(18, -1)
(-1, 1)
Comme nous pouvons déjà l'observer à partir des tuples de forme ci-dessus, la multiplication des éléments du tuple de forme (par exemple 2*9
, 3*6
etc.) doit toujours être égale au nombre total d'éléments dans le tenseur d'origine (18
dans notre exemple).
Une autre chose à observer est que nous avons utilisé un -1
dans l'un des endroits de chacun des tuples de forme. En utilisant a -1
, nous sommes paresseux dans le calcul nous-mêmes et déléguons plutôt la tâche à PyTorch pour faire le calcul de cette valeur pour la forme lors de la création de la nouvelle vue . Une chose importante à noter est que nous pouvons seulement utiliser un seul -1
dans le tuple de forme. Les valeurs restantes doivent être fournies explicitement par nous. Sinon PyTorch se plaindra en lançant un RuntimeError
:
RuntimeError: une seule dimension peut être déduite
Ainsi, avec toutes les formes mentionnées ci-dessus, PyTorch retournera toujours une nouvelle vue du tenseur d'originet
. Cela signifie essentiellement qu'il modifie simplement les informations de foulée du tenseur pour chacune des nouvelles vues demandées.
Vous trouverez ci-dessous quelques exemples illustrant comment les foulées des tenseurs sont modifiées à chaque nouvelle vue .
# stride of our original tensor `t`
In [53]: t.stride()
Out[53]: (1,)
Maintenant, nous allons voir les progrès pour les nouvelles vues :
# shape (1, 18)
In [54]: t1 = t.view(1, -1)
# stride tensor `t1` with shape (1, 18)
In [55]: t1.stride()
Out[55]: (18, 1)
# shape (2, 9)
In [56]: t2 = t.view(2, -1)
# stride of tensor `t2` with shape (2, 9)
In [57]: t2.stride()
Out[57]: (9, 1)
# shape (3, 6)
In [59]: t3 = t.view(3, -1)
# stride of tensor `t3` with shape (3, 6)
In [60]: t3.stride()
Out[60]: (6, 1)
# shape (6, 3)
In [62]: t4 = t.view(6,-1)
# stride of tensor `t4` with shape (6, 3)
In [63]: t4.stride()
Out[63]: (3, 1)
# shape (9, 2)
In [65]: t5 = t.view(9, -1)
# stride of tensor `t5` with shape (9, 2)
In [66]: t5.stride()
Out[66]: (2, 1)
# shape (18, 1)
In [68]: t6 = t.view(18, -1)
# stride of tensor `t6` with shape (18, 1)
In [69]: t6.stride()
Out[69]: (1, 1)
Voilà donc la magie de la view()
fonction. Il modifie simplement les enjambées du tenseur (d'origine) pour chacune des nouvelles vues , tant que la forme de la nouvelle vue est compatible avec la forme d'origine.
Une autre chose intéressante que l'on peut observer à partir des tuples de foulées est que la valeur de l'élément en 0 ème position est égale à la valeur de l'élément en 1 ère position du tuple de forme.
In [74]: t3.shape
Out[74]: torch.Size([3, 6])
|
In [75]: t3.stride() |
Out[75]: (6, 1) |
|_____________|
Ceci est dû au fait:
In [76]: t3
Out[76]:
tensor([[ 0, 1, 2, 3, 4, 5],
[ 6, 7, 8, 9, 10, 11],
[12, 13, 14, 15, 16, 17]])
la foulée (6, 1)
dit que pour passer d'un élément à l'élément suivant le long de la 0 ème dimension, il faut sauter ou faire 6 pas. (c.-à-d. pour aller de 0
à 6
, il faut faire 6 étapes.) Mais pour passer d'un élément à l'élément suivant dans la 1ère dimension, nous n'avons besoin que d'une seule étape (par exemple pour aller de 2
à3
).
Ainsi, les informations de pas sont au cœur de l'accès aux éléments depuis la mémoire pour effectuer le calcul.
Cette fonction retournerait une vue et est exactement la même que celle utilisée torch.Tensor.view()
tant que la nouvelle forme est compatible avec la forme du tenseur d'origine. Sinon, il en retournera une copie.
Cependant, les notes de torch.reshape()
préviennent que:
Les entrées contiguës et les entrées avec des pas compatibles peuvent être remodelées sans copie, mais il ne faut pas dépendre du comportement de copie par rapport à l'affichage.
reshape
dans PyTorch?!