Tensorflow Estimator: cómo ejecutar una operación dentro de model_fn

Tengo una pregunta muy breve, que probablemente tenga una respuesta muy simple, pero simplemente no puedo entenderla, aunque lo intenté durante horas.

Estoy usando Tensorflow Estimator y quiero acceder al paso global dentro de mi model_fn. He intentado tf.train.get_global_step, lo que me devuelve un Tensor. ¡Necesito el paso global como un entero (o como una cadena)!

Así que he intentado evaluar () (= tf.get_default_session (). Run (t)), pero no funciona …

¡Aclamaciones!

Puede usar tf.cast para convertir el Tensor a int o string .

Por ejemplo,

 tf.cast(tf.train.get_global_step(), dtype=tf.int) 

Vea la referencia aquí .

Una forma sería analizarlo desde el último archivo de punto de control en model_dir . Entonces, asumiendo que puede pasar model_dir a model_fn (ya sea a través del argumento params de tf.estimator.Estimator(..., params={'model_dir': 'path/to/model_dir'}) o mediante tf.flags.FLAGS , puedes usar esta función de utilidad:

 import tensorflow as tf def get_global_step_from_model_dir(model_dir): latest_checkpoint_file = tf.train.latest_checkpoint(model_dir) if latest_checkpoint_file is None: return 0 else: return int(os.path.basename(latest_checkpoint_file).split('-')[-1])