¿Cómo puedo enumerar todas las variables de Tensorflow de las que depende un nodo?

¿Cómo puedo enumerar todas las variables / constantes / marcadores de Tensorflow de las que depende un nodo?

Ejemplo 1 (sum de constantes):

import tensorflow as tf a = tf.constant(1, name = 'a') b = tf.constant(3, name = 'b') c = tf.constant(9, name = 'c') d = tf.add(a, b, name='d') e = tf.add(d, c, name='e') sess = tf.Session() print(sess.run([d, e])) 

Me gustaría tener una función list_dependencies() como:

  • list_dependencies(d) devuelve ['a', 'b']
  • list_dependencies(e) devuelve ['a', 'b', 'c']

Ejemplo 2 (multiplicación de matrices entre un marcador de posición y una matriz de ponderación, seguido de la adición de un vector de sesgo):

 tf.set_random_seed(1) input_size = 5 output_size = 3 input = tf.placeholder(tf.float32, shape=[1, input_size], name='input') W = tf.get_variable( "W", shape=[input_size, output_size], initializer=tf.contrib.layers.xavier_initializer()) b = tf.get_variable( "b", shape=[output_size], initializer=tf.constant_initializer(2)) output = tf.matmul(input, W, name="output") output_bias = tf.nn.xw_plus_b(input, W, b, name="output_bias") sess = tf.Session() sess.run(tf.global_variables_initializer()) print(sess.run([output,output_bias], feed_dict={input: [[2]*input_size]})) 

Me gustaría tener una función list_dependencies() como:

  • list_dependencies(output) devuelve ['W', 'input']
  • list_dependencies(output_bias) devuelve ['W', 'b', 'input']

Aquí están las utilidades que uso para esto (desde https://github.com/yaroslavvb/stuff/blob/master/linearize/linearize.py )

 # computation flows from parents to children def parents(op): return set(input.op for input in op.inputs) def children(op): return set(op for out in op.outputs for op in out.consumers()) def get_graph(): """Creates dictionary {node: {child1, child2, ..},..} for current TensorFlow graph. Result is compatible with networkx/toposort""" ops = tf.get_default_graph().get_operations() return {op: children(op) for op in ops} def print_tf_graph(graph): """Prints tensorflow graph in dictionary form.""" for node in graph: for child in graph[node]: print("%s -> %s" % (node.name, child.name)) 

Estas funciones funcionan en operaciones. Para obtener una op que produce tensor t , use t.op Para obtener los tensores producidos por op op , use op.outputs

La respuesta de Yaroslav Bulatov es genial, solo get_graph() una función de trazado que utiliza el get_graph() y children() Yaroslav:

 import matplotlib.pyplot as plt import networkx as nx def plot_graph(G): '''Plot a DAG using NetworkX''' def mapping(node): return node.name G = nx.DiGraph(G) nx.relabel_nodes(G, mapping, copy=False) nx.draw(G, cmap = plt.get_cmap('jet'), with_labels = True) plt.show() plot_graph(get_graph()) 

Trazando el ejemplo 1 de la pregunta:

 import matplotlib.pyplot as plt import networkx as nx import tensorflow as tf def children(op): return set(op for out in op.outputs for op in out.consumers()) def get_graph(): """Creates dictionary {node: {child1, child2, ..},..} for current TensorFlow graph. Result is compatible with networkx/toposort""" print('get_graph') ops = tf.get_default_graph().get_operations() return {op: children(op) for op in ops} def plot_graph(G): '''Plot a DAG using NetworkX''' def mapping(node): return node.name G = nx.DiGraph(G) nx.relabel_nodes(G, mapping, copy=False) nx.draw(G, cmap = plt.get_cmap('jet'), with_labels = True) plt.show() a = tf.constant(1, name = 'a') b = tf.constant(3, name = 'b') c = tf.constant(9, name = 'c') d = tf.add(a, b, name='d') e = tf.add(d, c, name='e') sess = tf.Session() print(sess.run([d, e])) plot_graph(get_graph()) 

salida:

introduzca la descripción de la imagen aquí

Trazando el ejemplo 2 de la pregunta:

introduzca la descripción de la imagen aquí

Si usa Microsoft Windows, puede encontrarse con este problema: Error de Python (ValueError: _getfullpathname: carácter nulo incrustado) , en cuyo caso necesita parchear matplotlib como lo explica el enlace.

En algunos casos, es posible que desee buscar todas las variables de “entrada” que están conectadas a un tensor de “salida”, como la pérdida de un gráfico. Para este objective, el siguiente código cortado puede ser útil (inspirado en el código anterior):

 def findVars(atensor): allinputs=atensor.op.inputs if len(allinputs)==0: if atensor.op.type == 'VariableV2' or atensor.op.type == 'Variable': return set([atensor.op]) a=set() for t in allinputs: a=a | findVars(t) return a 

Esto se puede usar en la depuración para averiguar dónde falta una conexión en el gráfico.