iMTE

starGAN 본문

Deep learning/GAN

starGAN

Wonju Seo 2018. 7. 25. 23:03

starGAN

StarGAN: Unified Generative Adversarial Networks for Multi-Domain Image-to-Image Translation

Choi, Yunjey, et al. "Stargan: Unified generative adversarial networks for multi-domain image-to-image translation." arXiv preprint 1711 (2017).

Facial emotion을 detection하는 알고리즘을 만들다가, starGAN 논문을 읽고 정리를 해본다.
StarGAN에서 주목한 점은, 기존의 접근 방식들이 (Pix2Pix, CycleGAN, DiscoGAN 등) 두 개 이상의 domain 사이의 관계를 학습할 때 generator와 discriminator를 독립적으로 학습을 시켜야한다는 점이다. 즉, Domain 수에 따라 generator 수가 증가하는 것이다. k 개 domain이면 k(k-1)개의 generator를 학습시켜야한다!

StarGAN은 두가지를 주목한 것으로 보이는데 첫번째는 앞에서 언급한대로, 여러개의 generator를 쓰지말고 하나의 generator와 discriminator로 분류를 해보자라는 것. 두번째는 서로 다른 dataset에서 얻을 수 있는 서로 다른 label을 generator가 학습해보자는 것이다. 예를 들어, CelebA 데이터는 hair color, gender, age등 에 대한 label이 있는데, Facial expression을 학습하기 위해서 RaFD라는 데이터를 사용하자는 것이다. 이 두가지가 메인이다. 즉, 다른 dataset에서 얻은 feature를 사용해서 원래 dataset의 데이터를 condition을 줘서 realistic한 image를 생성할 수 있다는 것이다.

구체적으로 이 논문에서 지목한 기존의 image-to-image translation model의 문제점은 다음과 같다.

1. 여러 domain을 설정하는 경우, 각 generator가 전체 training data를 완전히 활용하지 않는다는 점. training set을 완전히 활용하지 않으면 생성된 이미지의 품질이 제한된다는 점.

2. 각 dataset이 부분적으로 분류되기 때문에, 서로 다른 dataset의 domain을 공동으로 학습할 수 없다는 점.

3. 여러개의 generator가 필요한점

이 논문에서는 mask vector라는 개념을 사용해 서로 다른 dataset의 domain간 joint training이 가능하게했다. 이 방법은 특정 dataset에서 주어진 label에만 초점을 맞추고, 모르는 label은 무시하는 방법을 사용했다.

"Our proposed method ensures that the model can ignore unknown labels and focus on the label provided by a particular dataset"

학습의 방법은 다음과 같다.

(b)~(d)는 어디서 많이 본 그림이다. 자세히 보면, Generator는 target domain정보와 input image를 받고 fake image를 생성해내고, 생성된 fake image에 original domain이 concatenate되서 다시 generator를 통과되서 원래 input image와 비슷한 형태를 만들어낸다. 그리고 생성된 fake image는 discriminator를 속이기 위해 학습이되며 discriminator는 real/fake를 분류하는 것과 동시에 domain의 label을 classification하는 auxiliary classifier의 역할을 한다. 기존의 방법에서 추가된 점이 바로 condition을 주는 점인데 (cGAN) label에 대한 정보를 입력과 같이 넣어주는 것이다. 기존의 방법들은 목적 자체가 이미 정해진 domain사이를 본것이라면 (말->얼룩말, 금발->흑발 등), 이 network의 구조는 multi-label에 condition되도록 generator를 학습시킨 것이라고 볼 수 있다. (개인적인 의견으로는 기존 방법이 1:1으로 되는 generator이니까, 여기에 조건을 넣어서 하나의 generator안에 다양한 1:1 매칭이 되도록 학습을 시킨 것 같다.)

Loss

1. Adversarial loss

Generator는 input image와 target domain label c로 condition된 이미지 G(x,c)를 생성해내고, D는 real과 fake를 구분하도록 학습한다.

2. Domain classification loss

