iMTE

Variational Auto-Encoder (VAE) 본문

Deep learning/Keras

Variational Auto-Encoder (VAE)

Wonju Seo 2018. 7. 11. 17:26

참고 자료:

https://www.slideshare.net/ssuser06e0c5/variational-autoencoder-76552518

http://jaejunyoo.blogspot.com/2017/04/auto-encoding-variational-bayes-vae-1.html

https://ratsgo.github.io/generative%20model/2018/01/27/VAE/


Variational Auto-Encoder (VAE)


Auto-encoder는 high-dimensional data에 대해서 low-dimensional feature를 추출하고 (Encoder) 이 추출된 feature를 기반으로 original data를 복구하는 구조 (Decoder)를 갖고 있다.

개인적으로 참고자료들을 보고 (수많은 수식들...) 이해한 점은, low-dimensional 즉 latent variable을 가지고 새로운 data를 만들어보자라고 이해했다. 이전 Auto-encoder는 deterministic하게 latent variable을 고정했다면, Variational Auto-encoder는 latent variable에 zero-mean Gaussian Noise를 추가해서 본 이미지에서 살짝씩 달라지는 형태를 보자고 한 것 같다. 


다음은 Loss function이다.


이 항은 reconstruction error로 xi를 입력받은 encoder(q)가 생성해낸 z를 바탕으로 xi를 복구하는(decoder) 결과의 cross entropy를 나타내고 있다. 

이 항은 regularization 으로, sample된 z와 xi를 입력받은 encoder(q)가 생성해낸 z와의 probability distribution의 차이를 보고있다. 이 차이가 적으면 적을 수록 latent variable의 분포와 같아질 것이다. 


z를 sampling하는 과정이 미분이 불가능해서 backpropagation을 하지 못하는데, 이를 reparameterization trick을 사용했다. (참고자료를 참고하자!) 이 방법으로 미분가능하도록 만들어 Neural network가 학습을 하도록 하였다.


다음은 Keras에서 제공해주는 VAE를 구현하였다. (vae_loss가 동작안해서 인터넷에서 찾아서 함수를 사용했다.)


from keras.layers import Lambda, Input, Dense
from keras.models import Model
from keras.datasets import mnist
from keras.losses import mse, binary_crossentropy
from keras.utils import plot_model
from keras import backend as K
import numpy as np
import matplotlib.pyplot as plt
import os
from keras import objectives
# reparameterization trick
# instead of sampling from Q(z|x), samples eps = N(0,I)
# backpropagation을 위해서, sampling은 미분이 불가능해서 backprop이 불가능!
# z = z_mean + sqrt(var)*eps
def sampling(args):
    z_mean, z_log_var = args
    batch = K.shape(z_mean)[0]
    dim = K.int_shape(z_mean)[1]
    # by default, random_normal has mean =0, std=1.0
    epsilon = K.random_normal(shape=(batch,dim))
    return z_mean +K.exp(0.5*z_log_var)*epsilon

def main():
    (X_train,Y_train), (X_test,Y_test) = mnist.load_data()
    row = 28
    col = 28
    dim = row * col
    X_train = np.reshape(X_train,[-1,dim]).astype('float32')/255
    X_test = np.reshape(X_test,[-1,dim]).astype('float32')/255
        
    input_shape = (dim,)
    intermediate_dim = 512
    batch_size = 128
    latent_dim = 2 # mean and standard deviation!
    epochs = 30
    
    # VAE model = autoencoder (encoder + decoder)
    inputs = Input(shape=input_shape,name='encoder_input')
    # train q(z|x) -> approximation
    x = Dense(intermediate_dim,activation='relu')(inputs)
    x = Dense(intermediate_dim,activation='relu')(x)
    z_mean = Dense(latent_dim,name='z_mean')(x)
    z_log_var = Dense(latent_dim,name='z_log_var')(x)
    
    # use reparameterization trick to push the sampling out as input
    # z_mean+sqrt(var)*eps , Adding zero-mean Gaussian noise
    z = Lambda(sampling,output_shape=(latent_dim,),name='z')([z_mean,z_log_var])
    
    encoder = Model(inputs,[z_mean,z_log_var,z],name='encoder')
    encoder.summary()
    plot_model(encoder,to_file='vae_mlp_encoder.jpg',show_shapes=True)
    
    # decoder
    # p(x|z)
    latent_inputs = Input(shape=(latent_dim,),name='z_sampling')
    x = Dense(intermediate_dim,activation='relu')(latent_inputs)
    x = Dense(intermediate_dim,activation='relu')(x)
    outputs = Dense(dim,activation='sigmoid')(x) # 0~1
    
    decoder = Model(latent_inputs,outputs,name='decoder')
    decoder.summary()
    plot_model(decoder,to_file='vae_mlp_decoder.jpg',show_shapes=True)
    
    # VAE
    outputs = decoder(encoder(inputs)[2])
    vae = Model(inputs,outputs,name='vae_mlp')

    models = (encoder,decoder)
    data = (X_train,Y_train)
    
    def vae_loss(x, x_decoded_mean):
        xent_loss = objectives.binary_crossentropy(x, x_decoded_mean)
        kl_loss = -0.5 * K.mean(1 + z_log_var - K.square(z_mean) - K.exp(z_log_var))
        loss = xent_loss + kl_loss
        return loss
    
    vae.compile(optimizer='adam',loss=vae_loss)
    vae.summary()
    plot_model(vae,to_file='vae_mlp.jpg',show_shapes=True)
    vae.fit(X_train,X_train,epochs=epochs,batch_size=batch_size,validation_data=(X_test,X_test))
    vae.save_weights('vae_mlp_mnist.h5')
    
    plot_results(models,data,batch_size=batch_size,model_name='vae_mlp')

