¿Cómo TensorArray y while_loop trabajan juntos en tensorflow?

Estoy tratando de producir un ejemplo muy fácil para la combinación de TensorArray y while_loop:

# 1000 sequence in the length of 100 matrix = tf.placeholder(tf.int32, shape=(100, 1000), name="input_matrix") matrix_rows = tf.shape(matrix)[0] ta = tf.TensorArray(tf.float32, size=matrix_rows) ta = ta.unstack(matrix) init_state = (0, ta) condition = lambda i, _: i < n body = lambda i, ta: (i + 1, ta.write(i,ta.read(i)*2)) # run the graph with tf.Session() as sess: (n, ta_final) = sess.run(tf.while_loop(condition, body, init_state),feed_dict={matrix: tf.ones(tf.float32, shape=(100,1000))}) print (ta_final.stack()) 

Pero estoy recibiendo el siguiente error:

 ValueError: Tensor("while/LoopCond:0", shape=(), dtype=bool) must be from the same graph as Tensor("Merge:0", shape=(), dtype=float32). 

¿Alguien tiene una idea de cuál es el problema?

Hay varias cosas en su código para señalar. Primero, no necesita desastackr la matriz en TensorArray para usarla dentro del bucle, puede hacer referencia al Tensor matriz dentro del cuerpo e indexarlo con la notación de matrix[i] . Otro problema es el tipo de datos diferente entre su matriz ( tf.int32 ) y TensorArray ( tf.float32 ), según su código, está multiplicando las entradas de matriz por 2 y escribiendo el resultado en la matriz, por lo que debería ser int32 como bien. Finalmente, cuando desee leer el resultado final del bucle, la operación correcta es TensorArray.stack() que es lo que necesita para ejecutar en su llamada session.run .

Aquí hay un ejemplo de trabajo:

 import numpy as np import tensorflow as tf # 1000 sequence in the length of 100 matrix = tf.placeholder(tf.int32, shape=(100, 1000), name="input_matrix") matrix_rows = tf.shape(matrix)[0] ta = tf.TensorArray(dtype=tf.int32, size=matrix_rows) init_state = (0, ta) condition = lambda i, _: i < matrix_rows body = lambda i, ta: (i + 1, ta.write(i, matrix[i] * 2)) n, ta_final = tf.while_loop(condition, body, init_state) # get the final result ta_final_result = ta_final.stack() # run the graph with tf.Session() as sess: # print the output of ta_final_result print sess.run(ta_final_result, feed_dict={matrix: np.ones(shape=(100,1000), dtype=np.int32)})