¿Cómo alternar las operaciones de tren en tensorflow?

Estoy implementando un esquema de entrenamiento alternativo. La gráfica contiene dos operaciones de entrenamiento. El entrenamiento debe alternar entre estos.

Esto es relevante para investigaciones como esta o esta.

A continuación se muestra un pequeño ejemplo. Pero parece actualizar ambas operaciones a cada paso. ¿Cómo puedo alternar explícitamente entre estos?

from tensorflow.examples.tutorials.mnist import input_data import tensorflow as tf # Import data mnist = input_data.read_data_sets('/tmp/tensorflow/mnist/input_data', one_hot=True) # Create the model x = tf.placeholder(tf.float32, [None, 784]) W = tf.Variable(tf.zeros([784, 10]), name='weights') b = tf.Variable(tf.zeros([10]), name='biases') y = tf.matmul(x, W) + b # Define loss and optimizer y_ = tf.placeholder(tf.float32, [None, 10]) cross_entropy = tf.reduce_mean( tf.nn.softmax_cross_entropy_with_logits(labels=y_, logits=y)) global_step = tf.Variable(0, trainable=False) tvars1 = [b] train_step1 = tf.train.GradientDescentOptimizer(0.5).apply_gradients(zip(tf.gradients(cross_entropy, tvars1), tvars1), global_step) tvars2 = [W] train_step2 = tf.train.GradientDescentOptimizer(0.5).apply_gradients(zip(tf.gradients(cross_entropy, tvars2), tvars2), global_step) train_step = tf.cond(tf.equal(tf.mod(global_step,2), 0), true_fn= lambda:train_step1, false_fn=lambda : train_step2) sess = tf.InteractiveSession() tf.global_variables_initializer().run() # Train for i in range(1000): batch_xs, batch_ys = mnist.train.next_batch(100) sess.run(train_step, feed_dict={x: batch_xs, y_: batch_ys}) if i % 100 == 0: print(sess.run([cross_entropy, global_step], feed_dict={x: mnist.test.images, y_: mnist.test.labels})) 

Esto resulta en

 [2.0890141, 2] [0.38277805, 202] [0.33943111, 402] [0.32314575, 602] [0.3113254, 802] [0.3006627, 1002] [0.2965056, 1202] [0.29858461, 1402] [0.29135355, 1602] [0.29006076, 1802] 

El paso global se repite en 1802, por lo que ambas operaciones de tren se ejecutan cada vez que se llama a train_step . (Esto también sucede cuando la condición de siempre falsa es tf.equal(global_step,-1) por ejemplo).

Mi pregunta es cómo alternar entre la ejecución de train_step1 y train_step2 ?

Creo que la forma más sencilla es simplemente

 for i in range(1000): batch_xs, batch_ys = mnist.train.next_batch(100) if i % 2 == 0: sess.run(train_step1, feed_dict={x: batch_xs, y_: batch_ys}) else: sess.run(train_step2, feed_dict={x: batch_xs, y_: batch_ys}) 

Pero si es necesario realizar un cambio a través del flujo condicional de tensorflow, hágalo de esta manera:

 optimizer = tf.train.GradientDescentOptimizer(0.5) train_step = tf.cond(tf.equal(tf.mod(global_step, 2), 0), true_fn=lambda: optimizer.apply_gradients(zip(tf.gradients(cross_entropy, tvars1), tvars1), global_step), false_fn=lambda: optimizer.apply_gradients(zip(tf.gradients(cross_entropy, tvars2), tvars2), global_step))