Tensorflow: ¿Cómo guardar / restaurar un modelo?

Después de entrenar un modelo en Tensorflow:

  1. ¿Cómo se guarda el modelo entrenado?
  2. ¿Cómo restaurar más tarde este modelo guardado?

Docs

Construyeron un tutorial exhaustivo y útil -> https://www.tensorflow.org/guide/saved_model

De los documentos:

Salvar

# Create some variables. v1 = tf.get_variable("v1", shape=[3], initializer = tf.zeros_initializer) v2 = tf.get_variable("v2", shape=[5], initializer = tf.zeros_initializer) inc_v1 = v1.assign(v1+1) dec_v2 = v2.assign(v2-1) # Add an op to initialize the variables. init_op = tf.global_variables_initializer() # Add ops to save and restre all the variables. saver = tf.train.Saver() # Later, launch the model, initialize the variables, do some work, and save the # variables to disk. with tf.Session() as sess: sess.run(init_op) # Do some work with the model. inc_v1.op.run() dec_v2.op.run() # Save the variables to disk. save_path = saver.save(sess, "/tmp/model.ckpt") print("Model saved in path: %s" % save_path) 

Restaurar

 tf.reset_default_graph() # Create some variables. v1 = tf.get_variable("v1", shape=[3]) v2 = tf.get_variable("v2", shape=[5]) # Add ops to save and restre all the variables. saver = tf.train.Saver() # Later, launch the model, use the saver to restre variables from disk, and # do some work with the model. with tf.Session() as sess: # Restore variables from disk. saver.restre(sess, "/tmp/model.ckpt") print("Model restred.") # Check the values of the variables print("v1 : %s" % v1.eval()) print("v2 : %s" % v2.eval()) 

Tensor de flujo <2

simple_save

Muchas buenas respuestas, para completar, agregaré mis 2 centavos: simple_save . También un ejemplo de código independiente que utiliza la API tf.data.Dataset .

Python 3; Tensorflow 1.7

 import tensorflow as tf from tensorflow.python.saved_model import tag_constants with tf.Graph().as_default(): with tf.Session as sess: ... # Saving inputs = { "batch_size_placeholder": batch_size_placeholder, "features_placeholder": features_placeholder, "labels_placeholder": labels_placeholder, } outputs = {"prediction": model_output} tf.saved_model.simple_save( sess, 'path/to/your/location/', inputs, outputs ) 

Restaurando

 graph = tf.Graph() with restred_graph.as_default(): with tf.Session as sess: tf.saved_model.loader.load( sess, [tag_constants.SERVING], 'path/to/your/location/', ) batch_size_placeholder = graph.get_tensor_by_name('batch_size_placeholder:0') features_placeholder = graph.get_tensor_by_name('features_placeholder:0') labels_placeholder = graph.get_tensor_by_name('labels_placeholder:0') prediction = restred_graph.get_tensor_by_name('dense/BiasAdd:0') sess.run(prediction, feed_dict={ batch_size_placeholder: some_value, features_placeholder: some_other_value, labels_placeholder: another_value }) 

Ejemplo independiente

Entrada de blog original

El siguiente código genera datos aleatorios para la demostración.

  1. Comenzamos creando los marcadores de posición. Ellos mantendrán los datos en tiempo de ejecución. A partir de ellos, creamos el Dataset y luego su Iterator . Obtenemos el tensor generado del iterador, llamado input_tensor que servirá como entrada para nuestro modelo.
  2. El modelo en sí está construido a partir de input_tensor : un RNN bidireccional basado en GRU seguido de un clasificador denso. Porque, porque no.
  3. La pérdida es un softmax_cross_entropy_with_logits , optimizado con Adam . Después de 2 épocas (de 2 lotes cada una), guardamos el modelo “entrenado” con tf.saved_model.simple_save . Si ejecuta el código como está, entonces el modelo se guardará en una carpeta llamada simple/ en su directorio de trabajo actual.
  4. En un nuevo gráfico, luego restauramos el modelo guardado con tf.saved_model.loader.load . Tomamos los marcadores de posición y logits con graph.get_tensor_by_name y la operación de inicialización del Iterator con graph.get_operation_by_name .
  5. Por último, ejecutamos una inferencia para ambos lotes en el conjunto de datos y verificamos que el modelo guardado y restaurado produzca los mismos valores. ¡Ellas hacen!

