Le meilleur moyen d'enregistrer un modèle entraîné dans PyTorch?


192

Je cherchais des moyens alternatifs pour enregistrer un modèle entraîné dans PyTorch. Jusqu'à présent, j'ai trouvé deux alternatives.

  1. torch.save () pour enregistrer un modèle et torch.load () pour charger un modèle.
  2. model.state_dict () pour enregistrer un modèle entraîné et model.load_state_dict () pour charger le modèle enregistré.

Je suis tombé sur cette discussion où l'approche 2 est recommandée par rapport à l'approche 1.

Ma question est la suivante: pourquoi la deuxième approche est-elle préférée? Est-ce uniquement parce que les modules torch.nn ont ces deux fonctions et que nous sommes encouragés à les utiliser?


2
Je pense que c'est parce que torch.save () enregistre également toutes les variables intermédiaires, comme les sorties intermédiaires pour une utilisation en rétro-propagation. Mais il vous suffit de sauvegarder les paramètres du modèle, comme le poids / le biais, etc. Parfois, le premier peut être beaucoup plus grand que le second.
Dawei Yang

2
J'ai testé torch.save(model, f)et torch.save(model.state_dict(), f). Les fichiers enregistrés ont la même taille. Maintenant je suis confus. De plus, j'ai trouvé l'utilisation de pickle pour enregistrer model.state_dict () extrêmement lente. Je pense que le meilleur moyen est d'utiliser torch.save(model.state_dict(), f)puisque vous gérez la création du modèle et que la torche gère le chargement des poids du modèle, éliminant ainsi les problèmes possibles. Référence: discuss.pytorch.org/t/saving-torch-models/838/4
Dawei Yang

On dirait que PyTorch a abordé cela un peu plus explicitement dans sa section tutoriels - il y a beaucoup de bonnes informations qui ne sont pas répertoriées dans les réponses ici, y compris la sauvegarde de plus d'un modèle à la fois et des modèles de démarrage chaleureux.
whlteXbread

quel est le problème avec l'utilisation pickle?
Charlie Parker le

1
@CharlieParker torch.save est basé sur pickle. Ce qui suit est tiré du didacticiel lié ci-dessus: «[torch.save] enregistrera le module entier en utilisant le module pickle de Python. L'inconvénient de cette approche est que les données sérialisées sont liées aux classes spécifiques et à la structure de répertoire exacte utilisée lorsque le modèle est enregistré. La raison en est que pickle n'enregistre pas la classe de modèle elle-même. Il enregistre plutôt un chemin vers le fichier contenant la classe, qui est utilisé pendant le chargement. De ce fait, votre code peut être interrompu de différentes manières lorsque utilisé dans d'autres projets ou après des refactors. "
David Miller le

Réponses:


214

J'ai trouvé cette page sur leur dépôt github, je vais simplement coller le contenu ici.


Approche recommandée pour enregistrer un modèle

Il existe deux approches principales pour la sérialisation et la restauration d'un modèle.

Le premier (recommandé) enregistre et charge uniquement les paramètres du modèle:

torch.save(the_model.state_dict(), PATH)

Puis plus tard:

the_model = TheModelClass(*args, **kwargs)
the_model.load_state_dict(torch.load(PATH))

La seconde sauvegarde et charge l'ensemble du modèle:

torch.save(the_model, PATH)

Puis plus tard:

the_model = torch.load(PATH)

Cependant, dans ce cas, les données sérialisées sont liées aux classes spécifiques et à la structure de répertoire exacte utilisée, de sorte qu'elles peuvent se briser de diverses manières lorsqu'elles sont utilisées dans d'autres projets ou après de sérieux refactors.


8
Selon @smth discuss.pytorch.org/t/saving-and-loading-a-model-in-pytorch/ ... le modèle se recharge pour entraîner le modèle par défaut. donc besoin d'appeler manuellement the_model.eval () après le chargement, si vous le chargez pour l'inférence, pas pour reprendre l'entraînement.
WillZ

la deuxième méthode donne une erreur stackoverflow.com/questions/53798009/… sur Windows 10. n'a pas pu le résoudre
Gulzar

Existe-t-il une option pour enregistrer sans avoir besoin d'un accès pour la classe de modèle?
Michael D du

Avec cette approche, comment garder une trace des * args et ** kwargs que vous devez transmettre pour le cas de charge?
Mariano Kamp le

quel est le problème avec l'utilisation pickle?
Charlie Parker le

144

Cela dépend de ce que vous voulez faire.

