¿Cómo extraer el estado celular y el estado oculto de un modelo RNN en tensorflow?

Soy nuevo en TensorFlow y tengo dificultades para entender el módulo RNN. Estoy intentando extraer estados ocultos / celulares de un LSTM. Para mi código, estoy usando la implementación de https://github.com/aymericdamien/TensorFlow-Examples .

# tf Graph input x = tf.placeholder("float", [None, n_steps, n_input]) y = tf.placeholder("float", [None, n_classes]) # Define weights weights = {'out': tf.Variable(tf.random_normal([n_hidden, n_classes]))} biases = {'out': tf.Variable(tf.random_normal([n_classes]))} def RNN(x, weights, biases): # Prepare data shape to match `rnn` function requirements # Current data input shape: (batch_size, n_steps, n_input) # Required shape: 'n_steps' tensors list of shape (batch_size, n_input) # Permuting batch_size and n_steps x = tf.transpose(x, [1, 0, 2]) # Reshaping to (n_steps*batch_size, n_input) x = tf.reshape(x, [-1, n_input]) # Split to get a list of 'n_steps' tensors of shape (batch_size, n_input) x = tf.split(0, n_steps, x) # Define a lstm cell with tensorflow #with tf.variable_scope('RNN'): lstm_cell = rnn_cell.BasicLSTMCell(n_hidden, forget_bias=1.0, state_is_tuple=True) # Get lstm cell output outputs, states = rnn.rnn(lstm_cell, x, dtype=tf.float32) # Linear activation, using rnn inner loop last output return tf.matmul(outputs[-1], weights['out']) + biases['out'], states pred, states = RNN(x, weights, biases) # Define loss and optimizer cost = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(pred, y)) optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate).minimize(cost) # Evaluate model correct_pred = tf.equal(tf.argmax(pred,1), tf.argmax(y,1)) accuracy = tf.reduce_mean(tf.cast(correct_pred, tf.float32)) # Initializing the variables init = tf.initialize_all_variables() 

Ahora quiero extraer el estado de celda / oculto para cada paso de tiempo en una predicción. El estado se almacena en un LSTMStateTuple de la forma (c, h), que puedo averiguar evaluando los print states . Sin embargo, al intentar llamar a print states.c.eval() (que de acuerdo con la documentación debería proporcionarme valores en el tensor states.c ), se produce un error que indica que mis variables no están inicializadas aunque lo esté llamando justo después de que estoy prediciendo algo El código para esto está aquí:

 # Launch the graph with tf.Session() as sess: sess.run(init) step = 1 # Keep training until reach max iterations for v in tf.get_collection(tf.GraphKeys.VARIABLES, scope='RNN'): print v.name while step * batch_size < training_iters: batch_x, batch_y = mnist.train.next_batch(batch_size) # Reshape data to get 28 seq of 28 elements batch_x = batch_x.reshape((batch_size, n_steps, n_input)) # Run optimization op (backprop) sess.run(optimizer, feed_dict={x: batch_x, y: batch_y}) print states.c.eval() # Calculate batch accuracy acc = sess.run(accuracy, feed_dict={x: batch_x, y: batch_y}) step += 1 print "Optimization Finished!" 

y el mensaje de error es

 InvalidArgumentError: You must feed a value for placeholder tensor 'Placeholder' with dtype float [[Node: Placeholder = Placeholder[dtype=DT_FLOAT, shape=[], _device="/job:localhost/replica:0/task:0/cpu:0"]()]] 

Los estados tampoco son visibles en tf.all_variables() , solo los tensores de matriz / polarización entrenados (como se describe aquí: Tensorflow: muestra o guarda los valores de puerta olvidada en LSTM ). No quiero construir todo el LSTM desde cero, ya que tengo los estados en la variable states , solo necesito llamarlo.

Simplemente puede recostackr los valores de los states de la misma manera que se recostack la precisión.

Supongo que pred, states, acc = sess.run(pred, states, accuracy, feed_dict={x: batch_x, y: batch_y}) debería funcionar perfectamente bien.

Un comentario sobre su suposición: los “estados” solo tienen los valores de “estado oculto” y “celda de memoria” del último paso del tiempo.

Las “salidas” contienen el “estado oculto” de cada paso de tiempo que desee (el tamaño de las salidas es [batch_size, seq_len, hidden_size]. Por lo tanto, supongo que desea “salidas” variables, no “estados”. Consulte la documentación .

Tengo que estar en desacuerdo con la respuesta del usuario 3480922. Para el código:

 outputs, states = rnn.rnn(lstm_cell, x, dtype=tf.float32) 

para poder extraer el estado oculto de cada time_step en una predicción, tienes que usar las salidas. Porque las salidas tienen el valor de estado oculto para cada time_step. Sin embargo, no estoy seguro de que exista alguna manera de que podamos almacenar los valores del estado de la celda también para cada paso de tiempo. Debido a que la tupla de estados proporciona los valores de estado de celda, pero solo para el último paso de tiempo.

Por ejemplo, en la siguiente muestra con 5 time_steps, las salidas [4,:,:], time_step = 0, …, 4 tiene los valores de estado ocultos para time_step = 4, mientras que la tupla de estados h solo tiene el estado oculto valores para time_step = 4. Sin embargo, la tupla de estado c tiene el valor de celda en el paso de tiempo = 4.

  outputs = [[[ 0.0589103 -0.06925126 -0.01531546 0.06108122] [ 0.00861215 0.06067181 0.03790079 -0.04296958] [ 0.00597713 0.03916606 0.02355802 -0.0277683 ]] [[ 0.06252582 -0.07336216 -0.01607122 0.05024602] [ 0.05464711 0.03219429 0.06635305 0.00753127] [ 0.05385715 0.01259535 0.0524035 0.01696803]] [[ 0.0853352 -0.06414541 0.02524283 0.05798233] [ 0.10790729 -0.05008117 0.03003334 0.07391824] [ 0.10205664 -0.04479517 0.03844892 0.0693808 ]] [[ 0.10556188 0.0516542 0.09162509 -0.02726674] [ 0.11425048 -0.00211394 0.06025286 0.03575509] [ 0.11338984 0.02839304 0.08105748 0.01564003]] **[[ 0.10072514 0.14767936 0.12387902 -0.07391471] [ 0.10510238 0.06321315 0.08100517 -0.00940042] [ 0.10553667 0.0984127 0.10094948 -0.02546882]]**] states = LSTMStateTuple(c=array([[ 0.23870754, 0.24315512, 0.20842518, -0.12798975], [ 0.23749796, 0.10797793, 0.14181322, -0.01695861], [ 0.2413336 , 0.16692916, 0.17559692, -0.0453596 ]], dtype=float32), h=array(**[[ 0.10072514, 0.14767936, 0.12387902, -0.07391471], [ 0.10510238, 0.06321315, 0.08100517, -0.00940042], [ 0.10553667, 0.0984127 , 0.10094948, -0.02546882]]**, dtype=float32))