Código:

 import os import shutil import numpy as np import tensorflow as tf from tensorflow.python.saved_model import tag_constants def model(graph, input_tensor): """Create the model which consists of a bidirectional rnn (GRU(10)) followed by a dense classifier Args: graph (tf.Graph): Tensors' graph input_tensor (tf.Tensor): Tensor fed as input to the model Returns: tf.Tensor: the model's output layer Tensor """ cell = tf.nn.rnn_cell.GRUCell(10) with graph.as_default(): ((fw_outputs, bw_outputs), (fw_state, bw_state)) = tf.nn.bidirectional_dynamic_rnn( cell_fw=cell, cell_bw=cell, inputs=input_tensor, sequence_length=[10] * 32, dtype=tf.float32, swap_memory=True, scope=None) outputs = tf.concat((fw_outputs, bw_outputs), 2) mean = tf.reduce_mean(outputs, axis=1) dense = tf.layers.dense(mean, 5, activation=None) return dense def get_opt_op(graph, logits, labels_tensor): """Create optimization operation from model's logits and labels Args: graph (tf.Graph): Tensors' graph logits (tf.Tensor): The model's output without activation labels_tensor (tf.Tensor): Target labels Returns: tf.Operation: the operation performing a stem of Adam optimizer """ with graph.as_default(): with tf.variable_scope('loss'): loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits( logits=logits, labels=labels_tensor, name='xent'), name="mean-xent" ) with tf.variable_scope('optimizer'): opt_op = tf.train.AdamOptimizer(1e-2).minimize(loss) return opt_op if __name__ == '__main__': # Set random seed for reproducibility # and create synthetic data np.random.seed(0) features = np.random.randn(64, 10, 30) labels = np.eye(5)[np.random.randint(0, 5, (64,))] graph1 = tf.Graph() with graph1.as_default(): # Random seed for reproducibility tf.set_random_seed(0) # Placeholders batch_size_ph = tf.placeholder(tf.int64, name='batch_size_ph') features_data_ph = tf.placeholder(tf.float32, [None, None, 30], 'features_data_ph') labels_data_ph = tf.placeholder(tf.int32, [None, 5], 'labels_data_ph') # Dataset dataset = tf.data.Dataset.from_tensor_slices((features_data_ph, labels_data_ph)) dataset = dataset.batch(batch_size_ph) iterator = tf.data.Iterator.from_structure(dataset.output_types, dataset.output_shapes) dataset_init_op = iterator.make_initializer(dataset, name='dataset_init') input_tensor, labels_tensor = iterator.get_next() # Model logits = model(graph1, input_tensor) # Optimization opt_op = get_opt_op(graph1, logits, labels_tensor) with tf.Session(graph=graph1) as sess: # Initialize variables tf.global_variables_initializer().run(session=sess) for epoch in range(3): batch = 0 # Initialize dataset (could feed epochs in Dataset.repeat(epochs)) sess.run( dataset_init_op, feed_dict={ features_data_ph: features, labels_data_ph: labels, batch_size_ph: 32 }) values = [] while True: try: if epoch < 2: # Training _, value = sess.run([opt_op, logits]) print('Epoch {}, batch {} | Sample value: {}'.format(epoch, batch, value[0])) batch += 1 else: # Final inference values.append(sess.run(logits)) print('Epoch {}, batch {} | Final inference | Sample value: {}'.format(epoch, batch, values[-1][0])) batch += 1 except tf.errors.OutOfRangeError: break # Save model state print('\nSaving...') cwd = os.getcwd() path = os.path.join(cwd, 'simple') shutil.rmtree(path, ignore_errors=True) inputs_dict = { "batch_size_ph": batch_size_ph, "features_data_ph": features_data_ph, "labels_data_ph": labels_data_ph } outputs_dict = { "logits": logits } tf.saved_model.simple_save( sess, path, inputs_dict, outputs_dict ) print('Ok') # Restoring graph2 = tf.Graph() with graph2.as_default(): with tf.Session(graph=graph2) as sess: # Restore saved values print('\nRestoring...') tf.saved_model.loader.load( sess, [tag_constants.SERVING], path ) print('Ok') # Get restored placeholders labels_data_ph = graph2.get_tensor_by_name('labels_data_ph:0') features_data_ph = graph2.get_tensor_by_name('features_data_ph:0') batch_size_ph = graph2.get_tensor_by_name('batch_size_ph:0') # Get restored model output restored_logits = graph2.get_tensor_by_name('dense/BiasAdd:0') # Get dataset initializing operation dataset_init_op = graph2.get_operation_by_name('dataset_init') # Initialize restored dataset sess.run( dataset_init_op, feed_dict={ features_data_ph: features, labels_data_ph: labels, batch_size_ph: 32 } ) # Compute inference for both batches in dataset restored_values = [] for i in range(2): restored_values.append(sess.run(restored_logits)) print('Restored values: ', restored_values[i][0]) # Check if original inference and restored inference are equal valid = all((v == rv).all() for v, rv in zip(values, restored_values)) print('\nInferences match: ', valid) 

