Animando un gráfico de red para mostrar el progreso de un algoritmo

Me gustaría animar un gráfico de red para mostrar el progreso de un algoritmo. Estoy usando NetworkX para la creación de gráficos.

A partir de esta respuesta SO , se me ocurrió una solución que utiliza clear_ouput de IPython.display y el comando plt.pause() para administrar la velocidad de la animación. Esto funciona bien para gráficos pequeños con algunos nodos, pero cuando lo implemento en una cuadrícula de 10×10, la animación es muy lenta y la reducción del argumento en plt.pause() no parece tener ningún efecto en la velocidad de la animación. Aquí hay una MME con una implementación del algoritmo de Dijkstra donde actualizo los colores de los nodos en cada iteración del algoritmo:

 import math import queue import random import networkx as nx import matplotlib.pyplot as plt from IPython.display import clear_output %matplotlib inline # plotting function def get_fig(G,current,pred): nColorList = [] for i in G.nodes(): if i == current: nColorList.append('red') elif i==pred: nColorList.append('white') elif i==N: nColorList.append('grey') elif node_visited[i]==1:nColorList.append('dodgerblue') else: nColorList.append('powderblue') plt.figure(figsize=(10,10)) nx.draw_networkx(G,pos,node_color=nColorList,width=2,node_size=400,font_size=10) plt.axis('off') plt.show() # graph creation G=nx.DiGraph() pos={} cost={} for i in range(100): x= i % 10 y= math.floor(i/10) pos[i]=(x,y) if i % 10 != 9 and i+1 < 100: cost[(i,i+1)] = random.randint(0,9) cost[(i+1,i)] = random.randint(0,9) if i+10  cost[(i,j)] + lab[j]: lab[i] = cost[(i,j)] + lab[j] path[i] = j SE.put((lab[i],i)) clear_output(wait=True) get_fig(G,j,i) plt.pause(0.0001) print('end') 

Idealmente, me gustaría mostrar la animación completa en no más de 5 segundos, mientras que actualmente se necesitan unos minutos para completar el algoritmo, lo que sugiere que plt.pause(0.0001) no funciona según lo previsto.

Después de leer las publicaciones de SO en el gráfico de animación ( publicación 2 y publicación 3 ), parece que el módulo de animation de matplotlib podría usarse para resolver esto, pero no he podido implementar las respuestas correctamente en mi algoritmo. La respuesta en la publicación 2 sugiere el uso de FuncAnimation de matplotlib, pero estoy luchando para adaptar el método de update a mi problema y la respuesta en la publicación 3 conduce a un buen tutorial con una sugerencia similar.

Mi pregunta es cómo puedo mejorar la velocidad de la animación para mi problema: ¿es posible organizar los clear_output y plt.pause() para una animación más rápida o debo usar FuncAnimation de matplotlib? Si es lo último, ¿cómo debo definir la función de update ?

Gracias por tu ayuda.

EDITAR 1

 import math import queue import random import networkx as nx import matplotlib.pyplot as plt # plotting function def get_fig(G,current,pred): for i in G.nodes(): if i==current: G.node[i]['draw'].set_color('red') elif i==pred: G.node[i]['draw'].set_color('white') elif i==N: G.node[i]['draw'].set_color('grey') elif node_visited[i]==1: G.node[i]['draw'].set_color('dodgerblue') else: G.node[i]['draw'].set_color('powderblue') # graph creation G=nx.DiGraph() pos={} cost={} for i in range(100): x= i % 10 y= math.floor(i/10) pos[i]=(x,y) if i % 10 != 9 and i+1 < 100: cost[(i,i+1)] = random.randint(0,9) cost[(i+1,i)] = random.randint(0,9) if i+10  cost[(i,j)] + lab[j]: lab[i] = cost[(i,j)] + lab[j] path[i] = j SE.put((lab[i],i)) get_fig(G,j,i) plt.draw() plt.pause(0.00001) plt.close() 

Editar 2

 import math import queue import random import networkx as nx import matplotlib.pyplot as plt # graph creation G=nx.DiGraph() pos={} cost={} for i in range(100): x= i % 10 y= math.floor(i/10) pos[i]=(x,y) if i % 10 != 9 and i+1 < 100: cost[(i,i+1)] = random.randint(0,9) cost[(i+1,i)] = random.randint(0,9) if i+10  cost[(i,j)] + lab[j]: lab[i] = cost[(i,j)] + lab[j] path[i] = j SE.put((lab[i],i)) if i!=N: G.node[i]['draw'].set_alpha(0.7) G[i][j]['draw'].set_alpha(1.0) ax.draw_artist(G[i][j]['draw']) ax.draw_artist(G.node[i]['draw']) ax.draw_artist(G.node[j]['draw']) canvas.blit(ax.bbox) plt.pause(0.0001) plt.close() 

