En Tensorflow, ¿cómo usar tf.gather () para la última dimensión?

Estoy tratando de reunir partes de un tensor en términos de la última dimensión para la conexión parcial entre capas. Debido a que la forma del tensor de salida es [batch_size, h, w, depth] , quiero seleccionar cortes basados ​​en la última dimensión, como

 # L is intermediate tensor partL = L[:, :, :, [0,2,3,8]] 

Sin embargo, tf.gather(L, [0, 2,3,8]) parece funcionar solo para la primera dimensión (¿verdad?) ¿Alguien puede decirme cómo hacerlo?

Hay un error de seguimiento para admitir este caso de uso aquí: https://github.com/tensorflow/tensorflow/issues/206

Por ahora puedes:

  1. Transponga su matriz para que la dimensión a reunir sea la primera (la transposición es costosa)

  2. remodela tu tensor en 1d (remodelar es barato) y convierte tus índices de columna en una lista de índices de elementos individuales en la indexación lineal, luego remodela de nuevo

  3. utilizar gather_nd . Todavía tendrá que convertir sus índices de columna en una lista de índices de elementos individuales.

A partir de TensorFlow 1.3 tf.gather tiene un parámetro de axis , por lo que las diversas soluciones aquí ya no son necesarias.

https://www.tensorflow.org/versions/r1.3/api_docs/python/tf/gather https://github.com/tensorflow/tensorflow/issues/11223

Con gather_nd ahora puedes hacer esto de la siguiente manera:

 cat_idx = tf.concat([tf.range(0, tf.shape(x)[0]), indices_for_dim1], axis=0) result = tf.gather_nd(matrix, cat_idx) 

Además, según lo informado por el usuario Nova en un hilo referenciado por @Yaroslav Bulatov:

 x = tf.constant([[1, 2, 3], [4, 5, 6], [7, 8, 9]]) idx = tf.constant([1, 0, 2]) idx_flattened = tf.range(0, x.shape[0]) * x.shape[1] + idx y = tf.gather(tf.reshape(x, [-1]), # flatten input idx_flattened) # use flattened indices with tf.Session(''): print y.eval() # [2 4 9] 

La esencia es aplanar el tensor y usar direccionamiento 1D con tf.gather (…).

Otra solución más que usa tf.unstack (…), tf.gather (…) y tf.stack (..)

Código:

 import tensorflow as tf import numpy as np shape = [2, 2, 2, 10] L = np.arange(np.prod(shape)) L = np.reshape(L, shape) indices = [0, 2, 3, 8] axis = -1 # last dimension def gather_axis(params, indices, axis=0): return tf.stack(tf.unstack(tf.gather(tf.unstack(params, axis=axis), indices)), axis=axis) print(L) with tf.Session() as sess: partL = sess.run(gather_axis(L, indices, axis)) print(partL) 

Resultado:

 L = [[[[ 0 1 2 3 4 5 6 7 8 9] [10 11 12 13 14 15 16 17 18 19]] [[20 21 22 23 24 25 26 27 28 29] [30 31 32 33 34 35 36 37 38 39]]] [[[40 41 42 43 44 45 46 47 48 49] [50 51 52 53 54 55 56 57 58 59]] [[60 61 62 63 64 65 66 67 68 69] [70 71 72 73 74 75 76 77 78 79]]]] partL = [[[[ 0 2 3 8] [10 12 13 18]] [[20 22 23 28] [30 32 33 38]]] [[[40 42 43 48] [50 52 53 58]] [[60 62 63 68] [70 72 73 78]]]] 

Una versión correcta de la respuesta de @ Andrei sería:

 cat_idx = tf.stack([tf.range(0, tf.shape(x)[0]), indices_for_dim1], axis=1) result = tf.gather_nd(matrix, cat_idx) 

Puede intentar de esta manera, por ejemplo (en la mayoría de los casos, en NLP al menos),

El parámetro tiene forma [batch_size, depth] y los índices son [i, j, k, n, m] cuya longitud es batch_size. Entonces gather_nd puede ser útil.

 parameters = tf.constant([ [11, 12, 13], [21, 22, 23], [31, 32, 33], [41, 42, 43]]) targets = tf.constant([2, 1, 0, 1]) batch_nums = tf.range(0, limit=parameters.get_shape().as_list()[0]) indices = tf.stack((batch_nums, targets), axis=1) # the axis is the dimension number items = tf.gather_nd(parameters, indices) # which is what we want: [13, 22, 31, 42] 

Este fragmento primero encuentra la primera dimensión a través del batch_num y luego busca el elemento a lo largo de esa dimensión por el número objective.

Implementando 2. de @Yaroslav Bulatov’s:

 #Your indices indices = [0, 2, 3, 8] #Remember for final reshaping n_indices = tf.shape(indices)[0] flattened_L = tf.reshape(L, [-1]) #Walk strided over the flattened array offset = tf.expand_dims(tf.range(0, tf.reduce_prod(tf.shape(L)), tf.shape(L)[-1]), 1) flattened_indices = tf.reshape(tf.reshape(indices, [-1])+offset, [-1]) selected_rows = tf.gather(flattened_L, flattened_indices) #Final reshape partL = tf.reshape(selected_rows, tf.concat(0, [tf.shape(L)[:-1], [n_indices]])) 

Crédito a ¿Cómo seleccionar filas de un tensor 3-D en TensorFlow?

Tensor no tiene forma de atributo, pero el método get_shape (). A continuación es ejecutable por Python 2.7

 import tensorflow as tf import numpy as np x = tf.constant([[1, 2, 3], [4, 5, 6], [7, 8, 9]]) idx = tf.constant([1, 0, 2]) idx_flattened = tf.range(0, x.get_shape()[0]) * x.get_shape()[1] + idx y = tf.gather(tf.reshape(x, [-1]), # flatten input idx_flattened) # use flattened indices with tf.Session(''): print y.eval() # [2 4 9]