본문 바로가기
💻 Programming/AI & ML

[pytorch] Custom dataset, dataloader 만들기

by Moving J 2022. 1. 2.
반응형

* dataset 폴더 구조 

minc2500
├─images
│        ├─brick
│        │     ├─brick_000000.jpg
│        │     ├─brick_000001.jpg
│        │     ├─...
│        ├─carpet
│        │     ├─carpet_000000.jpg
│        │     ├─...
│        ├─...
│        │     ├─...
│        │     ├─...
...         ...            ...
├─labels
│        ├─train1.txt
│        ├─train2.txt
│        ├─...
│        ├─test1.txt
│        ├─test2.txt
│        └─...

 

 

import os
import os.path
import torch
import torch.utils.data as data
from PIL import Image
from torchvision import transforms
import numpy as np

# dataset의 class 이름, class index 저장
def find_classes(dir):
    classes = [d for d in os.listdir(dir) if os.path.isdir(os.path.join(dir, d))]
    classes.sort()
    class_to_idx = {classes[i]: i for i in range(len(classes))}
    return classes, class_to_idx

# txt 파일로 저장된 train, test data의 경로를 읽어와서 image와 label을 페어로 저장
def make_dataset(txtnames, datadir, class_to_idx):
    images = []
    labels = []
    for txtname in txtnames:
        with open(txtname, 'r') as lines:
            for line in lines:
                classname = line.split('/')[1]
                _img = os.path.join(datadir, line.strip())
                assert os.path.isfile(_img)
                images.append(_img)
                labels.append(class_to_idx[classname])

    return images, labels


class MINCDataset(data.Dataset):
    def __init__(self, args, transform=None, train=True):
        classes, class_to_idx = find_classes(os.path.join(config.dataset_path, 'images'))
        self.classes = classes
        self.class_to_idx = class_to_idx
        self.train = train
        self.transform = transform
		
        # train, test data의 경로가 저장된 txt 파일을 불러옴
        if train:
            filename = [os.path.join(config.dataset_path, 'labels/train' + args.split + '.txt'),
                        os.path.join(config.dataset_path, 'labels/validate' + args.split + '.txt')]
        else:
            filename = [os.path.join(config.dataset_path, 'labels/test' + args.split + '.txt')]

        self.images, self.labels = make_dataset(filename, args.dataset_path, class_to_idx)
        assert (len(self.images) == len(self.labels))

    def __getitem__(self, index): # 경로로 저장된 image를 PIL image로 불러옴
        _img = Image.open(self.images[index]).convert('RGB')
        _label = self.labels[index]
        if self.transform is not None:	# image에 지정된 transform 수행(tensor 변환 포함)
            _img = self.transform(_img)

        return _img, _label

    def __len__(self):
        return len(self.images)


class Dataloder():
    def __init__(self, args):
        normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                         std=[0.229, 0.224, 0.225])
        transform_train = transforms.Compose([
            transforms.Resize(256),
            transforms.RandomResizedCrop(224),
            transforms.RandomHorizontalFlip(),
            transforms.RandomVerticalFlip(),
            transforms.ToTensor(),
            normalize,
        ])
        transform_test = transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            normalize,
        ])

        trainset = MINCDataset(args, transform_train, train=True)
        testset = MINCDataset(args, transform_test, train=False)

        kwargs = {'num_workers': 8, 'pin_memory': True}
        trainloader = torch.utils.data.DataLoader(trainset, batch_size=64, shuffle=True, **kwargs)
        testloader = torch.utils.data.DataLoader(testset, batch_size=64, shuffle=False, **kwargs)
        self.classes = trainset.classes
        self.trainloader = trainloader
        self.testloader = testloader

    def getloader(self):
        return self.classes, self.trainloader, self.testloader

 

기본적인 dataset, dataloaer class 구조

- dataset

  • __init__ : 필요한 변수들을 선언, 전체 image, label들의 path를 저장
  • __getitem__ : index를 받아 image, label을 return

- dataloader

  • batch 기반의 학습을 위해서 mini batch를 만들어주는 역할. dataset 을 input으로 넣어주면 여러 옵션을 통해 batch를 생성. 

 

1. image 경로가 저장된 train1.txt, test1.txt 파일을 불러들임

(* txt파일 : data 경로를 모두 불러들이고 suffle 해서 일정 비율로 잘라서 train, test 지정)

*예시 

images/brick/brick_002089.jpg
images/brick/brick_000519.jpg
images/brick/brick_000216.jpg

 

2. dataset의 모든 class type을 저장하고 index 부여 (find_classes)

 

3. train.txt, test.txt 파일에서 image 경로를 한줄 씩 읽어들이고 image, label을 페어로 저장 (make_dataset)

txt 파일에서 for문 돌면서 line 한줄씩 
image 1개 경로 저장
image들의 경로, label index들을 페어로 저장

 

4. dataset __getitem__

dataloader의 __getitem__함수가 실행되면 dataset의 index가 넘어오는데, 해당 index의 image path를 PIL image로 불러들이고, transform으로 torch tensor로 변환

__getitem__에서 load한 하나의 data shape

 

5. dataloader

앞서 만든 dataset을 torch.utils.data.DataLoader 의 input으로 넣어주면 됨.

반응형