Esto imprimirá:

 $ python3 save_and_restre.py Epoch 0, batch 0 | Sample value: [-0.13851789 -0.3087595 0.12804556 0.20013677 -0.08229901] Epoch 0, batch 1 | Sample value: [-0.00555491 -0.04339041 -0.05111827 -0.2480045 -0.00107776] Epoch 1, batch 0 | Sample value: [-0.19321944 -0.2104792 -0.00602257 0.07465433 0.11674127] Epoch 1, batch 1 | Sample value: [-0.05275984 0.05981954 -0.15913513 -0.3244143 0.10673307] Epoch 2, batch 0 | Final inference | Sample value: [-0.26331693 -0.13013336 -0.12553 -0.04276478 0.2933622 ] Epoch 2, batch 1 | Final inference | Sample value: [-0.07730117 0.11119192 -0.20817074 -0.35660955 0.16990358] Saving... INFO:tensorflow:Assets added to graph. INFO:tensorflow:No assets to write. INFO:tensorflow:SavedModel written to: b'/some/path/simple/saved_model.pb' Ok Restoring... INFO:tensorflow:Restoring parameters from b'/some/path/simple/variables/variables' Ok Restored values: [-0.26331693 -0.13013336 -0.12553 -0.04276478 0.2933622 ] Restored values: [-0.07730117 0.11119192 -0.20817074 -0.35660955 0.16990358] Inferences match: True 

Estoy mejorando mi respuesta para agregar más detalles para guardar y restaurar modelos.

En (y después) Tensorflow versión 0.11 :

Guarde el modelo:

 import tensorflow as tf #Prepare to feed input, ie feed_dict and placeholders w1 = tf.placeholder("float", name="w1") w2 = tf.placeholder("float", name="w2") b1= tf.Variable(2.0,name="bias") feed_dict ={w1:4,w2:8} #Define a test operation that we will restre w3 = tf.add(w1,w2) w4 = tf.multiply(w3,b1,name="op_to_restre") sess = tf.Session() sess.run(tf.global_variables_initializer()) #Create a saver object which will save all the variables saver = tf.train.Saver() #Run the operation by feeding input print sess.run(w4,feed_dict) #Prints 24 which is sum of (w1+w2)*b1 #Now, save the graph saver.save(sess, 'my_test_model',global_step=1000) 

Restaura el modelo:

 import tensorflow as tf sess=tf.Session() #First let's load meta graph and restre weights saver = tf.train.import_meta_graph('my_test_model-1000.meta') saver.restre(sess,tf.train.latest_checkpoint('./')) # Access saved Variables directly print(sess.run('bias:0')) # This will print 2, which is the value of bias that we saved # Now, let's access and create placeholders variables and # create feed-dict to feed new data graph = tf.get_default_graph() w1 = graph.get_tensor_by_name("w1:0") w2 = graph.get_tensor_by_name("w2:0") feed_dict ={w1:13.0,w2:17.0} #Now, access the op that you want to run. op_to_restre = graph.get_tensor_by_name("op_to_restre:0") print sess.run(op_to_restre,feed_dict) #This will print 60 which is calculated 

Este y algunos casos de uso más avanzados se han explicado muy bien aquí.

Un rápido tutorial completo para guardar y restaurar modelos Tensorflow

En (y después) TensorFlow versión 0.11.0RC1, puede guardar y restaurar su modelo directamente llamando a tf.train.export_meta_graph y tf.train.import_meta_graph acuerdo con https://www.tensorflow.org/programmers_guide/meta_graph .

