Capture des motifs initiaux lors de l'utilisation de la rétropropagation tronquée dans le temps (RNN / LSTM)


12

Disons que j'utilise un RNN / LSTM pour faire une analyse de sentiment, qui est une approche à plusieurs (voir ce blog ). Le réseau est formé à travers une rétropropagation tronquée dans le temps (BPTT), où le réseau est déroulé pour seulement 30 dernières étapes comme d'habitude.

Dans mon cas, chacune de mes sections de texte que je veux classer est beaucoup plus longue que les 30 étapes déroulées (~ 100 mots). D'après mes connaissances, BPTT n'est exécuté qu'une seule fois pour une seule section de texte, c'est-à-dire lorsqu'il a passé la totalité de la section de texte et calculé la cible de classification binaire, , qu'il compare ensuite à la fonction de perte pour trouver l'erreur.y

Les gradients ne seront alors jamais calculés par rapport aux premiers mots de chaque section de texte. Comment le RNN / LSTM peut-il encore ajuster ses poids pour capturer des modèles spécifiques qui ne se produisent que dans les premiers mots? Par exemple, disons que toutes les phrases marquées comme commencent par "J'aime ça" et toutes les phrases marquées comme n e g a t i v e commencent par "Je déteste ça". Comment le RNN / LSTM capturerait-il cela alors qu'il n'est déroulé que pour les 30 dernières étapes lorsqu'il atteint la fin d'une longue séquence de 100 étapes?posjetjevenegunetjeve


généralement, l'abréviation est TBPTT pour Troncated Back-Propagation Through Time.
Charlie Parker

Réponses:


11

Il est vrai que limiter votre propagation de gradient à 30 pas de temps l'empêchera d'apprendre tout ce qui est possible dans votre jeu de données. Cependant, cela dépend fortement de votre ensemble de données si cela l'empêchera d'apprendre des choses importantes sur les fonctionnalités de votre modèle!

Limiter le gradient pendant la formation revient plus à limiter la fenêtre sur laquelle votre modèle peut assimiler les fonctionnalités d'entrée et l'état caché avec une grande confiance. Parce qu'au moment du test, vous appliquez votre modèle à la séquence d'entrée entière, il pourra toujours incorporer des informations sur toutes les fonctions d'entrée dans son état masqué. Il pourrait ne pas savoir exactement comment conserver ces informations jusqu'à ce qu'il fasse sa prédiction finale pour la phrase, mais il pourrait y avoir des connexions (certes plus faibles) qu'il serait toujours en mesure d'établir.

Pensez d'abord à un exemple artificiel. Supposons que votre réseau génère un 1 s'il y a un 1 n'importe où dans son entrée et un 0 dans le cas contraire. Supposons que vous entraînez le réseau sur des séquences de longueur 20 et limitez ensuite le gradient à 10 étapes. Si le jeu de données d'apprentissage ne contient jamais de 1 dans les 10 dernières étapes d'une entrée, le réseau va avoir un problème avec les entrées de test de n'importe quelle configuration. Cependant, si l'ensemble d'apprentissage contient des exemples comme [1 0 0 ... 0 0 0] et d'autres comme [0 0 0 ... 1 0 0], le réseau pourra détecter la "présence de une fonction 1 "n'importe où dans son entrée.

Revenons ensuite à l'analyse des sentiments. Disons que pendant la formation, votre modèle rencontre une longue phrase négative comme "Je déteste ça parce que ... autour et autour" avec, disons, 50 mots dans les points de suspension. En limitant la propagation du gradient à 30 pas de temps, le modèle ne connectera pas le "Je déteste cela parce que" à l'étiquette de sortie, donc il ne reprendra pas "Je", "Je déteste" ou "ceci" de cette formation exemple. Mais il reprendra les mots qui se trouvent dans les 30 pas de temps à partir de la fin de la phrase. Si votre ensemble d'entraînement contient d'autres exemples qui contiennent ces mêmes mots, éventuellement avec «haine», il a alors une chance de détecter le lien entre «haine» et l'étiquette de sentiment négatif. De plus, si vous avez des exemples de formation plus courts, dites: "Nous détestons cela parce que c'est terrible!" votre modèle pourra alors connecter les fonctionnalités "haine" et "ceci" à l'étiquette cible. Si vous avez suffisamment de ces exemples de formation, le modèle devrait être en mesure d'apprendre la connexion efficacement.

Au moment du test, disons que vous présentez le modèle avec une autre longue phrase comme "Je déteste ça parce que ... sur le gecko!" L'entrée du modèle commencera par "Je déteste ça", qui sera passé dans l'état caché du modèle sous une forme ou une autre. Cet état caché est utilisé pour influencer les futurs états cachés du modèle, donc même s'il peut y avoir 50 mots avant la fin de la phrase, l'état caché de ces mots initiaux a une chance théorique d'influencer la sortie, même s'il n'a jamais été formé sur des échantillons qui contenaient une si grande distance entre le "Je déteste ça" et la fin de la phrase.


0

@ Imjohns3 a raison, si vous traitez de longues séquences (taille N) et limitez la rétropropagation aux K dernières étapes, le réseau n'apprendra pas les schémas au début.

J'ai travaillé avec de longs textes et j'utilise l'approche où je calcule la perte et fais une rétropropagation après chaque K étapes. Supposons que ma séquence avait N = 1000 jetons, mon processus RNN d'abord K = 100, puis j'essaie de faire des prédictions (calculer la perte) et de rétropropager. Ensuite, tout en maintenant l'état RNN, freinez la chaîne de gradient (en pytorch-> détachez) et commencez un autre k = 100 pas.

Un bon exemple de cette technique que vous pouvez trouver ici: https://github.com/ksopyla/pytorch_neural_networks/blob/master/RNN/lstm_imdb_tbptt.py

En utilisant notre site, vous reconnaissez avoir lu et compris notre politique liée aux cookies et notre politique de confidentialité.
Licensed under cc by-sa 3.0 with attribution required.