À quoi sert torch.no_grad dans pytorch?


21

Je suis nouveau sur pytorch et j'ai commencé avec ce code github. Je ne comprends pas le commentaire de la ligne 60-61 du code "because weights have requires_grad=True, but we don't need to track this in autograd". J'ai compris que nous mentionnons requires_grad=Trueles variables dont nous avons besoin pour calculer les gradients pour utiliser autograd, mais qu'est-ce que cela signifie "tracked by autograd"?

Réponses:


24

L'encapsuleur "avec torch.no_grad ()" définit temporairement tous les indicateurs require_grad sur false. Un exemple du didacticiel officiel PyTorch ( https://pytorch.org/tutorials/beginner/blitz/autograd_tutorial.html#gradients ):

x = torch.randn(3, requires_grad=True)
print(x.requires_grad)
print((x ** 2).requires_grad)

with torch.no_grad():
    print((x ** 2).requires_grad)

En dehors:

True
True
False

Je vous recommande de lire tous les tutoriels du site Web ci-dessus.

Dans votre exemple: je suppose que l'auteur ne veut pas que PyTorch calcule les gradients des nouvelles variables définies w1 et w2 car il veut juste mettre à jour leurs valeurs.


6
with torch.no_grad()

rendra toutes les opérations dans le bloc sans dégradés.

Dans pytorch, vous ne pouvez pas faire de changement d'inplacement de w1 et w2, qui sont deux variables avec require_grad = True. Je pense qu'éviter le changement d'inplacement de w1 et w2 est parce que cela entraînera une erreur dans le calcul de la propagation arrière. Étant donné que le changement de placement changera totalement w1 et w2.

Cependant, si vous l'utilisez no_grad(), vous pouvez contrôler que le nouveau w1 et le nouveau w2 n'ont pas de dégradés car ils sont générés par des opérations, ce qui signifie que vous ne modifiez que la valeur de w1 et w2, pas une partie de gradient, ils ont toujours des informations de gradient variable définies précédemment et la propagation arrière peut continuer.

En utilisant notre site, vous reconnaissez avoir lu et compris notre politique liée aux cookies et notre politique de confidentialité.
Licensed under cc by-sa 3.0 with attribution required.