Guardar el modelo

 w1 = tf.Variable(tf.truncated_normal(shape=[10]), name='w1') w2 = tf.Variable(tf.truncated_normal(shape=[20]), name='w2') tf.add_to_collection('vars', w1) tf.add_to_collection('vars', w2) saver = tf.train.Saver() sess = tf.Session() sess.run(tf.global_variables_initializer()) saver.save(sess, 'my-model') # `save` method will call `export_meta_graph` implicitly. # you will get saved graph files:my-model.meta 

Restaurar el modelo

 sess = tf.Session() new_saver = tf.train.import_meta_graph('my-model.meta') new_saver.restre(sess, tf.train.latest_checkpoint('./')) all_vars = tf.get_collection('vars') for v in all_vars: v_ = sess.run(v) print(v_) 

Para la versión TensorFlow <0.11.0RC1:

Los puntos de control que se guardan contienen valores para la Variable s en su modelo, no el modelo / gráfico en sí, lo que significa que el gráfico debe ser el mismo cuando restaura el punto de control.

Aquí hay un ejemplo de una regresión lineal donde hay un ciclo de entrenamiento que guarda puntos de control variables y una sección de evaluación que restaurará las variables guardadas en una ejecución anterior y calculará predicciones. Por supuesto, también puede restaurar variables y continuar el entrenamiento si lo desea.

 x = tf.placeholder(tf.float32) y = tf.placeholder(tf.float32) w = tf.Variable(tf.zeros([1, 1], dtype=tf.float32)) b = tf.Variable(tf.ones([1, 1], dtype=tf.float32)) y_hat = tf.add(b, tf.matmul(x, w)) ...more setup for optimization and what not... saver = tf.train.Saver() # defaults to saving all variables - in this case w and b with tf.Session() as sess: sess.run(tf.initialize_all_variables()) if FLAGS.train: for i in xrange(FLAGS.training_steps): ...training loop... if (i + 1) % FLAGS.checkpoint_steps == 0: saver.save(sess, FLAGS.checkpoint_dir + 'model.ckpt', global_step=i+1) else: # Here's where you're restring the variables w and b. # Note that the graph is exactly as it was when the variables were # saved in a prior training run. ckpt = tf.train.get_checkpoint_state(FLAGS.checkpoint_dir) if ckpt and ckpt.model_checkpoint_path: saver.restre(sess, ckpt.model_checkpoint_path) else: ...no checkpoint found... # Now you can run the model to get predictions batch_x = ...load some data... predictions = sess.run(y_hat, feed_dict={x: batch_x}) 

Aquí están los documentos para Variable s, que cubren guardar y restaurar. Y aquí están los documentos para el Saver .

Mi entorno: Python 3.6, Tensorflow 1.3.0

Aunque ha habido muchas soluciones, la mayoría de ellas se basan en tf.train.Saver . Cuando .ckpt un .ckpt guardado por Saver , tenemos que redefinir la red de tensorflow o usar algún nombre extraño y recordado, por ejemplo, 'placehold_0:0' , 'dense/Adam/Weight:0' . Aquí recomiendo usar tf.saved_model , uno de los ejemplos más simples que se dan a continuación, puede obtener más información sobre tf.saved_model Servir un Modelo TensorFlow :

Guarde el modelo:

 import tensorflow as tf # define the tensorflow network and do some trains x = tf.placeholder("float", name="x") w = tf.Variable(2.0, name="w") b = tf.Variable(0.0, name="bias") h = tf.multiply(x, w) y = tf.add(h, b, name="y") sess = tf.Session() sess.run(tf.global_variables_initializer()) # save the model export_path = './savedmodel' builder = tf.saved_model.builder.SavedModelBuilder(export_path) tensor_info_x = tf.saved_model.utils.build_tensor_info(x) tensor_info_y = tf.saved_model.utils.build_tensor_info(y) prediction_signature = ( tf.saved_model.signature_def_utils.build_signature_def( inputs={'x_input': tensor_info_x}, outputs={'y_output': tensor_info_y}, method_name=tf.saved_model.signature_constants.PREDICT_METHOD_NAME)) builder.add_meta_graph_and_variables( sess, [tf.saved_model.tag_constants.SERVING], signature_def_map={ tf.saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY: prediction_signature }, ) builder.save() 

