¿Cómo se devuelven las predicciones Y las tags con tf.estimator (ya sea con el método de predicción o con el de evaluación)?

Estoy trabajando con Tensorflow 1.4.

Creé un tf.estimator personalizado para hacer una clasificación, como esto:

def model_fn(): # Some operations here [...] return tf.estimator.EstimatorSpec(mode=mode, predictions={"Preds": predictions}, loss=cost, train_op=loss, eval_metric_ops=eval_metric_ops, training_hooks=[summary_hook]) my_estimator = tf.estimator.Estimator(model_fn=model_fn, params=model_params, model_dir='/my/directory') 

Puedo entrenarlo fácilmente:

 input_fn = create_train_input_fn(path=train_files) my_estimator.train(input_fn=input_fn) 

donde input_fn es una función que lee datos de archivos tfrecords , con la API tf.data.Dataset .

Cuando estoy leyendo archivos de tfrecords, no tengo tags en la memoria cuando hago predicciones.

Mi pregunta es, ¿cómo puedo obtener las predicciones Y las tags devueltas, ya sea por el método predict () o por el método evaluar () ?

Parece que no hay manera de tener ambas cosas. predict () no tiene acceso (?) a las tags, y no es posible acceder al diccionario de predicciones con el método de evaluación () .

Después de que hayas terminado tu entrenamiento, en '/my/directory' tienes un montón de archivos de puntos de control.

Debe configurar su canal de entrada nuevamente, cargar manualmente uno de esos archivos, luego comenzar a recorrer sus lotes almacenando las predicciones y las tags:

 # Rebuild the input pipeline input_fn = create_eval_input_fn(path=eval_files) features, labels = input_fn() # Rebuild the model predictions = model_fn(features, labels, tf.estimator.ModeKeys.EVAL).predictions # Manually load the latest checkpoint saver = tf.train.Saver() with tf.Session() as sess: ckpt = tf.train.get_checkpoint_state('/my/directory') saver.restre(sess, ckpt.model_checkpoint_path) # Loop through the batches and store predictions and labels prediction_values = [] label_values = [] while True: try: preds, lbls = sess.run([predictions, labels]) prediction_values += preds label_values += lbls except tf.errors.OutOfRangeError: break # store prediction_values and label_values somewhere 

Actualización: se cambió para usar directamente la función model_fn que ya tiene.