주어진 input image x와 target domain c에 대해서, 이 논문은 c`의 target domain으로 분류되는 y input image로의 'translate'하는 것을 목표로 하고 있으며, auxiliary classifier를 D에 추가를 했다. 이 domain classification loss는 두 부분으로 나뉘는데, real image의 domain classification loss는 D를 최적화할 때 사용되고, fake image의 domain classification loss는 G를 최적화 할 때 사용된다.

: D에의해서 계산된 domain label에 대한 probability distribution을 나타낸다.

D에 대한 이 classification loss를 최소화함으로써, D는 original domain c`에 해당하는 real image를 분류할 수 있게 된다.

반대로, fake image의 domain classification의 loss function은 다음과 같이 정의된다.

G는 이 loss를 줄이도록 노력하면서 target domain c에 분류되는 image를 생성해낸다.

(CelebA의 input image가 A라는 label 정보를 갖고 있고, RaFD에서 B라는 label 정보를 바탕으로 B를 갖는 CelebA를 생성해내는 것이다. 그리고 다시 Generator는 fake image에서 A라는 label 정보를 입력받아 input image를 복구해내는 것이다! 그렇게 되면 Generator는 A와 B label에 대한 정보를 학습할 수 있게 되는 것이다. - 서로 다른 dataset임에도 불구하고)

3. Reconstruction loss

Generator가 만들어낸 이미지가 adversarial loss를 최소화 하기 위해서 realistic해지지만, 보장할 수 없으므로, DiscoGAN이나 CycleGAN에서 주목한 cycle-consistency loss를 추가해서 복원된 이미지가 입력된 이미지와 비슷하도록 학습을 한다.

"By minimizing the adversarial and classification losses, G is trained to generate images that are realistic and classified to its correct target domain. However, minimizing the losses does not guarantee that translated images preserve that content of its input images while chainging only the domain-related part of the inputs."

따라서, 최종 loss는 다음과 같아진다.

먼저, CelebA로만 학습시킨 모델의 결과는 다음과 같다.

다양한 label을 condition을 줘서 generator를 통해서 image를 생성할 수 있다. 위의 그림을 보면 StarGAN이 다른 모델에 비해서 상당히 realistic한 결과를 보여주고 있음을 확인할 수 있다. (단일 label만 고려한 경우에는 뚜렷한 차이를 보기가 어렵지만, label들이 여러개들이 고려되는 경우에 그 차이가 확실히 드러난다.)

두번째로, RaFD dataset으로 학습시킨 모델의 결과이다.

Facial expression에서 단일 label을 구현할 때도, realistic한 결과를 낼 수 있다는 것을 보여준다. StarGAN으로 생성된 이미지의 classification loss는 어느정도 될까? 이 논문에서는 기존의 모델(DIAT, cycleGAN, IcGAN)보다 classification error가 더 적을 뿐만 아니라 parameter의 수도 적다는 것을 보여주었다.

다음으로, multi domain뿐만 아니라 multi dataset으로 학습시킨 결과이다.

StarGAN -SNG은 RaFD로 학습시킨 모델로 CelebA에 적용시킨 결과이고, StarGAN-JNT는 CelebA와 RaFD로 학습시킨 모델로 CelebA에 적용시킨 결과이다. 단일 dataset에 학습시킨 것보다 CelebA와 RaFD로 학습시킨 코델이 더 잘 realistic한 image를 생성해낸다는 것을 바로 확인할 수 있다. 

(github) https://github.com/yunjey/StarGAN 

(한번 구현을 해봐야겠다.)

(이 posting에서 mask vector에 대한 설명과, correct mask vector와 wrong mask vector에 대한 설명은 skip했습니다. 관심있는 분은 꼭 논문을 읽어보시길 바랍니다.)



'Deep learning > GAN' 카테고리의 다른 글

Pix2Pix  (0) 2018.07.21
DiscoGAN  (0) 2018.07.20
CycleGAN  (0) 2018.07.19
Super-resolution GAN (SRGAN)  (0) 2018.07.19
Conditional GAN  (0) 2018.06.15
Comments