¿Cómo puedo devolver el mismo lote dos veces desde un iterador de conjunto de datos de tensorflow?

Estoy convirtiendo algo de código heredado para usar la API de conjunto de datos: este código utiliza feed_dict para alimentar un lote a la operación del tren (en realidad tres veces) y luego vuelve a calcular las pérdidas para mostrar usando el mismo lote . Así que necesito tener un iterador que devuelva exactamente el mismo lote dos (o varias) veces. Desafortunadamente, parece que no puedo encontrar una manera de hacerlo con conjuntos de datos de tensorflow, ¿es posible?

Puede repetir elementos individuales de un Dataset de Dataset utilizando Dataset.flat_map() , Dataset.from_tensors() y Dataset.repeat() juntos. Por ejemplo, para repetir elementos dos veces:

 NUM_REPEATS = 2 dataset = tf.data.Dataset.range(10) # ...or the output of `.batch()`, etc. # Repeat each element of `dataset` NUM_REPEATS times. dataset = dataset.flat_map( lambda x: tf.data.Dataset.from_tensors(x).repeat(NUM_REPEATS))