Cas n ° 1: enregistrez le modèle pour l'utiliser vous-même pour l'inférence : vous enregistrez le modèle, vous le restaurez, puis vous changez le modèle en mode d'évaluation. Ceci est fait parce que vous avez en général BatchNormet des Dropoutcouches qui sont par défaut en mode train sur la construction:

torch.save(model.state_dict(), filepath)

#Later to restore:
model.load_state_dict(torch.load(filepath))
model.eval()

Cas n ° 2: Enregistrer le modèle pour reprendre l'entraînement plus tard : Si vous devez continuer à entraîner le modèle que vous êtes sur le point d'enregistrer, vous devez enregistrer plus que le modèle. Vous devez également enregistrer l'état de l'optimiseur, les époques, le score, etc. Vous le feriez comme ceci:

state = {
    'epoch': epoch,
    'state_dict': model.state_dict(),
    'optimizer': optimizer.state_dict(),
    ...
}
torch.save(state, filepath)

Pour reprendre l'entraînement, vous feriez des choses comme:, state = torch.load(filepath)puis, pour restaurer l'état de chaque objet individuel, quelque chose comme ceci:

model.load_state_dict(state['state_dict'])
optimizer.load_state_dict(state['optimizer'])

Puisque vous reprenez l'entraînement, NE PAS appeler model.eval()une fois que vous avez restauré les états lors du chargement.

Cas n ° 3: Modèle à utiliser par quelqu'un d'autre sans accès à votre code : Dans Tensorflow, vous pouvez créer un .pbfichier qui définit à la fois l'architecture et les poids du modèle. Ceci est très pratique, surtout lors de l'utilisation Tensorflow serve. La façon équivalente de faire cela dans Pytorch serait:

torch.save(model, filepath)

# Then later:
model = torch.load(filepath)

Cette méthode n'est toujours pas à l'épreuve des balles et comme pytorch subit encore de nombreux changements, je ne le recommanderais pas.


1
Existe-t-il une fin de fichier recommandée pour les 3 cas? Ou est-ce toujours .pth?
Verena Haunschmid

1
Dans le cas n ° 3 torch.loadrenvoie juste un OrderedDict. Comment obtenir le modèle pour faire des prédictions?
Alber8295

Bonjour, Puis-je savoir comment faire le "Cas n ° 2: Enregistrer le modèle pour reprendre l'entraînement plus tard"? J'ai réussi à charger le point de contrôle pour modéliser, puis je n'ai pas pu exécuter ou reprendre le train de modèle comme "model.to (appareil) model = train_model_epoch (modèle, critère, optimiseur, ordonnanceur, époques)"
dnez

1
Salut, pour le cas un qui est pour l'inférence, dans le document officiel de pytorch, dites que vous devez enregistrer l'optimiseur state_dict pour l'inférence ou la formation. "Lors de l'enregistrement d'un point de contrôle général, à utiliser pour l'inférence ou la reprise de la formation, vous devez enregistrer plus que le state_dict du modèle. Il est important de sauvegarder également le state_dict de l'optimiseur, car il contient des tampons et des paramètres qui sont mis à jour à mesure que le modèle s'entraîne . "
Mohammed Awney

1
Dans le cas n ° 3, la classe de modèle doit être définie quelque part.
Michael D

12

La bibliothèque pickle Python implémente des protocoles binaires pour la sérialisation et la désérialisation d'un objet Python.

Lorsque vous import torch(ou lorsque vous utilisez PyTorch), ce sera import picklepour vous et vous n'avez pas besoin d'appeler pickle.dump()et pickle.load()directement, quelles sont les méthodes pour enregistrer et charger l'objet.

En fait, torch.save()et torch.load()emballera pickle.dump()et pickle.load()pour vous.

Une state_dictautre réponse mentionnée mérite juste quelques notes supplémentaires.

Qu'avons state_dict-nous à l'intérieur de PyTorch? Il y a en fait deux state_dictart.

Le modèle est PyTorch torch.nn.Modulea model.parameters()appel pour obtenir des paramètres apprenables (w et b). Ces paramètres apprenables, une fois définis aléatoirement, seront mis à jour au fil du temps à mesure que nous apprenons. Les paramètres apprenables sont les premiers state_dict.

Le second state_dictest le dict d'état de l'optimiseur. Vous vous rappelez que l'optimiseur est utilisé pour améliorer nos paramètres apprenables. Mais l'optimiseur state_dictest fixe. Rien à apprendre là-dedans.

Les state_dictobjets étant des dictionnaires Python, ils peuvent être facilement enregistrés, mis à jour, modifiés et restaurés, ajoutant une grande modularité aux modèles et optimiseurs PyTorch.

Créons un modèle super simple pour expliquer cela:

import torch
import torch.optim as optim

model = torch.nn.Linear(5, 2)

