L'attention est une méthode d'agrégation d'un ensemble de vecteurs vi en un seul vecteur, souvent via un vecteur de recherche u . Habituellement, vi est soit les entrées du modèle, soit les états cachés des pas de temps précédents, soit les états cachés un niveau plus bas (dans le cas des LSTM empilés).
Le résultat est souvent appelé le vecteur de contexte c , car il contient le contexte correspondant au pas de temps actuel.
Ce vecteur de contexte supplémentaire c est ensuite également introduit dans le RNN / LSTM (il peut être simplement concaténé avec l'entrée d'origine). Par conséquent, le contexte peut être utilisé pour aider à la prédiction.
La façon la plus simple de le faire est de calculer le vecteur de probabilité p=softmax(VTu) et c=∑ipivi où V est la concaténation de tous les vi précédents . Un vecteur de recherche commun u est l'état caché actuel ht .
Il existe de nombreuses variantes à ce sujet et vous pouvez rendre les choses aussi compliquées que vous le souhaitez. Par exemple, au lieu d'utiliser vTjeu comme logits, on peut choisir F( vje, u ) place, où F est un réseau neuronal arbitraire.
Un mécanisme d'attention commun pour les modèles de séquence à séquence utilise p = softmax ( qTtanh( W1vje+ W2ht) ) , où v sont les états cachés du codeur et ht est l'état caché actuel du décodeur. q et les deux W s sont des paramètres.
Quelques articles qui montrent différentes variations sur l'idée d'attention:
Les réseaux de pointeurs font attention aux entrées de référence afin de résoudre les problèmes d'optimisation combinatoire.
Les réseaux d'entités récurrents maintiennent des états de mémoire distincts pour différentes entités (personnes / objets) lors de la lecture de texte et mettent à jour l'état de mémoire correct en faisant attention.
Les modèles de transformateurs font également largement appel à l'attention. Leur formulation de l'attention est légèrement plus générale et implique également des vecteurs clés kje : les poids d'attention p sont en fait calculés entre les clés et la recherche, et le contexte est ensuite construit avec le vje .
Voici une mise en œuvre rapide d'une forme d'attention, bien que je ne puisse garantir l'exactitude au-delà du fait qu'elle a réussi quelques tests simples.
RNN de base:
def rnn(inputs_split):
bias = tf.get_variable('bias', shape = [hidden_dim, 1])
weight_hidden = tf.tile(tf.get_variable('hidden', shape = [1, hidden_dim, hidden_dim]), [batch, 1, 1])
weight_input = tf.tile(tf.get_variable('input', shape = [1, hidden_dim, in_dim]), [batch, 1, 1])
hidden_states = [tf.zeros((batch, hidden_dim, 1), tf.float32)]
for i, input in enumerate(inputs_split):
input = tf.reshape(input, (batch, in_dim, 1))
last_state = hidden_states[-1]
hidden = tf.nn.tanh( tf.matmul(weight_input, input) + tf.matmul(weight_hidden, last_state) + bias )
hidden_states.append(hidden)
return hidden_states[-1]
Avec attention, nous ajoutons seulement quelques lignes avant que le nouvel état caché soit calculé:
if len(hidden_states) > 1:
logits = tf.transpose(tf.reduce_mean(last_state * hidden_states[:-1], axis = [2, 3]))
probs = tf.nn.softmax(logits)
probs = tf.reshape(probs, (batch, -1, 1, 1))
context = tf.add_n([v * prob for (v, prob) in zip(hidden_states[:-1], tf.unstack(probs, axis = 1))])
else:
context = tf.zeros_like(last_state)
last_state = tf.concat([last_state, context], axis = 1)
hidden = tf.nn.tanh( tf.matmul(weight_input, input) + tf.matmul(weight_hidden, last_state) + bias )
le code complet