¿Es posible reemplazar el marcador de posición con una constante en un gráfico existente?

Tengo un gráfico congelado de un modelo entrenado, tiene un tf.placeholder que siempre le doy el mismo valor.

Me preguntaba si es posible reemplazarlo con tf.constant en tf.constant lugar. Si es de alguna manera, cualquier ejemplo sería apreciado!

EDITAR: Aquí está cómo se ve con el código, para ayudar a visualizar la pregunta

Estoy usando un modelo pre-entrenado (por otras personas) para ejecutar la inferencia. El modelo se almacena localmente como un archivo de gráfico congelado con extensión .pb .

El código se ve así:

 # load graph graph = load_graph('frozen.pb') session = tf.Session(graph=graph) # Get input and output tensors images_placeholder = graph.get_tensor_by_name("input:0") output = graph.get_tensor_by_name("output:0") phase_train_placeholder = graph.get_tensor_by_name("phase_train:0") feed_dict = {images_placeholder: images, phase_train_placeholder: False} result = session.run(output, feed_dict=feed_dict) 

El problema es que siempre alimento phase_train_placeholder: False para mis propósitos, por lo que me preguntaba si es posible eliminar ese marcador de posición y reemplazarlo con algo como tf.constant(False, dtype=bool, shape=[])

Así que no conseguí encontrar una forma adecuada, pero logré hacerlo de una manera pirateada, al reconstruir la definición del gráfico y al sustituir el nodo que necesitaba sustituir. Inspirado en este código.

Aquí está el código (super hacky, uso bajo su propio riesgo):

 INPUT_GRAPH_DEF_FILE = 'path/to/file' OUTPUT_GRAPH_DEF_FILE = 'another/one' # Get NodeDef of a constant tensor we want to put in place of # the placeholder. # (There is probably a better way to do this) example_graph = tf.Graph() with tf.Session(graph=example_graph): c = tf.constant(False, dtype=bool, shape=[], name='phase_train') for node in example_graph.as_graph_def().node: if node.name == 'phase_train': c_def = node # load our graph graph = load_graph(INPUT_GRAPH_DEF_FILE) graph_def = graph.as_graph_def() # Create new graph, and rebuild it from original one # replacing phase train node def with constant new_graph_def = graph_pb2.GraphDef() for node in graph_def.node: if node.name == 'phase_train': new_graph_def.node.extend([c_def]) else: new_graph_def.node.extend([copy.deepcopy(node)]) # save new graph with tf.gfile.GFile(OUTPUT_GRAPH_DEF_FILE, "wb") as f: f.write(new_graph_def.SerializeToString()) 

Recientemente he tenido que volver a escribir la respuesta anterior.

 import tensorflow as tf import sys from tensorflow.core.framework import graph_pb2 import copy INPUT_GRAPH_DEF_FILE = sys.argv[1] OUTPUT_GRAPH_DEF_FILE = sys.argv[2] # load our graph def load_graph(filename): graph_def = tf.GraphDef() with tf.gfile.FastGFile(filename, 'rb') as f: graph_def.ParseFromString(f.read()) return graph_def graph_def = load_graph(INPUT_GRAPH_DEF_FILE) target_node_name = sys.argv[3] c = tf.constant(False, dtype=bool, shape=[], name=target_node_name) # Create new graph, and rebuild it from original one # replacing phase train node def with constant new_graph_def = graph_pb2.GraphDef() for node in graph_def.node: if node.name == target_node_name: new_graph_def.node.extend([c.op.node_def]) else: new_graph_def.node.extend([copy.deepcopy(node)]) # save new graph with tf.gfile.GFile(OUTPUT_GRAPH_DEF_FILE, "wb") as f: f.write(new_graph_def.SerializeToString())