Tensorflow: ¿Por qué tf.case me está dando el resultado equivocado?

Estoy tratando de usar tf.case ( https://www.tensorflow.org/api_docs/python/tf/case ) para actualizar condicionalmente un Tensor. Como se muestra, estoy tratando de actualizar learning_rate a 0.01 cuando global_step == 2 , y a 0.001 cuando global_step == 4 .

Sin embargo, cuando global_step == 2 , ya obtengo learning_rate = 0.001 . Tras una inspección adicional, parece que tf.case me está dando el resultado incorrecto cuando global_step == 2 (obtengo 0.001 lugar de 0.01 ). Esto sucede aunque el predicado para 0.01 se evalúa como Verdadero y el predicado para 0.001 se evalúa como Falso.

¿Estoy haciendo algo mal o es un error?

Versión TF: 1.0.0

Código:

 import tensorflow as tf global_step = tf.Variable(0, dtype=tf.int64) train_op = tf.assign(global_step, global_step + 1) learning_rate = tf.Variable(0.1, dtype=tf.float32, name='learning_rate') # Update the learning_rate tensor conditionally # When global_step == 2, update to 0.01 # When global_step == 4, update to 0.001 cases = [] case_tensors = [] for step, new_rate in [(2, 0.01), (4, 0.001)]: pred = tf.equal(global_step, step) fn_tensor = tf.constant(new_rate, dtype=tf.float32) cases.append((pred, lambda: fn_tensor)) case_tensors.append((pred, fn_tensor)) update = tf.case(cases, default=lambda: learning_rate) updated_learning_rate = tf.assign(learning_rate, update) print tf.__version__ with tf.Session() as sess: sess.run(tf.global_variables_initializer()) for _ in xrange(6): print sess.run([global_step, case_tensors, update, updated_learning_rate]) sess.run(train_op) 

Resultados:

 1.0.0 [0, [(False, 0.0099999998), (False, 0.001)], 0.1, 0.1] [1, [(False, 0.0099999998), (False, 0.001)], 0.1, 0.1] [2, [(True, 0.0099999998), (False, 0.001)], 0.001, 0.001] [3, [(False, 0.0099999998), (False, 0.001)], 0.001, 0.001] [4, [(False, 0.0099999998), (True, 0.001)], 0.001, 0.001] [5, [(False, 0.0099999998), (False, 0.001)], 0.001, 0.001] 

Esto se respondió en https://github.com/tensorflow/tensorflow/issues/8776

Resulta que el comportamiento de tf.case no está definido si, en fn_tensors , las lambdas devuelven un tensor que se creó fuera de la lambda. El uso correcto es definir las lambdas de modo que devuelvan un tensor recién creado.

De acuerdo con el tema de Github vinculado, este uso es necesario porque tf.case debe crear el tensor para poder conectar las entradas del tensor a la twig correcta del predicado.