Si su gráfico no es demasiado grande, puede probar el siguiente enfoque que establece las propiedades para nodos y bordes individuales. El truco consiste en guardar la salida de las funciones de dibujo, lo que le da un control de las propiedades del objeto, como el color, la transparencia y la visibilidad.

 import networkx as nx import matplotlib.pyplot as plt G = nx.cycle_graph(12) pos = nx.spring_layout(G) cf = plt.figure(1, figsize=(8,8)) ax = cf.add_axes((0,0,1,1)) for n in G: G.node[n]['draw'] = nx.draw_networkx_nodes(G,pos,nodelist=[n], with_labels=False,node_size=200,alpha=0.5,node_color='r') for u,v in G.edges(): G[u][v]['draw']=nx.draw_networkx_edges(G,pos,edgelist=[(u,v)],alpha=0.5,arrows=False,width=5) plt.ion() plt.draw() sp = nx.shortest_path(G,0,6) edges = zip(sp[:-1],sp[1:]) for u,v in edges: plt.pause(1) G.node[u]['draw'].set_color('r') G.node[v]['draw'].set_color('r') G[u][v]['draw'].set_alpha(1.0) G[u][v]['draw'].set_color('r') plt.draw() 

EDITAR

Aquí hay un ejemplo en una cuadrícula de 10×10 usando graphviz para hacer el diseño. Todo funciona en aproximadamente 1 segundo en mi máquina.

 import networkx as nx import matplotlib.pyplot as plt G = nx.grid_2d_graph(10,10) pos = nx.graphviz_layout(G) cf = plt.figure(1, figsize=(8,8)) ax = cf.add_axes((0,0,1,1)) for n in G: G.node[n]['draw'] = nx.draw_networkx_nodes(G,pos,nodelist=[n], with_labels=False,node_size=200,alpha=0.5,node_color='k') for u,v in G.edges(): G[u][v]['draw']=nx.draw_networkx_edges(G,pos,edgelist=[(u,v)],alpha=0.5,arrows=False,width=5) plt.ion() plt.draw() plt.show() sp = nx.shortest_path(G,(0,0),(9,9)) edges = zip(sp[:-1],sp[1:]) for u,v in edges: G.node[u]['draw'].set_color('r') G.node[v]['draw'].set_color('r') G[u][v]['draw'].set_alpha(1.0) G[u][v]['draw'].set_color('r') plt.draw() 

Editar 2

Aquí hay otro enfoque que es más rápido (no redibuja el eje o todos los nodos) y utiliza un primer algoritmo de búsqueda amplio. Este se ejecuta en unos 2 segundos en mi máquina. Noté que algunos backends son más rápidos, estoy usando GTKAgg.

 import networkx as nx import matplotlib.pyplot as plt def single_source_shortest_path(G,source): ax = plt.gca() canvas = ax.figure.canvas background = canvas.copy_from_bbox(ax.bbox) level=0 # the current level nextlevel={source:1} # list of nodes to check at next level paths={source:[source]} # paths dictionary (paths to key from source) G.node[source]['draw'].set_color('r') G.node[source]['draw'].set_alpha('1.0') while nextlevel: thislevel=nextlevel nextlevel={} for v in thislevel: # canvas.restre_region(background) s = G.node[v]['draw'] s.set_color('r') s.set_alpha('1.0') for w in G[v]: if w not in paths: n = G.node[w]['draw'] n.set_color('r') n.set_alpha('1.0') e = G[v][w]['draw'] e.set_alpha(1.0) e.set_color('k') ax.draw_artist(e) ax.draw_artist(n) ax.draw_artist(s) paths[w]=paths[v]+[w] nextlevel[w]=1 canvas.blit(ax.bbox) level=level+1 return paths if __name__=='__main__': G = nx.grid_2d_graph(10,10) pos = nx.graphviz_layout(G) cf = plt.figure(1, figsize=(8,8)) ax = cf.add_axes((0,0,1,1)) for n in G: G.node[n]['draw'] = nx.draw_networkx_nodes(G,pos,nodelist=[n], with_labels=False,node_size=200,alpha=0.2,node_color='k') for u,v in G.edges(): G[u][v]['draw']=nx.draw_networkx_edges(G,pos,edgelist=[(u,v)],alpha=0.5,arrows=False,width=5) plt.ion() plt.show() path = single_source_shortest_path(G,source=(0,0))