💻 Programming/AI & ML

[pytorch] Dataloader의 'collate_fn'을 사용한 이미지 패딩. 가변 사이즈의 이미지를 batch로 묶어 Dataloader에 주입하는 방법.

뭅즤 2023. 3. 3. 11:17
반응형

Pytorch의 Dataloader는 인덱스에 따른 데이터를 반환해주는 dataset, 가져올 데이터의 인덱스를 컨트롤하는 sampler와 batch로 묶인 데이터를 batch로 묶을 때 필요한 함수를 정의하는 collate_fn 등의 파라미터를 가진다.

 

딥러닝 모델을 학습 또는 인퍼런스 하다보면 가변 사이즈의 데이터를 모델에 주입해야 할 경우가 생기는데, 이미지 데이터의 경우 일반적으로 특정 사이즈(e.g. 224x224)로 이미지를 리사이즈해서 사용하는 경우가 많다. 그래서 일반적으로 퍼블릭 데이터를 사용하는 경우 별 생각없이 transforms.Resize() 함수를 사용해서 모든 데이터를 일괄된 사이즈로 변경해서 사용하는 경우가 대부분이다.

 

하지만, 실제 환경에서 일괄된 이미지 리사이징을 사용하는 경우 이미지의 height, width 비율이 크게 변경될 수 있다. 이는 모델의 성능과도 직결되기 때문에 중요한 문제이다. Batch로 데이터를 묶어서 모델에 통과시키기 위해서는 데이터의 사이즈를 동일하게 만들어야 하기 때문에 각 이미지를 원본 비율로 모델에 통과시키기 위해서는 batch를 1로 만들어야 한다. 하지만 이 경우 모델의 속도가 굉장히 떨어지기 때문에 비효율적이다. 

 

 

그렇다면 각기 다른 비율의 이미지들을 과한 리사이징 없이 batch로 묶어 모델에 넣어주는 방법은 없을까?

 

torch.utils.data. DataLoader 의 파라미터 중 하나인 collate_fn

 

이미지 리사이징과 함께 부족한 픽셀을 패딩해주면 되는데, map-style 데이터셋의 데이터를 batch로 묶을 때 필요한 전처리를 수행하게 해주는 collate_fn를 이용하면 쉽게 리사이징 + 패딩을 적용해서 데이터 사이즈를 동일하게 만든 batch를 구성할 수 있다.

 

이외에도 batch를 구성하기 전에 수행하고 싶은 전처리가 있다면 적용할 수 있다.

 

 

collate_fn 사용 예시

6 개의 텍스트 이미지로 데이터셋을 구성한 후 Dataset의 transform을 이용하여 데이터의 크기를 일정하게 리사이징해주는 방법과 커스텀 collate_fn 함수를 정의하여 리사이징과 패딩을 적용한 예시를 시각화.

 

Dataloader에 들어가는 데이터 시각화 (1행 : resize 적용, 2행 : resize + padding 적용)

 

- 1행 : 서로 다른 사이즈의 이미지를 일정한 64x256 의 크기로 리사이징하는 transform을 적용

    ㄴ ex) 32x100 -(리사이징)-> 64x256 

- 2행 : 서로 다른 사이즈의 이미지를 height, width 비율을 고정한 채로 리사이징 후 부족한 픽셀 부분을 padding

    ㄴ ex) 32x100 -(리사이징)-> 64x200 -(패딩)-> 64x256

 

*리사이징 vs 리사이징 + 패딩

적절한 비율을 가진 텍스트 이미지의 경우 리사이징해도 텍스트가 잘 보이지만, 첫 번째 이미지인 '말' 텍스트의 경우 가로로 너무 길게 리사이징하다 보니 이미지가 과하게 변형된 것을 볼 수 있다. 이는 모델 인식 성능을 저하시킬 수 있기 때문에 적당한 리사이즈 후 앞이나 뒷부분을 패딩으로 처리하는 것이 더 좋은 결과를 얻을 수 있다.

