[pytorch] Custom dataset, dataloader 만들기

2022. 1. 2. 16:28·💻 Programming/AI & ML
반응형

* dataset 폴더 구조 

minc2500
├─images
│        ├─brick
│        │     ├─brick_000000.jpg
│        │     ├─brick_000001.jpg
│        │     ├─...
│        ├─carpet
│        │     ├─carpet_000000.jpg
│        │     ├─...
│        ├─...
│        │     ├─...
│        │     ├─...
...         ...            ...
├─labels
│        ├─train1.txt
│        ├─train2.txt
│        ├─...
│        ├─test1.txt
│        ├─test2.txt
│        └─...

 

 

import os
import os.path
import torch
import torch.utils.data as data
from PIL import Image
from torchvision import transforms
import numpy as np

# dataset의 class 이름, class index 저장
def find_classes(dir):
    classes = [d for d in os.listdir(dir) if os.path.isdir(os.path.join(dir, d))]
    classes.sort()
    class_to_idx = {classes[i]: i for i in range(len(classes))}
    return classes, class_to_idx

# txt 파일로 저장된 train, test data의 경로를 읽어와서 image와 label을 페어로 저장
def make_dataset(txtnames, datadir, class_to_idx):
    images = []
    labels = []
    for txtname in txtnames:
        with open(txtname, 'r') as lines:
            for line in lines:
                classname = line.split('/')[1]
                _img = os.path.join(datadir, line.strip())
                assert os.path.isfile(_img)
                images.append(_img)
                labels.append(class_to_idx[classname])

    return images, labels


class MINCDataset(data.Dataset):
    def __init__(self, args, transform=None, train=True):
        classes, class_to_idx = find_classes(os.path.join(config.dataset_path, 'images'))
        self.classes = classes
        self.class_to_idx = class_to_idx
        self.train = train
        self.transform = transform
		
        # train, test data의 경로가 저장된 txt 파일을 불러옴
        if train:
            filename = [os.path.join(config.dataset_path, 'labels/train' + args.split + '.txt'),
                        os.path.join(config.dataset_path, 'labels/validate' + args.split + '.txt')]
        else:
            filename = [os.path.join(config.dataset_path, 'labels/test' + args.split + '.txt')]

        self.images, self.labels = make_dataset(filename, args.dataset_path, class_to_idx)
        assert (len(self.images) == len(self.labels))

    def __getitem__(self, index): # 경로로 저장된 image를 PIL image로 불러옴
        _img = Image.open(self.images[index]).convert('RGB')
        _label = self.labels[index]
        if self.transform is not None:	# image에 지정된 transform 수행(tensor 변환 포함)
            _img = self.transform(_img)

        return _img, _label

    def __len__(self):
        return len(self.images)


class Dataloder():
    def __init__(self, args):
        normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                         std=[0.229, 0.224, 0.225])
        transform_train = transforms.Compose([
            transforms.Resize(256),
            transforms.RandomResizedCrop(224),
            transforms.RandomHorizontalFlip(),
            transforms.RandomVerticalFlip(),
            transforms.ToTensor(),
            normalize,
        ])
        transform_test = transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            normalize,
        ])

        trainset = MINCDataset(args, transform_train, train=True)
        testset = MINCDataset(args, transform_test, train=False)

        kwargs = {'num_workers': 8, 'pin_memory': True}
        trainloader = torch.utils.data.DataLoader(trainset, batch_size=64, shuffle=True, **kwargs)
        testloader = torch.utils.data.DataLoader(testset, batch_size=64, shuffle=False, **kwargs)
        self.classes = trainset.classes
        self.trainloader = trainloader
        self.testloader = testloader

    def getloader(self):
        return self.classes, self.trainloader, self.testloader

 

기본적인 dataset, dataloaer class 구조

- dataset

  • __init__ : 필요한 변수들을 선언, 전체 image, label들의 path를 저장
  • __getitem__ : index를 받아 image, label을 return

