Diferencia entre Variable y get_variable en TensorFlow

Que yo sepa, Variable es la operación predeterminada para hacer una variable, y get_variable se usa principalmente para compartir el peso.

Por un lado, hay algunas personas que sugieren usar get_variable lugar de la operación Variable primitiva siempre que necesite una variable. Por otro lado, simplemente veo cualquier uso de get_variable en los documentos y demostraciones oficiales de TensorFlow.

Por lo tanto, quiero saber algunas reglas básicas sobre cómo usar correctamente estos dos mecanismos. ¿Hay algún principio “estándar”?

Recomiendo usar siempre tf.get_variable(...) – hará que sea más fácil refactorizar su código si necesita compartir variables en cualquier momento, por ejemplo, en una configuración multi-gpu (vea la multi-gpu Ejemplo de CIFAR). No hay ningún inconveniente en ello.

La tf.Variable es de nivel inferior; en algún punto, tf.get_variable() no existía, por lo que algunos códigos todavía utilizan la forma de bajo nivel.

tf.Variable es una clase, y hay varias maneras de crear tf.Variable, incluyendo tf.Variable .__ init__ y tf.get_variable.

tf.Variable .__ init__: Crea una nueva variable con initial_value .

 W = tf.Variable(, name=) 

tf.get_variable: Obtiene una variable existente con estos parámetros o crea uno nuevo. También puede utilizar el inicializador.

 W = tf.get_variable(name, shape=None, dtype=tf.float32, initializer=None, regularizer=None, trainable=True, collections=None) 

Es muy útil usar inicializadores como xavier_initializer:

 W = tf.get_variable("W", shape=[784, 256], initializer=tf.contrib.layers.xavier_initializer()) 

Más información en https://www.tensorflow.org/versions/r0.8/api_docs/python/state_ops.html#Variable .

Puedo encontrar dos diferencias principales entre una y la otra:

  1. Primero, tf.Variable siempre creará una nueva variable, ya sea que tf.get_variable obtenga de la gráfica una variable existente con esos parámetros, y si no existe, crea una nueva.

  2. tf.Variable requiere que se especifique un valor inicial.

Es importante aclarar que la función tf.get_variable prefija el nombre con el scope de la variable actual para realizar reutilizaciones. Por ejemplo:

 with tf.variable_scope("one"): a = tf.get_variable("v", [1]) #a.name == "one/v:0" with tf.variable_scope("one"): b = tf.get_variable("v", [1]) #ValueError: Variable one/v already exists with tf.variable_scope("one", reuse = True): c = tf.get_variable("v", [1]) #c.name == "one/v:0" with tf.variable_scope("two"): d = tf.get_variable("v", [1]) #d.name == "two/v:0" e = tf.Variable(1, name = "v", expected_shape = [1]) #e.name == "two/v_1:0" assert(a is c) #Assertion is true, they refer to the same object. assert(a is d) #AssertionError: they are different objects assert(d is e) #AssertionError: they are different objects 

El último error de aserción es interesante: se supone que dos variables con el mismo nombre en el mismo ámbito son la misma variable. Pero si prueba los nombres de las variables d y e se dará cuenta de que Tensorflow cambió el nombre de la variable e :

 d.name #d.name == "two/v:0" e.name #e.name == "two/v_1:0" 

Otra diferencia radica en que uno está en ('variable_store',) colección pero el otro no.

Por favor vea el código fuente :

 def _get_default_variable_store(): store = ops.get_collection(_VARSTORE_KEY) if store: return store[0] store = _VariableStore() ops.add_to_collection(_VARSTORE_KEY, store) return store 

Déjame ilustrar eso:

 import tensorflow as tf from tensorflow.python.framework import ops embedding_1 = tf.Variable(tf.constant(1.0, shape=[30522, 1024]), name="word_embeddings_1", dtype=tf.float32) embedding_2 = tf.get_variable("word_embeddings_2", shape=[30522, 1024]) graph = tf.get_default_graph() collections = graph.collections for c in collections: stores = ops.get_collection(c) print('collection %s: ' % str(c)) for k, store in enumerate(stores): try: print('\t%d: %s' % (k, str(store._vars))) except: print('\t%d: %s' % (k, str(store))) print('') 

La salida:

collection ('__variable_store',): 0: {'word_embeddings_2': }