[pytorch] Custom dataset, dataloader 만들기

2022. 1. 2. 16:28·💻 Programming/AI & ML
반응형

* dataset 폴더 구조 

minc2500
├─images
│        ├─brick
│        │     ├─brick_000000.jpg
│        │     ├─brick_000001.jpg
│        │     ├─...
│        ├─carpet
│        │     ├─carpet_000000.jpg
│        │     ├─...
│        ├─...
│        │     ├─...
│        │     ├─...
...         ...            ...
├─labels
│        ├─train1.txt
│        ├─train2.txt
│        ├─...
│        ├─test1.txt
│        ├─test2.txt
│        └─...

 

 

import os
import os.path
import torch
import torch.utils.data as data
from PIL import Image
from torchvision import transforms
import numpy as np

# dataset의 class 이름, class index 저장
def find_classes(dir):
    classes = [d for d in os.listdir(dir) if os.path.isdir(os.path.join(dir, d))]
    classes.sort()
    class_to_idx = {classes[i]: i for i in range(len(classes))}
    return classes, class_to_idx

# txt 파일로 저장된 train, test data의 경로를 읽어와서 image와 label을 페어로 저장
def make_dataset(txtnames, datadir, class_to_idx):
    images = []
    labels = []
    for txtname in txtnames:
        with open(txtname, 'r') as lines:
            for line in lines:
                classname = line.split('/')[1]
                _img = os.path.join(datadir, line.strip())
                assert os.path.isfile(_img)
                images.append(_img)
                labels.append(class_to_idx[classname])

    return images, labels


class MINCDataset(data.Dataset):
    def __init__(self, args, transform=None, train=True):
        classes, class_to_idx = find_classes(os.path.join(config.dataset_path, 'images'))
        self.classes = classes
        self.class_to_idx = class_to_idx
        self.train = train
        self.transform = transform
		
        # train, test data의 경로가 저장된 txt 파일을 불러옴
        if train:
            filename = [os.path.join(config.dataset_path, 'labels/train' + args.split + '.txt'),
                        os.path.join(config.dataset_path, 'labels/validate' + args.split + '.txt')]
        else:
            filename = [os.path.join(config.dataset_path, 'labels/test' + args.split + '.txt')]

        self.images, self.labels = make_dataset(filename, args.dataset_path, class_to_idx)
        assert (len(self.images) == len(self.labels))

    def __getitem__(self, index): # 경로로 저장된 image를 PIL image로 불러옴
        _img = Image.open(self.images[index]).convert('RGB')
        _label = self.labels[index]
        if self.transform is not None:	# image에 지정된 transform 수행(tensor 변환 포함)
            _img = self.transform(_img)

        return _img, _label

    def __len__(self):
        return len(self.images)


class Dataloder():
    def __init__(self, args):
        normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                         std=[0.229, 0.224, 0.225])
        transform_train = transforms.Compose([
            transforms.Resize(256),
            transforms.RandomResizedCrop(224),
            transforms.RandomHorizontalFlip(),
            transforms.RandomVerticalFlip(),
            transforms.ToTensor(),
            normalize,
        ])
        transform_test = transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            normalize,
        ])

        trainset = MINCDataset(args, transform_train, train=True)
        testset = MINCDataset(args, transform_test, train=False)

        kwargs = {'num_workers': 8, 'pin_memory': True}
        trainloader = torch.utils.data.DataLoader(trainset, batch_size=64, shuffle=True, **kwargs)
        testloader = torch.utils.data.DataLoader(testset, batch_size=64, shuffle=False, **kwargs)
        self.classes = trainset.classes
        self.trainloader = trainloader
        self.testloader = testloader

    def getloader(self):
        return self.classes, self.trainloader, self.testloader

 

기본적인 dataset, dataloaer class 구조

- dataset

  • __init__ : 필요한 변수들을 선언, 전체 image, label들의 path를 저장
  • __getitem__ : index를 받아 image, label을 return

- dataloader

  • batch 기반의 학습을 위해서 mini batch를 만들어주는 역할. dataset 을 input으로 넣어주면 여러 옵션을 통해 batch를 생성. 

 

