¿Por qué las imágenes de CIFAR-10 no se muestran correctamente con matplotlib?

Del conjunto de entrenamiento tomé una imagen (‘img’) de tamaño (3,32,32). He utilizado plt.imshow (img.T). La imagen no es clara. Ahora los cambios que tengo que hacer en la imagen (‘img’) para hacerla más claramente visible. Gracias.

Esta es la imagen que tengo

A continuación imprime cuadrícula 5X5 de imágenes aleatorias de Cifar10. No es borroso, aunque tampoco es perfecto. Cualquier sugerencia de bienvenida.

%matplotlib inline import numpy as np import matplotlib.pyplot as plt from six.moves import cPickle f = open('data/cifar10/cifar-10-batches-py/data_batch_1', 'rb') datadict = cPickle.load(f,encoding='latin1') f.close() X = datadict["data"] Y = datadict['labels'] X = X.reshape(10000, 3, 32, 32).transpose(0,2,3,1).astype("uint8") Y = np.array(Y) #Visualizing CIFAR 10 fig, axes1 = plt.subplots(5,5,figsize=(3,3)) for j in range(5): for k in range(5): i = np.random.choice(range(len(X))) axes1[j][k].set_axis_off() axes1[j][k].imshow(X[i:i+1][0]) 

La imagen está borrosa debido a la interpolación. Para evitar el desenfoque en matplotlib, llame a imshow con palabra clave interpolation='nearest' :

 plt.imshow(img.T, interpolation='nearest') 

Además, parece que tus ejes x e y se están intercambiando cuando usas la transposición, por lo que es posible que desees mostrar esto en su lugar:

 plt.imshow(np.transpose(img, (1, 2, 0)), interpolation='nearest') 

He usado el siguiente código para mostrar todos los datos de CIFAR como una imagen grande. El código muestra la imagen, pero si desea guardarla y no ser borrosa, sugiero utilizar plt.savefig(fname, format='png', dpi=1000)

 import numpy as np import matplotlib.pyplot as plt def reshape_and_print(self, cifar_data): # number of images in rows and columns rows = cols = np.sqrt(cifar_data.shape[0]).astype(np.int32) # Image hight and width. Divide by 3 because of 3 color channels imh = imw = np.sqrt(cifar_data.shape[1] // 3).astype(np.int32) # reshape to number of images X color channels X image size # transpose to color channels X number of images X image size timg = cifar_data.reshape(rows * cols, 3, imh * imh).transpose(1, 0, 2) # reshape to color channels X rows X cols X image hight X image with # swap axis to color channels X rows X image hight X cols X image with timg = timg.reshape(3, rows, cols, imh, imw).swapaxes(2, 3) # reshape to color channels X combined image hight X combined image with # transpose to combined image hight X combined image with X color channels timg = timg.reshape(3, rows * imh, cols * imw).transpose(1, 2, 0) plt.imshow(timg) plt.show() 

Hice una clase de ayuda rápida de datos que usé para un proyecto de prueba pequeño, espero que pueda ser útil:

 import gzip import pickle import numpy as np import matplotlib.pyplot as plt class DataSet(object): def __init__(self, seed=42, setsize=10000): self.seed = seed # set the seed for reproducability np.random.seed(seed) # load the data train_set, test_set = self.load_data() # self.split_data(train_set, valid_set, test_set) self.split_data(train_set, test_set, setsize) def split_data(self, data_set, test_set, split_size): permutation = np.random.permutation(data_set.shape[0]) self.train = data_set[permutation[:split_size]] self.valid = data_set[permutation[split_size:split_size * 2]] self.test = test_set[:split_size] def reshape_for_print(self, data): raise NotImplemented def load_data(self): raise NotImplemented def show_all_imgs(self, data): raise NotImplemented class CIFAR(DataSet): def load_data(self): # try to load data with open('./data/cifar-100-python/train', 'rb') as f: data = pickle.load(f, encoding='latin1') train_set = data['data'].astype(np.float32) / 255.0 with open('./data/cifar-100-python/test', 'rb') as f: data = pickle.load(f, encoding='latin1') test_set = data['data'].astype(np.float32) / 255.0 return train_set, test_set def reshape_for_print(self, data): gh = gw = np.sqrt(data.shape[0]).astype(np.int32) imh = imw = np.sqrt(data.shape[1] // 3).astype(np.int32) timg = data.reshape(gh * gw, 3, imh * imh).transpose(1, 0, 2) timg = timg.reshape(3, gh, gw, imh, imw).swapaxes(2, 3) timg = timg.reshape(3, gh * imh, gw * imw).transpose(1, 2, 0) return timg def show_all_imgs(self, data): timg = self.reshape_for_print(data) plt.imshow(timg) plt.show() class MNIST(DataSet): def load_data(self): # try to load data with gzip.open('./data/mnist.pkl.gz', 'rb') as f: train_set, valid_set, test_set = pickle.load(f, encoding='latin1') return train_set[0], test_set[0] def reshape_for_print(self, data): gh = gw = np.sqrt(data.shape[0]).astype(np.int32) imh = imw = np.sqrt(data.shape[1]).astype(np.int32) timg = data.reshape(gh, gw, imh, imw).swapaxes(1, 2) timg = timg.reshape(gh * imh, gw * imw) return timg def show_all_imgs(self, data): timg = self.reshape_for_print(data) plt.imshow(timg, cmap=plt.cm.gray) plt.show() 

