¿Usando entrenamiento hecho con la API de python como entrada para el módulo LabelImage en la API de java?

Tengo un problema con java tensorflow API. He ejecutado la capacitación utilizando la API de tensorflow de python, generando los archivos output_graph.pb y output_labels.txt. Ahora, por alguna razón, quiero usar esos archivos como entrada para el módulo LabelImage en la API de tensorflow de Java. Pensé que todo habría funcionado bien ya que ese módulo quiere exactamente un .pb y un .txt. Sin embargo, cuando ejecuto el módulo, me sale este error:

2017-04-26 10:12:56.711402: W tensorflow/core/framework/op_def_util.cc:332] Op BatchNormWithGlobalNormalization is deprecated. It will cease to work in GraphDef version 9. Use tf.nn.batch_normalization(). Exception in thread "main" java.lang.IllegalArgumentException: No Operation named [input] in the Graph at org.tensorflow.Session$Runner.operationByName(Session.java:343) at org.tensorflow.Session$Runner.feed(Session.java:137) at org.tensorflow.Session$Runner.feed(Session.java:126) at it.zero11.LabelImage.executeInceptionGraph(LabelImage.java:115) at it.zero11.LabelImage.main(LabelImage.java:68) 

Le agradecería mucho que me ayudara a encontrar dónde está el problema. Además, quiero preguntarle si hay una forma de ejecutar la capacitación desde la API de tensorflow de Java, porque eso facilitaría las cosas.

Ser más preciso:

De hecho, no uso código autoescrito, al menos para los pasos relevantes. Todo lo que he hecho es hacer el entrenamiento con este módulo, https://github.com/tensorflow/tensorflow/blob/master/tensorflow/examples/image_retraining/retrain.py , alimentándolo con el directorio que contiene las imágenes divididas entre subdirectorios. Según su descripción. En particular, creo que estas son las líneas que generan los resultados:

 output_graph_def = graph_util.convert_variables_to_constants( sess, graph.as_graph_def(), [FLAGS.final_tensor_name]) with gfile.FastGFile(FLAGS.output_graph, 'wb') as f: f.write(output_graph_def.SerializeToString()) with gfile.FastGFile(FLAGS.output_labels, 'w') as f: f.write('\n'.join(image_lists.keys()) + '\n') 

Luego, doy las salidas (una some_graph.pb y una some_labels.txt) como entrada para este módulo java: https://github.com/tensorflow/tensorflow/blob/master/tensorflow/java/src/main/java/ org / tensorflow / examples / LabelImage.java , reemplazando las entradas predeterminadas. El error que recibo es el reportado arriba.

El modelo utilizado por defecto en LabelImage.java es diferente al modelo que se está reentrenando, por lo que los nombres de las entradas y los nodos de salida no se alinean. Tenga en cuenta que los modelos TensorFlow son gráficos y los argumentos para feed() y fetch() son nombres de nodos en el gráfico. Así que necesitas saber los nombres apropiados para tu modelo.

Al retrain.py , parece que tiene un nodo que toma el contenido sin procesar de un archivo JPEG como entrada (el nodo DecodeJpeg/contents ) y produce el conjunto de tags en el nodo final_result .

Si ese es el caso, entonces harías algo como lo siguiente en Java (y no necesitas el bit que construye un gráfico para normalizar la imagen, ya que parece ser parte del modelo reentrenado, así que reemplaza LabelImage.java:64 con algo como:

 try (Tensor image = Tensor.create(imageBytes); Graph g = new Graph()) { g.importGraphDef(graphDef); try (Session s = new Session(g); // Note the change to the name of the node and the fact // that it is being provided the raw imageBytes as input Tensor result = s.runner().feed("DecodeJpeg/contents", image).fetch("final_result").run().get(0)) { final long[] rshape = result.shape(); if (result.numDimensions() != 2 || rshape[0] != 1) { throw new RuntimeException( String.format( "Expected model to produce a [1 N] shaped tensor where N is the number of labels, instead it produced one with shape %s", Arrays.toString(rshape))); } int nlabels = (int) rshape[1]; float[] probabilities = result.copyTo(new float[1][nlabels])[0]; // At this point nlabels = number of classes in your retrained model DoSomethingWith(probabilities); } } 

Espero que ayude.

Con respecto al error “Sin operación”, pude resolverlo utilizando los nombres de capa de entrada y salida “Mul” y “final_result”, respectivamente. Ver:

https://github.com/tensorflow/tensorflow/issues/2883