Transferencia de aprendizaje con tf.estimator.Estimator framework

Estoy tratando de transferir el aprendizaje de un modelo Inception-resnet v2 pre-entrenado en imagenet, usando mi propio conjunto de datos y clases. Mi base de código original era una modificación de una muestra tf.slim que ya no puedo encontrar y ahora estoy tratando de volver a escribir el mismo código utilizando el marco tf.estimator.* .

Sin embargo, me estoy topando con el problema de cargar solo algunos de los pesos desde el punto de control pre-entrenado, inicializando las capas restantes con sus inicializadores predeterminados.

Al investigar el problema, encontré este problema de GitHub y esta pregunta , ambos mencionando la necesidad de usar tf.train.init_from_checkpoint en mi model_fn . Lo intenté, pero dada la falta de ejemplos en ambos, supongo que me equivoqué.

Este es mi ejemplo mínimo:

 import sys import os os.environ['CUDA_VISIBLE_DEVICES'] = '-1' import tensorflow as tf import numpy as np import inception_resnet_v2 NUM_CLASSES = 900 IMAGE_SIZE = 299 def input_fn(mode, num_classes, batch_size=1): # some code that loads images, reshapes them to 299x299x3 and batches them return tf.constant(np.zeros([batch_size, 299, 299, 3], np.float32)), tf.one_hot(tf.constant(np.zeros([batch_size], np.int32)), NUM_CLASSES) def model_fn(images, labels, num_classes, mode): with tf.contrib.slim.arg_scope(inception_resnet_v2.inception_resnet_v2_arg_scope()): logits, end_points = inception_resnet_v2.inception_resnet_v2(images, num_classes, is_training=(mode==tf.estimator.ModeKeys.TRAIN)) predictions = { 'classes': tf.argmax(input=logits, axis=1), 'probabilities': tf.nn.softmax(logits, name='softmax_tensor') } if mode == tf.estimator.ModeKeys.PREDICT: return tf.estimator.EstimatorSpec(mode=mode, predictions=predictions) exclude = ['InceptionResnetV2/Logits', 'InceptionResnetV2/AuxLogits'] variables_to_restre = tf.contrib.slim.get_variables_to_restre(exclude=exclude) scopes = { os.path.dirname(v.name) for v in variables_to_restre } tf.train.init_from_checkpoint('inception_resnet_v2_2016_08_30.ckpt', {s+'/':s+'/' for s in scopes}) tf.losses.softmax_cross_entropy(onehot_labels=labels, logits=logits) total_loss = tf.losses.get_total_loss() #obtain the regularization losses as well # Configure the training op if mode == tf.estimator.ModeKeys.TRAIN: global_step = tf.train.get_or_create_global_step() optimizer = tf.train.AdamOptimizer(learning_rate=0.00002) train_op = optimizer.minimize(total_loss, global_step) else: train_op = None return tf.estimator.EstimatorSpec( mode=mode, predictions=predictions, loss=total_loss, train_op=train_op) def main(unused_argv): # Create the Estimator classifier = tf.estimator.Estimator( model_fn=lambda features, labels, mode: model_fn(features, labels, NUM_CLASSES, mode), model_dir='model/MCVE') # Train the model classifier.train( input_fn=lambda: input_fn(tf.estimator.ModeKeys.TRAIN, NUM_CLASSES, batch_size=1), steps=1000) # Evaluate the model and print results eval_results = classifier.evaluate( input_fn=lambda: input_fn(tf.estimator.ModeKeys.EVAL, NUM_CLASSES, batch_size=1)) print() print('Evaluation results:\n %s' % eval_results) if __name__ == '__main__': tf.app.run(main=main, argv=[sys.argv[0]]) 

donde inception_resnet_v2 es la implementación del modelo en el repository de modelos de Tensorflow .