Cargue el modelo:

 import tensorflow as tf sess=tf.Session() signature_key = tf.saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY input_key = 'x_input' output_key = 'y_output' export_path = './savedmodel' meta_graph_def = tf.saved_model.loader.load( sess, [tf.saved_model.tag_constants.SERVING], export_path) signature = meta_graph_def.signature_def x_tensor_name = signature[signature_key].inputs[input_key].name y_tensor_name = signature[signature_key].outputs[output_key].name x = sess.graph.get_tensor_by_name(x_tensor_name) y = sess.graph.get_tensor_by_name(y_tensor_name) y_out = sess.run(y, {x: 3.0}) 

El modelo tiene dos partes, la definición del modelo, guardada por Supervisor como graph.pbtxt en el directorio del modelo y los valores numéricos de los tensores, guardados en archivos de punto de control como model.ckpt-1003418 .

La definición del modelo se puede restaurar usando tf.import_graph_def , y los pesos se restauran usando Saver .

Sin embargo, Saver utiliza una lista especial de colecciones de variables que se adjunta al modelo Graph, y esta colección no se inicializa con import_graph_def, por lo que no puede usar las dos juntas en este momento (está en nuestra hoja de ruta para corregir). Por ahora, tiene que usar el enfoque de Ryan Sepassi: construya manualmente un gráfico con nombres de nodos idénticos y use Saver para cargar los pesos en él.

(Alternativamente, puede piratearlo utilizando import_graph_def , creando variables manualmente y usando tf.add_to_collection(tf.GraphKeys.VARIABLES, variable) para cada variable, luego usando Saver )

También puedes tomar este camino más fácil.

Paso 1: inicializa todas tus variables

 W1 = tf.Variable(tf.truncated_normal([6, 6, 1, K], stddev=0.1), name="W1") B1 = tf.Variable(tf.constant(0.1, tf.float32, [K]), name="B1") Similarly, W2, B2, W3, ..... 

Paso 2: guarda la sesión dentro del modelo Saver y guárdala

 model_saver = tf.train.Saver() # Train the model and save it in the end model_saver.save(session, "saved_models/CNN_New.ckpt") 

Paso 3: restaurar el modelo

 with tf.Session(graph=graph_cnn) as session: model_saver.restre(session, "saved_models/CNN_New.ckpt") print("Model restred.") print('Initialized') 

Paso 4: revisa tu variable

 W1 = session.run(W1) print(W1) 

Mientras se ejecuta en diferentes instancias de python, use

 with tf.Session() as sess: # Restore latest checkpoint saver.restre(sess, tf.train.latest_checkpoint('saved_model/.')) # Initalize the variables sess.run(tf.global_variables_initializer()) # Get default graph (supply your custom graph if you have one) graph = tf.get_default_graph() # It will give tensor object W1 = graph.get_tensor_by_name('W1:0') # To get the value (numpy array) W1_value = session.run(W1) 

En la mayoría de los casos, guardar y restaurar desde el disco usando un tf.train.Saver es su mejor opción:

 ... # build your model saver = tf.train.Saver() with tf.Session() as sess: ... # train the model saver.save(sess, "/tmp/my_great_model") with tf.Session() as sess: saver.restre(sess, "/tmp/my_great_model") ... # use the model 

También puede guardar / restaurar la estructura del gráfico (consulte la documentación de MetaGraph para obtener más información). Por defecto, el Saver guarda la estructura del gráfico en un archivo .meta . Puede llamar a import_meta_graph() para restaurarlo. Restaura la estructura del gráfico y devuelve un Saver que puede usar para restaurar el estado del modelo:

 saver = tf.train.import_meta_graph("/tmp/my_great_model.meta") with tf.Session() as sess: saver.restre(sess, "/tmp/my_great_model") ... # use the model 

