¿Cómo guardar el estimador en Tensorflow para su uso posterior?

Seguí el tutorial “Una guía para las capas de TF: Construyendo una neural network convolucional” (aquí está el código: https://github.com/tensorflow/tensorflow/blob/r1.1/tensorflow/examples/tutorials/layers/cnn_mnist .py ).

Adapte el tutorial a mis necesidades, que es la detección de manos.

Por lo que entiendo, este tutorial crea el estimador (que es una CNN), luego realiza el ajuste y, finalmente, evalúa el rendimiento del estimador. Ahora, mi problema es que quiero usar el estimador en otro archivo, que será mi progtwig principal. ¿Cómo accedo al estimador desde otro archivo? ¿Tengo que ajustar el estimador cada vez que quiero usarlo? (Espero que no)

Me preguntaba si alguien podría ayudarme a entender cómo guardar el estimador para usarlo más adelante. (Hasta donde entiendo, no puedo crear un protector con tf.train.Saver , porque no tengo una sesión en ejecución).

Aquí está el código de mi archivo train.py :

 def main(unused_argv): #Load training and eval data (part missing) # Create the estimator hand_detector = learn.Estimator(model_fn=cnn_model_fn, model_dir="\cnn_model_fn") # Set up logging for predictions # Log the values in the "Softmax" tensor with label "probabilities" tensors_to_log = {"probabilities": "softmax_tensor"} logging_hook = tf.train.LoggingTensorHook( tensors=tensors_to_log, every_n_iter=50) # Train the model hand_detector.fit( x=train_data, y=train_labels, batch_size=100, steps=20000, monitors=[logging_hook]) # Configure the accuracy metric for evaluation metrics = { "accuracy": learn.MetricSpec( metric_fn=tf.metrics.accuracy, prediction_key="classes"), } # Evaluate the model and print results eval_results = hand_detector.evaluate( x=eval_data, y=eval_labels, metrics=metrics) print(eval_results) # Save the model for later use (part missing!) 

Casi todas las aplicaciones reales de aprendizaje automático buscan entrenar un modelo una vez y luego guardarlo para usos futuros con nuevos datos. La mayoría de los clasificadores pasan horas en la etapa de entrenamiento y solo unos segundos en la etapa de prueba, por lo que es fundamental aprender cómo guardar con éxito un modelo entrenado.

Voy a explicar cómo exportar modelos de Tensorflow de “alto nivel” (usando export_savedmodel ). La función export_savedmodel requiere el argumento serve_input_receiver_fn, que es una función sin argumentos, que define la entrada del modelo y el predictor. Por lo tanto, debe crear su propio serve_input_receiver_fn , donde el tipo de entrada del modelo coincide con la entrada del modelo en el script de entrenamiento, y el tipo de entrada del predictor coincide con la entrada del predictor en el script de prueba. Por otro lado, si crea un modelo personalizado, debe definir los export_outputs, definidos por la función tf.estimator.export.PredictOutput , cuya entrada es un diccionario que define el nombre que tiene que coincidir con el nombre de la salida del predictor en el guión de prueba.

Por ejemplo:

ESCRITURA DE ENTRENAMIENTO

 def serving_input_receiver_fn(): serialized_tf_example = tf.placeholder(dtype=tf.string, shape=[None], name='input_tensors') receiver_tensors = {"predictor_inputs": serialized_tf_example} feature_spec = {"words": tf.FixedLenFeature([25],tf.int64)} features = tf.parse_example(serialized_tf_example, feature_spec) return tf.estimator.export.ServingInputReceiver(features, receiver_tensors) def estimator_spec_for_softmax_classification(logits, labels, mode): predicted_classes = tf.argmax(logits, 1) if (mode == tf.estimator.ModeKeys.PREDICT): export_outputs = {'predict_output': tf.estimator.export.PredictOutput({"pred_output_classes": predicted_classes, 'probabilities': tf.nn.softmax(logits)})} return tf.estimator.EstimatorSpec(mode=mode, predictions={'class': predicted_classes, 'prob': tf.nn.softmax(logits)}, export_outputs=export_outputs) # IMPORTANT!!! onehot_labels = tf.one_hot(labels, 31, 1, 0) loss = tf.losses.softmax_cross_entropy(onehot_labels=onehot_labels, logits=logits) if (mode == tf.estimator.ModeKeys.TRAIN): optimizer = tf.train.AdamOptimizer(learning_rate=0.01) train_op = optimizer.minimize(loss, global_step=tf.train.get_global_step()) return tf.estimator.EstimatorSpec(mode, loss=loss, train_op=train_op) eval_metric_ops = {'accuracy': tf.metrics.accuracy(labels=labels, predictions=predicted_classes)} return tf.estimator.EstimatorSpec(mode=mode, loss=loss, eval_metric_ops=eval_metric_ops) def model_custom(features, labels, mode): bow_column = tf.feature_column.categorical_column_with_identity("words", num_buckets=1000) bow_embedding_column = tf.feature_column.embedding_column(bow_column, dimension=50) bow = tf.feature_column.input_layer(features, feature_columns=[bow_embedding_column]) logits = tf.layers.dense(bow, 31, activation=None) return estimator_spec_for_softmax_classification(logits=logits, labels=labels, mode=mode) def main(): # ... # preprocess-> features_train_set and labels_train_set # ... classifier = tf.estimator.Estimator(model_fn = model_custom) train_input_fn = tf.estimator.inputs.numpy_input_fn(x={"words": features_train_set}, y=labels_train_set, batch_size=batch_size_param, num_epochs=None, shuffle=True) classifier.train(input_fn=train_input_fn, steps=100) full_model_dir = classifier.export_savedmodel(export_dir_base="C:/models/directory_base", serving_input_receiver_fn=serving_input_receiver_fn) 

ESCRITURA DE PRUEBA

 def main(): # ... # preprocess-> features_test_set # ... with tf.Session() as sess: tf.saved_model.loader.load(sess, [tf.saved_model.tag_constants.SERVING], full_model_dir) predictor = tf.contrib.predictor.from_saved_model(full_model_dir) model_input = tf.train.Example(features=tf.train.Features( feature={"words": tf.train.Feature(int64_list=tf.train.Int64List(value=features_test_set)) })) model_input = model_input.SerializeToString() output_dict = predictor({"predictor_inputs":[model_input]}) y_predicted = output_dict["pred_output_classes"][0] 

(Código probado en Python 3.6.3, Tensorflow 1.4.0)

Estimator tiene una función miembro export_savedmodel para ese propósito. Encontrarás los documentos aquí .