
[pytorch] Image padding with the Dataloader's 'collate_fn'. How to group variable-size images into a batch and feed them to the Dataloader.

by ๋ญ…์ฆค 2023. 3. 3.
๋ฐ˜์‘ํ˜•

Pytorch์˜ Dataloader๋Š” ์ธ๋ฑ์Šค์— ๋”ฐ๋ฅธ ๋ฐ์ดํ„ฐ๋ฅผ ๋ฐ˜ํ™˜ํ•ด์ฃผ๋Š” dataset, ๊ฐ€์ ธ์˜ฌ ๋ฐ์ดํ„ฐ์˜ ์ธ๋ฑ์Šค๋ฅผ ์ปจํŠธ๋กคํ•˜๋Š” sampler์™€ batch๋กœ ๋ฌถ์ธ ๋ฐ์ดํ„ฐ๋ฅผ batch๋กœ ๋ฌถ์„ ๋•Œ ํ•„์š”ํ•œ ํ•จ์ˆ˜๋ฅผ ์ •์˜ํ•˜๋Š” collate_fn ๋“ฑ์˜ ํŒŒ๋ผ๋ฏธํ„ฐ๋ฅผ ๊ฐ€์ง„๋‹ค.

 

When training or running inference with a deep learning model, you sometimes need to feed it variable-size data. With image data, the common practice is to resize every image to a specific size (e.g. 224x224), so when working with public datasets most people reach for transforms.Resize() without a second thought and force all data to a single size.
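
That one-size-fits-all pipeline looks roughly like this (a minimal sketch; the 224x224 target and the file name are placeholders):

from PIL import Image
import torchvision.transforms as transforms

preprocess = transforms.Compose([
    transforms.Resize((224, 224)),  # forces every image to 224x224, ignoring its aspect ratio
    transforms.ToTensor(),
])
# img = Image.open('example.jpg').convert('RGB')  # hypothetical input file
# tensor = preprocess(img)                        # always 3 x 224 x 224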

 

In a real-world setting, however, uniform resizing can drastically change an image's height-to-width ratio, and this matters because it is directly tied to model performance. To pass data through a model in batches, all samples in a batch must have the same size, so feeding each image at its original aspect ratio would force a batch size of 1. That is inefficient, because throughput drops sharply.
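
The size constraint is easy to verify: the default collation stacks the sample tensors, and stacking tensors of different widths fails. A minimal sketch:

import torch

a = torch.zeros(3, 64, 200)  # C x H x W image tensor
b = torch.zeros(3, 64, 256)  # same height, different width
try:
    torch.stack([a, b])      # what the default collate_fn effectively does
except RuntimeError as err:
    print(err)               # "stack expects each tensor to be equal size ..."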


So, is there a way to group images of different aspect ratios into a batch and feed them to the model without excessive resizing?

 

collate_fn, one of the parameters of torch.utils.data.DataLoader

 

The idea is to resize the image and pad the missing pixels. collate_fn, which performs whatever preprocessing is needed when grouping map-style dataset samples into a batch, makes it easy to apply resize + padding and build a batch in which every sample has the same size.
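
The padding step itself is cheap; for instance, with torch.nn.functional.pad (a sketch assuming a 64x200 tensor padded on the right to width 256):

import torch
import torch.nn.functional as F

img = torch.rand(3, 64, 200)         # C x H x W after an aspect-preserving resize
pad_right = 256 - img.shape[-1]      # pixels missing to reach the target width
padded = F.pad(img, (0, pad_right))  # (left, right) zero padding on the last dim
print(padded.shape)                  # torch.Size([3, 64, 256])

Note that F.pad zero-fills, while the example code below replicates the last image column into the padded region; which fill works better can depend on the model and data.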

 

์ด์™ธ์—๋„ batch๋ฅผ ๊ตฌ์„ฑํ•˜๊ธฐ ์ „์— ์ˆ˜ํ–‰ํ•˜๊ณ  ์‹ถ์€ ์ „์ฒ˜๋ฆฌ๊ฐ€ ์žˆ๋‹ค๋ฉด ์ ์šฉํ•  ์ˆ˜ ์žˆ๋‹ค.


Example: using collate_fn

6 ๊ฐœ์˜ ํ…์ŠคํŠธ ์ด๋ฏธ์ง€๋กœ ๋ฐ์ดํ„ฐ์…‹์„ ๊ตฌ์„ฑํ•œ ํ›„ Dataset์˜ transform์„ ์ด์šฉํ•˜์—ฌ ๋ฐ์ดํ„ฐ์˜ ํฌ๊ธฐ๋ฅผ ์ผ์ •ํ•˜๊ฒŒ ๋ฆฌ์‚ฌ์ด์ง•ํ•ด์ฃผ๋Š” ๋ฐฉ๋ฒ•๊ณผ ์ปค์Šคํ…€ collate_fn ํ•จ์ˆ˜๋ฅผ ์ •์˜ํ•˜์—ฌ ๋ฆฌ์‚ฌ์ด์ง•๊ณผ ํŒจ๋”ฉ์„ ์ ์šฉํ•œ ์˜ˆ์‹œ๋ฅผ ์‹œ๊ฐํ™”.

 

Visualization of the data fed to the Dataloader (row 1: resize only, row 2: resize + padding)

 

- Row 1: a transform resizes images of different sizes to a fixed 64x256

    ใ„ด e.g. 32x100 -(resize)-> 64x256

- Row 2: images of different sizes are resized with their height/width ratio fixed, then the missing pixels are padded (the resized width is computed as in the sketch below)

    ใ„ด e.g. 32x100 -(resize)-> 64x200 -(padding)-> 64x256
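
The resized width in row 2 comes straight from the original aspect ratio, capped at the target width. A worked sketch of the computation (the same logic appears in My_Collate below):

import math

h, w = 32, 100        # original size
imgH, imgW = 64, 256  # target height and maximum width

ratio = w / float(h)                            # 3.125
resized_w = min(math.ceil(imgH * ratio), imgW)  # ceil(64 * 3.125) = 200
print(resized_w)  # 200 -> resize to 64x200, then pad 64x200 -> 64x256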

 

*๋ฆฌ์‚ฌ์ด์ง• vs ๋ฆฌ์‚ฌ์ด์ง• + ํŒจ๋”ฉ

Text images with a moderate aspect ratio remain legible after plain resizing, but the first image (the text '๋ง') is stretched so wide that it is heavily distorted. Because such distortion can degrade recognition performance, resizing moderately and padding the leading or trailing region usually gives better results.

This trick is worth considering not only for text images but whenever you run batched inference on data uploaded directly by end users, since user-uploaded data comes in a huge variety of shapes.


Code (GitHub link)

- Code that visualizes and compares plain image resizing against resize + padding via collate_fn

    - data_loader_with_transform : dataloader whose Dataset applies a resize transform
    - data_loader_with_collate_fn : dataloader with a collate_fn that performs resize + padding

import os
import math
import glob

import numpy as np
import torch
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
from PIL import Image

class ListDataset(torch.utils.data.Dataset):
    def __init__(self, image_list, transforms = None):
        self.image_list = image_list
        self.nSamples = len(image_list)
        self.transforms = transforms

    def __len__(self):
        return self.nSamples

    def __getitem__(self, index):
        img = self.image_list[index]
        img = Image.open(img).convert('RGB')
        if self.transforms is not None :
            img = self.transforms(img)
        label = np.random.randint(0, 5, (1))  # arbitrary random label, stands in for a real annotation
        return img, label

class Image_Pad(object):
    # pads a resized image on the right up to a fixed (C, H, W) size
    def __init__(self, max_size, PAD_type='right'):
        self.toTensor = transforms.ToTensor()
        self.max_size = max_size
        self.max_width_half = math.floor(max_size[2] / 2)
        self.PAD_type = PAD_type

    def __call__(self, img):
        img = self.toTensor(img)
        c, h, w = img.size()
        Pad_img = torch.FloatTensor(*self.max_size).fill_(0)
        Pad_img[:, :, :w] = img  # place the image at the left edge
        if self.max_size[2] != w:  # fill the remaining width by replicating the last image column
            Pad_img[:, :, w:] = img[:, :, w - 1].unsqueeze(2).expand(c, h, self.max_size[2] - w)

        return Pad_img


class My_Collate(object):
    def __init__(self, imgH=32, imgW=100):
        self.imgH = imgH
        self.imgW = imgW

    def __call__(self, batch):
        batch = filter(lambda x: x is not None, batch)  # drop samples the Dataset returned as None
        images, labels = zip(*batch)

        resized_max_w = self.imgW
        transform = Image_Pad((3, self.imgH, resized_max_w))

        resized_images = []
        for idx, image in enumerate(images):
            print(f'sample {idx} shape:', np.array(image).shape)
            w, h = image.size
            ratio = w / float(h)
            # resize to the target height while keeping the aspect ratio,
            # capping the width at imgW
            if math.ceil(self.imgH * ratio) > self.imgW:
                resized_w = self.imgW
            else:
                resized_w = math.ceil(self.imgH * ratio)

            resized_image = image.resize((resized_w, self.imgH), Image.BICUBIC)

            transformed_image = transform(resized_image)  # pad the width up to imgW
            resized_images.append(transformed_image)

        image_tensors = torch.cat([t.unsqueeze(0) for t in resized_images], 0)

        return image_tensors, labels



if __name__ == '__main__':
    imgH = 64
    imgW = 256

    image_list = glob.glob(os.path.join(os.path.dirname(__file__), '*.jpg'))

    transform = transforms.Compose([
            transforms.Resize((imgH, imgW)),
            transforms.ToTensor(),
        ])
    My_collate = My_Collate(imgH=imgH, imgW=imgW)
    
    dataset_with_transform = ListDataset(image_list, transforms = transform)
    dataset = ListDataset(image_list)

    data_loader_with_transform = torch.utils.data.DataLoader(dataset_with_transform, batch_size=len(image_list), shuffle=False)
    data_loader_with_collate_fn = torch.utils.data.DataLoader(dataset, batch_size=len(image_list), shuffle=False, collate_fn=My_collate)
    
    data_with_transform = next(iter(data_loader_with_transform))
    data_with_collate_fn = next(iter(data_loader_with_collate_fn))

    data = torch.vstack([data_with_transform[0], data_with_collate_fn[0]])

    grid = torchvision.utils.make_grid(data, nrow = len(image_list))
    plt.imshow(grid.permute(1,2,0))
    plt.show()
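
Run the script with a few *.jpg files next to it (the glob above picks them up): it should print each sample's original shape from inside the collate_fn and display a grid whose first row is the resize-only batch and whose second row is the resize + padding batch, i.e. the visualization described above.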
