Pourquoi avons-nous besoin d'appeler zero_grad () dans PyTorch?


Réponses:


144

Dans PyTorch, nous devons définir les gradients sur zéro avant de commencer à effectuer une rétro-prolifération, car PyTorch accumule les gradients lors des passages en arrière suivants. Ceci est pratique lors de la formation des RNN. Ainsi, l'action par défaut est d' accumuler (c'est-à-dire de faire la somme) des gradients à chaque loss.backward()appel.

Pour cette raison, lorsque vous démarrez votre boucle d'entraînement, vous devriez idéalement zero out the gradientsfaire la mise à jour des paramètres correctement. Sinon, le gradient pointerait dans une autre direction que la direction prévue vers le minimum (ou le maximum , en cas d'objectifs de maximisation).

Voici un exemple simple:

import torch
from torch.autograd import Variable
import torch.optim as optim

def linear_model(x, W, b):
    return torch.matmul(x, W) + b

data, targets = ...

W = Variable(torch.randn(4, 3), requires_grad=True)
b = Variable(torch.randn(3), requires_grad=True)

optimizer = optim.Adam([W, b])

for sample, target in zip(data, targets):
    # clear out the gradients of all Variables 
    # in this optimizer (i.e. W, b)
    optimizer.zero_grad()
    output = linear_model(sample, W, b)
    loss = (output - target) ** 2
    loss.backward()
    optimizer.step()

Alternativement, si vous effectuez une descente en dégradé vanille , alors:

W = Variable(torch.randn(4, 3), requires_grad=True)
b = Variable(torch.randn(3), requires_grad=True)

for sample, target in zip(data, targets):
    # clear out the gradients of Variables 
    # (i.e. W, b)
    W.grad.data.zero_()
    b.grad.data.zero_()

    output = linear_model(sample, W, b)
    loss = (output - target) ** 2
    loss.backward()

    W -= learning_rate * W.grad.data
    b -= learning_rate * b.grad.data

Remarque : L' accumulation (c'est-à-dire la somme ) des gradients se produit lorsque .backward()est appelé sur le losstenseur .


3
merci beaucoup, c'est vraiment utile! Savez-vous si le tensorflow a le comportement?
layser

Juste pour être sûr ... si vous ne le faites pas, vous rencontrerez un problème de gradient explosif, non?
zwep

2
@zwep Si nous accumulons des dégradés, cela ne signifie pas que leur magnitude augmente: un exemple serait si le signe du dégradé continue de basculer. Cela ne garantirait donc pas que vous vous heurteriez au problème du gradient qui explose. En outre, des dégradés explosifs existent même si vous mettez à zéro correctement.
Tom Roth le

Lorsque vous exécutez la descente de dégradé vanille, n'obtenez-vous pas une erreur «Variable feuille qui nécessite que grad a été utilisée dans une opération sur place» lorsque vous essayez de mettre à jour les pondérations?
MUAS

1

zero_grad () redémarre la boucle sans pertes à partir de la dernière étape si vous utilisez la méthode du gradient pour diminuer l'erreur (ou les pertes)

si vous n'utilisez pas zero_grad () la perte sera diminuée et non augmentée comme requis

par exemple, si vous utilisez zero_grad (), vous trouverez la sortie suivante:

model training loss is 1.5
model training loss is 1.4
model training loss is 1.3
model training loss is 1.2

si vous n'utilisez pas zero_grad (), vous trouverez la sortie suivante:

model training loss is 1.4
model training loss is 1.9
model training loss is 2
model training loss is 2.8
model training loss is 3.5
En utilisant notre site, vous reconnaissez avoir lu et compris notre politique liée aux cookies et notre politique de confidentialité.
Licensed under cc by-sa 3.0 with attribution required.