TensorFlow dynamic_rnn state

Mi pregunta es sobre el método TensorFlow tf.nn.dynamic_rnn . Devuelve la salida de cada paso de tiempo y el estado final.

Me gustaría saber si el estado final devuelto es el estado de la celda en la longitud de secuencia máxima o si se determina individualmente por el argumento de sequence_length .

Para comprender mejor un ejemplo: tengo 3 secuencias con longitud [10,20,30] y [10,20,30] el estado final [3,512] (si el estado oculto de la celda tiene la longitud 512).

¿Son los tres estados ocultos devueltos para las tres secuencias el estado de la celda en el paso de tiempo 30 o estoy recuperando los estados en los pasos de tiempo [10,20,30] ?

tf.nn.dynamic_rnn devuelve dos tensores: outputs y states .

Las outputs contienen las salidas de todas las celdas para todas las secuencias en un lote. Entonces, si una secuencia en particular es más corta y está rellenada con ceros, las outputs de las últimas celdas serán cero.

Los states mantienen el último estado de celda o, de manera equivalente, la última salida distinta de cero por secuencia (si está utilizando BasicRNNCell ).

Aquí hay un ejemplo:

 import numpy as np import tensorflow as tf n_steps = 2 n_inputs = 3 n_neurons = 5 X = tf.placeholder(dtype=tf.float32, shape=[None, n_steps, n_inputs]) seq_length = tf.placeholder(tf.int32, [None]) basic_cell = tf.nn.rnn_cell.BasicRNNCell(num_units=n_neurons) outputs, states = tf.nn.dynamic_rnn(basic_cell, X, sequence_length=seq_length, dtype=tf.float32) X_batch = np.array([ # t = 0 t = 1 [[0, 1, 2], [9, 8, 7]], # instance 0 [[3, 4, 5], [0, 0, 0]], # instance 1 ]) seq_length_batch = np.array([2, 1]) with tf.Session() as sess: sess.run(tf.global_variables_initializer()) outputs_val, states_val = sess.run([outputs, states], feed_dict={X: X_batch, seq_length: seq_length_batch}) print('outputs:') print(outputs_val) print('\nstates:') print(states_val) 

Esto imprime algo como:

 outputs: [[[-0.85381496 -0.19517037 0.36011398 -0.18617202 0.39162001] [-0.99998015 -0.99461144 -0.82241321 0.93778896 0.90737367]] [[-0.99849552 -0.88643843 0.20635395 0.157896 0.76042926] [ 0. 0. 0. 0. 0. ]]] # because len=1 states: [[-0.99998015 -0.99461144 -0.82241321 0.93778896 0.90737367] [-0.99849552 -0.88643843 0.20635395 0.157896 0.76042926]] 

Tenga en cuenta que los states tienen los mismos vectores que en la output , y son las últimas salidas no cero por instancia de lote.