- dataloader

  • batch 기반의 학습을 위해서 mini batch를 만들어주는 역할. dataset 을 input으로 넣어주면 여러 옵션을 통해 batch를 생성. 

 

1. image 경로가 저장된 train1.txt, test1.txt 파일을 불러들임

(* txt파일 : data 경로를 모두 불러들이고 suffle 해서 일정 비율로 잘라서 train, test 지정)

*예시 

images/brick/brick_002089.jpg
images/brick/brick_000519.jpg
images/brick/brick_000216.jpg

 

2. dataset의 모든 class type을 저장하고 index 부여 (find_classes)

 

3. train.txt, test.txt 파일에서 image 경로를 한줄 씩 읽어들이고 image, label을 페어로 저장 (make_dataset)

txt 파일에서 for문 돌면서 line 한줄씩 
image 1개 경로 저장
image들의 경로, label index들을 페어로 저장

 

4. dataset __getitem__

dataloader의 __getitem__함수가 실행되면 dataset의 index가 넘어오는데, 해당 index의 image path를 PIL image로 불러들이고, transform으로 torch tensor로 변환

__getitem__에서 load한 하나의 data shape

 

5. dataloader

앞서 만든 dataset을 torch.utils.data.DataLoader 의 input으로 넣어주면 됨.

반응형

'💻 Programming > AI & ML' 카테고리의 다른 글

[pytorch] pytorch 모델 로드 중 Missing key(s) in state_dict 에러  (0) 2022.12.15
[pytorch] COCO Data Format 전용 Custom Dataset 생성  (1) 2022.06.04
[pytorch] model 에 접근하기, 특정 layer 변경하기  (0) 2022.01.05
[pytorch] DataParallel 로 학습한 모델 load  (0) 2021.02.17
[pytorch] 모델의 일부 레이어 웨이트 업데이트 막기 | model freezing (모델 프리징)  (0) 2021.02.17
'💻 Programming/AI & ML' 카테고리의 다른 글
  • [pytorch] COCO Data Format 전용 Custom Dataset 생성
  • [pytorch] model 에 접근하기, 특정 layer 변경하기
  • [pytorch] DataParallel 로 학습한 모델 load
  • [pytorch] 모델의 일부 레이어 웨이트 업데이트 막기 | model freezing (모델 프리징)
뭅즤
뭅즤
AI 기술 블로그
    반응형
  • 뭅즤
    moovzi’s Doodle
    뭅즤
  • 전체
    오늘
    어제
  • 공지사항

    • ✨ About Me
    • 분류 전체보기 (213)
      • 📖 Fundamentals (34)
        • Computer Vision (9)
        • 3D vision & Graphics (6)
        • AI & ML (16)
        • NLP (2)
        • etc. (1)
      • 🏛 Research (75)
        • Deep Learning (7)
        • Perception (19)
        • OCR (7)
        • Multi-modal (5)
        • Image•Video Generation (18)
        • 3D Vision (4)
        • Material • Texture Recognit.. (8)
        • Large-scale Model (7)
        • etc. (0)
      • 🛠️ Engineering (8)
        • Distributed Training & Infe.. (5)
        • AI & ML 인사이트 (3)
      • 💻 Programming (92)
        • Python (18)
        • Computer Vision (12)
        • LLM (4)
        • AI & ML (18)
        • Database (3)
        • Distributed Computing (6)
        • Apache Airflow (6)
        • Docker & Kubernetes (14)
        • 코딩 테스트 (4)
        • C++ (1)
        • etc. (6)
      • 💬 ETC (4)
        • 책 리뷰 (4)
  • 링크

    • 리틀리 프로필 (멘토링, 면접책,...)
    • 『나는 AI 엔지니어입니다』
    • Instagram
    • Brunch
    • Github
  • 인기 글

  • 최근 댓글

  • 최근 글

  • hELLO· Designed By정상우.v4.10.3
뭅즤
[pytorch] Custom dataset, dataloader 만들기
상단으로

티스토리툴바