¿Cuál es la estructura “DERECHA” para guardar / restaurar un modelo en Tensorflow durante el entrenamiento / val / test?

Quiero escribir algunos códigos en Tensoflow que puedan capacitar a un modelo, ejecutar la validación durante el entrenamiento y, finalmente, informar los resultados en los datos de prueba para el mejor modelo seleccionado a través de los datos de validación. Me preguntaba si la siguiente estructura es la forma correcta de hacerlo. [considerando ámbitos variables, compartir parámetros, guardar / restaurar, …]

MyModel.py

class MyModel(object): def build_model(self, reuse): with tf.variable_scope("Model", reuse = reuse) as scope: self.v1 = tf.get_variable("v1", [1, 2]) // rest of the codes def train(self, sess): self.build_model(False) s1 = tf.train.Saver() init_opt =tf.global_variables_initializer() sess.run(init_opt) // model training // ... s1.save(sess, "/tmp/model.ckpt") def val(self, sess): self.build_model(True) s2 = tf.train.Saver() // do the validation s2.save(sess, "/tmp/best_model.ckpt") def test(self, sess): self.build_model(False) s3 = tf.train.Saver() s3.restre(sess, "/tmp/model_best.ckpt") //rest of the codes ... 

Y escribí las siguientes funciones en los dos archivos diferentes:

train.py:

  with tf.Session() as sess: mtrain = MyModel() mval = MyModel() for iter_i in range(num_training_iters): mtrain.train(sess) mval.val(sess) 

test.py

 with tf.Session() as sess: mtest = MyModel() mtest.test(sess) 

Miré los tutoriales de Tensorflow, pero ninguno de ellos tiene esta estructura. Cualquier ayuda sería muy apreciada.

Gracias

Usted puede encontrar esto útil:

https://gist.github.com/earonesty/ac0617a5672ae1a41be1eaf316dd63e4

Le permite guardar / restaurar variables globales y de ámbito:

desde varlib import vartemp

 x = 3 y = 4 with vartemp({'x':93,'y':94}): print(x) print(y) print(x) print(y)