iMTE

Meat-Learning 8. Meta-learning for semi-supervised few-shot classification 본문

Deep learning study/Meta-Learning

Meat-Learning 8. Meta-learning for semi-supervised few-shot classification

Wonju Seo 2019. 5. 25. 16:13

Meta-learning for semi-supervised few-shot classification (ICLR, 2018)


Abstract

"In this work, we advance this few-shot classification paradigm towards a scenario where unlabeled examples are also available within each episode....To address this paradigm, we propose novel extensions of Prototypical Networks that are augmented with the ability to use unlabeled examples when producing prototypes."

Unlabeled data에 대해서 few-shot learning을 적용시키는 것이, 이 논문의 동기이다. 이전에 소개한, prototypical networks를 사용해서 unlabeled data에 대한 learning을 진행시켜보자는 것이다.

Introduction

"However, current deep learning approaches struggle in tackling problems for which labeled data are scarce. Specifically, while current methods excel at tackling a single problem with lots of labeled data, methods that can simultaneous solve a large variety of problems only a few labels are lacking...For this reason, recently there has been an increasing body of work on few-shot learning, which considers the design of learning algorithms that specifically allow for better generalization on problems with small labeled training sets."

위의 문장을 천천히 읽어보길 바란다. 이 논문에서는 few-shot learning scheme에서 unlabeled data가 추가될 때 어떻게 학습을 구성할지를 몇가지 방법으로 제안을 하고 있는데, 천천히 알아보도록 해보자.

Background (Prototypical networks)

Prototypical network는 embedding function h(x)를 학습시키는 것인데, 이는 입력을 feature로 바꿔주고, 같은 class는 가깝게 다른 class는 멀리 mapping한다. 각 class마다 prototype이 결정되며, 이는 다음의 식을 통해서 계산된다.

간단하게, 각 class 얻은 embedding의 평균을 prototype이라고 정의를 한다. 만약 새로운 example이 입력되는 경우, class는 다음과 같은 연산을 통해서 결정된다.

즉, 입력된 example의 embedding과 각 class의 prototype과의 비교(e.g., Euclidean distance)하고 softmax를 취해주는 형태로서, class를 결정짓는다. 학습시에는 softmax function이 사용되었기 때문에, negative log-likelihood가 사용된다.

그리고, class는 다음과 같은 argmax function으로 구해진다.

쉽게 직관적으로 이해할 수 있을 것이라고 생각하기 때문에, 다음 part로 넘어가보자.

Semi-supervised few-shot learning

Unlabeled data에 대해서 어떻게 few-shot learning 즉, prototypical networks를 적용시킬 수 있을까? 본 논문에서는 refinement라는 과정을 소개를 한다. 이 과정은 밑의 그림에서 나타나있다.

1. Prototypical networks with soft k-means

Unlabeled data를 고려한 prototypical networks를 제안하는 방법중 가장 간단한 방법은 semi-supervised clustering 방법을 사용하는 것이다. Support set과 unlabeled set을 사용해서 prototype cluster를 수정하는 것이다. 이때, 미분이 가능한 soft k-mean clustering을 사용하는 것이다. 다음과 같은 식을 통해서 prototype이 refinement된다.

즉, unlabeled sample을 갖고 현재 cluster에 대해서 classification을 한 다음에, 이 결과를 반영하여 prototype을 구한다는 것이다. 하지만, 이런 경우 unlabeled cluster가 심각하게 prototype을 손상시킬 수 있는 문제를 가질 수 있다.

2. Prototypical networks with soft k-means with a distractor cluster.

앞서 설명한 것 처럼, soft k-mean 방법은 unlabeled example이 N개의 class에 속할 수 있다는 가정하고 있지만, 이런 가정 대신에 다른 class가 있는 모델을 견고하게 만드는 것이 일반적일 수 있고 이러한 class를 distractor라고 부른다. 예를 들어, 외발 자전거와 스쿠터 사진을 구분하기 위해서, label이 없는 웹 이미지를 다운받는다고 하자. 이때, 모든 이미지가 외바 자전거와 스쿠터만 있을 것이라고 생각하는 것은 현실적이지 않다. 몇몇 이미지는 자전거와 같은 비슷한 class를 포함하게 된다. 

