Rebanado Tensorflow basado en variable

Descubrí que la indexación aún es un problema abierto en tensorflow (# 206) , por lo que me pregunto qué podría usar como solución en este momento. Quiero indexar / cortar una fila / columna de una matriz en función de una variable que cambia para cada ejemplo de entrenamiento.

Lo que he probado hasta ahora:

  1. Rebanado basado en marcador de posición (no funciona)

Los siguientes segmentos de código (en funcionamiento) se basan en un número fijo.

import tensorflow as tf import numpy as np x = tf.placeholder("float") y = tf.slice(x,[0],[1]) #initialize init = tf.initialize_all_variables() sess = tf.Session() sess.run(init) #run result = sess.run(y, feed_dict={x:[1,2,3,4,5]}) print(result) 

Sin embargo, parece que no puedo simplemente reemplazar uno de estos números fijos con un tf.placeholder. El siguiente código me da el error Error de tipo : Lista de tensores cuando se espera un solo tensor”.

 import tensorflow as tf import numpy as np x = tf.placeholder("float") i = tf.placeholder("int32") y = tf.slice(x,[i],[1]) #initialize init = tf.initialize_all_variables() sess = tf.Session() sess.run(init) #run result = sess.run(y, feed_dict={x:[1,2,3,4,5],i:0}) print(result) 

Esto suena como que los corchetes alrededor de [i] son demasiado, pero quitarlos tampoco ayuda. ¿Cómo utilizar un marcador de posición / variable como índice?

  1. Rebanado basado en la variable python (no funciona correctamente / no se actualiza correctamente)

También he intentado usar una variable de python normal como índice. Esto no conduce a un error, pero la red no aprende nada mientras se entrena. Supongo que debido a que la variable cambiante no está registrada correctamente, el gráfico tiene un formato incorrecto y las actualizaciones no funcionan.

  1. Cortar a través de un vector caliente + multiplicación (funciona, pero es lento)

Una solución que encontré es usar un vector caliente. Hacer un vector de un solo calor en números, pasarlo usando un marcador de posición y luego hacer el corte a través de la multiplicación de matrices. Esto funciona, pero es bastante lento.

¿Alguna idea de cómo dividir / indexar de manera eficiente en función de una variable?

El corte basado en un marcador de posición debería funcionar bien. Parece que se está ejecutando en un error de tipo, debido a algunos problemas sutiles de formas y tipos. Donde tengas lo siguiente:

 x = tf.placeholder("float") i = tf.placeholder("int32") y = tf.slice(x,[i],[1]) 

… deberías tener en su lugar

 x = tf.placeholder("float") i = tf.placeholder("int32") y = tf.slice(x,i,[1]) 

… y luego debes sess.run() i como [0] en la llamada a sess.run() .

Para hacer esto un poco más claro, recomiendo volver a escribir el código de la siguiente manera:

 import tensorflow as tf import numpy as np x = tf.placeholder(tf.float32, shape=[None]) # 1-D tensor i = tf.placeholder(tf.int32, shape=[1]) y = tf.slice(x, i, [1]) #initialize init = tf.initialize_all_variables() sess = tf.Session() sess.run(init) #run result = sess.run(y, feed_dict={x: [1, 2, 3, 4, 5], i: [0]}) print(result) 

Los argumentos de shape adicionales de la tf.placeholder ayudan a garantizar que los valores que alimenta tengan las formas adecuadas, y también que TensorFlow generará un error si las formas no son correctas.