Tensorflow actualiza el primer elemento coincidente en cada fila

Sobre la base de esta pregunta , busco actualizar los valores de un tensor 2-D la primera vez en una fila donde se cumple la condición tf.where. Aquí hay un código de muestra que estoy usando para simular:

tf.reset_default_graph() graph = tf.Graph() with graph.as_default(): val = "hello" new_val = "goodbye" matrix = tf.constant([["word","hello","hello"], ["word", "other", "hello"], ["hello", "hello","hello"], ["word", "word", "word"] ]) matching_indices = tf.where(tf.equal(matrix, val)) first_matching_idx = tf.segment_min(data = matching_indices[:, 1], segment_ids = matching_indices[:, 0]) sess = tf.InteractiveSession(graph=graph) print(sess.run(first_matching_idx)) 

Esto dará como resultado [1, 2, 0] donde 1 es la ubicación del primer saludo en la fila 1, el 2 es la colocación del primer saludo en la fila 2 y el 0 es la colocación del primer saludo en la fila 3 .

Sin embargo, no puedo encontrar una manera de actualizar el primer índice coincidente con el nuevo valor; básicamente, quiero que el primer “hola” se convierta en “adiós”. He intentado usar tf.scatter_update () pero no parece funcionar en tensores 2D. ¿Hay alguna manera de modificar el tensor 2-D como se describe?

Una solución sencilla es usar tf.py_func con una matriz numpy

 def ch_val(array, val, new_val): idx = np.array([[s, list(row).index(val)] for s, row in enumerate(array) if val in row]) idx = tuple((idx[:, 0], idx[:, 1])) array[idx] = new_val return array ... matrix = tf.Variable([["word","hello","hello"], ["word", "other", "hello"], ["hello", "hello","hello"], ["word", "word", "word"] ]) matrix = tf.py_func(ch_val, [matrix, 'hello', 'goodbye'], tf.string) with tf.Session() as sess: sess.run(tf.global_variables_initializer()) print(sess.run(matrix)) # results: [['word' 'goodbye' 'hello'] ['word' 'other' 'goodbye'] ['goodbye' 'hello' 'hello'] ['word' 'word' 'word']]