シリーズで、NNablaのPython-like C++APIをご紹介したいと思います。
第八回目は、半教師学習(VAT)のサンプルを元にして、一イテレーションで複数の学習処理を行う学習例をご紹介いたします。
深層学習における半教師学習の代表例であるVAT(Virtual Adversarial Training:仮想敵対学習)の学習サンプルを説明します。
半教師学習は、ラベルありデータを使った教師あり学習と、ラベルなしデータを使った教師なし学習を組み合わせて学習する機械学習手法です。
半教師学習は、主として、教師あり学習の性能をラベルなしデータを使って向上する目的で用います。
VATは2015年、Takeru Miyatoらによって提案された深層学習の半教師学習の手法の一つです。
VATは、Ian Goodfellowらに提唱された教師あり学習の敵対学習(Adversarial Training:ここでは敵対学習と呼ぶことにします)を、半教師学習で使えるように拡張した技術です。
敵対学習は、識別器が最も誤りやすいノイズをデータに付加して学習する技術です。
この技術の狙いは、最も誤りやすいノイズに対しても耐性をもつような学習をすることで、外乱にロバストな識別器を獲得することです。
敵対学習の名前の由来は、識別器とノイズが敵対的な関係にあることに起因します。
敵対学習では、普通、敵対ノイズの算出に正解ラベルが必要です。
このため、教師あり学習が前提となります。
これに対し、仮想敵対学習(VAT:Virtual Adversarial Training)は、識別器の予測が最も変化しやすいノイズをデータに付加して学習します。
つまり、ノイズの算出に正解ラベルが不要となり、教師なし学習や半教師学習で用いることができます。
以下は、VATの論文です。
Distributional Smoothing with Virtual Adversarial Training
Takeru Miyato, Shin-ichi Maeda, Masanori Koyama, Ken Nakae, Shin Ishii
https://arxiv.org/abs/1507.00677
VATを用いた半教師学習は3つの処理を交互に繰り返して行います。
ラベルありデータを用いた教師あり学習、敵対ノイズの算出、ラベルなしデータと敵対ノイズを用いた教師なし学習の3つの処理です。
ラベルありデータによる教師あり学習には、クロスエントロピーなどの標準的なロスを用います。
ラベルなしデータによる教師なし学習は、以下のロスを使います。
\(
L_u (x;n,\theta)=-D(f_\theta (x),f_\theta (x+\epsilon n))
\)
このロスは、ノイズなしデータとノイズありデータの特徴空間上での距離に相当します。
クラス分類の場合は、この特徴ベクトルをもとに算出したクラスの確率分布の間の距離を用います。
クラスの確率分布の算出には、ソフトマックス関数を用います。
また、クラスの確率分布の間の距離には、カルバックライブラー距離を用います。
\(
D(f_\theta (x),f_\theta (x+\epsilon n))=D_{\text{KL}}(\text{softmax}(f_\theta (x)); \text{softmax}(f_\theta (x+\epsilon n)))
\)
このロスを小さくすることで、ノイズがあっても予測が変わらないようなロバストな識別器を獲得するのが狙いです。
ここで用いるノイズが仮想敵対ノイズであり、上記ロスを最大にするようなノイズ\(n^*\)を用います。
\(
n^*=\text{argmax}_n {L_u (x;n,\theta)} s.t.|n|^2=1
\)
この計算には、以下の処理を再帰的に繰り返すパワーメソッドという逐次近似を使います。
\(
n←\nabla_n [L_u (x;n,\theta)]/ \sqrt{|\nabla_n [L_u (x;n,\theta)]|^2 }
\)
逐次回数は、論文などによれば、実験的には一回で十分であることがわかっています。
VATによる半教師学習のプログラムは、以下の3つのソースコードファイルで構成されます。
以下では、プログラムのコアにあたるvat_training.hppを中心に説明します。
VAT学習プログラム:vat_training.hpp
このソースコードファイルでは、VATのネットワークのビルド、学習、評価ループの記述を行っています。
ここでは、VATのネットワークの記述と、学習ループの記述を説明します。
- VATのネットワークの記述
VATのネットワークは、教師あり学習用ネットワーク、教師なし学習用ネットワークで構成されます。
教師あり学習用ネットワークは、フィードフォワードネットワークにクロスエントロピーによるロスを接続したネットワークです。
フィードフォワードネットワークは、データを特徴空間に射影する役割を持っています。
以下3行目のmlp_net関数の中身をサンプルプログラムでみると、そのネットワークが記述されています。
auto xl = make_shared(Shape_t({batch_size_l, 1, 28, 28}), false);
auto tl = make_shared(Shape_t({batch_size_l, 1}), false);
auto yl = mlp_net(xl, n_h, n_y, params, false);
auto loss_l = f::mean(f::softmax_cross_entropy(yl, tl, 1), {0, 1}, false);
教師なし学習用ネットワークは、ノイズなしデータとノイズありデータの特徴空間上での距離をロスとするネットワークです。
このネットワークは、ノイズ算出時と、教師なし学習時でパラメータを共有しています。
ネットワーク構造は、mlp_net関数で、パラメータはparamsで記述されています。
ただし、ネットワークは、ノイズ算出用と学習用でわけて作っています。
これは、敵対ノイズのノルムを個別に設定するためです。
特徴空間上での特徴ベクトル同士の距離は、distance関数で記述しています。
以下は、特徴空間でのノイズなしデータの特徴ベクトルを算出するためのネットワークです。
以下でyuをy1で再定義している理由は、yuの計算を何度も繰り返さないためと、yuへの誤差逆伝搬を防ぐためです。
auto xu = make_shared(Shape_t({batch_size_u, 1, 28, 28}), false);
auto yu = mlp_net(xu, n_h, n_y, params, false);
auto y1 = make_shared(yu->variable(), true);
y1->set_need_grad(false);
以下は敵対ノイズを算出するためのネットワークです。
以下で\(r\)は、前述の逐次解法における規格化処理を行うための変数です。
この変数は、あとの処理でも用いることができるように、persistentを設定しています。
auto noise = make_shared(Shape_t({batch_size_u, 1, 28, 28}), true);
auto r = noise / f::pow_scalar((f::sum(f::pow_scalar(noise, 2), {1, 2, 3}, true)), 0.5);
r->set_persistent(true);
以下は、敵対ノイズを算出するためのロスと、教師なし学習を行うためのロスを算出する箇所です。
auto y2 = mlp_net(xu + xi * r, n_h, n_y, params, false);
auto y3 = mlp_net(xu + eps * r, n_h, n_y, params, false);
auto loss_k = f::mean(distance(y1, y2), {0}, false);
auto loss_u = f::mean(distance(y1, y3), {0}, false);
- 学習ループの記述
学習ループも、教師あり学習の処理と、敵対ノイズを算出する処理、教師なし学習を行う処理に分かれます。
まず、以下は、教師あり学習の処理で、これは今まで説明してきたとおりです。
// Training with labeled data solver->zero_grad();
loss_l->forward(/*clear_buffer=*/false, /*clear_no_need_grad=*/true);
loss_l->variable()->grad()->fill(1.0);
loss_l->backward(/*NdArrayPtr grad=*/nullptr, /*clear_buffer=*/true);
solver->weight_decay(weight_decay);
solver->update();
続いて、教師なし学習の処理ですが、まずノイズなしデータから特徴ベクトルを算出します。
// Training with unlabeled data
train_data_iterator.provide_data(cpu_ctx, batch_size_u, xu, tu);
yu->forward(/*clear_buffer=*/false, /*clear_no_need_grad=*/false);
敵対ノイズは、まず、逐次解法の初期値をガウスノイズから求めます。
// Calculating virtual adversarial noise
float_t *n_d = noise->variable()->cast_data_and_get_pointer(cpu_ctx, true);
for (int i = 0; i < noise->variable()->size(); i++, n_d++) *n_d = normal(engine);
続いて、敵対ノイズの逐次法による算出では、ノイズに逆伝搬された勾配を、ノイズに置き換える処理を繰り返します。
for (int k = 0; k < max_iter_power_method; k++) {
r->variable()->grad()->fill(0.0);
loss_k->forward(/*clear_buffer=*/false, /*clear_no_need_grad=*/true);
loss_k->variable()->grad()->fill(1.0);
loss_k->backward(/*NdArrayPtr grad=*/nullptr, /*clear_buffer=*/true);
n_d = noise->variable()->cast_data_and_get_pointer(cpu_ctx, true);
float_t *r_g = r->variable()->cast_grad_and_get_pointer(cpu_ctx, false);
for (int i = 0; i < noise->variable()->size(); i++, n_d++, r_g++) *n_d = *r_g;
}
最後に、教師なし学習の処理を行います。
// Updating with virtual adversarial noise
solver->zero_grad();
loss_u->forward(/*clear_buffer=*/false, /*clear_no_need_grad=*/true);
loss_u->variable()->grad()->fill(1.0);
loss_u->backward(/*NdArrayPtr grad=*/nullptr, /*clear_buffer=*/true);
solver->weight_decay(weight_decay);
solver->update();
以上の三つの処理を、ロスの更新が収束するか、あるいは、繰り返し最大数まで繰り返すことで学習を完了します。