シリーズで、NNablaのPython-like C++APIをご紹介したいと思います。
第六回目は、DCGANのサンプルを元にして、複数のソルバーを使う学習例をご紹介いたします。
生成モデル用ニューラルネットワークの別の代表例であるGAN(Generative Adversarial Network:敵対生成ネットワーク)の派生版であるDCGAN(Deep Convolutional GAN)の学習プログラム例を説明します。
GANは、2014年に、Ian Goodfellowらによって提唱された生成モデルと識別モデルを敵対させるネットワークです。
以下がその論文です。
Generative Adversarial Networks
Ian J. Goodfellow, Jean Pouget-Abadie, Mehdi Mirza, Bing Xu, David Warde-Farley, Sherjil Ozair, Aaron Courville, Yoshua Bengio
https://arxiv.org/abs/1406.2661
GANにおいて、生成モデルは、乱数を入力し、データと同じ次元をもつ疑似データを生成するニューラルネットワークです。
また、識別モデルは、生成データか本物のデータのいずれかを入力し、本物か偽物かを識別するニューラルネットワークです。
学習時に、生成モデルは、識別モデルに生成データを偽物を見破られないように、また、識別モデルは、生成データを偽物を見破れるように対立させることで、より本物に近いデータを生成する生成モデルの実現が期待されます。
このため、識別モデルは、生成データか、本物のデータを入力したときに、これを本物か偽物かを正しく識別するように学習されます。
目的関数は、以下のように二クラス分類に適したバイナリクロスエントロピーを用います。
\(L_{discriminator} (\phi;X)=\text{E}_{x\sim X} [-\text{log}(D_\phi (x))]+\text{E}_{z\sim N(z;0,1)} [-\text{log}(1-D_\phi (G_\theta (z))) ]
\)
ここで、\(X\)は本物データの集合、\(\phi \)は識別モデルのパラメータ、\(\theta\)は生成モデルのパラメータ、\(D_\phi (x)\)は\(x\)が本物である確率を算出するニューラルネットワーク、\(G_\theta (z)\)は乱数\(z\)から疑似データを生成するニューラルネットワークです。
また、第一項は、本物データを本物と識別したときに小さくなるロス、第二項は、生成データを偽物と識別したときに小さくなるロスです。
識別モデルの学習時には、生成モデルのパラメータは学習しないように構成されています。
一方、生成モデルは、識別が誤った識別をすると小さくなるロス関数で学習されます。
目的関数は、たとえば、以下のように設定されます。
\(L_{generater} (\theta)=\text{E}_{z\sim N(z;0,1)} [-\text{log}(D_\phi (G_\theta (z))) ]
\)
生成モデルの学習時には、識別モデルのパラメータは学習しないように構成されています。
DCGAN(Deep Convolutional GAN)は、2015年にAlec Radfordらによって提唱され、GANをたとえば画像生成に適するようにコンボリューションとバッチノーマライゼーションを用いて精緻化したネットワーク構成をしています。
以下がその論文です。
Unsupervised Representation Learning with Deep Convolutional Generative Adversarial Networks
Alec Radford, Luke Metz, Soumith Chintala
https://arxiv.org/abs/1511.06434
DCGANの学習実行ファイルを作るプログラムは、以下の3つのソースコードファイルで構成されます。
以下では、プログラムのコアにあたるdcgan_training.hppから、これまで説明しなかった機能を中心に説明します。
DCGAN学習プログラム:dcgan_training.hpp
このソースコードファイルでは、DCGANにおける生成モデル、識別モデルのネットワークの記述、それぞれの入力からロスまでのネットワークの構築、ソルバーのセットアップ、学習、評価の記述を行っています。
ここではGANに固有な記述である入力からロスまでのネットワークの構築と、ソルバーのセットアップについて、これまで説明していない機能の説明を交えて紹介します。
- GANのロス関数
GANのロスには、識別モデルのロスと、生成モデルの2つのロスがあり、それぞれ異なるネットワークを用いて算出します。
さらに、識別モデルのロスは、本物データから算出されたロスと、生成データ(偽物データ)から算出される2つのロスを足し合わせて算出されます。
サンプルプログラムでは、これらの合計3つに分けられたロスに対して、それぞれ対応するネットワークを構築しています。
3つのネットワークは、本物データから識別モデルのロスの第一項を算出するネットワーク、偽物データから識別モデルのロスの第二項を算出するネットワーク、偽物データから生成モデルのロスを算出するネットワークとなっています。
偽物データから生成モデルのロスへ
以下は、偽物データから生成モデルのロスを算出するネットワークを記述する部分です。
偽物データは乱数から作りますが、乱数から偽物データを生成するまでのネットワークの構築は、generator関数の中で行っています。
また、偽物データの本物らしさ(本物と推定される確率)を算出するネットワークの構築は、discriminator関数で行っています。
途中のset_persistent(true)は、この変数(fake)を永続化する指示を表しています。
この指示が必要となるのは、変数fakeを後で利用するからです。
NNablaでは、ネットワーク中の中間変数を後で使いたい場合には、このような指示が必要です。
auto z = make_shared(Shape_t({batch_size, 100, 1, 1}), false);
auto fake = generator(z, max_h, false, params["gen"]);
fake->set_persistent(true);
auto pred_fake = discriminator(fake, max_h, false, params["dis"]);
auto loss_gen = f::mean(f::sigmoid_cross_entropy(pred_fake, f::constant(1, {batch_size, 1})), {0, 1}, false);
偽物データから識別モデルのロスへ
以下は、偽物データから識別モデルのロスを算出するネットワークです。
偽物データfakeは、まず、一旦新たな変数fake_disに置き換えられます。
そして、discriminator関数に入力すると、偽物入力の本物らしさを算出するネットワークが構築されます。
偽物データfakeを新たな変数fake_disに置き換える理由は、生成モデルへの誤差逆伝搬が起きないようにネットワークを切断するためです。
なお、set_need_grad(true)は、fake_disまでは誤差逆伝搬が起こるようにする指示です。
auto fake_dis = make_shared< CgVariable>(fake->variable(), true);
fake_dis->set_need_grad(true);
auto pred_fake_dis = discriminator(fake_dis, max_h, false, params["dis"]);
pred_fake_dis->set_persistent(true);
auto loss_dis = f::mean(f::sigmoid_cross_entropy(pred_fake_dis, f::constant(0, {batch_size, 1})), {0, 1}, false);
本物データから識別モデルのロスへ
以下は、本物データから識別モデルのロスを算出するネットワークです。
本物データxは、discriminator関数に入力すると、本物入力の本物らしさを算出するネットワークが接続されます。
auto x = make_shared(Shape_t({batch_size, 1, 28, 28}), false);
auto pred_real = discriminator(x, max_h, false, params["dis"]);
loss_dis = loss_dis + f::mean(f::sigmoid_cross_entropy(pred_real, f::constant(1, {batch_size, 1})), {0, 1}, false);
- モデル別ソルバー
GANでは、普通、生成モデルと識別モデルを別々に学習します。
このため、ソルバーを2つ用意し、それぞれのネットワークに関するパラメータだけを学習できるようにする必要があります。
ネットワーク別のパラメータの指定は、パラメータ保持クラスであるParameterDirectoryの機能を使います。
以下のように、ParameterDirectoryに、予め設定したモデル別の名前空間(”gen”と“dis”)を与えることで、パラメータを区別することができます。
auto solver_gen = create_AdamSolver(ctx, learning_rate, 0.5, 0.999, 1.0e-8);
auto solver_dis = create_AdamSolver(ctx, learning_rate, 0.5, 0.999, 1.0e-8);
solver_gen->set_parameters(params["gen"].get_parameters());
solver_dis->set_parameters(params["dis"].get_parameters());