일 | 월 | 화 | 수 | 목 | 금 | 토 |
---|---|---|---|---|---|---|
1 | 2 | 3 | 4 | |||
5 | 6 | 7 | 8 | 9 | 10 | 11 |
12 | 13 | 14 | 15 | 16 | 17 | 18 |
19 | 20 | 21 | 22 | 23 | 24 | 25 |
26 | 27 | 28 | 29 | 30 | 31 |
- Interpretability
- 코딩테스트
- keras
- 딥러닝
- Score-CAM
- coding test
- Deep learning
- python
- Unsupervised learning
- SmoothGrad
- Cam
- 백준
- Class activation map
- 설명가능한 인공지능
- Artificial Intelligence
- 설명가능한
- 시계열 분석
- Machine Learning
- xai
- 코딩 테스트
- 기계학습
- meta-learning
- 메타러닝
- GAN
- 머신러닝
- Explainable AI
- cs231n
- 인공지능
- AI
- grad-cam
- Today
- Total
iMTE
Meta-Learning 4. Siamese, Prototypical, Relation, and Matching Networks 본문
Meta-Learning 4. Siamese, Prototypical, Relation, and Matching Networks
Wonju Seo 2019. 5. 13. 22:54Meta-Learning sub parts
Siamese, Prototypical, Relation, and Matching networks
Meta-learning이 사용 될 수 있는 networks를 소개하고자 한다. (개념을 위주로 이해해보자!)
1. Siamese Networks
Deep Metric Learning에서 다루었던 내용이지만, Siamese networks는 이미지 두개를 받는 두개의 networks (weight은 sharing하고 같은 구조)로 구성되어 있다. 두개의 networks는 두개의 이미지로 부터 feature vector (embedding)을 생성해내고, 이 feature vectors의 similarity로 부터 학습된다. Similarity 계산은 feature vector 사이의 거리를 계산하는데, Euclidean distance 혹은 cosine distance가 사용될 수 있다.
E = ||f(X1)-f(X2)|| - Euclidean distance
E = dot(f(X1),f(X2))/(||f(X1)||*||f(X2)||)
아주 직관적으로 풀어보자면, networks는 이미지로부터 특징을 추출해내고, 두 이미지로부터 추출된 특징들을 비교한다는 것이다. 이는, 단순히 그림을 비교하는 것보다 high-level의 특징을 고려할 수 있음으로 더 정확한 구분을 할 수 있다. (예, 두 이미지는 같다. 혹은 두 이미지는 다르다.)
Loss로는 Contrastive Loss가 사용된다.
Contrastive Loss = Y*E^2+(1-Y)*max(margin-E,0)^2, Y는 label, E는 similarity, margin은 margin.
다시 직관적으로 보자면, 같은 이미지는 최대한 가깝도록 feature vector가 생성되도록 하며, 다른 이미지는 margin 정도의 E를 갖도록 하여 feature vector가 공간적으로 떨어지도록 한다. (공간이라는 단어를 사용한 이유는, 거리의 개념이 사용되기 때문이다.)
2. Prototypical Networks
Prototypical network는는 간단하게, 각 이미지로부터 feature vector를 구한 다음에, 각 class에 해당하는 이미지들의 feature vector를 평균을 취해서 class prototype을 구한 이후, 새로운 이미지 (query image)로부터 얻은 feature vector와 비교를 하고 (distance 계산) softmax를 취해서, probability를 보여주는 networks이다.
Support set (e.g, training set)에서 각 이미지의 feature vector가 networks를 통해서 추출되고, 각 class의 이미지로 부터 얻은 feature vector를 평균을 취한다. 이는 class prototype이 된다. (직관적으로 이해하자면, 이 prototype은 현재 갖고 있는 이미지의 feature vector의 general한 정보를 담고 있다.) 그리고, 새로운 query image에 동일한 feature extractor (e.g., networks)로부터 feature가 추출되고, 이 feature와 class prototype이 비교된다. 비교되는 방법은 당연히 distance를 계산하는 방법이고 (e.g., calculation similarity), 각 class에 해당하는 prototype과 비교한 distance에 softmax를 취해서 probability를 계산한다. Loss로는 negative log probability가 사용된다.
Variant로 Gaussian prototypical network와 semi-prototypical network가 있다. Gaussian prototypical network는 단순히 class prototype을 구하는게 아니라, covariance matrix도 같이 계산하여, confidence range를 고려하는 network이다. 이 경우, noise에 강건해진다. Semi-prototypical network는 unlabeled data를 처리하기 위한 networks로서, 임의의 class를 부여해서 prototypical network를 학습시킨다. 하지만, 임의의 class를 하나만 고려하는 경우, high variance를 갖는 문제가 있다. 이 경우 모든 unlabeled data를 고려하지말고, class prototype으로부터 threshold를 정해서 그 안에 있는 data만 고려함으로써 부분적으로 해결할 수 있다.
3. Relation Networks
Relation networks도 이전 siamese, prototypical networks와 비슷한 형태이다. 마찬가지로, networks가 feature vector를 추출하는데, support set에서 추출된 feature vector와 query image의 feature vector의 combination을 relation network를 통과해서 각 class에 해당하는 relationship (0~1)을 계산한다. 이전에는 단순히 feature vector를 비교하였다면, relation network라는 새로운 network를 통과시켜서 relationship을 계산한다는 점이 siamese network와 prototypical network와 다른점이다. 만약 few-shot learning에 적용하는 경우라면, feature vector를 combine할 때, support set의 각 class에 해당하는 모든 이미지의 feature vector를 더한 다음에 query image의 feature vector와 combine한다.
4. Matching Networks
Matching networks는 relation networks와 비슷한데, support set과 query image에서 사용되는 feature extractor (e.g., networks)가 다르며, support set과 query image의 feature vector의 cosine distance로 부터 attention matrix가 계산되고, attention matrix가 one-hot encoding에 곱해져서 각 class에 대한 probability가 계산된다는 점이 다르다. (이렇게 적고보니 완전 다르다. 근데, 다른게 맞다.) Cosine similarity에 softmax를 적용해서 attention map을 적용했다는 점이 흥미롭다. 하지만, 굳이 one-hot encoding에 곱해야하는지는 의문이다. (왜 computational load를 늘릴까?)
다음에는 Neural Turing Machine, Memory Augmented Neural Networks를 다루겠다.