Pytorch์ Dataloader๋ ์ธ๋ฑ์ค์ ๋ฐ๋ฅธ ๋ฐ์ดํฐ๋ฅผ ๋ฐํํด์ฃผ๋ dataset, ๊ฐ์ ธ์ฌ ๋ฐ์ดํฐ์ ์ธ๋ฑ์ค๋ฅผ ์ปจํธ๋กคํ๋ sampler์ batch๋ก ๋ฌถ์ธ ๋ฐ์ดํฐ๋ฅผ batch๋ก ๋ฌถ์ ๋ ํ์ํ ํจ์๋ฅผ ์ ์ํ๋ collate_fn ๋ฑ์ ํ๋ผ๋ฏธํฐ๋ฅผ ๊ฐ์ง๋ค.
๋ฅ๋ฌ๋ ๋ชจ๋ธ์ ํ์ต ๋๋ ์ธํผ๋ฐ์ค ํ๋ค๋ณด๋ฉด ๊ฐ๋ณ ์ฌ์ด์ฆ์ ๋ฐ์ดํฐ๋ฅผ ๋ชจ๋ธ์ ์ฃผ์ ํด์ผ ํ ๊ฒฝ์ฐ๊ฐ ์๊ธฐ๋๋ฐ, ์ด๋ฏธ์ง ๋ฐ์ดํฐ์ ๊ฒฝ์ฐ ์ผ๋ฐ์ ์ผ๋ก ํน์ ์ฌ์ด์ฆ(e.g. 224x224)๋ก ์ด๋ฏธ์ง๋ฅผ ๋ฆฌ์ฌ์ด์ฆํด์ ์ฌ์ฉํ๋ ๊ฒฝ์ฐ๊ฐ ๋ง๋ค. ๊ทธ๋์ ์ผ๋ฐ์ ์ผ๋ก ํผ๋ธ๋ฆญ ๋ฐ์ดํฐ๋ฅผ ์ฌ์ฉํ๋ ๊ฒฝ์ฐ ๋ณ ์๊ฐ์์ด transforms.Resize() ํจ์๋ฅผ ์ฌ์ฉํด์ ๋ชจ๋ ๋ฐ์ดํฐ๋ฅผ ์ผ๊ด๋ ์ฌ์ด์ฆ๋ก ๋ณ๊ฒฝํด์ ์ฌ์ฉํ๋ ๊ฒฝ์ฐ๊ฐ ๋๋ถ๋ถ์ด๋ค.
ํ์ง๋ง, ์ค์ ํ๊ฒฝ์์ ์ผ๊ด๋ ์ด๋ฏธ์ง ๋ฆฌ์ฌ์ด์ง์ ์ฌ์ฉํ๋ ๊ฒฝ์ฐ ์ด๋ฏธ์ง์ height, width ๋น์จ์ด ํฌ๊ฒ ๋ณ๊ฒฝ๋ ์ ์๋ค. ์ด๋ ๋ชจ๋ธ์ ์ฑ๋ฅ๊ณผ๋ ์ง๊ฒฐ๋๊ธฐ ๋๋ฌธ์ ์ค์ํ ๋ฌธ์ ์ด๋ค. Batch๋ก ๋ฐ์ดํฐ๋ฅผ ๋ฌถ์ด์ ๋ชจ๋ธ์ ํต๊ณผ์ํค๊ธฐ ์ํด์๋ ๋ฐ์ดํฐ์ ์ฌ์ด์ฆ๋ฅผ ๋์ผํ๊ฒ ๋ง๋ค์ด์ผ ํ๊ธฐ ๋๋ฌธ์ ๊ฐ ์ด๋ฏธ์ง๋ฅผ ์๋ณธ ๋น์จ๋ก ๋ชจ๋ธ์ ํต๊ณผ์ํค๊ธฐ ์ํด์๋ batch๋ฅผ 1๋ก ๋ง๋ค์ด์ผ ํ๋ค. ํ์ง๋ง ์ด ๊ฒฝ์ฐ ๋ชจ๋ธ์ ์๋๊ฐ ๊ต์ฅํ ๋จ์ด์ง๊ธฐ ๋๋ฌธ์ ๋นํจ์จ์ ์ด๋ค.
๊ทธ๋ ๋ค๋ฉด ๊ฐ๊ธฐ ๋ค๋ฅธ ๋น์จ์ ์ด๋ฏธ์ง๋ค์ ๊ณผํ ๋ฆฌ์ฌ์ด์ง ์์ด batch๋ก ๋ฌถ์ด ๋ชจ๋ธ์ ๋ฃ์ด์ฃผ๋ ๋ฐฉ๋ฒ์ ์์๊น?
์ด๋ฏธ์ง ๋ฆฌ์ฌ์ด์ง๊ณผ ํจ๊ป ๋ถ์กฑํ ํฝ์ ์ ํจ๋ฉํด์ฃผ๋ฉด ๋๋๋ฐ, map-style ๋ฐ์ดํฐ์ ์ ๋ฐ์ดํฐ๋ฅผ batch๋ก ๋ฌถ์ ๋ ํ์ํ ์ ์ฒ๋ฆฌ๋ฅผ ์ํํ๊ฒ ํด์ฃผ๋ collate_fn๋ฅผ ์ด์ฉํ๋ฉด ์ฝ๊ฒ ๋ฆฌ์ฌ์ด์ง + ํจ๋ฉ์ ์ ์ฉํด์ ๋ฐ์ดํฐ ์ฌ์ด์ฆ๋ฅผ ๋์ผํ๊ฒ ๋ง๋ batch๋ฅผ ๊ตฌ์ฑํ ์ ์๋ค.
์ด์ธ์๋ batch๋ฅผ ๊ตฌ์ฑํ๊ธฐ ์ ์ ์ํํ๊ณ ์ถ์ ์ ์ฒ๋ฆฌ๊ฐ ์๋ค๋ฉด ์ ์ฉํ ์ ์๋ค.
collate_fn ์ฌ์ฉ ์์
6 ๊ฐ์ ํ ์คํธ ์ด๋ฏธ์ง๋ก ๋ฐ์ดํฐ์ ์ ๊ตฌ์ฑํ ํ Dataset์ transform์ ์ด์ฉํ์ฌ ๋ฐ์ดํฐ์ ํฌ๊ธฐ๋ฅผ ์ผ์ ํ๊ฒ ๋ฆฌ์ฌ์ด์งํด์ฃผ๋ ๋ฐฉ๋ฒ๊ณผ ์ปค์คํ collate_fn ํจ์๋ฅผ ์ ์ํ์ฌ ๋ฆฌ์ฌ์ด์ง๊ณผ ํจ๋ฉ์ ์ ์ฉํ ์์๋ฅผ ์๊ฐํ.
- 1ํ : ์๋ก ๋ค๋ฅธ ์ฌ์ด์ฆ์ ์ด๋ฏธ์ง๋ฅผ ์ผ์ ํ 64x256 ์ ํฌ๊ธฐ๋ก ๋ฆฌ์ฌ์ด์งํ๋ transform์ ์ ์ฉ
ใด ex) 32x100 -(๋ฆฌ์ฌ์ด์ง)-> 64x256
- 2ํ : ์๋ก ๋ค๋ฅธ ์ฌ์ด์ฆ์ ์ด๋ฏธ์ง๋ฅผ height, width ๋น์จ์ ๊ณ ์ ํ ์ฑ๋ก ๋ฆฌ์ฌ์ด์ง ํ ๋ถ์กฑํ ํฝ์ ๋ถ๋ถ์ padding
ใด ex) 32x100 -(๋ฆฌ์ฌ์ด์ง)-> 64x200 -(ํจ๋ฉ)-> 64x256
*๋ฆฌ์ฌ์ด์ง vs ๋ฆฌ์ฌ์ด์ง + ํจ๋ฉ
์ ์ ํ ๋น์จ์ ๊ฐ์ง ํ ์คํธ ์ด๋ฏธ์ง์ ๊ฒฝ์ฐ ๋ฆฌ์ฌ์ด์งํด๋ ํ ์คํธ๊ฐ ์ ๋ณด์ด์ง๋ง, ์ฒซ ๋ฒ์งธ ์ด๋ฏธ์ง์ธ '๋ง' ํ ์คํธ์ ๊ฒฝ์ฐ ๊ฐ๋ก๋ก ๋๋ฌด ๊ธธ๊ฒ ๋ฆฌ์ฌ์ด์งํ๋ค ๋ณด๋ ์ด๋ฏธ์ง๊ฐ ๊ณผํ๊ฒ ๋ณํ๋ ๊ฒ์ ๋ณผ ์ ์๋ค. ์ด๋ ๋ชจ๋ธ ์ธ์ ์ฑ๋ฅ์ ์ ํ์ํฌ ์ ์๊ธฐ ๋๋ฌธ์ ์ ๋นํ ๋ฆฌ์ฌ์ด์ฆ ํ ์์ด๋ ๋ท๋ถ๋ถ์ ํจ๋ฉ์ผ๋ก ์ฒ๋ฆฌํ๋ ๊ฒ์ด ๋ ์ข์ ๊ฒฐ๊ณผ๋ฅผ ์ป์ ์ ์๋ค.
ํ ์คํธ ์ด๋ฏธ์ง๋ฟ๋ง ์๋๋ผ, ๋ฅ๋ฌ๋ ๋ชจ๋ธ์ ์ค๋ ์ ์ ๊ฐ ์ง์ ์ ๋ก๋ํ ๋ฐ์ดํฐ๋ฅผ ์ฃผ์ ํ๋ ๊ฒฝ์ฐ์ ๋ฐฐ์น ํํ๋ก ๋ฐ์ดํฐ๋ฅผ ๋ฅ๋ฌ๋ ๋ชจ๋ธ์ ์ธํผ๋ฐ์คํ๊ธฐ ์ํด ๊ณ ๋ คํด๋ณด๋ฉด ์ข์ ํธ๋ฆญ์ด๋ค. ์ค๋ ์ ์ ๊ฐ ์ ๋ก๋ํ๋ ๋ฐ์ดํฐ๋ ๊ต์ฅํ ๋ค์ํ ํํ๋ฅผ ๊ฐ์ง๋๊น.
์ฝ๋ (๊นํ๋ธ ๋งํฌ)
- ๋จ์ํ ์ด๋ฏธ์ง๋ฅผ resize ํ๋ ๊ฒฝ์ฐ์ collate_fn์ ์ด์ฉํ์ฌ resize + padding ํ๋ ๊ฒฝ์ฐ๋ฅผ ์๊ฐํํด์ ๋น๊ตํ๊ธฐ ์ํ ์ฝ๋
- data_loader_with_transform : ์ด๋ฏธ์ง resize transform์ด ํฌํจ๋ dataloader
- data_loader_with_collate_fn : ์ด๋ฏธ์ง resize + padding ๊ณผ์ ์ ํฌํจํ collate_fn์ ํฌํจํ dataloader
import os
import math
import torch
import glob
from PIL import Image
import cv2
import numpy as np
import torchvision.transforms as transforms
import torchvision
import matplotlib.pyplot as plt
curr_dir = os.path.dirname(__file__)
class ListDataset(torch.utils.data.Dataset):
def __init__(self, image_list, transforms = None):
self.image_list = image_list
self.nSamples = len(image_list)
self.transforms = transforms
def __len__(self):
return self.nSamples
def __getitem__(self, index):
img = self.image_list[index]
img = Image.open(img).convert('RGB')
if self.transforms is not None :
img = self.transforms(img)
label = np.random.randint(0,5, (1)) # ์์์ ๋๋ค ๋ ์ด๋ธ
return img, label
class Image_Pad(object):
def __init__(self, max_size, PAD_type='right'):
self.toTensor = transforms.ToTensor()
self.max_size = max_size
self.max_width_half = math.floor(max_size[2] / 2)
self.PAD_type = PAD_type
def __call__(self, img):
img = self.toTensor(img)
c, h, w = img.size()
Pad_img = torch.FloatTensor(*self.max_size).fill_(0)
Pad_img[:, :, :w] = img # right pad
if self.max_size[2] != w: # add border Pad
Pad_img[:, :, w:] = img[:, :, w - 1].unsqueeze(2).expand(c, h, self.max_size[2] - w)
return Pad_img
class My_Collate(object):
def __init__(self, imgH=32, imgW=100):
self.imgH = imgH
self.imgW = imgW
def __call__(self, batch):
batch = filter(lambda x: x is not None, batch)
images, labels = zip(*batch)
resized_max_w = self.imgW
transform = Image_Pad((3, self.imgH, resized_max_w))
resized_images = []
for idx, image in enumerate(images):
print(f'{idx} ๋ฒ์งธ ๋ฐ์ดํฐ shape :', np.array(image).shape)
w, h = image.size
ratio = w / float(h)
if math.ceil(self.imgH * ratio) > self.imgW:
resized_w = self.imgW
else:
resized_w = math.ceil(self.imgH * ratio)
resized_image = image.resize((resized_w, self.imgH), Image.BICUBIC)
transformed_image = transform(resized_image)
resized_images.append(transformed_image)
image_tensors = torch.cat([t.unsqueeze(0) for t in resized_images], 0)
return image_tensors, labels
if __name__ == '__main__':
imgH = 64
imgW = 256
image_list = glob.glob(os.path.join(os.path.dirname(__file__), '*.jpg'))
transform = transforms.Compose([
transforms.Resize((64,256)),
transforms.ToTensor(),
])
My_collate = My_Collate(imgH=imgH, imgW=imgW)
dataset_with_transform = ListDataset(image_list, transforms = transform)
dataset = ListDataset(image_list)
data_loader_with_transform = torch.utils.data.DataLoader(dataset_with_transform, batch_size=len(image_list), shuffle=False)
data_loader_with_collate_fn = torch.utils.data.DataLoader(dataset, batch_size=len(image_list), shuffle=False, collate_fn=My_collate)
data_with_transform = next(iter(data_loader_with_transform))
data_with_collate_fn = next(iter(data_loader_with_collate_fn))
data = torch.vstack([data_with_transform[0], data_with_collate_fn[0]])
grid = torchvision.utils.make_grid(data, nrow = len(image_list))
plt.imshow(grid.permute(1,2,0))
plt.show()