[pytorch] Creating a Custom Dataset for the COCO Data Format

by 뭅즤 2022. 6. 4.

This post shows how to build a custom dataset for the COCO data format, which is widely used in object detection and segmentation.

 

Besides the well-known COCO dataset itself, many other datasets follow the COCO data format. To work with such datasets, this post explains how to put together a custom dataset and how to use pycocotools, the COCO API.

 

COCO Data Format

A detection task needs bounding box locations and class labels, while a segmentation task needs segment masks. These annotations are provided as JSON, and the JSON file holds five kinds of information: Info, License, Images, Categories, and Annotations. Note that a single image can contain several objects, so it can have several annotations. When loading the dataset, we therefore need to fetch every annotation that belongs to an image and turn them into a mask to build the ground truth.
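To make the layout concrete, here is a minimal sketch of a COCO-style annotation file, assuming a toy dataset with one image and two objects. The file name, sizes, and category names are made up for illustration; only the five top-level keys and the id fields that link them come from the format itself.

```python
import json

# Toy COCO-style annotation file (all values are made up for illustration).
coco_json = {
    "info": {"description": "toy dataset"},
    "licenses": [{"id": 1, "name": "hypothetical license"}],
    "images": [{"id": 0, "file_name": "0000.jpg", "height": 4, "width": 6}],
    "categories": [
        {"id": 1, "name": "cat", "supercategory": "animal"},
        {"id": 2, "name": "dog", "supercategory": "animal"},
    ],
    "annotations": [
        {"id": 10, "image_id": 0, "category_id": 1, "bbox": [0, 0, 2, 2]},
        {"id": 11, "image_id": 0, "category_id": 2, "bbox": [3, 1, 2, 3]},
    ],
}

# Round-trip through JSON text, as if the file had been read from disk.
data = json.loads(json.dumps(coco_json))

# One image can own several annotations: collect their class names by image_id.
id_to_name = {c["id"]: c["name"] for c in data["categories"]}
labels_per_image = {}
for ann in data["annotations"]:
    labels_per_image.setdefault(ann["image_id"], []).append(id_to_name[ann["category_id"]])

print(labels_per_image)
```

The annotations only carry ids, so the image and category sections have to be joined in by id before the labels become readable.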

 

<Annotation dictionary>

  • image_id : index (id) of the image
  • category_id : label information
  • segmentation : pixel coordinates
  • bbox : bounding box information
  • ... and so on
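As a quick illustration of these fields, the sketch below unpacks one hypothetical annotation. It relies on the COCO conventions that bbox is stored as [x, y, width, height] and that a polygon segmentation is a flat list [x1, y1, x2, y2, ...]; the numbers and the two helper names are invented for the example.

```python
# A hypothetical annotation entry (values invented for illustration).
ann = {
    "id": 1,
    "image_id": 0,
    "category_id": 3,
    "bbox": [10.0, 20.0, 30.0, 40.0],                    # [x, y, width, height]
    "segmentation": [[10, 20, 40, 20, 40, 60, 10, 60]],  # one polygon, flat x/y list
    "area": 1200.0,
    "iscrowd": 0,
}

def bbox_to_corners(bbox):
    """Convert [x, y, w, h] to [x_min, y_min, x_max, y_max]."""
    x, y, w, h = bbox
    return [x, y, x + w, y + h]

def polygon_to_points(flat):
    """Turn a flat [x1, y1, x2, y2, ...] list into [(x1, y1), (x2, y2), ...]."""
    return list(zip(flat[0::2], flat[1::2]))

corners = bbox_to_corners(ann["bbox"])
points = polygon_to_points(ann["segmentation"][0])
print(corners)
print(points)
```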

 

Pycocotools 

What we need from the JSON file is the annotation information for each image, but that information is scattered across the file. pycocotools lets us handle data in the COCO data format efficiently.

 

* Frequently used methods

coco = COCO(dataDir) # dataDir : path to the annotation JSON file

coco.getCatIds() # returns category ids
coco.loadCats(cat_ids) # category ids → list of dicts with category name and supercategory

coco.getImgIds(imgIds=image_id) # image id or category id → image ids
coco.loadImgs(image_id)[0] # image id → image dict (image info)

coco.getAnnIds(image_id) # image id, category id → annotation ids (all annotations of that image or category)
coco.loadAnns(ann_ids) # annotation ids → annotation info
coco.annToMask(ann) # builds a binary mask from a single annotation
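To give a feel for what these lookups do, here is a rough stdlib-only sketch of the kind of index that gets built when the annotation file is loaded. It is not the library's code: ToyIndex and its method names are hypothetical and only imitate the id-based lookups listed above on a made-up dataset.

```python
from collections import defaultdict

class ToyIndex:
    """Hypothetical, simplified stand-in for the id-based lookups listed above."""

    def __init__(self, dataset):
        self.anns = {a["id"]: a for a in dataset["annotations"]}
        self.img_to_ann_ids = defaultdict(list)  # image id -> annotation ids
        self.cat_to_img_ids = defaultdict(set)   # category id -> image ids
        for a in dataset["annotations"]:
            self.img_to_ann_ids[a["image_id"]].append(a["id"])
            self.cat_to_img_ids[a["category_id"]].add(a["image_id"])

    def ann_ids_for_image(self, image_id):
        return list(self.img_to_ann_ids.get(image_id, []))

    def img_ids_for_category(self, category_id):
        return sorted(self.cat_to_img_ids.get(category_id, set()))

    def load_anns(self, ann_ids):
        return [self.anns[i] for i in ann_ids]

# Made-up annotations: image 0 has two objects, image 1 has one.
toy = {
    "annotations": [
        {"id": 1, "image_id": 0, "category_id": 1},
        {"id": 2, "image_id": 0, "category_id": 2},
        {"id": 3, "image_id": 1, "category_id": 1},
    ]
}

index = ToyIndex(toy)
ann_ids = index.ann_ids_for_image(0)
cat_ids_in_image = [a["category_id"] for a in index.load_anns(ann_ids)]
img_ids_with_cat_1 = index.img_ids_for_category(1)
print(ann_ids, cat_ids_in_image, img_ids_with_cat_1)
```

Building the dictionaries once up front is what makes the per-image lookups in a dataset's __getitem__ cheap.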
 

 

COCO Data Format Custom Dataset / DataLoader
dataset_path = '/data' # set this to your dataset path
train_path = dataset_path + '/train.json'
val_path = dataset_path + '/val.json'
test_path = dataset_path + '/test.json'

import os
import cv2
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from pycocotools.coco import COCO
import matplotlib.pyplot as plt

class COCO_dataformat(Dataset):
    def __init__(self, data_dir, mode = 'train', transform = None):
        super().__init__()
        self.mode = mode
        self.transform = transform # transform : uses the albumentations library
        self.coco = COCO(data_dir) # data_dir : path to the annotation JSON file
        self.img_ids = self.coco.getImgIds() # all image ids, used to map a dataset index to an image id

        self.cat_ids = self.coco.getCatIds() # returns category ids
        self.cats = self.coco.loadCats(self.cat_ids) # category ids -> dicts with category name and supercategory
        self.classNameList = ['Background'] # class names; index 0 is background, lookup by category_id assumes ids 1..N
        for cat in self.cats:
            self.classNameList.append(cat['name'])

    def __getitem__(self, index: int):
        image_id = self.img_ids[index] # dataset index -> image id (ids need not start at 0)
        image_infos = self.coco.loadImgs(image_id)[0] # image id -> image info

        # load the image with cv2 (BGR -> RGB -> float32 numpy array -> scale to 0~1)
        images = cv2.imread(os.path.join(dataset_path, image_infos['file_name']))
        images = cv2.cvtColor(images, cv2.COLOR_BGR2RGB).astype(np.float32)
        images /= 255.0 # ToTensorV2 from albumentations does not normalize, so scale beforehand

        if (self.mode in ('train', 'val')):
            ann_ids = self.coco.getAnnIds(imgIds=image_infos['id']) # image id, category id -> matching annotation ids
            anns = self.coco.loadAnns(ann_ids) # annotation ids -> annotation info

            # build the label mask from the annotations: Background = 0, every other pixel gets its "category id"
            masks = np.zeros((image_infos["height"], image_infos["width"]))
            # ascending order of first-polygon length (assumes polygon segmentations); later annotations overwrite earlier ones
            anns = sorted(anns, key=lambda idx : len(idx['segmentation'][0]), reverse=False)
            for i in range(len(anns)): # loop over the annotations of this image
                pixel_value = anns[i]['category_id'] # label value written into the mask
                #className = classNameList[anns[i]['category_id']] # class name
                masks[self.coco.annToMask(anns[i]) == 1] = pixel_value # annToMask gives a binary mask; write the label where the object is
            masks = masks.astype(np.int8) # int8 holds label values up to 127

            if self.transform is not None:
                transformed = self.transform(image=images, mask=masks)
                images = transformed["image"]
                masks = transformed["mask"]
            return images, masks, image_infos

        if self.mode == 'test':
            if self.transform is not None:
                transformed = self.transform(image=images)
                images = transformed["image"]
            return images, image_infos

    def __len__(self) -> int:
        return len(self.img_ids) # size of the whole dataset
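One thing to watch in the class above is that classNameList is indexed directly by category_id, which only lines up when the ids run 1..N without gaps. The original COCO annotations, for example, use 80 categories whose ids go up to 90. A minimal sketch of a remapping, assuming a few example category entries and a hypothetical helper named build_label_map:

```python
def build_label_map(categories):
    """Map arbitrary category ids to contiguous labels 1..N (0 is background)."""
    sorted_ids = sorted(c["id"] for c in categories)
    id_to_label = {cat_id: label for label, cat_id in enumerate(sorted_ids, start=1)}
    names = {c["id"]: c["name"] for c in categories}
    class_names = ["Background"] + [names[cat_id] for cat_id in sorted_ids]
    return id_to_label, class_names

# Example entries with gaps between the ids (illustrative only).
categories = [
    {"id": 1, "name": "person"},
    {"id": 13, "name": "stop sign"},
    {"id": 90, "name": "toothbrush"},
]

id_to_label, class_names = build_label_map(categories)
print(id_to_label)
print(class_names)
```

With such a map, the mask-building loop would write id_to_label[category_id] instead of the raw id, and class_names[label] would recover the name.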

import albumentations as A
from albumentations.pytorch import ToTensorV2

train_transform = A.Compose([
                            ToTensorV2()
                            ])

val_transform = A.Compose([
                          ToTensorV2()
                          ])

test_transform = A.Compose([
                          ToTensorV2()
                          ])

train_dataset = COCO_dataformat(data_dir=train_path, mode='train', transform=train_transform)
val_dataset = COCO_dataformat(data_dir=val_path, mode='val', transform=val_transform)
test_dataset = COCO_dataformat(data_dir=test_path, mode='test', transform=test_transform)

batch_size = 8
def collate_fn(batch):
    return tuple(zip(*batch))

train_loader = torch.utils.data.DataLoader(dataset=train_dataset, 
                                           batch_size=batch_size,
                                           shuffle=True,
                                           num_workers=4,
                                           collate_fn=collate_fn)

val_loader = torch.utils.data.DataLoader(dataset=val_dataset, 
                                         batch_size=batch_size,
                                         shuffle=False,
                                         num_workers=4,
                                         collate_fn=collate_fn)

test_loader = torch.utils.data.DataLoader(dataset=test_dataset,
                                          batch_size=batch_size,
                                          num_workers=4,
                                          collate_fn=collate_fn)
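The collate_fn passed to the loaders deserves a short illustration. Rather than stacking the samples into one tensor, which needs every sample to have the same shape, it just regroups the batch by field. The sketch below uses strings as stand-ins for real tensors, and the sample values are invented.

```python
def collate_fn(batch):
    # batch is a list of per-sample tuples; zip(*batch) regroups them by field.
    return tuple(zip(*batch))

# Stand-ins for (image, mask, image_info) samples; strings replace real tensors.
batch = [
    ("img0", "mask0", {"id": 0, "file_name": "0000.jpg"}),
    ("img1", "mask1", {"id": 1, "file_name": "0001.jpg"}),
]

imgs, masks, image_infos = collate_fn(batch)
print(imgs)
print(masks)
print(image_infos[1]["file_name"])
```

Each field comes back as a tuple of individual items, so a training loop would typically stack the images and masks itself (for example with torch.stack) once their shapes are known to match.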

 

 

Visualizing data with the DataLoader
numOfSample = 3
fig, axes = plt.subplots(nrows=numOfSample, ncols=2, figsize=(15, 15))
tmp_imgs = []
tmp_masks = []
coco = COCO('/content/gdrive/My Drive/baseline_code/input/data/train.json')

for imgs, masks, image_infos in train_loader:
    for idx in range(numOfSample):
      tmp_info = []
      for anns in coco.loadAnns(coco.getAnnIds(image_infos[idx]['id'])) :
        if [anns['category_id'],train_dataset.classNameList[anns['category_id']]] not in tmp_info:
          tmp_info.append([anns['category_id'],train_dataset.classNameList[anns['category_id']]])
      print('labels in image', idx, ':', tmp_info)
      axes[idx][0].imshow(imgs[idx].permute([1,2,0]))
      axes[idx][0].grid(False)
      axes[idx][0].set_title("input image : {}".format(image_infos[idx]['file_name']), fontsize = 15)

      axes[idx][1].imshow(masks[idx])
      axes[idx][1].grid(False)
      axes[idx][1].set_title("mask : {}".format(image_infos[idx]['file_name']), fontsize = 15)
    break

plt.show()

 