def plot_results(models,
                 data,
                 batch_size=128,
                 model_name="vae_mnist"):
    """Plots labels and MNIST digits as function of 2-dim latent vector
    # Arguments:
        models (tuple): encoder and decoder models
        data (tuple): test data and label
        batch_size (int): prediction batch size
        model_name (string): which model is using this function
    """

    encoder, decoder = models
    x_test, y_test = data
    os.makedirs(model_name, exist_ok=True)

    filename = os.path.join(model_name, "vae_mean.png")
    # display a 2D plot of the digit classes in the latent space
    z_mean, _, _ = encoder.predict(x_test,
                                   batch_size=batch_size)
    plt.figure(figsize=(12, 10))
    plt.scatter(z_mean[:, 0], z_mean[:, 1], c=y_test)
    plt.colorbar()
    plt.xlabel("z[0]")
    plt.ylabel("z[1]")
    plt.savefig(filename)
    plt.show()

    filename = os.path.join(model_name, "digits_over_latent.png")
    # display a 30x30 2D manifold of digits
    n = 30
    digit_size = 28
    figure = np.zeros((digit_size * n, digit_size * n))
    # linearly spaced coordinates corresponding to the 2D plot
    # of digit classes in the latent space
    grid_x = np.linspace(-4, 4, n)
    grid_y = np.linspace(-4, 4, n)[::-1]

    for i, yi in enumerate(grid_y):
        for j, xi in enumerate(grid_x):
            z_sample = np.array([[xi, yi]])
            x_decoded = decoder.predict(z_sample)
            digit = x_decoded[0].reshape(digit_size, digit_size)
            figure[i * digit_size: (i + 1) * digit_size,
                   j * digit_size: (j + 1) * digit_size] = digit

    plt.figure(figsize=(10, 10))
    start_range = digit_size // 2
    end_range = n * digit_size + start_range + 1
    pixel_range = np.arange(start_range, end_range, digit_size)
    sample_range_x = np.round(grid_x, 1)
    sample_range_y = np.round(grid_y, 1)
    plt.xticks(pixel_range, sample_range_x)
    plt.yticks(pixel_range, sample_range_y)
    plt.xlabel("z[0]")
    plt.ylabel("z[1]")
    plt.imshow(figure, cmap='Greys_r')
    plt.savefig(filename)
    plt.show()

if __name__ == '__main__':
    main()


VAE structure

Decoder structure

Encoder structure


Keras code에서 Lambda를 사용해서 reparameterization trick이 구현되었다.


-----------------------------------------------------------------------------------------------------------

공부하면서 알은 점들.


1. Posterior probability distribution을 구하기 위해서 trick을 사용하는데, VAE에서는 Variational Inference(변분추론)을 사용해서 Lower bounding을 최대화하는 과정을 통해서 posterior probability distribution의 parameter를 추정하였음.


2. Neural network로 학습시키고자 하는 probability distribution 을 위해서 학습하고자 하는 distribution과 NN의 distribution 사이의 분포 차이를 봄: 분포 차이는 Kullback Leibler Divergence (KLD) 로 계산을 하며 다음과 같이 정의됨.

만약 둘의 분포가 같다면 KLD은 0이 된다.

Comments