본문 바로가기
🏛 Research/Detection & Segmentation

[간단 설명] Semi-Supervised Semantic Segmentation / Segmentation에서 unlabeled 데이터를 사용하여 학습하는 방법

by 뭅즤 2022. 1. 13.
반응형

Semi-supervised semantice segmentation 이라는 분야를 설명하기 위해 아래 논문들을 소개합니다.

  1. Semi-supervised semantic segmentation needs strong, varied perturbations (BMVC 2020)
  2. Semi-Supervised Semantic Segmentation with Cross-Consistency Training (CVPR 2020)
  3. Guided Collaborative Training for Pixel-wise Semi-Supervised Learning (ECCV 2020)
  4. PSEUDOSEG: DESIGNING PSEUDO LABELS FOR SEMANTIC SEGMENTATION (ICLR 2021)

Semi-supervised learning 은 적은 수의 labeld data와 많은 un-labeled data로 network를 학습시키는 방법입니다. segmentation의 경우 classification task 보다 아래의 이유들로 까다롭습니다.

1. Pixel-wise prediction을 해야함

2. Semi-supervised learning 에서는 이미지에 특정 augmentation(perturbation)을 주는 방식을 많이 사용하는데, affine transformation 기반의 이러한 augmentation이 mask(label) 또한 바뀌어야하기 때문에 segmentation task에 적합하지 않습니다.

 

Paper 1 : Semi-supervised semantic segmentation needs strong, varied perturbations (BMVC2020)

CutMix-Seg

위 논문은 mean teacher framwork(Mean teachers are better role models: Weight-averaged consistency targets improve semi-supervised deep learning results, NIPS2017)에 affine transform 기반의 augmentation이 아닌 cutmix 방법을 사용해서 semi-supervised segmentation 을 수행합니다.

Consistency loss

구체적인 방법은 2개의 서로 다른 input image들을 cutmix로 합치고 student network로 prediction 한 mask와, 2개의 input image들을 각각 Teacher network에 forwarding시켜 만들어진 prediction map을 cutmix로 합친 mask로 consistency loss를 구해 network를 학습시킵니다.

 

Paper 2 : Semi-Supervised Semantic Segmentation with Cross-Consistency Training (CVPR2020)

unlabeld data는 K개의 보조 decoder들을 사용하여 encoder에서 추출된 feature에 특정 perturbation을 주고 decoder들에서 생성된 mask들과, main decoder가 생성한 mask간의 unsupervised loss를 계산하여 network를 학습합니다.

perturbation은 Feature based perturbation(feature map의 activation 중 일부를 제거하거나 노이즈 삽입), Prediction based perturbation(main 또는 auxiliary decoder의 예측을 기반으로 perturbation 추가), Random perturbation 세가지를 사용합니다.

 

반응형

 

Paper3 : Guided Collaborative Training for Pixel-wise Semi-Supervised Learning (ECCV 2020)

GCT Framework

다르게 초기화가 된(두 모델 사이의 perturbation을 이용하기 위해 필수) 2개의 모델을 만들고(T1, T2), input image x가 들어가면 T_k(x)를 예측하고, x와 Tk가 concat 되어 F로 들어가서 flaw probability map을 추정합니다. prediction confidence map은 1-F(x,Tk(x))로 approximate 됩니다.

GCT framework는 GAN 처럼 두 단계로 학습되는데, 첫 번째 단계에서는 fixed된 F(F : flaw detector)로 Tk를 학습합니다. label 이 있는 data의 경우 Tk(xl)은 해당 label y 와의 supervised loss를 계산합니다.

Unlabeled data를 학습하기 위해 본 논문에서는 dynamic consistency constraint(Ldc) 와 flaw correlation constraint(Lfc) 를 제안합니다. Ldc는 T1(x), T2(x)에서 신뢰할 수 있는 pixel을 앙상블하기 위한 loss 이고, Lfc는 신뢰할수 없는 prediction을 수정하기 위한 loss 입니다. 

 

 

Paper 4 : PSEUDOSEG: DESIGNING PSEUDO LABELS FOR SEMANTIC SEGMENTATION

PseudoSeg Framework

PseudoSeg 모델은 grad-cam 을 이용하는 참신하고 간단한 방법으로 pseudo label을 만들어서 성능을 향상시킵니다.

Input image에 weak augmentation을 가한 후 네트워크(f)에 주입 후 출력된 decoder prediction과 self-attention grad cam으로 출력한 feature을 fusion하여 pseudo label을 만들고, input image에 strong augmentation을 가한 후 네트워크(f)에 주입하여 나온 prediction feature를 생성합니다. 생성된 pseudo label과 prediction으로 loss를 계산하여 모델을 학습시킵니다.

 

아래 그림을 보면 decoer 출력보다 SGC map과 fusion한 pseudo label이 정확하고, 이를 decoder(strong) 과 loss를 계산하여 모델을 학습시키는 것을 시각적으로 확인할 수 있습니다.

 

반응형