Restaurar subconjunto de variables en Tensorflow

Estoy entrenando una Red de Publicidad Generativa (GAN) en tensorflow, donde básicamente tenemos dos redes diferentes, cada una con su propio optimizador.

self.G, self.layer = self.generator(self.inputCT,batch_size_tf) self.D, self.D_logits = self.discriminator(self.GT_1hot) ... self.g_optim = tf.train.MomentumOptimizer(self.learning_rate_tensor, 0.9).minimize(self.g_loss, global_step=self.global_step) self.d_optim = tf.train.AdamOptimizer(self.learning_rate, beta1=0.5) \ .minimize(self.d_loss, var_list=self.d_vars) 

El problema es que primero entreno una de las redes (g) y luego quiero entrenar g y d juntos. Sin embargo, cuando llamo a la función de carga:

 self.sess.run(tf.initialize_all_variables()) self.sess.graph.finalize() self.load(self.checkpoint_dir) def load(self, checkpoint_dir): print(" [*] Reading checkpoints...") ckpt = tf.train.get_checkpoint_state(checkpoint_dir) if ckpt and ckpt.model_checkpoint_path: ckpt_name = os.path.basename(ckpt.model_checkpoint_path) self.saver.restre(self.sess, ckpt.model_checkpoint_path) return True else: return False 

Tengo un error como este (con mucho más rastreo):

 Tensor name "beta2_power" not found in checkpoint files checkpoint/MR2CT.model-96000 

Puedo restaurar la red g y seguir entrenando con esa función, pero cuando quiero comenzar d desde cero y g desde el modelo almacenado tengo ese error.

Para restaurar un subconjunto de variables, debe crear un nuevo tf.train.Saver y pasarle una lista específica de variables para restaurar en el argumento var_list opcional.

De forma predeterminada, un tf.train.Saver creará operaciones que (i) guardarán cada variable en su gráfico cuando llame a saver.save() y (ii) busquen (por nombre) cada variable en el punto de control dado cuando llame a saver.restre() . Si bien esto funciona para los escenarios más comunes, debe proporcionar más información para trabajar con subconjuntos específicos de las variables:

  1. Si solo desea restaurar un subconjunto de las variables, puede obtener una lista de estas variables llamando a tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope=G_NETWORK_PREFIX) , asumiendo que pone la red “g” en común with tf.name_scope(G_NETWORK_PREFIX): o tf.variable_scope(G_NETWORK_PREFIX): bloque. A continuación, puede pasar esta lista al constructor tf.train.Saver .

  2. Si desea restaurar un subconjunto de la variable y / o las variables en el punto de control tienen nombres diferentes , puede pasar un diccionario como el argumento var_list . De forma predeterminada, cada variable en un punto de control está asociada con una clave , que es el valor de su propiedad tf.Variable.name . Si el nombre es diferente en el gráfico de destino (p. Ej., Porque agregó un prefijo de scope), puede especificar un diccionario que tf.Variable claves de cadena (en el archivo de punto de control) a los objetos de tf.Variable (en el gráfico de destino).

Inspirado por @mrry, propongo una solución para este problema. Para dejarlo claro, formulo el problema como restaurar un subconjunto de la variable desde el punto de control, cuando el modelo se construye sobre un modelo pre-entrenado. Primero, debemos usar la función print_tensors_in_checkpoint_file de la biblioteca inspect_checkpoint o simplemente extraer esta función mediante:

 from tensorflow.python import pywrap_tensorflow def print_tensors_in_checkpoint_file(file_name, tensor_name, all_tensors): varlist=[] reader = pywrap_tensorflow.NewCheckpointReader(file_name) if all_tensors: var_to_shape_map = reader.get_variable_to_shape_map() for key in sorted(var_to_shape_map): varlist.append(key) return varlist varlist=print_tensors_in_checkpoint_file(file_name=the path of the ckpt file,all_tensors=True,tensor_name=None) 

Luego usamos tf.get_collection () como @mrry saied:

 variables = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES) 

Finalmente, podemos inicializar el ahorrador mediante:

 saver = tf.train.Saver(variable[:len(varlist)]) 

La versión completa se puede encontrar en mi github: https://github.com/pobingwanghai/tensorflow_trick/blob/master/restre_from_checkpoint.py

En mi situación, las nuevas variables se agregan al final del modelo, así que simplemente puedo usar [: length ()] para identificar las variables necesarias, para una situación más compleja, es posible que tenga que hacer un trabajo de alineación manual o escriba una función de coincidencia de cadena simple para determinar las variables requeridas.

Tuve un problema similar al restaurar solo una parte de mis variables desde un punto de control y algunas de las variables guardadas no existían en el nuevo modelo. Inspirado por la respuesta de @Lidong, modifiqué un poco la función de lectura:

 def get_tensors_in_checkpoint_file(file_name,all_tensors=True,tensor_name=None): varlist=[] var_value =[] reader = pywrap_tensorflow.NewCheckpointReader(file_name) if all_tensors: var_to_shape_map = reader.get_variable_to_shape_map() for key in sorted(var_to_shape_map): varlist.append(key) var_value.append(reader.get_tensor(key)) else: varlist.append(tensor_name) var_value.append(reader.get_tensor(tensor_name)) return (varlist, var_value) 

y agregó una función de carga:

 def build_tensors_in_checkpoint_file(loaded_tensors): full_var_list = list() # Loop all loaded tensors for i, tensor_name in enumerate(loaded_tensors[0]): # Extract tensor try: tensor_aux = tf.get_default_graph().get_tensor_by_name(tensor_name+":0") except: print('Not found: '+tensor_name) full_var_list.append(tensor_aux) return full_var_list 

Entonces simplemente puede cargar todas las variables comunes usando:

 CHECKPOINT_NAME = path to save file restred_vars = get_tensors_in_checkpoint_file(file_name=CHECKPOINT_NAME) tensors_to_load = build_tensors_in_checkpoint_file(restred_vars) loader = tf.train.Saver(tensors_to_load) loader.restre(sess, CHECKPOINT_NAME) 

Edición: estoy usando tensorflow 1.2

Puede crear una instancia separada de tf.train.Saver() con el argumento var_list establecido en las variables que desea restaurar. Y crea una instancia separada para guardar las variables.