iMTE

Pix2Pix 본문

Deep learning/GAN

Pix2Pix

Wonju Seo 2018. 7. 21. 18:38

Pix2Pix

Image-to-Image Translation with Conditional Adversarial Networks

Isola, Phillip, et al. "Image-to-image translation with conditional adversarial networks." arXiv preprint (2017).


cycleGAN, DiscoGAN, Pix2Pix 와 같은 image-to-image translation model은 보란듯이 첫 페이지에 결과를 보여준다. (이러니 안궁금할 수가 없지.) 핵심만 간단하게 해석하고 접근을 해보자.

먼저, Pix2Pix는 다른 cylceGAN, DiscoGAN과 달리 Paired image를 요구한다. 즉, input output이 서로 관련이 있는 것으로, 위의 그림에서와 같이 pair가 있는 것이다. 가장 쉽게 얻을 수 있는 것은 RGB to gray pair일 것이다. RGB image를 gray scale로 만드는 건 그렇게 어렵지 않다. 만약 이 데이터 셋을 사용한다면 우리는 gray를 RGB로 변환시킬 수 있는 모델을 학습시킬 수 있을 것이다. 

CNN (Convolutional Neural Network)에게 paired image이니까.. 입력과 출력의 차이를 줄여주도록 학습하면 완성이 될 것이라고 볼 수 있지만.... 출력으로 blurry한 image가 나오게된다. 논문은 다음과 같이 언급한다.

"If we take a naive approach, and ask the CNN to minimize Euclidean distance between predicted and ground truth pixels, it will tend to produce blurry results"

논문은 이 문제를 지목했고, 우리가 볼때 blurry한 이미지는 'real' image라고 볼 수 없으니.. 여기에 GAN을 써보자는게 이 논문의 접근 방법이다. GAN은 그럴듯한 이미지를 만들어내는 것이 목적이니, 이 그럴듯한 이미지로 blurry한 문제를 해결해보자는 것이다.

위 그림은 L1 loss로 학습된 결과 (blurry)와 제안된 모델로 학습된 결과를 보여주고 있다. (너무 극명하게 차이가 나는 것을 확인할 수 있다.)

네트워크의 구조는 U-net을 사용했고, discriminator는 PatchGAN의 방법을 사용했다. 이미지 전체를 보고 가짜다 진짜다라고 판단하지말고, patch형태로 나눠서 보자라고 해석하는게 나은 것 같다. patch형태로 본다면 어떤 patch가 loss가 큰지를 generator는 알고 그 부분을 더 잘 학습할 것이다.

논문에서는 conditional GAN의 형태의 loss function을 사용했는데, 입력 image와 noise인 z를 사용하는 형태이다.

논문에서는 z를 random gaussian noise를 사용하기 보다, dropout을 사용했다.

"Instead, for our final models, we provide noise only in the form of dropout, applied on several layers of our generator at both training and test time"

그리고 L1 loss를 추가해서, 최종적인 loss는 다음과 같이 계산된다.

이 논문에서는 loss function을 위와 같이 정했는데, 여기에는 크게 두가지 이유가 있다. L1 loss는 image의 low-frequency content를 학습할 수 있다. (그래서 averaging effect가 나타나는 것이다.) 반대로 adversarial loss는 realistic image를 만들기 위해서 high-frequency content를 학습하게 된다. 따라서, 기존의 L1 혹은 L2 loss로만 얻을 수 있었던 이미지와 달리 좀더 sharp하고 실제와 같은 이미지를 얻게되는 것이다. (서로 다른 content를 학습시키는데 기존 방법(L1, L2 loss)과 GAN이라는 새로운 방법론이 사용될 수 있다는 아이디어를 생각해낸것이다!)

위 그림은 Pix2Pix가 학습하는 방법에 대해서 설명하고 있다. Paired 이미지가 Discriminator에 입력되고 Discriminator는 fake와 real을 구분하게 된다.

위의 그림은 이 논문에서 다양한 domain에 적용시켜본 결과를 보여주고 있다. edge2photo에서 부터 sketh2portrait, depth2streetview, background removal등의 예들이 보여지고 있다. 한눈에 볼때, 잘 학습이 되어 좋은 결과를 보여주는 것을 확인할 수 있다.

다음은 keras로 cityscapes dataset으로 구현해본 Pix2pix의 결과이다.

generated된 이미지가 살짝 흐린 감이 있지만, 그래도 논문에나온 L1 loss만 고려할 때 보다 더 sharp하고 realistic한 이미지를 얻을 수 있었다.



'Deep learning > GAN' 카테고리의 다른 글

starGAN  (0) 2018.07.25
DiscoGAN  (0) 2018.07.20
CycleGAN  (0) 2018.07.19
Super-resolution GAN (SRGAN)  (0) 2018.07.19
Conditional GAN  (0) 2018.06.15
Comments