Si ejecuto este script, obtengo un montón de registro de información de init_from_checkpoint , pero luego, en el momento de la creación de la sesión, parece que intenta cargar los pesos de Logits desde el punto de control y falla debido a formas incompatibles. Este es el rastreo completo:

 Traceback (most recent call last): File "", line 1, in  runfile('C:/Users/1/Desktop/transfer_learning_tutorial-master/MCVE.py', wdir='C:/Users/1/Desktop/transfer_learning_tutorial-master') File "C:\Users\1\Anaconda3\envs\tensorflow\lib\site-packages\spyder\utils\site\sitecustomize.py", line 710, in runfile execfile(filename, namespace) File "C:\Users\1\Anaconda3\envs\tensorflow\lib\site-packages\spyder\utils\site\sitecustomize.py", line 101, in execfile exec(compile(f.read(), filename, 'exec'), namespace) File "C:/Users/1/Desktop/transfer_learning_tutorial-master/MCVE.py", line 77, in  tf.app.run(main=main, argv=[sys.argv[0]]) File "C:\Users\1\Anaconda3\envs\tensorflow\lib\site-packages\tensorflow\python\platform\app.py", line 48, in run _sys.exit(main(_sys.argv[:1] + flags_passthrough)) File "C:/Users/1/Desktop/transfer_learning_tutorial-master/MCVE.py", line 68, in main steps=1000) File "C:\Users\1\Anaconda3\envs\tensorflow\lib\site-packages\tensorflow\python\estimator\estimator.py", line 302, in train loss = self._train_model(input_fn, hooks, saving_listeners) File "C:\Users\1\Anaconda3\envs\tensorflow\lib\site-packages\tensorflow\python\estimator\estimator.py", line 780, in _train_model log_step_count_steps=self._config.log_step_count_steps) as mon_sess: File "C:\Users\1\Anaconda3\envs\tensorflow\lib\site-packages\tensorflow\python\training\monitored_session.py", line 368, in MonitoredTrainingSession stop_grace_period_secs=stop_grace_period_secs) File "C:\Users\1\Anaconda3\envs\tensorflow\lib\site-packages\tensorflow\python\training\monitored_session.py", line 673, in __init__ stop_grace_period_secs=stop_grace_period_secs) File "C:\Users\1\Anaconda3\envs\tensorflow\lib\site-packages\tensorflow\python\training\monitored_session.py", line 493, in __init__ self._sess = _RecoverableSession(self._coordinated_creator) File "C:\Users\1\Anaconda3\envs\tensorflow\lib\site-packages\tensorflow\python\training\monitored_session.py", line 851, in __init__ _WrappedSession.__init__(self, self._create_session()) File "C:\Users\1\Anaconda3\envs\tensorflow\lib\site-packages\tensorflow\python\training\monitored_session.py", line 856, in _create_session return self._sess_creator.create_session() File "C:\Users\1\Anaconda3\envs\tensorflow\lib\site-packages\tensorflow\python\training\monitored_session.py", line 554, in create_session self.tf_sess = self._session_creator.create_session() File "C:\Users\1\Anaconda3\envs\tensorflow\lib\site-packages\tensorflow\python\training\monitored_session.py", line 428, in create_session init_fn=self._scaffold.init_fn) File "C:\Users\1\Anaconda3\envs\tensorflow\lib\site-packages\tensorflow\python\training\session_manager.py", line 279, in prepare_session sess.run(init_op, feed_dict=init_feed_dict) File "C:\Users\1\Anaconda3\envs\tensorflow\lib\site-packages\tensorflow\python\client\session.py", line 889, in run run_metadata_ptr) File "C:\Users\1\Anaconda3\envs\tensorflow\lib\site-packages\tensorflow\python\client\session.py", line 1120, in _run feed_dict_tensor, options, run_metadata) File "C:\Users\1\Anaconda3\envs\tensorflow\lib\site-packages\tensorflow\python\client\session.py", line 1317, in _do_run options, run_metadata) File "C:\Users\1\Anaconda3\envs\tensorflow\lib\site-packages\tensorflow\python\client\session.py", line 1336, in _do_call raise type(e)(node_def, op, message) InvalidArgumentError: Assign requires shapes of both tensors to match. lhs shape= [900] rhs shape= [1001] [[Node: Assign_1145 = Assign[T=DT_FLOAT, _class=["loc:@InceptionResnetV2/Logits/Logits/biases"], use_locking=true, validate_shape=true, _device="/job:localhost/replica:0/task:0/device:CPU:0"](InceptionResnetV2/Logits/Logits/biases, checkpoint_initializer_1145)]] 

