Différence entre Variable et get_variable dans TensorFlow


125

Autant que je sache, Variablec'est l'opération par défaut pour créer une variable, et get_variableest principalement utilisée pour le partage de poids.

D'une part, certaines personnes suggèrent d'utiliser à la get_variableplace de l' Variableopération primitive chaque fois que vous avez besoin d'une variable. D'un autre côté, je vois simplement toute utilisation de get_variabledans les documents officiels et les démos de TensorFlow.

Je souhaite donc connaître quelques règles empiriques sur la manière d'utiliser correctement ces deux mécanismes. Existe-t-il des principes «standards»?


6
get_variable est une nouvelle façon, la variable est une ancienne méthode (qui pourrait être prise en charge pour toujours) comme le dit Lukasz (PS: il a écrit une grande partie de la portée du nom de variable dans TF)
Yaroslav Bulatov

Réponses:


90

Je recommanderais de toujours utiliser tf.get_variable(...)- cela facilitera la refactorisation de votre code si vous avez besoin de partager des variables à tout moment, par exemple dans un paramètre multi-gpu (voir l'exemple CIFAR multi-gpu). Il n'y a aucun inconvénient à cela.

Pure tf.Variableest de niveau inférieur; à un moment donné, il tf.get_variable()n'existait pas, donc certains codes utilisent toujours la méthode de bas niveau.


5
Merci beaucoup pour votre réponse. Mais j'ai encore une question sur la façon de remplacer tf.Variablepar tf.get_variablepartout. C'est à ce moment que je veux initialiser une variable avec un tableau numpy, je ne peux pas trouver un moyen propre et efficace de le faire comme je le fais avec tf.Variable. Comment le résolvez-vous? Merci.
Lifu Huang

69

tf.Variable est une classe, et il existe plusieurs façons de créer tf.Variable, y compris tf.Variable.__init__et tf.get_variable.

tf.Variable.__init__: Crée une nouvelle variable avec valeur_initial .

W = tf.Variable(<initial-value>, name=<optional-name>)

tf.get_variable: Obtient une variable existante avec ces paramètres ou en crée une nouvelle. Vous pouvez également utiliser l'initialiseur.

W = tf.get_variable(name, shape=None, dtype=tf.float32, initializer=None,
       regularizer=None, trainable=True, collections=None)

Il est très utile d'utiliser des initialiseurs tels que xavier_initializer:

W = tf.get_variable("W", shape=[784, 256],
       initializer=tf.contrib.layers.xavier_initializer())

Plus d'informations ici .


Oui, en Variablefait, je veux dire en utilisant son __init__. Comme get_variablec'est si pratique, je me demande pourquoi la plupart du code TensorFlow que j'ai vu utiliser à la Variableplace get_variable. Y a-t-il des conventions ou des facteurs à prendre en compte lors du choix entre eux. Je vous remercie!
Lifu Huang

Si vous voulez avoir une certaine valeur, l'utilisation de Variable est simple: x = tf.Variable (3).
Sung Kim

@SungKim normalement lorsque nous l'utilisons, tf.Variable()nous pouvons l'initialiser en tant que valeur aléatoire à partir d'une distribution normale tronquée. Voici mon exemple w1 = tf.Variable(tf.truncated_normal([5, 50], stddev = 0.01), name = 'w1'). Quel serait l'équivalent de cela? comment dire que je veux une normale tronquée? Dois-je juste faire w1 = tf.get_variable(name = 'w1', shape = [5,50], initializer = tf.truncated_normal, regularizer = tf.nn.l2_loss)?
Euler_Salter

@Euler_Salter: Vous pouvez utiliser tf.truncated_normal_initializer()pour obtenir le résultat souhaité.
Bêta du

46

Je peux trouver deux différences principales entre l'un et l'autre:

  1. Tout d'abord, cela tf.Variablecréera toujours une nouvelle variable, alors que tf.get_variablerécupère une variable existante avec des paramètres spécifiés à partir du graphique, et si elle n'existe pas, en crée une nouvelle.

  2. tf.Variable nécessite qu'une valeur initiale soit spécifiée.

Il est important de préciser que la fonction tf.get_variablepréfixe le nom avec la portée de la variable actuelle pour effectuer des vérifications de réutilisation. Par exemple:

with tf.variable_scope("one"):
    a = tf.get_variable("v", [1]) #a.name == "one/v:0"
with tf.variable_scope("one"):
    b = tf.get_variable("v", [1]) #ValueError: Variable one/v already exists
with tf.variable_scope("one", reuse = True):
    c = tf.get_variable("v", [1]) #c.name == "one/v:0"

with tf.variable_scope("two"):
    d = tf.get_variable("v", [1]) #d.name == "two/v:0"
    e = tf.Variable(1, name = "v", expected_shape = [1]) #e.name == "two/v_1:0"

assert(a is c)  #Assertion is true, they refer to the same object.
assert(a is d)  #AssertionError: they are different objects
assert(d is e)  #AssertionError: they are different objects

La dernière erreur d'assertion est intéressante: deux variables de même nom sous la même portée sont supposées être la même variable. Mais si vous testez les noms des variables det que evous vous rendrez compte que Tensorflow a changé le nom de la variable e:

d.name   #d.name == "two/v:0"
e.name   #e.name == "two/v_1:0"

Excellent exemple! En ce qui concerne d.nameet e.name, je viens de tomber sur un document TensorFlow sur l'opération de dénomination des graphes tensoriels qui l'explique:If the default graph already contained an operation named "answer", the TensorFlow would append "_1", "_2", and so on to the name, in order to make it unique.
Atlas7

2

Une autre différence réside dans le fait que l'un est en ('variable_store',)collection mais que l'autre ne l'est pas.

Veuillez consulter le code source :

def _get_default_variable_store():
  store = ops.get_collection(_VARSTORE_KEY)
  if store:
    return store[0]
  store = _VariableStore()
  ops.add_to_collection(_VARSTORE_KEY, store)
  return store

Laissez-moi illustrer cela:

import tensorflow as tf
from tensorflow.python.framework import ops

embedding_1 = tf.Variable(tf.constant(1.0, shape=[30522, 1024]), name="word_embeddings_1", dtype=tf.float32) 
embedding_2 = tf.get_variable("word_embeddings_2", shape=[30522, 1024])

graph = tf.get_default_graph()
collections = graph.collections

for c in collections:
    stores = ops.get_collection(c)
    print('collection %s: ' % str(c))
    for k, store in enumerate(stores):
        try:
            print('\t%d: %s' % (k, str(store._vars)))
        except:
            print('\t%d: %s' % (k, str(store)))
    print('')

Le résultat:

collection ('__variable_store',): 0: {'word_embeddings_2': <tf.Variable 'word_embeddings_2:0' shape=(30522, 1024) dtype=float32_ref>}

En utilisant notre site, vous reconnaissez avoir lu et compris notre politique liée aux cookies et notre politique de confidentialité.
Licensed under cc by-sa 3.0 with attribution required.