Distractor를 단순히 soft k-means에 포함시키는 것은 refinement process에 악영향을 미침으로, distractor를 고려하면서 prototype을 refinement해야한다. 다음의 방법으로 refinement가 진행된다.

위의 가정에 의하면 distractor의 cluster는 origin에 있게 된다. 그리고 본 논문에서는 추가적으로 length-scales을 추가하여 cluster 내의 distances의 variations을 고려하였다.

본 논문에서 r값은 1~N의 class에 대해서는 1로 결정하였고 오직 distractor의 r만 학습이 되도록 하였다.

3. Prototypical networks with soft k-means and masking

단순히 distractor unlabeled example을 하나의 cluster로 modeling 하는 것은 너무 쉬운 방법이다. 그리고 distractor가 하나의 class로 mapping되지 않는다면 매우 큰 variance를 갖게 될 것이다. (상상해보자, 사자와 개가 distractor인 경우, 이 두개의 variance는 얼마나 클지!) 이러한 상황을 반영하기 위해서, multiple class로 mapping하는 과정이 필요하다.

"To address this problem, we propose an improved variant : instead of capturing distractors with a high-variance catch-all cluster, we model distractors as examples that are within some area of any of the legitimate class prototype. This is done by incorporating a soft-masking mechanism on the contribution of unlabeled examples. At a high level, we want unlabeled examples that are closer to a prototype to be masked less than those that are farther."

위의 문장이 핵심 적인 내용을 포함하고 있는데, 분산이 큰 cluster를 사용해서 distractors를 capture하는 대신에, class prototype의 특정 영역 내에 있는 examples로서 distractor를 modeling 하겠다는 것이다. 이를 위해서 soft-masking mechanism이 사용된다. 예를 들어, prototype에 가까운 unlabeled example은 덜 masking하도록 하는 것이다. (즉, 이전에 직접 distractor를 정해주기 보다는, 학습의 형태로 distractor를 잘 정하도록 한다는 것이다.)

먼저 Soft k-mean은 다음과 같이 수정이 된다. 

그런 다음, soft threshold인 beta와 slope r을 각 prototype에 대해서 예측한다. 앞에서 계산된 normalized distance인 d의 다양한 통계적 정보를 입력으로 넣은 networks를 사용한다.

위의 MLP는 각 threshold가 intra-cluster variation의 양의 정보를 사용해서 unlabeled examples을 얼마나 고려하지 않을지를 결정하게 된다.

다음으로, 각 prototype의 각 example의 기여에 대한 soft mask m을 normalized distances의 threshold인 b와 비교를 함으로 계산해낸다.

앞에서 말했던 것 처럼, distractor에 대한 정보를 각 prototype의 threshold와 slope을 갖고 추출해내고 (이것이 distractor인지 아닌지를 말이다.) 이를 masking하여서 prototype 생성에 반영하겠다는 것이 3번째 방법의 핵심이다.

Related works and Experiments (생략)

Conclusion

"In this work, we propose a novel semi-supervised few-shot learning paradigm, where an unlabeled set is added to each episode. We also extend the setup to more realistic situations where the unlabeled set has novel classes distinct from the labeled classes."

위의 문장을 천천히 읽어보도록 하자. 결국, 이 논문은 어떻게 few-shot learning에서 unlabeled dataset을 처리할 수 있는 지 방법을 제시했다는 점이 핵심이다.

지금까지 여러 meta-learning algorithm을 review했습니다. 여러모로 읽는 분들이 도움이 됬을 것이라고 생각하고, meta-learning 편은 이정도까지만 review하고 다른 algorithm을 review하도록 하겠습니다. (그렇다고 공부를 안한다는 건 아닙니다 ㅎ)

Comments