intenta usar

 import matplotlib.pyplot as plt from scipy.misc import toimage plt.imshow(toimage(img)) 

No estoy 100% seguro de cómo funciona el código, pero creo que debido a que las imágenes se almacenan en matrices numpy de punto flotante, la función imshow () tiene dificultades para mapearlas con los colores correctos. Al encasillarlos a imagen usando toimage (), los convierte a un formato de imagen adecuado que imshow () espera, es decir, no una matriz, sino una imagen codificada como .png o .jpg.

Este código me funciona cada vez que quiero mostrar imágenes en python.

Hice una función para trazar la imagen RGB de una fila en el conjunto de datos CIFAR10. La imagen se verá borrosa en el mejor de los casos ya que el tamaño original de la imagen es muy pequeño (32 px x 32 px).

Imagen de muestra

 def unpickle(file): with open(file, 'rb') as fo: dict1 = pickle.load(fo, encoding='bytes') return dict1 pd_tr = pd.DataFrame() tr_y = pd.DataFrame() for i in range(1,6): data = unpickle('data/data_batch_' + str(i)) pd_tr = pd_tr.append(pd.DataFrame(data[b'data'])) tr_y = tr_y.append(pd.DataFrame(data[b'labels'])) pd_tr['labels'] = tr_y tr_x = np.asarray(pd_tr.iloc[:, :3072]) tr_y = np.asarray(pd_tr['labels']) ts_x = np.asarray(unpickle('data/test_batch')[b'data']) ts_y = np.asarray(unpickle('data/test_batch')[b'labels']) labels = unpickle('data/batches.meta')[b'label_names'] def plot_CIFAR(ind): arr = tr_x[ind] sc_dpi = 157.35 R = arr[0:1024].reshape(32,32)/255.0 G = arr[1024:2048].reshape(32,32)/255.0 B = arr[2048:].reshape(32,32)/255.0 img = np.dstack((R,G,B)) title = re.sub('[!@#$b]', '', str(labels[tr_y[ind]])) fig = plt.figure(figsize=(3,3)) ax = fig.add_subplot(111) ax.imshow(img,interpolation='bicubic') ax.set_title('Category = '+ title,fontsize =15) plot_CIFAR(4) 

Este archivo lee el conjunto de datos cifar10 y traza imágenes individuales utilizando matplotlib .

 import _pickle as pickle import argparse import numpy as np import os import matplotlib.pyplot as plt cifar10 = "./cifar-10-batches-py/" parser = argparse.ArgumentParser("Plot training images in cifar10 dataset") parser.add_argument("-i", "--image", type=int, default=0, help="Index of the image in cifar10. In range [0, 49999]") args = parser.parse_args() def unpickle(file): with open(file, 'rb') as fo: dict = pickle.load(fo, encoding='bytes') return dict def cifar10_plot(data, meta, im_idx=0): im = data[b'data'][im_idx, :] im_r = im[0:1024].reshape(32, 32) im_g = im[1024:2048].reshape(32, 32) im_b = im[2048:].reshape(32, 32) img = np.dstack((im_r, im_g, im_b)) print("shape: ", img.shape) print("label: ", data[b'labels'][im_idx]) print("category:", meta[b'label_names'][data[b'labels'][im_idx]]) plt.imshow(img) plt.show() def main(): batch = (args.image // 10000) + 1 idx = args.image - (batch-1)*10000 data = unpickle(os.path.join(cifar10, "data_batch_" + str(batch))) meta = unpickle(os.path.join(cifar10, "batches.meta")) cifar10_plot(data, meta, im_idx=idx) if __name__ == "__main__": main() 

Añadir 0.5:

 plt.imshow(np.transpose(img, (1, 2, 0)) + 0.5)