1. image 경로가 저장된 train1.txt, test1.txt 파일을 불러들임

(* txt파일 : data 경로를 모두 불러들이고 suffle 해서 일정 비율로 잘라서 train, test 지정)

*예시 

images/brick/brick_002089.jpg
images/brick/brick_000519.jpg
images/brick/brick_000216.jpg

 

2. dataset의 모든 class type을 저장하고 index 부여 (find_classes)

 

3. train.txt, test.txt 파일에서 image 경로를 한줄 씩 읽어들이고 image, label을 페어로 저장 (make_dataset)

txt 파일에서 for문 돌면서 line 한줄씩 
image 1개 경로 저장
image들의 경로, label index들을 페어로 저장

 

4. dataset __getitem__

dataloader의 __getitem__함수가 실행되면 dataset의 index가 넘어오는데, 해당 index의 image path를 PIL image로 불러들이고, transform으로 torch tensor로 변환

__getitem__에서 load한 하나의 data shape

 

5. dataloader

앞서 만든 dataset을 torch.utils.data.DataLoader 의 input으로 넣어주면 됨.

반응형

'💻 Programming > AI & ML' 카테고리의 다른 글

[pytorch] pytorch 모델 로드 중 Missing key(s) in state_dict 에러  (0) 2022.12.15
[pytorch] COCO Data Format 전용 Custom Dataset 생성  (1) 2022.06.04
[pytorch] model 에 접근하기, 특정 layer 변경하기  (0) 2022.01.05
[pytorch] DataParallel 로 학습한 모델 load  (0) 2021.02.17
[pytorch] 모델의 일부 레이어 웨이트 업데이트 막기 | model freezing (모델 프리징)  (0) 2021.02.17
'💻 Programming/AI & ML' 카테고리의 다른 글
  • [pytorch] COCO Data Format 전용 Custom Dataset 생성
  • [pytorch] model 에 접근하기, 특정 layer 변경하기
  • [pytorch] DataParallel 로 학습한 모델 load
  • [pytorch] 모델의 일부 레이어 웨이트 업데이트 막기 | model freezing (모델 프리징)
뭅즤
뭅즤
AI 기술 블로그
    반응형
  • 뭅즤
    CV DOODLE
    뭅즤
  • 전체
    오늘
    어제
  • 공지사항

    • ✨ About Me
    • 분류 전체보기 (202)
      • 📖 Fundamentals (33)
        • Computer Vision (9)
        • 3D vision & Graphics (6)
        • AI & ML (15)
        • NLP (2)
        • etc. (1)
      • 🏛 Research (67)
        • Deep Learning (7)
        • Image Classification (2)
        • Detection & Segmentation (17)
        • OCR (7)
        • Multi-modal (4)
        • Generative AI (8)
        • 3D Vision (3)
        • Material & Texture Recognit.. (8)
        • NLP & LLM (11)
        • etc. (0)
      • 🛠️ Engineering (7)
        • Distributed Training (4)
        • AI & ML 인사이트 (3)
      • 💻 Programming (86)
        • Python (18)
        • Computer Vision (12)
        • LLM (4)
        • AI & ML (18)
        • Database (3)
        • Apache Airflow (6)
        • Docker & Kubernetes (14)
        • 코딩 테스트 (4)
        • C++ (1)
        • etc. (6)
      • 💬 ETC (3)
        • 책 리뷰 (3)
  • 링크

  • 인기 글

  • 태그

    프롬프트엔지니어링
    객체 검출
    multi-modal
    ChatGPT
    VLP
    generative ai
    OpenAI
    ml
    Text recognition
    OpenCV
    OCR
    nlp
    deep learning
    Computer Vision
    CNN
    material recognition
    AI
    object detection
    pandas
    도커
    airflow
    LLM
    객체검출
    pytorch
    segmentation
    파이썬
    딥러닝
    컴퓨터비전
    3D Vision
    Python
  • 최근 댓글

  • 최근 글

  • hELLO· Designed By정상우.v4.10.3
뭅즤
[pytorch] Custom dataset, dataloader 만들기
상단으로

티스토리툴바