Sin embargo, hay casos en los que necesitas algo mucho más rápido. Por ejemplo, si implementa la detención temprana, desea guardar los puntos de control cada vez que el modelo mejore durante el entrenamiento (según lo medido en el conjunto de validación), luego, si no hay progreso durante algún tiempo, desea volver al mejor modelo. Si guarda el modelo en el disco cada vez que mejora, ralentizará enormemente el entrenamiento. El truco es guardar los estados variables en la memoria y luego restaurarlos más tarde:

 ... # build your model # get a handle on the graph nodes we need to save/restre the model graph = tf.get_default_graph() gvars = graph.get_collection(tf.GraphKeys.GLOBAL_VARIABLES) assign_ops = [graph.get_operation_by_name(v.op.name + "/Assign") for v in gvars] init_values = [assign_op.inputs[1] for assign_op in assign_ops] with tf.Session() as sess: ... # train the model # when needed, save the model state to memory gvars_state = sess.run(gvars) # when needed, restre the model state feed_dict = {init_value: val for init_value, val in zip(init_values, gvars_state)} sess.run(assign_ops, feed_dict=feed_dict) 

Una explicación rápida: cuando crea una variable X , TensorFlow crea automáticamente una operación de asignación X/Assign para establecer el valor inicial de la variable. En lugar de crear marcadores de posición y operaciones de asignación adicionales (que solo harían desordenado el gráfico), solo usamos estas operaciones de asignación existentes. La primera entrada de cada operación de asignación es una referencia a la variable que se supone que debe inicializar, y la segunda entrada ( assign_op.inputs[1] ) es el valor inicial. Entonces, para establecer cualquier valor que queramos (en lugar del valor inicial), necesitamos usar un feed_dict y reemplazar el valor inicial. Sí, TensorFlow le permite alimentar un valor para cualquier operación, no solo para marcadores de posición, así que esto funciona bien.

Como dijo Yaroslav, puede hackear la restauración desde un graph_def y un punto de control importando el gráfico, creando variables manualmente y luego usando un Saver.

Implementé esto para mi uso personal, así que pensé en compartir el código aquí.

Enlace: https://gist.github.com/nikitakit/6ef3b72be67b86cb7868

(Esto es, por supuesto, un pirateo, y no hay garantía de que los modelos guardados de esta manera sigan siendo legibles en futuras versiones de TensorFlow).

Si es un modelo guardado internamente, simplemente especifique un restaurador para todas las variables como

 restrer = tf.train.Saver(tf.all_variables()) 

y úselo para restaurar variables en una sesión actual:

 restrer.restre(self._sess, model_file) 

Para el modelo externo, debe especificar la asignación de los nombres de las variables a los nombres de las variables. Puedes ver los nombres de las variables del modelo usando el comando

 python /path/to/tensorflow/tensorflow/python/tools/inspect_checkpoint.py --file_name=/path/to/pretrained_model/model.ckpt 

El script inspect_checkpoint.py se puede encontrar en la carpeta ‘./tensorflow/python/tools’ de la fuente Tensorflow.

Para especificar la asignación, puede usar mi Tensorflow-Worklab , que contiene un conjunto de clases y scripts para entrenar y volver a entrenar diferentes modelos. Incluye un ejemplo de reentrenamiento de modelos de ResNet, que se encuentra aquí.

Aquí está mi solución simple para los dos casos básicos que difieren en si desea cargar el gráfico desde un archivo o comstackrlo durante el tiempo de ejecución.

Esta respuesta es válida para Tensorflow 0.12+ (incluyendo 1.0).

Reconstruyendo la gráfica en código

Ahorro

 graph = ... # build the graph saver = tf.train.Saver() # create the saver after the graph with ... as sess: # your session object saver.save(sess, 'my-model') 

Cargando

 graph = ... # build the graph saver = tf.train.Saver() # create the saver after the graph with ... as sess: # your session object saver.restre(sess, tf.train.latest_checkpoint('./')) # now you can use the graph, continue training or whatever 

Cargando también la gráfica de un archivo.

Cuando utilice esta técnica, asegúrese de que todas sus capas / variables hayan establecido explícitamente nombres únicos. De lo contrario, Tensorflow hará que los nombres sean únicos y, por lo tanto, serán diferentes de los nombres almacenados en el archivo. No es un problema en la técnica anterior, porque los nombres están “mutilados” de la misma manera, tanto en la carga como en el ahorro.

Ahorro

 graph = ... # build the graph for op in [ ... ]: # operators you want to use after restring the model tf.add_to_collection('ops_to_restre', op) saver = tf.train.Saver() # create the saver after the graph with ... as sess: # your session object saver.save(sess, 'my-model') 