텍스트 이미지뿐만 아니라, 딥러닝 모델에 앤드 유저가 직접 업로드한 데이터를 주입하는 경우에 배치 형태로 데이터를 딥러닝 모델에 인퍼런스하기 위해 고려해보면 좋을 트릭이다. 앤드 유저가 업로드하는 데이터는 굉장히 다양한 형태를 가지니까.

 

 

코드 (깃허브 링크)

- 단순히 이미지를 resize 하는 경우와 collate_fn을 이용하여 resize + padding 하는 경우를 시각화해서 비교하기 위한 코드

    - data_loader_with_transform : 이미지 resize transform이 포함된 dataloader
    - data_loader_with_collate_fn : 이미지 resize + padding 과정을 포함한 collate_fn을 포함한 dataloader

import os
import math
import torch
import glob
from PIL import Image
import cv2
import numpy as np
import torchvision.transforms as transforms
import torchvision
import matplotlib.pyplot as plt

curr_dir = os.path.dirname(__file__)

class ListDataset(torch.utils.data.Dataset):
    def __init__(self, image_list, transforms = None):
        self.image_list = image_list
        self.nSamples = len(image_list)
        self.transforms = transforms

    def __len__(self):
        return self.nSamples

    def __getitem__(self, index):
        img = self.image_list[index]
        img = Image.open(img).convert('RGB')
        if self.transforms is not None :
            img = self.transforms(img)
        label = np.random.randint(0,5, (1)) # 임의의 랜덤 레이블
        return img, label

class Image_Pad(object):
    def __init__(self, max_size, PAD_type='right'):
        self.toTensor = transforms.ToTensor()
        self.max_size = max_size
        self.max_width_half = math.floor(max_size[2] / 2)
        self.PAD_type = PAD_type

    def __call__(self, img):
        img = self.toTensor(img)
        c, h, w = img.size()
        Pad_img = torch.FloatTensor(*self.max_size).fill_(0)
        Pad_img[:, :, :w] = img  # right pad
        if self.max_size[2] != w:  # add border Pad
            Pad_img[:, :, w:] = img[:, :, w - 1].unsqueeze(2).expand(c, h, self.max_size[2] - w)

        return Pad_img


class My_Collate(object):
    def __init__(self, imgH=32, imgW=100):
        self.imgH = imgH
        self.imgW = imgW

    def __call__(self, batch):
        batch = filter(lambda x: x is not None, batch)
        images, labels = zip(*batch)

        resized_max_w = self.imgW
        transform = Image_Pad((3, self.imgH, resized_max_w))

        resized_images = []
        for idx, image in enumerate(images):
            print(f'{idx} 번째 데이터 shape :', np.array(image).shape)
            w, h = image.size
            ratio = w / float(h)
            if math.ceil(self.imgH * ratio) > self.imgW:
                resized_w = self.imgW
            else:
                resized_w = math.ceil(self.imgH * ratio)

            resized_image = image.resize((resized_w, self.imgH), Image.BICUBIC)

            transformed_image = transform(resized_image)
            resized_images.append(transformed_image)

        image_tensors = torch.cat([t.unsqueeze(0) for t in resized_images], 0)

        return image_tensors, labels



if __name__ == '__main__':
    imgH = 64
    imgW = 256

    image_list = glob.glob(os.path.join(os.path.dirname(__file__), '*.jpg'))

    transform = transforms.Compose([
            transforms.Resize((64,256)),
            transforms.ToTensor(),
        ])
    My_collate = My_Collate(imgH=imgH, imgW=imgW)
    
    dataset_with_transform = ListDataset(image_list, transforms = transform)
    dataset = ListDataset(image_list)

    data_loader_with_transform = torch.utils.data.DataLoader(dataset_with_transform, batch_size=len(image_list), shuffle=False)
    data_loader_with_collate_fn = torch.utils.data.DataLoader(dataset, batch_size=len(image_list), shuffle=False, collate_fn=My_collate)
    
    data_with_transform = next(iter(data_loader_with_transform))
    data_with_collate_fn = next(iter(data_loader_with_collate_fn))

    data = torch.vstack([data_with_transform[0], data_with_collate_fn[0]])

    grid = torchvision.utils.make_grid(data, nrow = len(image_list))
    plt.imshow(grid.permute(1,2,0))
    plt.show()

 

 

반응형