Subprocesos paralelos con TensorFlow Dataset API y flat_map

Estoy cambiando mi código TensorFlow de la antigua interfaz de la cola a la nueva API de conjunto de datos . Con la antigua interfaz podría especificar el argumento num_threads para la cola tf.train.shuffle_batch . Sin embargo, la única forma de controlar la cantidad de subprocesos en la API del conjunto de datos parece estar en la función de map usando el argumento num_parallel_calls . Sin embargo, estoy usando la función flat_map lugar, que no tiene tal argumento.

Pregunta : ¿Hay alguna forma de controlar el número de subprocesos / procesos para la función flat_map ? ¿O es posible utilizar el map en combinación con flat_map y aún especificar el número de llamadas paralelas?

Tenga en cuenta que es de vital importancia ejecutar varios subprocesos en paralelo, ya que tengo la intención de ejecutar un preprocesamiento pesado en la CPU antes de que los datos entren en la cola.

Hay dos publicaciones relacionadas ( aquí y aquí ) en GitHub, pero no creo que respondan a esta pregunta.

Aquí hay un ejemplo de código mínimo de mi caso de uso para ilustración:

 with tf.Graph().as_default(): data = tf.ones(shape=(10, 512), dtype=tf.float32, name="data") input_tensors = (data,) def pre_processing_func(data_): # normally I would do data-augmentation here results = (tf.expand_dims(data_, axis=0),) return tf.data.Dataset.from_tensor_slices(results) dataset_source = tf.data.Dataset.from_tensor_slices(input_tensors) dataset = dataset_source.flat_map(pre_processing_func) # do something with 'dataset' 

Related of "Subprocesos paralelos con TensorFlow Dataset API y flat_map"

Según mi conocimiento, en este momento, flat_map no ofrece opciones de paralelismo. Dado que la mayor parte del cálculo se realiza en pre_processing_func , lo que podría usar como solución alternativa es una llamada de map paralelo seguida de un búfer, y luego usar una llamada flat_map con una función lambda de identidad que se encarga de aplanar la salida.

En codigo:

 NUM_THREADS = 5 BUFFER_SIZE = 1000 def pre_processing_func(data_): # data-augmentation here # generate new samples starting from the sample `data_` artificial_samples = generate_from_sample(data_) return atificial_samples dataset_source = (tf.data.Dataset.from_tensor_slices(input_tensors). map(pre_processing_func, num_parallel_calls=NUM_THREADS). prefetch(BUFFER_SIZE). flat_map(lambda *x : tf.data.Dataset.from_tensor_slices(x)). shuffle(BUFFER_SIZE)) # my addition, probably necessary though 

Nota (para mí y para quien quiera que intente entender la tubería):

Dado que pre_processing_func genera un número arbitrario de nuevas muestras a partir de la muestra inicial (organizadas en matrices de forma (?, 512) ), la llamada flat_map es necesaria para convertir todas las matrices generadas en Dataset contengan muestras individuales (de ahí el tf.data.Dataset.from_tensor_slices(x) en la lambda) y luego aplanar todos estos conjuntos de datos en un gran Dataset contenga muestras individuales.

Probablemente sea una buena idea .shuffle() el conjunto de datos o las muestras generadas se empaquetarán juntas.