¿Cómo entrenar la red TensorFlow usando un generador para producir entradas?

Los documentos de TensorFlow describen un montón de formas de leer datos usando TFRecordReader, TextLineReader, QueueRunner, etc. y colas.

Lo que me gustaría hacer es mucho, mucho más simple: tengo una función de generador de python que produce una secuencia infinita de datos de entrenamiento como tuplas (X, y) (ambas son matrices numpy, y la primera dimensión es el tamaño del lote). Solo quiero entrenar una red usando esos datos como entradas.

¿Existe un ejemplo simple y autónomo de entrenamiento de una red TensorFlow utilizando un generador que produce los datos? (a lo largo de las líneas de los ejemplos MNIST o CIFAR)

Supongamos que tiene una función que genera datos:

  def generator(data): ... yield (X, y) 

Ahora necesita otra función que describa la architecture de su modelo. Podría ser cualquier función que procese X y tenga que predecir y como salida (por ejemplo, neural network).

Supongamos que su función acepta X e y como entradas, calcula una predicción para y desde X de alguna manera y devuelve la función de pérdida (por ejemplo, entropía cruzada o MSE en el caso de regresión) entre y y predice y:

  def neural_network(X, y): # computation of prediction for y using X ... return loss(y, y_pred) 

Para hacer que su modelo funcione, necesita definir marcadores de posición para X e y, a continuación, ejecutar una sesión:

  X = tf.placeholder(tf.float32, shape=(batch_size, x_dim)) y = tf.placeholder(tf.float32, shape=(batch_size, y_dim)) 

Los marcadores de posición son algo así como “variables libres” que debe especificar al ejecutar la sesión mediante feed_dict :

  with tf.Session() as sess: # variables need to be initialized before any sess.run() calls tf.global_variables_initializer().run() for X_batch, y_batch in generator(data): feed_dict = {X: X_batch, y: y_batch} _, loss_value, ... = sess.run([train_op, loss, ...], feed_dict) # train_op here stands for optimization operation you have defined # and loss for loss function (return value of neural_network function) 

Espero que te sea de utilidad. Sin embargo, tenga en cuenta que esto no es una implementación completamente funcional, sino más bien un pseudocódigo, ya que casi no especificó detalles.