반응형
* dataset 폴더 구조
minc2500
├─images
│ ├─brick
│ │ ├─brick_000000.jpg
│ │ ├─brick_000001.jpg
│ │ ├─...
│ ├─carpet
│ │ ├─carpet_000000.jpg
│ │ ├─...
│ ├─...
│ │ ├─...
│ │ ├─...
... ... ...
├─labels
│ ├─train1.txt
│ ├─train2.txt
│ ├─...
│ ├─test1.txt
│ ├─test2.txt
│ └─...
import os
import os.path
import torch
import torch.utils.data as data
from PIL import Image
from torchvision import transforms
import numpy as np
# dataset의 class 이름, class index 저장
def find_classes(dir):
classes = [d for d in os.listdir(dir) if os.path.isdir(os.path.join(dir, d))]
classes.sort()
class_to_idx = {classes[i]: i for i in range(len(classes))}
return classes, class_to_idx
# txt 파일로 저장된 train, test data의 경로를 읽어와서 image와 label을 페어로 저장
def make_dataset(txtnames, datadir, class_to_idx):
images = []
labels = []
for txtname in txtnames:
with open(txtname, 'r') as lines:
for line in lines:
classname = line.split('/')[1]
_img = os.path.join(datadir, line.strip())
assert os.path.isfile(_img)
images.append(_img)
labels.append(class_to_idx[classname])
return images, labels
class MINCDataset(data.Dataset):
def __init__(self, args, transform=None, train=True):
classes, class_to_idx = find_classes(os.path.join(config.dataset_path, 'images'))
self.classes = classes
self.class_to_idx = class_to_idx
self.train = train
self.transform = transform
# train, test data의 경로가 저장된 txt 파일을 불러옴
if train:
filename = [os.path.join(config.dataset_path, 'labels/train' + args.split + '.txt'),
os.path.join(config.dataset_path, 'labels/validate' + args.split + '.txt')]
else:
filename = [os.path.join(config.dataset_path, 'labels/test' + args.split + '.txt')]
self.images, self.labels = make_dataset(filename, args.dataset_path, class_to_idx)
assert (len(self.images) == len(self.labels))
def __getitem__(self, index): # 경로로 저장된 image를 PIL image로 불러옴
_img = Image.open(self.images[index]).convert('RGB')
_label = self.labels[index]
if self.transform is not None: # image에 지정된 transform 수행(tensor 변환 포함)
_img = self.transform(_img)
return _img, _label
def __len__(self):
return len(self.images)
class Dataloder():
def __init__(self, args):
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225])
transform_train = transforms.Compose([
transforms.Resize(256),
transforms.RandomResizedCrop(224),
transforms.RandomHorizontalFlip(),
transforms.RandomVerticalFlip(),
transforms.ToTensor(),
normalize,
])
transform_test = transforms.Compose([
transforms.Resize(256),
transforms.CenterCrop(224),
transforms.ToTensor(),
normalize,
])
trainset = MINCDataset(args, transform_train, train=True)
testset = MINCDataset(args, transform_test, train=False)
kwargs = {'num_workers': 8, 'pin_memory': True}
trainloader = torch.utils.data.DataLoader(trainset, batch_size=64, shuffle=True, **kwargs)
testloader = torch.utils.data.DataLoader(testset, batch_size=64, shuffle=False, **kwargs)
self.classes = trainset.classes
self.trainloader = trainloader
self.testloader = testloader
def getloader(self):
return self.classes, self.trainloader, self.testloader
기본적인 dataset, dataloaer class 구조
- dataset
- __init__ : 필요한 변수들을 선언, 전체 image, label들의 path를 저장
- __getitem__ : index를 받아 image, label을 return
- dataloader
- batch 기반의 학습을 위해서 mini batch를 만들어주는 역할. dataset 을 input으로 넣어주면 여러 옵션을 통해 batch를 생성.
1. image 경로가 저장된 train1.txt, test1.txt 파일을 불러들임
(* txt파일 : data 경로를 모두 불러들이고 suffle 해서 일정 비율로 잘라서 train, test 지정)
*예시
images/brick/brick_002089.jpg
images/brick/brick_000519.jpg
images/brick/brick_000216.jpg
2. dataset의 모든 class type을 저장하고 index 부여 (find_classes)
3. train.txt, test.txt 파일에서 image 경로를 한줄 씩 읽어들이고 image, label을 페어로 저장 (make_dataset)
4. dataset __getitem__
dataloader의 __getitem__함수가 실행되면 dataset의 index가 넘어오는데, 해당 index의 image path를 PIL image로 불러들이고, transform으로 torch tensor로 변환
5. dataloader
앞서 만든 dataset을 torch.utils.data.DataLoader 의 input으로 넣어주면 됨.
반응형
'💻 Programming > AI & ML' 카테고리의 다른 글
[pytorch] pytorch 모델 로드 중 Missing key(s) in state_dict 에러 (0) | 2022.12.15 |
---|---|
[pytorch] COCO Data Format 전용 Custom Dataset 생성 (1) | 2022.06.04 |
[pytorch] model 에 접근하기, 특정 layer 변경하기 (0) | 2022.01.05 |
[pytorch] DataParallel 로 학습한 모델 load (0) | 2021.02.17 |
[pytorch] 모델의 일부 레이어 웨이트 업데이트 막기 | model freezing (모델 프리징) (0) | 2021.02.17 |