Cargando

 with ... as sess: # your session object saver = tf.train.import_meta_graph('my-model.meta') saver.restre(sess, tf.train.latest_checkpoint('./')) ops = tf.get_collection('ops_to_restre') # here are your operators in the same order in which you saved them to the collection 

También puede consultar ejemplos en TensorFlow / skflow , que ofrece métodos de save y restre que pueden ayudarlo a administrar fácilmente sus modelos. Tiene parámetros que también puede controlar con qué frecuencia desea realizar una copia de seguridad de su modelo.

Si usa tf.train.MonitoredTrainingSession como la sesión predeterminada, no necesita agregar código adicional para guardar / restaurar cosas. Just pass a checkpoint dir name to MonitoredTrainingSession’s constructor, it will use session hooks to handle these.

All the answers here are great, but I want to add two things.

First, to elaborate on @user7505159’s answer, the “./” can be important to add to the beginning of the file name that you are restring.

For example, you can save a graph with no “./” in the file name like so:

 # Some graph defined up here with specific names saver = tf.train.Saver() save_file = 'model.ckpt' with tf.Session() as sess: sess.run(tf.global_variables_initializer()) saver.save(sess, save_file) 

But in order to restre the graph, you may need to prepend a “./” to the file_name:

 # Same graph defined up here saver = tf.train.Saver() save_file = './' + 'model.ckpt' # String addition used for emphasis with tf.Session() as sess: sess.run(tf.global_variables_initializer()) saver.restre(sess, save_file) 

You will not always need the “./”, but it can cause problems depending on your environment and version of TensorFlow.

It also want to mention that the sess.run(tf.global_variables_initializer()) can be important before restring the session.

If you are receiving an error regarding uninitialized variables when trying to restre a saved session, make sure you include sess.run(tf.global_variables_initializer()) before the saver.restre(sess, save_file) line. It can save you a headache.

As described in issue 6255 :

 use '**./**model_name.ckpt' saver.restre(sess,'./my_model_final.ckpt') 

en lugar de

 saver.restre('my_model_final.ckpt') 

According to the new Tensorflow version, tf.train.Checkpoint is the preferable way of saving and restring a model:

Checkpoint.save and Checkpoint.restre write and read object-based checkpoints, in contrast to tf.train.Saver which writes and reads variable.name based checkpoints. Object-based checkpointing saves a graph of dependencies between Python objects (Layers, Optimizers, Variables, etc.) with named edges, and this graph is used to match variables when restring a checkpoint. It can be more robust to changes in the Python program, and helps to support restre-on-create for variables when executing eagerly. Prefer tf.train.Checkpoint over tf.train.Saver for new code .

Aquí hay un ejemplo:

 import tensorflow as tf import os tf.enable_eager_execution() checkpoint_directory = "/tmp/training_checkpoints" checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt") checkpoint = tf.train.Checkpoint(optimizer=optimizer, model=model) status = checkpoint.restre(tf.train.latest_checkpoint(checkpoint_directory)) for _ in range(num_training_steps): optimizer.minimize( ... ) # Variables will be restred on creation. status.assert_consumed() # Optional sanity checks. checkpoint.save(file_prefix=checkpoint_prefix) 

More information and example here.

Use tf.train.Saver to save a model, remerber, you need to specify the var_list, if you want to reduce the model size. The val_list can be tf.trainable_variables or tf.global_variables.

You can save the variables in the network using

 saver = tf.train.Saver() saver.save(sess, 'path of save/fileName.ckpt') 

To restre the network for reuse later or in another script, use:

 saver = tf.train.Saver() saver.restre(sess, tf.train.latest_checkpoint('path of save/') sess.run(....) 

Important points:

  1. sess must be same between first and later runs (coherent structure).
  2. saver.restre needs the path of the folder of the saved files, not an individual file path.

Wherever you want to save the model,

 self.saver = tf.train.Saver() with tf.Session() as sess: sess.run(tf.global_variables_initializer()) ... self.saver.save(sess, filename) 

Make sure, all your tf.Variable have names, because you may want to restre them later using their names. And where you want to predict,

 saver = tf.train.import_meta_graph(filename) name = 'name given when you saved the file' with tf.Session() as sess: saver.restre(sess, name) print(sess.run('W1:0')) #example to retrieve by variable name 

Make sure that saver runs inside the corresponding session. Remember that, if you use the tf.train.latest_checkpoint('./') , then only the latest check point will be used.