¿Cómo agregar un mecanismo de atención en keras?

Actualmente estoy usando este código que obtengo de una discusión en github Aquí está el código del mecanismo de atención:

_input = Input(shape=[max_length], dtype='int32') # get the embedding layer embedded = Embedding( input_dim=vocab_size, output_dim=embedding_size, input_length=max_length, trainable=False, mask_zero=False )(_input) activations = LSTM(units, return_sequences=True)(embedded) # compute importance for each step attention = Dense(1, activation='tanh')(activations) attention = Flatten()(attention) attention = Activation('softmax')(attention) attention = RepeatVector(units)(attention) attention = Permute([2, 1])(attention) sent_representation = merge([activations, attention], mode='mul') sent_representation = Lambda(lambda xin: K.sum(xin, axis=-2), output_shape=(units,))(sent_representation) probabilities = Dense(3, activation='softmax')(sent_representation) 

¿Es esta la forma correcta de hacerlo? Esperaba la existencia de una capa distribuida en el tiempo ya que el mecanismo de atención se distribuye en cada paso de tiempo de la RNN. Necesito que alguien confirme que esta implementación (el código) es una implementación correcta del mecanismo de atención. Gracias.

Si desea tener una atención a lo largo de la dimensión de tiempo, entonces esta parte de su código me parece correcta:

 activations = LSTM(units, return_sequences=True)(embedded) # compute importance for each step attention = Dense(1, activation='tanh')(activations) attention = Flatten()(attention) attention = Activation('softmax')(attention) attention = RepeatVector(units)(attention) attention = Permute([2, 1])(attention) sent_representation = merge([activations, attention], mode='mul') 

Has desarrollado el vector de atención de la forma (batch_size, max_length) :

 attention = Activation('softmax')(attention) 

Nunca he visto este código antes, así que no puedo decir si este es realmente correcto o no:

 K.sum(xin, axis=-2) 

Lecturas adicionales (puede que tengas un vistazo):

El mecanismo de atención presta atención a diferentes partes de la oración:

activations = LSTM(units, return_sequences=True)(embedded)

Y determina la contribución de cada estado oculto de esa oración por

  1. Cálculo de la agregación de cada attention = Dense(1, activation='tanh')(activations) estado oculto attention = Dense(1, activation='tanh')(activations)
  2. Asignación de pesos a diferentes estados de attention = Activation('softmax')(attention)

Y, finalmente, prestar atención a los diferentes estados:

sent_representation = merge([activations, attention], mode='mul')

No entiendo muy bien esta parte: sent_representation = Lambda(lambda xin: K.sum(xin, axis=-2), output_shape=(units,))(sent_representation)

Para comprender más, puede referirse a esto y esto , y también a esta le da una buena implementación, ver si puede entender más por su cuenta.