Tensorflow, ¿la mejor manera de salvar el estado en RNNs?

Actualmente tengo el siguiente código para una serie de RNN encadenados en tensorflow. No estoy usando MultiRNN ya que tenía que hacer algo más adelante con la salida de cada capa.

for r in range(RNNS): with tf.variable_scope('recurent_%d' % r) as scope: state = [tf.zeros((BATCH_SIZE, sz)) for sz in rnn_func.state_size] time_outputs = [None] * TIME_STEPS for t in range(TIME_STEPS): rnn_input = getTimeStep(rnn_outputs[r - 1], t) time_outputs[t], state = rnn_func(rnn_input, state) time_outputs[t] = tf.reshape(time_outputs[t], (-1, 1, RNN_SIZE)) scope.reuse_variables() rnn_outputs[r] = tf.concat(1, time_outputs) 

Actualmente tengo un número fijo de pasos de tiempo. Sin embargo, me gustaría cambiarlo para que tenga un solo paso de tiempo pero recuerde el estado entre lotes. Por lo tanto, necesitaría crear una variable de estado para cada capa y asignarle el estado final de cada una de las capas. Algo como esto.

 for r in range(RNNS): with tf.variable_scope('recurent_%d' % r) as scope: saved_state = tf.get_variable('saved_state', ...) rnn_outputs[r], state = rnn_func(rnn_outputs[r - 1], saved_state) saved_state = tf.assign(saved_state, state) 

Luego, para cada una de las capas necesitaría evaluar el estado guardado en mi función sess.run, así como llamar a mi función de entrenamiento. Tendría que hacer esto para cada capa rnn. Esto parece una especie de molestia. Tendría que hacer un seguimiento de cada estado guardado y evaluarlo en ejecución. Además, la ejecución debería copiar el estado de mi GPU a la memoria del host, lo que sería ineficaz e innecesario. ¿Hay una mejor manera de hacer esto?

Aquí está el código para actualizar el estado inicial de la LSTM, cuando state_is_tuple=True definiendo las variables de estado. También soporta múltiples capas.

Definimos dos funciones: una para obtener las variables de estado con un estado cero inicial y otra para devolver una operación, que podemos pasar a session.run para actualizar las variables de estado con el último estado oculto de la LSTM.

 def get_state_variables(batch_size, cell): # For each layer, get the initial state and make a variable out of it # to enable updating its value. state_variables = [] for state_c, state_h in cell.zero_state(batch_size, tf.float32): state_variables.append(tf.contrib.rnn.LSTMStateTuple( tf.Variable(state_c, trainable=False), tf.Variable(state_h, trainable=False))) # Return as a tuple, so that it can be fed to dynamic_rnn as an initial state return tuple(state_variables) def get_state_update_op(state_variables, new_states): # Add an operation to update the train states with the last state tensors update_ops = [] for state_variable, new_state in zip(state_variables, new_states): # Assign the new state to the state variables on this layer update_ops.extend([state_variable[0].assign(new_state[0]), state_variable[1].assign(new_state[1])]) # Return a tuple in order to combine all update_ops into a single operation. # The tuple's actual value should not be used. return tf.tuple(update_ops) 

Podemos usar eso para actualizar el estado del LSTM después de cada lote. Tenga en cuenta que uso tf.nn.dynamic_rnn para desenrollar:

 data = tf.placeholder(tf.float32, (batch_size, max_length, frame_size)) cell_layer = tf.contrib.rnn.GRUCell(256) cell = tf.contrib.rnn.MultiRNNCell([cell] * num_layers) # For each layer, get the initial state. states will be a tuple of LSTMStateTuples. states = get_state_variables(batch_size, cell) # Unroll the LSTM outputs, new_states = tf.nn.dynamic_rnn(cell, data, initial_state=states) # Add an operation to update the train states with the last state tensors. update_op = get_state_update_op(states, new_states) sess = tf.Session() sess.run(tf.global_variables_initializer()) sess.run([outputs, update_op], {data: ...}) 

La principal diferencia con esta respuesta es que state_is_tuple=True hace que el estado de LSTM sea un LSTMStateTuple que contiene dos variables (estado de celda y estado oculto) en lugar de una sola variable. El uso de múltiples capas hace que el estado de LSTM sea una tupla de LSTMStateTuples, una por capa.

Restablecer a cero

Al usar un modelo entrenado para la predicción / deencoding, es posible que desee restablecer el estado a cero. Entonces, puedes hacer uso de esta función:

 def get_state_reset_op(state_variables, cell, batch_size): # Return an operation to set each variable in a list of LSTMStateTuples to zero zero_states = cell.zero_state(batch_size, tf.float32) return get_state_update_op(state_variables, zero_states) 

Por ejemplo, como arriba:

 reset_state_op = get_state_reset_op(state, cell, max_batch_size) # Reset the state to zero before feeding input sess.run([reset_state_op]) sess.run([outputs, update_op], {data: ...}) 

Ahora estoy guardando los estados RNN usando las dependencias tf.control_control. Aquí hay un ejemplo.

  saved_states = [tf.get_variable('saved_state_%d' % i, shape = (BATCH_SIZE, sz), trainable = False, initializer = tf.constant_initializer()) for i, sz in enumerate(rnn.state_size)] W = tf.get_variable('W', shape = (2 * RNN_SIZE, RNN_SIZE), initializer = tf.truncated_normal_initializer(0.0, 1 / np.sqrt(2 * RNN_SIZE))) b = tf.get_variable('b', shape = (RNN_SIZE,), initializer = tf.constant_initializer()) rnn_output, states = rnn(last_output, saved_states) with tf.control_dependencies([tf.assign(a, b) for a, b in zip(saved_states, states)]): dense_input = tf.concat(1, (last_output, rnn_output)) dense_output = tf.tanh(tf.matmul(dense_input, W) + b) last_output = dense_output + last_output 

Solo me aseguro de que parte de mi gráfica dependa de guardar el estado.

Estos dos enlaces también son relacionados y útiles para esta pregunta:

https://github.com/tensorflow/tensorflow/issues/2695 https://github.com/tensorflow/tensorflow/issues/2838