Est-il courant de minimiser la perte moyenne sur les lots au lieu de la somme?


15

Tensorflow propose un exemple de didacticiel sur la classification de CIFAR-10 . Dans le didacticiel, la perte d'entropie croisée moyenne sur le lot est minimisée.

def loss(logits, labels):
  """Add L2Loss to all the trainable variables.
  Add summary for for "Loss" and "Loss/avg".
  Args:
    logits: Logits from inference().
    labels: Labels from distorted_inputs or inputs(). 1-D tensor
            of shape [batch_size]
  Returns:
    Loss tensor of type float.
  """
  # Calculate the average cross entropy loss across the batch.
  labels = tf.cast(labels, tf.int64)
  cross_entropy = tf.nn.sparse_softmax_cross_entropy_with_logits(
      logits, labels, name='cross_entropy_per_example')
  cross_entropy_mean = tf.reduce_mean(cross_entropy, name='cross_entropy')
  tf.add_to_collection('losses', cross_entropy_mean)

  # The total loss is defined as the cross entropy loss plus all of the weight
  # decay terms (L2 loss).
  return tf.add_n(tf.get_collection('losses'), name='total_loss')

Voir cifar10.py , ligne 267.

Pourquoi ne minimise-t-il pas plutôt la somme sur le lot? Est-ce que cela fait une différence? Je ne comprends pas comment cela affecterait le calcul du backprop.


Pas exactement lié à la somme / moyenne, mais le choix de perte est un choix de conception d'application. Par exemple, si vous avez raison d'avoir raison en moyenne, optimisez la moyenne. Si votre application est sensible au pire des cas (par exemple, un accident automobile), vous devez optimiser la valeur maximale.
Alex Kreimer

Réponses:


15

Comme mentionné par pkubik, il y a généralement un terme de régularisation pour les paramètres qui ne dépend pas de l'entrée, par exemple dans tensorflow c'est comme

# Loss function using L2 Regularization
regularizer = tf.nn.l2_loss(weights)
loss = tf.reduce_mean(loss + beta * regularizer)

Dans ce cas, la moyenne sur le mini-lot permet de maintenir un rapport fixe entre la cross_entropyperte et la regularizerperte pendant que la taille du lot est modifiée.

De plus, le taux d'apprentissage est également sensible à l'ampleur de la perte (gradient), afin de normaliser le résultat de différentes tailles de lot, la prise de la moyenne semble une meilleure option.


Mise à jour

Cet article de Facebook (Accurate, Large Minibatch SGD: Training ImageNet in 1 Hour) montre qu'en réalité, la mise à l'échelle du taux d'apprentissage en fonction de la taille du lot fonctionne très bien:

Règle de mise à l'échelle linéaire: lorsque la taille du mini-lot est multipliée par k, multipliez le taux d'apprentissage par k.

ce qui est essentiellement la même chose que de multiplier le gradient par k et de garder le taux d'apprentissage inchangé, donc je suppose que prendre la moyenne n'est pas nécessaire.


8

Je vais me concentrer sur la partie:

Je ne comprends pas comment cela affecterait le calcul du backprop.

1BLSUM=BLAVGBdLSUMdx=BdLAVGdx

dLdx=limΔ0L(x+Δ)L(x)Δ
d(cL)dx=limΔ0cL(x+Δ)cL(x)Δ
d(cL)dx=climΔ0L(x+Δ)L(x)Δ=cdLdx

Dans SGD, nous mettions à jour les poids en utilisant leur gradient multiplié par le taux d'apprentissage et nous pouvons clairement voir que nous pouvons choisir ce paramètre de telle manière que les mises à jour finales des poids soient égales. La première règle de mise à jour: et la deuxième règle de mise à jour (imaginez que ): λ

W:=W+λ1dLSUMdW
λ1=λ2B
W:=W+λ1dLAVGdW=W+λ2BdLSUMdW


L'excellente conclusion de dontloo peut suggérer que l'utilisation de la somme pourrait être une approche un peu plus appropriée. Pour justifier la moyenne qui semble être plus populaire, j'ajouterais que l'utilisation de la somme pourrait probablement causer des problèmes de régularisation du poids. Le réglage du facteur d'échelle pour les régularisateurs pour différentes tailles de lot peut être tout aussi ennuyeux que le réglage du taux d'apprentissage.

En utilisant notre site, vous reconnaissez avoir lu et compris notre politique liée aux cookies et notre politique de confidentialité.
Licensed under cc by-sa 3.0 with attribution required.