Contador de época con TensorFlow Dataset API

Estoy cambiando mi código TensorFlow de la antigua interfaz de la cola a la nueva API de conjunto de datos . En mi código anterior, realicé el seguimiento del recuento de épocas incrementando una tf.Variable cada vez que se accede a un nuevo tensor de entrada y se procesa en la cola. Me gustaría que esta época contara con la nueva API de conjunto de datos, pero tengo algunos problemas para hacer que funcione.

Dado que estoy produciendo una cantidad variable de elementos de datos en la etapa de preprocesamiento, no es una simple cuestión de incrementar un contador (Python) en el bucle de entrenamiento. Necesito calcular el recuento de épocas con respecto a la entrada del colas o conjunto de datos.

Imité lo que tenía antes con el antiguo sistema de colas, y esto es lo que terminé con la API Dataset (ejemplo simplificado):

 with tf.Graph().as_default(): data = tf.ones(shape=(10, 512), dtype=tf.float32, name="data") input_tensors = (data,) epoch_counter = tf.Variable(initial_value=0.0, dtype=tf.float32, trainable=False) def pre_processing_func(data_): data_size = tf.constant(0.1, dtype=tf.float32) epoch_counter_op = tf.assign_add(epoch_counter, data_size) with tf.control_dependencies([epoch_counter_op]): # normally I would do data-augmentation here results = (tf.expand_dims(data_, axis=0),) return tf.data.Dataset.from_tensor_slices(results) dataset_source = tf.data.Dataset.from_tensor_slices(input_tensors) dataset = dataset_source.flat_map(pre_processing_func) dataset = dataset.repeat() # ... do something with 'dataset' and print # the value of 'epoch_counter' every once a while 

Sin embargo, esto no funciona. Se bloquea con un mensaje de error críptico:

  TypeError: In op 'AssignAdd', input types ([tf.float32, tf.float32]) are not compatible with expected types ([tf.float32_ref, tf.float32]) 

Una inspección más cercana muestra que la variable epoch_counter podría no ser accesible en pre_processing_func en pre_processing_func . ¿Vive en una gráfica diferente quizás?

¿Alguna idea de cómo arreglar el ejemplo anterior? ¿O cómo obtener el contador de la época (con puntos decimales, por ejemplo, 0,4 o 2,9) a través de algún otro medio?

TL; DR : Reemplaza la definición de epoch_counter con lo siguiente:

 epoch_counter = tf.get_variable("epoch_counter", initializer=0.0, trainable=False, use_resource=True) 

Existen algunas limitaciones en cuanto al uso de las variables TensorFlow dentro de tf.data.Dataset transformaciones del conjunto de datos tf.data.Dataset . La limitación principal es que todas las variables deben ser “variables de recursos” y no las “variables de referencia” anteriores; desafortunadamente tf.Variable todavía crea “variables de referencia” por razones de compatibilidad hacia atrás.

En general, no recomendaría el uso de variables en una canalización tf.data si es posible evitarlo. Por ejemplo, podría usar Dataset.range() para definir un contador de época y luego hacer algo como:

 epoch_counter = tf.data.Dataset.range(NUM_EPOCHS) dataset = epoch_counter.flat_map(lambda i: tf.data.Dataset.zip( (pre_processing_func(data), tf.data.Dataset.from_tensors(i).repeat())) 

El fragmento de código anterior adjunta un contador de época a cada valor como un segundo componente.

Para agregar a la gran respuesta de @mrry, si desea permanecer dentro del tf.data y también desea realizar un seguimiento de la iteración dentro de cada época, puede probar mi solución a continuación. Si tiene un tamaño de lote no unitario, supongo que tendría que agregar la línea data = data.batch(bs) .

 import tensorflow as tf import itertools def step_counter(): for i in itertools.count(): yield i num_examples = 3 num_epochs = 2 num_iters = num_examples * num_epochs features = tf.data.Dataset.range(num_examples) labels = tf.data.Dataset.range(num_examples) data = tf.data.Dataset.zip((features, labels)) data = data.shuffle(num_examples) step = tf.data.Dataset.from_generator(step_counter, tf.int32) data = tf.data.Dataset.zip((data, step)) epoch = tf.data.Dataset.range(num_epochs) data = epoch.flat_map( lambda i: tf.data.Dataset.zip( (data, tf.data.Dataset.from_tensors(i).repeat()))) data = data.repeat(num_epochs) it = data.make_one_shot_iterator() example = it.get_next() with tf.Session() as sess: for _ in range(num_iters): ((x, y), st), ep = sess.run(example) print(f'step {st} \t epoch {ep} \tx {x} \ty {y}') 

Huellas dactilares:

 step 0 epoch 0 x 2 y 2 step 1 epoch 0 x 0 y 0 step 2 epoch 0 x 1 y 1 step 0 epoch 1 x 2 y 2 step 1 epoch 1 x 0 y 0 step 2 epoch 1 x 1 y 1 

La línea data = data.repeat(num_epochs) hace que se repita el conjunto de datos ya repetidos para num_epochs (también el contador de la época). Se puede obtener fácilmente reemplazando for _ in range(num_iters): con for _ in range(num_iters+1):

itertool el código de ejemplo de numerica a lotes y reemplacé la parte de itertool :

 num_examples = 5 num_epochs = 4 batch_size = 2 num_iters = int(num_examples * num_epochs / batch_size) features = tf.data.Dataset.range(num_examples) labels = tf.data.Dataset.range(num_examples) data = tf.data.Dataset.zip((features, labels)) data = data.shuffle(num_examples) epoch = tf.data.Dataset.range(num_epochs) data = epoch.flat_map( lambda i: tf.data.Dataset.zip(( data, tf.data.Dataset.from_tensors(i).repeat(), tf.data.Dataset.range(num_examples) )) ) # to flatten the nested datasets data = data.map(lambda samples, *cnts: samples+cnts ) data = data.batch(batch_size=batch_size) it = data.make_one_shot_iterator() x, y, ep, st = it.get_next() with tf.Session() as sess: for _ in range(num_iters): x_, y_, ep_, st_ = sess.run([x, y, ep, st]) print(f'step {st_}\t epoch {ep_} \tx {x_} \ty {y_}')