Producto cartesiano en tensorflow

¿Hay alguna manera fácil de hacer un producto cartesiano en Tensorflow como itertools.product? Quiero obtener una combinación de elementos de dos tensores ( b ), en Python es posible a través de itertools como list(product(a, b)) . Estoy buscando una alternativa en Tensorflow.

Voy a suponer aquí que a y b son tensores 1-D.

Para obtener el producto cartesiano de los dos, usaría una combinación de tf.expand_dims y tf.tile :

 a = tf.constant([1,2,3]) b = tf.constant([4,5,6,7]) tile_a = tf.tile(tf.expand_dims(a, 1), [1, tf.shape(b)[0]]) tile_a = tf.expand_dims(tile_a, 2) tile_b = tf.tile(tf.expand_dims(b, 0), [tf.shape(a)[0], 1]) tile_b = tf.expand_dims(tile_b, 2) cartesian_product = tf.concat([tile_a, tile_b], axis=2) cart = tf.Session().run(cartesian_product) print(cart.shape) print(cart) 

Usted termina con un tensor de len (a) * len (b) * 2 donde cada combinación de los elementos de a y b se representa en la última dimensión.

Una solución más corta a la misma, utilizando tf.add() para la difusión (probado):

 import tensorflow as tf a = tf.constant([1,2,3]) b = tf.constant([4,5,6,7]) a, b = a[ None, :, None ], b[ :, None, None ] cartesian_product = tf.concat( [ a + tf.zeros_like( b ), tf.zeros_like( a ) + b ], axis = 2 ) with tf.Session() as sess: print( sess.run( cartesian_product ) ) 

saldrá:

[[[1 4]
[2 4]
[3 4]]

[[15]
[2 5]
[3 5]]

[[dieciséis]
[2 6]
[3 6]]

[[1 7]
[2 7]
[3 7]]]

Me inspira la respuesta de Jaba. Si desea obtener el producto cartesiano de dos tensores 2-D, puede hacerlo de la siguiente manera:

ingrese a: [N, L] yb: [M, L], obtenga un tensor concat [N * M, L]

 tile_a = tf.tile(tf.expand_dims(a, 1), [1, M, 1]) tile_b = tf.tile(tf.expand_dims(b, 0), [N, 1, 1]) cartesian_product = tf.concat([tile_a, tile_b], axis=2) cartesian = tf.reshape(cartesian_product, [N*M, -1]) cart = tf.Session().run(cartesian) print(cart.shape) print(cart) 
 import tensorflow as tf a = tf.constant([0, 1, 2]) b = tf.constant([2, 3]) c = tf.stack(tf.meshgrid(a, b, indexing='ij'), axis=-1) c = tf.reshape(c, (-1, 2)) with tf.Session() as sess: print(sess.run(c)) 

Salida:

 [[0 2] [0 3] [1 2] [1 3] [2 2] [2 3]] 

crédito a jdehesa: enlace