# Initialize optimizer
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)

print("Model's state_dict:")
for param_tensor in model.state_dict():
    print(param_tensor, "\t", model.state_dict()[param_tensor].size())

print("Model weight:")    
print(model.weight)

print("Model bias:")    
print(model.bias)

print("---")
print("Optimizer's state_dict:")
for var_name in optimizer.state_dict():
    print(var_name, "\t", optimizer.state_dict()[var_name])

Ce code affichera les éléments suivants:

Model's state_dict:
weight   torch.Size([2, 5])
bias     torch.Size([2])
Model weight:
Parameter containing:
tensor([[ 0.1328,  0.1360,  0.1553, -0.1838, -0.0316],
        [ 0.0479,  0.1760,  0.1712,  0.2244,  0.1408]], requires_grad=True)
Model bias:
Parameter containing:
tensor([ 0.4112, -0.0733], requires_grad=True)
---
Optimizer's state_dict:
state    {}
param_groups     [{'lr': 0.001, 'momentum': 0.9, 'dampening': 0, 'weight_decay': 0, 'nesterov': False, 'params': [140695321443856, 140695321443928]}]

Notez qu'il s'agit d'un modèle minimal. Vous pouvez essayer d'ajouter une pile de séquences

model = torch.nn.Sequential(
          torch.nn.Linear(D_in, H),
          torch.nn.Conv2d(A, B, C)
          torch.nn.Linear(H, D_out),
        )

Notez que seules les couches avec des paramètres apprenables (couches convolutionnelles, couches linéaires, etc.) et des tampons enregistrés (couches batchnorm) ont des entrées dans le modèle state_dict.

Les choses non apprenables appartiennent à l'objet optimiseur state_dict, qui contient des informations sur l'état de l'optimiseur, ainsi que les hyperparamètres utilisés.

Le reste de l'histoire est le même; dans la phase d'inférence (c'est une phase où l'on utilise le modèle après l'entraînement) pour la prédiction; nous prédisons en fonction des paramètres que nous avons appris. Donc, pour l'inférence, nous avons juste besoin de sauvegarder les paramètres model.state_dict().

torch.save(model.state_dict(), filepath)

Et pour utiliser plus tard model.load_state_dict (torch.load (filepath)) model.eval ()

Remarque: n'oubliez pas la dernière ligne qui model.eval()est cruciale après le chargement du modèle.

N'essayez pas non plus de sauvegarder torch.save(model.parameters(), filepath). Le model.parameters()n'est que l'objet générateur.

De l'autre côté, torch.save(model, filepath)enregistre l'objet de modèle lui-même, mais gardez à l'esprit que le modèle n'a pas d'optimiseur state_dict. Vérifiez l'autre excellente réponse de @Jadiel de Armas pour enregistrer le dict d'état de l'optimiseur.


Bien que ce ne soit pas une solution simple, l'essence du problème est profondément analysée! Vote positif.
Jason Young le

7

Une convention PyTorch courante consiste à enregistrer les modèles en utilisant une extension de fichier .pt ou .pth.

Enregistrer / charger tout le modèle Enregistrer:

path = "username/directory/lstmmodelgpu.pth"
torch.save(trainer, path)

Charge:

La classe de modèle doit être définie quelque part

model = torch.load(PATH)
model.eval()

4

Si vous souhaitez enregistrer le modèle et que vous souhaitez reprendre l'entraînement plus tard:

GPU unique: Enregistrer:

state = {
        'epoch': epoch,
        'state_dict': model.state_dict(),
        'optimizer': optimizer.state_dict(),
}
savepath='checkpoint.t7'
torch.save(state,savepath)

Charge:

checkpoint = torch.load('checkpoint.t7')
model.load_state_dict(checkpoint['state_dict'])
optimizer.load_state_dict(checkpoint['optimizer'])
epoch = checkpoint['epoch']

GPU multiple: enregistrer

state = {
        'epoch': epoch,
        'state_dict': model.module.state_dict(),
        'optimizer': optimizer.state_dict(),
}
savepath='checkpoint.t7'
torch.save(state,savepath)

Charge:

checkpoint = torch.load('checkpoint.t7')
model.load_state_dict(checkpoint['state_dict'])
optimizer.load_state_dict(checkpoint['optimizer'])
epoch = checkpoint['epoch']

#Don't call DataParallel before loading the model otherwise you will get an error

model = nn.DataParallel(model) #ignore the line if you want to load on Single GPU
En utilisant notre site, vous reconnaissez avoir lu et compris notre politique liée aux cookies et notre politique de confidentialité.
Licensed under cc by-sa 3.0 with attribution required.