Matplotlib imshow / matshow muestra los valores en el gráfico

Estoy intentando crear una cuadrícula de 10×10 usando imshow o matshow en Matplotlib. La siguiente función toma una matriz numpy como entrada y traza la cuadrícula. Sin embargo, me gustaría que los valores de la matriz también se muestren dentro de las celdas definidas por la cuadrícula. Hasta ahora no he podido encontrar una manera adecuada de hacerlo. Puedo usar plt.text para colocar cosas sobre la cuadrícula, pero esto requiere coordenadas de cada celda, totalmente inconveniente. ¿Hay una mejor manera de hacer lo que estoy tratando de lograr?

¡Gracias!

NOTA: el código siguiente no toma los valores de la matriz todavía, solo estaba jugando con plt.text .

 import numpy as np import matplotlib.pyplot as plt from matplotlib import colors board = np.zeros((10, 10)) def visBoard(board): cmap = colors.ListedColormap(['white', 'red']) bounds=[0,0.5,1] norm = colors.BoundaryNorm(bounds, cmap.N) plt.figure(figsize=(4,4)) plt.matshow(board, cmap=cmap, norm=norm, interpolation='none', vmin=0, vmax=1) plt.xticks(np.arange(0.5,10.5), []) plt.yticks(np.arange(0.5,10.5), []) plt.text(-0.1, 0.2, 'x') plt.text(0.9, 0.2, 'o') plt.text(1.9, 0.2, 'x') plt.grid() visBoard(board) 

Salida:

introduzca la descripción de la imagen aquí

¿Puedes hacer algo como:

 import numpy as np import matplotlib.pyplot as plt fig, ax = plt.subplots() min_val, max_val = 0, 10 ind_array = np.arange(min_val + 0.5, max_val + 0.5, 1.0) x, y = np.meshgrid(ind_array, ind_array) for i, (x_val, y_val) in enumerate(zip(x.flatten(), y.flatten())): c = 'x' if i%2 else 'o' ax.text(x_val, y_val, c, va='center', ha='center') #alternatively, you could do something like #for x_val, y_val in zip(x.flatten(), y.flatten()): # c = 'x' if (x_val + y_val)%2 else 'o' ax.set_xlim(min_val, max_val) ax.set_ylim(min_val, max_val) ax.set_xticks(np.arange(max_val)) ax.set_yticks(np.arange(max_val)) ax.grid() 

introduzca la descripción de la imagen aquí


Editar:

Aquí hay un ejemplo actualizado con un fondo de imshow .

 import numpy as np import matplotlib.pyplot as plt fig, ax = plt.subplots() min_val, max_val, diff = 0., 10., 1. #imshow portion N_points = (max_val - min_val) / diff imshow_data = np.random.rand(N_points, N_points) ax.imshow(imshow_data, interpolation='nearest') #text portion ind_array = np.arange(min_val, max_val, diff) x, y = np.meshgrid(ind_array, ind_array) for x_val, y_val in zip(x.flatten(), y.flatten()): c = 'x' if (x_val + y_val)%2 else 'o' ax.text(x_val, y_val, c, va='center', ha='center') #set tick marks for grid ax.set_xticks(np.arange(min_val-diff/2, max_val-diff/2)) ax.set_yticks(np.arange(min_val-diff/2, max_val-diff/2)) ax.set_xticklabels([]) ax.set_yticklabels([]) ax.set_xlim(min_val-diff/2, max_val-diff/2) ax.set_ylim(min_val-diff/2, max_val-diff/2) ax.grid() plt.show() 

introduzca la descripción de la imagen aquí

Para tu gráfica deberías probar con pyplot.table :

 import matplotlib.pyplot as plt import numpy as np board = np.zeros((10, 10)) board[0,0] = 1 board[0,1] = -1 board[0,2] = 1 def visBoard(board): data = np.empty(board.shape,dtype=np.str) data[:,:] = ' ' data[board==1.0] = 'X' data[board==-1.0] = 'O' plt.axis('off') size = np.ones(board.shape[0])/board.shape[0] plt.table(cellText=data,loc='center',colWidths=size,cellLoc='center',bbox=[0,0,1,1]) plt.show() visBoard(board) 

Algunos detalles sobre el código de @wflynny convirtiéndolo en una función que toma cualquier matriz sin importar el tamaño y traza sus valores.

 import numpy as np import matplotlib.pyplot as plt cols = np.random.randint(low=1,high=30) rows = np.random.randint(low=1,high=30) X = np.random.rand(rows,cols) def plotMat(X): fig, ax = plt.subplots() #imshow portion ax.imshow(X, interpolation='nearest') #text portion diff = 1. min_val = 0. rows = X.shape[0] cols = X.shape[1] col_array = np.arange(min_val, cols, diff) row_array = np.arange(min_val, rows, diff) x, y = np.meshgrid(col_array, row_array) for col_val, row_val in zip(x.flatten(), y.flatten()): c = '+' if X[row_val.astype(int),col_val.astype(int)] < 0.5 else '-' ax.text(col_val, row_val, c, va='center', ha='center') #set tick marks for grid ax.set_xticks(np.arange(min_val-diff/2, cols-diff/2)) ax.set_yticks(np.arange(min_val-diff/2, rows-diff/2)) ax.set_xticklabels([]) ax.set_yticklabels([]) ax.set_xlim(min_val-diff/2, cols-diff/2) ax.set_ylim(min_val-diff/2, rows-diff/2) ax.grid() plt.show() plotMat(X)