¿Cómo cargar imágenes en Pytorch DataLoader?

El tutorial de pytorch para la carga y el procesamiento de datos es bastante específico para un ejemplo. ¿Podría alguien ayudarme con el aspecto de la función para una carga de imágenes más simple y genérica?

Tutorial: http://pytorch.org/tutorials/beginner/data_loading_tutorial.html

Mis datos:

Tengo el conjunto de datos MINST como jpg en la siguiente estructura de carpetas. (Sé que solo puedo usar la clase de conjunto de datos, pero esto es simplemente para ver cómo cargar imágenes simples en pytorch sin funciones csv o complejas).

El nombre de la carpeta es la etiqueta y las imágenes son 28×28 png en escala de grises, no se requieren transformaciones.

data train 0 3.png 5.png 13.png 23.png ... 1 3.png 10.png 11.png ... 2 4.png 13.png ... 3 8.png ... 4 ... 5 ... 6 ... 7 ... 8 ... 9 ... 

Si está usando mnist, ya hay un preset en pytorch a través de torchvision.
Podrías hacerlo

 import torch import torchvision import torchvision.transforms as transforms import pandas as pd transform = transforms.Compose( [transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]) mnistTrainSet = torchvision.datasets.MNIST(root='./data', train=True, download=True, transform=transform) mnistTrainLoader = torch.utils.data.DataLoader(mnistTrainSet, batch_size=16, shuffle=True, num_workers=2) 

Si desea generalizar a un directorio de imágenes (las mismas importaciones que arriba), puede hacer

 class mnistmTrainingDataset(torch.utils.data.Dataset): def __init__(self,text_file,root_dir,transform=transformMnistm): """ Args: text_file(string): path to text file root_dir(string): directory with all train images """ self.name_frame = pd.read_csv(text_file,sep=" ",usecols=range(1)) self.label_frame = pd.read_csv(text_file,sep=" ",usecols=range(1,2)) self.root_dir = root_dir self.transform = transform def __len__(self): return len(self.name_frame) def __getitem__(self, idx): img_name = os.path.join(self.root_dir, self.name_frame.iloc[idx, 0]) image = Image.open(img_name) image = self.transform(image) labels = self.label_frame.iloc[idx, 0] #labels = labels.reshape(-1, 2) sample = {'image': image, 'labels': labels} return sample mnistmTrainSet = mnistmTrainingDataset(text_file ='Downloads/mnist_m/mnist_m_train_labels.txt', root_dir = 'Downloads/mnist_m/mnist_m_train') mnistmTrainLoader = torch.utils.data.DataLoader(mnistmTrainSet,batch_size=16,shuffle=True, num_workers=2) 

A continuación, puede iterar sobre él como:

 for i_batch,sample_batched in enumerate(mnistmTrainLoader,0): print("training sample for mnist-m") print(i_batch,sample_batched['image'],sample_batched['labels']) 

Hay un montón de formas de generalizar pytorch para la carga de dataset de imágenes, el método que conozco es subclasificar torch.utils.data.dataset

Esto es lo que hice para pytorch 4.1

 def load_dataset(): data_path = 'data/train/' train_dataset = torchvision.datasets.ImageFolder( root=data_path, transform=torchvision.transforms.ToTensor() ) train_loader = torch.utils.data.DataLoader( train_dataset, batch_size=64, num_workers=0, shuffle=True ) return train_loader for batch_idx, (data, target) in enumerate(load_dataset()): #train network