Estoy tratando de usar tensorflow para el aprendizaje por transferencia. Descargué el modelo pre-entrenado inception3 del tutorial. En el código, para la predicción:
prediction = sess.run(softmax_tensor,{'DecodeJpeg/contents:0'}:image_data})
¿Hay una manera de alimentar la imagen png. Intenté cambiar DecodeJpeg
a DecodePng
pero no funcionó. Además, ¿qué debo cambiar si deseo alimentar un archivo de imagen decodificado como una matriz numpy o un lote de matrices?
¡¡Gracias!!
El gráfico InceptionV3 que se incluye en classify_image.py
solo admite imágenes JPEG listas para usar. Hay dos formas de usar este gráfico con imágenes PNG:
Convierta la imagen PNG a una height
width
x 3 (canales) Numpy array, por ejemplo, usando PIL , luego alimente el 'DecodeJpeg:0'
:
import numpy as np from PIL import Image # ... image = Image.open("example.png") image_array = np.array(image)[:, :, 0:3] # Select RGB channels only. prediction = sess.run(softmax_tensor, {'DecodeJpeg:0': image_array})
Quizás confusamente, 'DecodeJpeg:0'
es la salida de la DecodeJpeg
, por lo que al alimentar este tensor, puede alimentar datos de imagen en bruto.
Agregue un tf.image.decode_png()
op al gráfico importado. Simplemente cambiando el nombre del tensor alimentado de 'DecodeJpeg/contents:0'
a 'DecodePng/contents:0'
no funciona porque no hay 'DecodePng'
en el gráfico enviado. Puede agregar dicho nodo al gráfico utilizando el argumento tf.import_graph_def()
para tf.import_graph_def()
:
png_data = tf.placeholder(tf.string, shape=[]) decoded_png = tf.image.decode_png(png_data, channels=3) # ... graph_def = ... softmax_tensor = tf.import_graph_def( graph_def, input_map={'DecodeJpeg:0': decoded_png}, return_elements=['softmax:0']) sess.run(softmax_tensor, {png_data: ...})
El siguiente código debe manejar de ambos casos.
import numpy as np from PIL import Image image_file = 'test.jpeg' with tf.Session() as sess: # softmax_tensor = sess.graph.get_tensor_by_name('final_result:0') if image_file.lower().endswith('.jpeg'): image_data = tf.gfile.FastGFile(image_file, 'rb').read() prediction = sess.run('final_result:0', {'DecodeJpeg/contents:0': image_data}) elif image_file.lower().endswith('.png'): image = Image.open(image_file) image_array = np.array(image)[:, :, 0:3] prediction = sess.run('final_result:0', {'DecodeJpeg:0': image_array}) prediction = prediction[0] print(prediction)
o versión más corta con cuerdas directas:
image_file = 'test.png' # or 'test.jpeg' image_data = tf.gfile.FastGFile(image_file, 'rb').read() ph = tf.placeholder(tf.string, shape=[]) with tf.Session() as sess: predictions = sess.run(output_layer_name, {ph: image_data} )