TensorFlow obtiene elementos de cada fila para columnas específicas

Si A es una variable TensorFlow como tal

 A = tf.Variable([[1, 2], [3, 4]]) 

e index es otra variable

 index = tf.Variable([0, 1]) 

Quiero usar este índice para seleccionar columnas en cada fila. En este caso, el artículo 0 de la primera fila y el artículo 1 de la segunda fila.

Si A era una matriz de Numpy, para obtener las columnas de las filas correspondientes mencionadas en el índice podemos hacer

 x = A[np.arange(A.shape[0]), index] 

y el resultado seria

 [1, 4] 

¿Cuál es la operación / operaciones equivalentes de TensorFlow para esto? Sé que TensorFlow no admite muchas operaciones de indexación. ¿Cuál sería el trabajo alrededor si no se puede hacer directamente?

Puede ampliar sus índices de columna con índices de fila y luego usar gather_nd:

 import tensorflow as tf A = tf.constant([[1, 2], [3, 4]]) indices = tf.constant([1, 0]) # prepare row indices row_indices = tf.range(tf.shape(indices)[0]) # zip row indices with column indices full_indices = tf.stack([row_indices, indices], axis=1) # retrieve values by indices S = tf.gather_nd(A, full_indices) session = tf.InteractiveSession() session.run(S) 

Después de chapotear durante bastante tiempo. Encontré dos funciones que podrían ser útiles.

Uno es tf.gather_nd() que puede ser útil si puede producir un tensor de la forma [[0, 0], [1, 1]] y, por lo tanto, podría hacer

index = tf.constant([[0, 0], [1, 1]])

tf.gather_nd(A, index)

Si no puede producir un vector de la forma [[0, 0], [1, 1]] (no pude producir esto porque el número de filas en mi caso dependía de un marcador de posición) por alguna razón, entonces el La tf.py_func() que encontré es usar tf.py_func() . Aquí hay un código de ejemplo sobre cómo se puede hacer esto.

 import tensorflow as tf import numpy as np def index_along_every_row(array, index): N, _ = array.shape return array[np.arange(N), index] a = tf.Variable([[1, 2], [3, 4]], dtype=tf.int32) index = tf.Variable([0, 1], dtype=tf.int32) a_slice_op = tf.py_func(index_along_every_row, [a, index], [tf.int32])[0] session = tf.InteractiveSession() a.initializer.run() index.initializer.run() a_slice = a_slice_op.eval() 

a_slice será una matriz numpy [1, 4]

Puede usar un método para crear una matriz one_hot y usarla como una máscara booleana para seleccionar los índices que desee.

 A = tf.Variable([[1, 2], [3, 4]]) index = tf.Variable([0, 1]) one_hot_mask = tf.one_hot(index, A.shape[1], on_value = True, off_value = False, dtype = tf.bool) output = tf.boolean_mask(A, one_hot_mask) 

Podemos hacer lo mismo usando esta combinación de map_fn y gather_nd .

 def get_element(a, indices): """ Outputs (ith element of indices) from (ith row of a) """ return tf.map_fn(lambda x: tf.gather_nd(x[0], x[1]), (a, indices), dtype = tf.float32) 

Aquí hay un ejemplo de uso.

 A = tf.constant(np.array([[1,2,3], [4,5,6], [7,8,9]], dtype = np.float32)) idx = tf.constant(np.array([[2],[1],[0]])) elems = get_element(A, idx) with tf.Session() as sess: e = sess.run(elems) print(e) 

No sé si esto será mucho más lento que otras respuestas.

Tiene la ventaja de que no es necesario especificar el número de filas de A por adelantado, siempre y cuando los indices tengan el mismo número de filas en el tiempo de ejecución.

Tenga en cuenta que la salida de lo anterior será la clasificación 1. Si prefiere que tenga la clasificación 2, reemplace gather_nd por gather