Estado de salida del codificador multicapa a decodificador multicapa en el modelo Seq2Seq TF 1.0

Tensorflow versión 1.0

Mi pregunta es, qué dimensión del argumento tf.contrib.seq2seq attention_decoder_fn_train espera.

¿Puede tomar el estado del codificador multicapa?

Contexto :

Quiero crear una atención bidireccional multicapa basada en seq2seq en tensorflow 1.0 .

Mi codificador

 cell = LSTM(n) cell = MultiRnnCell([cell]*4) ((encoder_fw_outputs,encoder_bw_outputs), (encoder_fw_state,encoder_bw_state)) = (tf.nn.bidirectional_dynamic_rnn(cell_fw=cell, cell_bw = cell.... ) 

Ahora, el codificador bidireccional de varias capas devuelve el encoder cell_states[c] y hidden_states[h] para cada capa y también para el paso hacia atrás y hacia adelante. Concatenaré los estados de paso hacia adelante y paso hacia atrás para pasarlo a encoder_state:

self.encoder_state = tf.concat((encoder_fw_state, encoder_bw_state), -1)

Y se lo paso a mi decodificador:

 decoder_fn_train = seq2seq.simple_decoder_fn_train(encoder_state=self.encoder_state) (self.decoder_outputs_train, self.decoder_state_train, self.decoder_context_state_train) = seq2seq.dynamic_rnn_decoder(cell=decoder_cell,... ) 

Pero da el siguiente error:

ValueError: The two structures don't have the same number of elements. First structure: Tensor("BidirectionalEncoder/transpose:0", shape=(?, 2, 2, 20), dtype=float32), second structure: (LSTMStateTuple(c=20, h=20), LSTMStateTuple(c=20, h=20)).

Mi decoder_cell también es multicapa.

Enlace a mi código

1 :

Encontré un problema con mi implementación. Así que publicarlo aquí. El problema fue la concatenación de encoder_fw_state y encoder_bw_state . La forma correcta de hacerlo es la siguiente:

  self.encoder_state = [] for i in range(self.num_layers): if isinstance(encoder_fw_state[i], LSTMStateTuple): encoder_state_c = tf.concat((encoder_fw_state[i].c, encoder_bw_state[i].c), 1, name='bidirectional_concat_c') encoder_state_h = tf.concat((encoder_fw_state[i].h, encoder_bw_state[i].h), 1, name='bidirectional_concat_h') encoder_state = LSTMStateTuple(c=encoder_state_c, h=encoder_state_h) elif isinstance(encoder_fw_state[i], tf.Tensor): encoder_state = tf.concat((encoder_fw_state[i], encoder_bw_state[i]), 1, name='bidirectional_concat') self.encoder_state.append(encoder_state) self.encoder_state = tuple(self.encoder_state)