シリーズで、NNablaのPython-like C++APIをご紹介したいと思います。
第七回目は、メトリック学習(SiameseNet)のサンプルを元にして、複数のデータ供給を使った学習例をご紹介いたします。
メトリック学習に適したニューラルネットワークの代表例であるシャミーズネットワークを説明します。
メトリック学習は、データ間の距離情報を元にデータから特徴空間への射影方法を学習する学習技術です。
また、シャミーズネットワークは、2つのデータをそれぞれ入力できる同一の並列ネットワークと、これらのネットワークの出力の距離を算出するネットワークで構成されています。
シャミーズネットワークを使うと、2つのデータの特徴空間上での距離を算出できます。
シャミーズネットワークを用いたメトリック学習は、Yann LeCunらにより提案されました。
この学習では、算出されたネットワークが、ラベルとして与えられた距離に近づくようにパラメータを更新していきます。
以下がその論文です。
Learning a Similarity Metric Discriminatively, with Application to Face Verification
Sumit Chopra, Raia Hadsell, Yann LeCun
http://yann.lecun.com/exdb/publis/pdf/chopra-05.pdf
SiameseNetの実行ファイルを作るプログラムは、以下の3つのソースコードファイルで構成されます。
以下では、プログラムのコアにあたるsiamese_training.hppから、これまで説明しなかった内容を中心に説明します。
シャミーズネットワーク学習プログラム:siamese_training.hpp
このソースコードファイルでは、シャミーズネットワークの構築、ロスの接続、学習、評価の記述を行っています。
ここでは、ネットワークモデルの構築とロスの説明を行います。
- シャミーズネットワークの設計
このプログラムで、シャミーズネットワークは、mnist_lenet_siamese関数で記述されます。
この関数には、2つのデータを特徴空間に射影する処理と、2つの特徴変数の距離を算出する処理が記述されます。
データを特徴空間に射影する処理は、mnist_lenet_feature関数で記述されます。
この関数は、データから特徴変数を算出する通常のフィードフォワードネットワークです。
このネットワークは2つのデータを同一の特徴空間に射影するためパラメータを共有します。
プログラム上では、同一のパラメータを引数として渡すことで実現しています。
また、距離を算出する処理は、ここでは、特徴ベクトルの差の二乗和を算出することで実現しています。
auto h0 = mnist_lenet_feature(x0, params, test);
auto h1 = mnist_lenet_feature(x1, params, test);
auto h = f::squared_error(h0, h1);
auto p = f::sum(h, {1}, true);
- コントラスティブロスの記述
コントラスティブロスは、以下の式で表されます。
\(L(x_0,x_1,t;\theta)=t*S_\theta (x_0,x_1 )+(1-t)*[\text{max}(m-\sqrt{S_\theta (x_0,x_1)+\epsilon},0) ]^2
\)
ここで、\(x_0\)、\(x_1\)は入力ペア、\(t\)は、入力ペアが同じクラスならば1、違うクラスならば0となるラベル、\(S_\theta(x_0, x_1)\)はペアの特徴空間上での距離で、\(\theta\)はニューラルネットワークのパラメータです。
\(m\)はマージンパラメータで、\(\epsilon\)は微小な数値です。
第一項は、同じクラスの入力ペアは特徴空間上で近くなると小さくなります。
第二項は、異なるクラスの入力ペアの距離がマージンパラメータ程度になると小さくなります。
2つの効果を組み合わせることで、同じクラスのデータは特徴空間上で近く、異なるクラスのデータは特徴空間上でマージン程度離れるように、データから特徴空間への射影が学習されます。
プログラム上では、コントラスティブロスは、以下のようにcontrastive_loss関数で記述されています。
変分オートエンコーダで説明したような四則演算のオーバライドを利用することで、簡潔な表現が実現できています。
auto sim_cost = l * sd;
auto dissim_cost = (1 - l) * (f::pow_scalar(f::maximum_scalar(margin - f::pow_scalar(sd + eps, 0.5), 0), 2.0));
return sim_cost + dissim_cost;