シリーズで、NNablaのPython-like C++APIをご紹介したいと思います。
第五回目は、VAEのサンプルを元にして、四則演算や算術関数を使ったロスの記述例をご紹介いたします。
生成モデル用ニューラルネットワークの代表例であるVAE(Variational Auto-Encoder:変分オートエンコーダ)を説明します。
VAEは、2013年に、Diederik Kingmaらによって提唱されたニューラルネットのオートエンコーダの形式で変分ベイズ推論を書き下したモデルです。
以下はその論文です。
Auto-Encoding Variational Bayes
Diederik P Kingma, Max Welling
https://arxiv.org/abs/1312.6114
VAEは、データの生成分布と、潜在変数の事後分布をオートエンコーダのデコーダ(出力側)と、エンコーダ(入力側)でモデル化するニューラルネットです。
パラメータの学習は、通常のニューラルネットの誤差逆伝搬学習を、誤差として以下のELBO(変分下界)と呼ばれる目的関数(符号は反転)を用いることで実行します。
\(L(x;\phi,\theta)=\text{E}_{z \sim q_\phi (z│x)} [\text{log}(p_\theta (x│z))+\text{log}(p_\theta (z))-\text{log}(q_\phi (z│x))
\)
ここで、\(\phi\) はエンコーダのパラメータ、\(\theta\)はデコーダのパラメータ、\(x\)はデータ、\(z\)は潜在変数、\(q_\phi (z|x)\)は潜在変数の事後分布、\(p_\theta(x|z)\)はデータの生成分布、\(p_\theta (z)\)は潜在変数の事前分布です。
目的関数の第一項は、潜在変数の分布からデータが生成される尤度を表します。
また、第二項と第三項は、潜在変数の事前分布と、データから推定される潜在変数の事後分布の分布間距離を表しています。
VAEでは、潜在変数の事前分布を標準ガウス分布とすることで、潜在変数の分布が標準ガウス分布に近づくように学習がなされます。
期待値計算における「\(z\sim q_\phi (z|x)\)」は、潜在変数の事後分布からのモンテカルロサンプリングを表します。
しかし、潜在変数を実際にサンプリングしてしまうと、潜在変数以前のネットワークに誤差が逆伝搬できないという問題が生じます。
この問題を回避するために、VAEでは、まず、潜在変数の事後分布を共分散が対角化されたガウス分布に制限しています。
\(q_\phi (z|x) = N(z; \mu_\phi (x), \sigma_\phi ^ 2 (x))
\)
ここで、\(\mu_\phi (x)\)、\( \sigma_\phi ^ 2 (x)\)はガウス分布の平均、分散パラメータです。
この平均、分散パラメータは、データを入力して学習パラメータのニューラルネットを介して算出されるようにします。
このようにすると、モンテカルロサンプリングの処理が、以下のように標準ガウス分布のサンプリングとニューラルネットに分離できます。
\(z = \mu_\phi (x) + \sigma_\phi(x) \epsilon
\)
ここで、\(\epsilon\)は標準ガウス分布に従ってサンプリングされた値でネットワークを介した情報は含まれていません。
一方、\(\mu_\phi (x)\)、\(\sigma_\phi(x)\)には、ネットワークを介した情報が含まれているため、誤差の逆伝搬が可能になります。
VAEでは、このようにネットワークの情報を残したままサンプリングを行うことで、モンテカルロサンプリングを実現しつつ、潜在変数以前のネットワークへの誤差逆伝搬を可能にしています。
このような処理のことをリパラメタライゼーショントリックといいます。
なお、モンテカルロサンプリングの回数は、繰り返し学習の効果があることから、一サンプルあたり一回でも十分であるとされています。
VAEの実行ファイルを作るプログラムは、以下の3つのソースコードファイルで構成されます。
以下では、プログラムのコアにあたるvae_training.hppから、これまで説明しなかった内容を説明します。
VAE学習プログラム:vae_training.hpp
このソースコードファイルでは、VAEによる生成モデルの記述と、モデルのビルド、学習、評価の記述を行っています。
モデルのビルド、学習、評価の記述は、これまでの記述と大きな違いはありません。
ここでは、VAEのネットワーク構造の記述について、これまで説明しきれなかった内容を中心に紹介します。
- エンコーダ
エンコーダは、データを入力して、前述の平均、分散パラメータを算出します。
ここで、データはMNISTの白黒画像で、平均、分散パラメータは、潜在変数の次元毎に設定されるパラメータです。
データからパラメータまでの処理は、二層の全結合と、パラメータ別に分岐されたそれぞれ一層の全結合層で構成されています。
分散パラメータは0以上となるような制約が必要ですが、ここでは、対数分散(logvar)とみなすことで回避しています。
// Fully connected layers, and Elu replaced from original Softplus.
auto h1 = f::elu(pf::affine(xa, 1, 500, parameters["fc1"]), 1.0);
auto h2 = f::elu(pf::affine(h1, 1, 500, parameters["fc2"]), 1.0);
// The outputs are the parameters of Gauss probability density.
auto mu = pf::affine(h2, 1, 50, parameters["fc_mu"]);
auto logvar = pf::affine(h2, 1, 50, parameters["fc_logvar"]);
auto sigma = f::exp(logvar * 0.5);
- デコーダ
デコーダでは、まず、潜在変数のサンプリングを行い、続いて、潜在変数から、データの生成分布のパラメータを算出します。
潜在変数のサンプリング
潜在変数のサンプリングでは、前述したリパラメタライゼーションによる処理を行います。
以下は、リパラメタライゼーションを利用した潜在変数のサンプリングの例になります。
// The prior variable and the reparameterization trick
auto epsilon = f::randn(0.0, 1.0, shape_z, 706);
auto z = mu + sigma * epsilon;
ガウス乱数レイヤ
Python-like C++APIでは、PythonAPIと同様乱数を出力するレイヤが用意されています。
リパラメタライゼーションに用いる乱数の生成には、ガウス乱数を生成するrandnレイヤを用います。
引数のshape_zは、潜在変数のテンソルサイズでその算出方法は、容易なのでここでは割愛しています。
四則演算子の上書き
Python-like C++APIでは、PythonAPIと同様、四則演算子の上書き(オーバーロード)を行っています。
演算は、CgVariableクラス同士の四則演算の他に、スカラとCgVariableクラスの四則演算ができます。
データの生成分布のパラメータ算出
サンプリングされた潜在変数は、二層の全結合層を介し、さらに、もう一層の全結合層を経て、データの生成分布のパラメータを算出します。
データの生成分布とは、画素ごとの白黒分布をモデル化する二項分布(ベルヌーイ分布)のことであ、パラメータとはベルヌーイ確率のことです。
以下でprobは画素ごとのベルヌーイ確率を集めたテンソルです。
以下のn_pbは、画素数を表し、shape_pbは画像のテンソル形状を表します。
これらの算出方法は、容易なので割愛しています。
// Fully connected layers, and Elu replaced from original Softplus.
auto h3 = f::elu(pf::affine(z, 1, 500, parameters["fc3"]), 1.0);
auto h4 = f::elu(pf::affine(h3, 1, 500, parameters["fc4"]), 1.0);
auto h5 = pf::affine(h4, 1, n_pb, parameters["fc5"]);
auto prob = f::reshape(h5, shape_pb, true);
- ELBO(Evidence Lower Bound変分下界)ロス
ロスの計算は、データを白黒二値化する処理と、実際のデータの生成分布とデータからロスを算出する部分に分かれます。
データを二値化する処理は、以下のようにgreater_equal_scalarレイヤを用います。
// Binarized input
auto xb = f::greater_equal_scalar(xa, 0.5);
一方、ロスを計算する処理は、前述の式の三項をそれぞれ算出したのち、足し合わせて計算します。
以下のプログラムでは、各項の計算は、足し合わせる際にキャンセルされる定数項を除いた式が記述されています。
// E_q(z|x)[log(q(z|x))]
// without some constant terms that will canceled after summation of loss
auto logqz = 0.5 * f::sum(1.0 + logvar, {1}, false);
// E_q(z|x)[log(p(z))]
// without some constant terms that will canceled after summation of loss
auto logpz = 0.5 * f::sum(mu * mu + sigma * sigma, {1}, false);
// E_q(z|x)[log(p(x|z))]
auto logpx = f::sum(f::sigmoid_cross_entropy(prob, xb), {1, 2, 3}, false);
// Vae loss, the negative evidence lowerbound
auto loss = f::mean(logpx + logpz - logqz, {0}, false);