La bibliothèque pickle Python implémente des protocoles binaires pour la sérialisation et la désérialisation d'un objet Python.
Lorsque vous import torch
(ou lorsque vous utilisez PyTorch), ce sera import pickle
pour vous et vous n'avez pas besoin d'appeler pickle.dump()
et pickle.load()
directement, quelles sont les méthodes pour enregistrer et charger l'objet.
En fait, torch.save()
et torch.load()
emballera pickle.dump()
et pickle.load()
pour vous.
Une state_dict
autre réponse mentionnée mérite juste quelques notes supplémentaires.
Qu'avons state_dict
-nous à l'intérieur de PyTorch? Il y a en fait deux state_dict
art.
Le modèle est PyTorch torch.nn.Module
a model.parameters()
appel pour obtenir des paramètres apprenables (w et b). Ces paramètres apprenables, une fois définis aléatoirement, seront mis à jour au fil du temps à mesure que nous apprenons. Les paramètres apprenables sont les premiers state_dict
.
Le second state_dict
est le dict d'état de l'optimiseur. Vous vous rappelez que l'optimiseur est utilisé pour améliorer nos paramètres apprenables. Mais l'optimiseur state_dict
est fixe. Rien à apprendre là-dedans.
Les state_dict
objets étant des dictionnaires Python, ils peuvent être facilement enregistrés, mis à jour, modifiés et restaurés, ajoutant une grande modularité aux modèles et optimiseurs PyTorch.
Créons un modèle super simple pour expliquer cela:
import torch
import torch.optim as optim
model = torch.nn.Linear(5, 2)
# Initialize optimizer
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
print("Model's state_dict:")
for param_tensor in model.state_dict():
print(param_tensor, "\t", model.state_dict()[param_tensor].size())
print("Model weight:")
print(model.weight)
print("Model bias:")
print(model.bias)
print("---")
print("Optimizer's state_dict:")
for var_name in optimizer.state_dict():
print(var_name, "\t", optimizer.state_dict()[var_name])
Ce code affichera les éléments suivants:
Model's state_dict:
weight torch.Size([2, 5])
bias torch.Size([2])
Model weight:
Parameter containing:
tensor([[ 0.1328, 0.1360, 0.1553, -0.1838, -0.0316],
[ 0.0479, 0.1760, 0.1712, 0.2244, 0.1408]], requires_grad=True)
Model bias:
Parameter containing:
tensor([ 0.4112, -0.0733], requires_grad=True)
---
Optimizer's state_dict:
state {}
param_groups [{'lr': 0.001, 'momentum': 0.9, 'dampening': 0, 'weight_decay': 0, 'nesterov': False, 'params': [140695321443856, 140695321443928]}]
Notez qu'il s'agit d'un modèle minimal. Vous pouvez essayer d'ajouter une pile de séquences
model = torch.nn.Sequential(
torch.nn.Linear(D_in, H),
torch.nn.Conv2d(A, B, C)
torch.nn.Linear(H, D_out),
)
Notez que seules les couches avec des paramètres apprenables (couches convolutionnelles, couches linéaires, etc.) et des tampons enregistrés (couches batchnorm) ont des entrées dans le modèle state_dict
.
Les choses non apprenables appartiennent à l'objet optimiseur state_dict
, qui contient des informations sur l'état de l'optimiseur, ainsi que les hyperparamètres utilisés.
Le reste de l'histoire est le même; dans la phase d'inférence (c'est une phase où l'on utilise le modèle après l'entraînement) pour la prédiction; nous prédisons en fonction des paramètres que nous avons appris. Donc, pour l'inférence, nous avons juste besoin de sauvegarder les paramètres model.state_dict()
.
torch.save(model.state_dict(), filepath)
Et pour utiliser plus tard model.load_state_dict (torch.load (filepath)) model.eval ()
Remarque: n'oubliez pas la dernière ligne qui model.eval()
est cruciale après le chargement du modèle.
N'essayez pas non plus de sauvegarder torch.save(model.parameters(), filepath)
. Le model.parameters()
n'est que l'objet générateur.
De l'autre côté, torch.save(model, filepath)
enregistre l'objet de modèle lui-même, mais gardez à l'esprit que le modèle n'a pas d'optimiseur state_dict
. Vérifiez l'autre excellente réponse de @Jadiel de Armas pour enregistrer le dict d'état de l'optimiseur.