La mayoría de los tutoriales se centran en el caso en el que todo el conjunto de datos de entrenamiento encaja en la memoria. Sin embargo, tengo un iterador que actúa como un flujo infinito de (características, tags) -tuples (que los crean de forma barata sobre la marcha).
Al implementar el input_fn
para el estimador de input_fn
de tensión, ¿puedo devolver una instancia del iterador como
def input_fn(): (feature_batch, label_batch) = next(it) return tf.constant(feature_batch), tf.constant(label_batch)
o ¿ input_fn
tiene que devolver las mismas (características, tags) -tuples en cada llamada?
Además, esta función se llama varias veces durante el entrenamiento, ya que espero que sea como en el siguiente pseudocódigo:
for i in range(max_iter): learn_op(input_fn())
El argumento de input_fn
se usa a lo largo del entrenamiento, pero la función en sí se llama una vez. Por lo tanto, crear un input_fn
sofisticado que vaya más allá de devolver una matriz constante como se explica en el tutorial no es tan sencillo.
Tensorflow propone dos ejemplos de este tipo de entrada no trivial input_fn
para matrices numpy y panda , pero comienzan desde una matriz en la memoria, por lo que esto no le ayuda con su problema.
También puede echar un vistazo a su código siguiendo los enlaces anteriores, para ver cómo implementan un input_fn
no trivial input_fn
, pero puede encontrar que requiere más código que le gustaría.
Si está dispuesto a utilizar la interfaz de menor nivel de Tensorflow, las cosas son más sencillas y más flexibles en mi humilde opinión. Existe un tutorial que cubre la mayoría de las necesidades y las soluciones propuestas son fáciles de implementar (-er).
En particular, si ya tiene un iterador que devuelve datos como describió en su pregunta, el uso de marcadores de posición (sección “Alimentación” en el enlace anterior) debe ser sencillo.
Encontré una solicitud de extracción que convierte un generator
en un input_fn
: https://github.com/tensorflow/tensorflow/pull/7045/files
La parte relevante es
def _generator_input_fn(): """generator input function.""" queue = feeding_functions.enqueue_data( x, queue_capacity, shuffle=shuffle, num_threads=num_threads, enqueue_size=batch_size, num_epochs=num_epochs) features = (queue.dequeue_many(batch_size) if num_epochs is None else queue.dequeue_up_to(batch_size)) if not isinstance(features, list): features = [features] features = dict(zip(input_keys, features)) if target_key is not None: if len(target_key) > 1: target = {key: features.pop(key) for key in target_key} else: target = features.pop(target_key[0]) return features, target return features return _generator_input_fn
from tensorflow.contrib.learn.python.learn.learn_io import generator_io import numpy as np # define generator def generator(): for index in range(2): yield {'a': np.ones(1) * index,'b': np.ones(1) * index + 32,'label': np.ones(1) * index - 32} input_fn = generator_io.generator_input_fn(generator, target_key='label', batch_size=2, shuffle=False, num_epochs=1) features, target = input_fn()
Consulte el caso de prueba https://github.com/tensorflow/tensorflow/pull/7045/files