¿Qué estoy haciendo mal al usar init_from_checkpoint ? ¿Cómo se supone que debemos “usarlo” en nuestro model_fn ? ¿Y por qué el estimador está tratando de cargar los pesos de Logits desde el punto de control cuando le digo explícitamente que no lo Logits ?

Actualizar:

Después de la sugerencia en los comentarios, probé formas alternativas de llamar a tf.train.init_from_checkpoint .

Usando {v.name: v.name}

Si, como se sugiere en el comentario, sustituyo la llamada con {v.name:v.name for v in variables_to_restre} , obtengo este error:

 ValueError: Assignment map with scope only name InceptionResnetV2/Conv2d_2a_3x3 should map to scope only InceptionResnetV2/Conv2d_2a_3x3/weights:0. Should be 'scope/': 'other_scope/'. 

Usando {v.name: v}

Si, en cambio, trato de usar el name:variable asignación de name:variable , obtengo el siguiente error:

 ValueError: Tensor InceptionResnetV2/Conv2d_2a_3x3/weights:0 is not found in inception_resnet_v2_2016_08_30.ckpt checkpoint {'InceptionResnetV2/Repeat_2/block8_4/Branch_1/Conv2d_0c_3x1/BatchNorm/moving_mean': [256], 'InceptionResnetV2/Repeat/block35_9/Branch_0/Conv2d_1x1/BatchNorm/beta': [32], ... 

El error continúa enumerando lo que creo que son todos los nombres de variables en el punto de control (¿o podrían ser los ámbitos en su lugar?).

Actualizar (2)

Después de inspeccionar el error más reciente aquí arriba, veo que InceptionResnetV2/Conv2d_2a_3x3/weights está en la lista de las variables de punto de control. El problema es que :0 al final! Ahora verificaré si esto realmente resuelve el problema y publicaré una respuesta si ese es el caso.

Gracias al comentario de @KathyWu, fui por el camino correcto y encontré el problema.

De hecho, la forma en que estaba calculando los scopes incluiría el InceptionResnetV2/ scope, que activaría la carga de todas las variables “debajo” del ámbito (es decir, todas las variables en la red). Reemplazar esto con el diccionario correcto, sin embargo, no fue trivial.

De los posibles modos de init_from_checkpoint acepta init_from_checkpoint , el que tuve que usar fue el 'scope_variable_name': variable one, pero sin usar el atributo variable.name real .

La variable.name ve así: 'some_scope/variable_name:0' . Eso :0 no está en el nombre de la variable de punto de control y, por lo tanto, usar scopes = {v.name:v.name for v in variables_to_restre} generará un error de “Variable no encontrado”.

El truco para hacerlo funcionar fue quitar el índice de tensor del nombre :

 tf.train.init_from_checkpoint('inception_resnet_v2_2016_08_30.ckpt', {v.name.split(':')[0]: v for v in variables_to_restre}) 

Descubrí que {s+'/':s+'/' for s in scopes} no funcionó, solo porque las variables_to_restre incluyen algo como "global_step" , así que los ámbitos incluyen los ámbitos globales que podrían incluir todo. Necesitas imprimir variables_to_restre , encontrar algo "global_step" y ponerlo en "exclude" .