Creando `input_fn` desde el iterador

La mayoría de los tutoriales se centran en el caso en el que todo el conjunto de datos de entrenamiento encaja en la memoria. Sin embargo, tengo un iterador que actúa como un flujo infinito de (características, tags) -tuples (que los crean de forma barata sobre la marcha).

Al implementar el input_fn para el estimador de input_fn de tensión, ¿puedo devolver una instancia del iterador como

 def input_fn(): (feature_batch, label_batch) = next(it) return tf.constant(feature_batch), tf.constant(label_batch) 

o ¿ input_fn tiene que devolver las mismas (características, tags) -tuples en cada llamada?

Además, esta función se llama varias veces durante el entrenamiento, ya que espero que sea como en el siguiente pseudocódigo:

 for i in range(max_iter): learn_op(input_fn()) 

El argumento de input_fn se usa a lo largo del entrenamiento, pero la función en sí se llama una vez. Por lo tanto, crear un input_fn sofisticado que vaya más allá de devolver una matriz constante como se explica en el tutorial no es tan sencillo.

Tensorflow propone dos ejemplos de este tipo de entrada no trivial input_fn para matrices numpy y panda , pero comienzan desde una matriz en la memoria, por lo que esto no le ayuda con su problema.

También puede echar un vistazo a su código siguiendo los enlaces anteriores, para ver cómo implementan un input_fn no trivial input_fn , pero puede encontrar que requiere más código que le gustaría.

Si está dispuesto a utilizar la interfaz de menor nivel de Tensorflow, las cosas son más sencillas y más flexibles en mi humilde opinión. Existe un tutorial que cubre la mayoría de las necesidades y las soluciones propuestas son fáciles de implementar (-er).

En particular, si ya tiene un iterador que devuelve datos como describió en su pregunta, el uso de marcadores de posición (sección “Alimentación” en el enlace anterior) debe ser sencillo.

Encontré una solicitud de extracción que convierte un generator en un input_fn : https://github.com/tensorflow/tensorflow/pull/7045/files

La parte relevante es

  def _generator_input_fn(): """generator input function.""" queue = feeding_functions.enqueue_data( x, queue_capacity, shuffle=shuffle, num_threads=num_threads, enqueue_size=batch_size, num_epochs=num_epochs) features = (queue.dequeue_many(batch_size) if num_epochs is None else queue.dequeue_up_to(batch_size)) if not isinstance(features, list): features = [features] features = dict(zip(input_keys, features)) if target_key is not None: if len(target_key) > 1: target = {key: features.pop(key) for key in target_key} else: target = features.pop(target_key[0]) return features, target return features return _generator_input_fn 
 from tensorflow.contrib.learn.python.learn.learn_io import generator_io import numpy as np # define generator def generator(): for index in range(2): yield {'a': np.ones(1) * index,'b': np.ones(1) * index + 32,'label': np.ones(1) * index - 32} input_fn = generator_io.generator_input_fn(generator, target_key='label', batch_size=2, shuffle=False, num_epochs=1) features, target = input_fn() 

Consulte el caso de prueba https://github.com/tensorflow/tensorflow/pull/7045/files