Neural Network Librariesが公開されてもうすぐ2年となります。2017年6月のv0.9.1公開時に比べ、多くの機能が新たに加わり、Neural Network Librariesはますます使いやすくなってきました。そこで、これからNeural Network Librariesを使い始めてみたい!自分でネットワークを組んでみたい!より高度な使い方をしたい!という方のために、Neural Network Libraries Step by Stepと題して、いくつかチュートリアル記事を連載していきたいと思います。今回はその第1回目の記事として、Neural Network Librariesの中で最も基本的なコンポーネントであるVariableについて紹介していきます。
Variableは、ニューラルネットワークにおける数値演算時に重要な役割を果たします。
そのふるまいを簡単な演算を例にとって見ていきましょう。
x = nn.Variable((1,))
y = nn.Variable((1,))
z = x + y # (i)
しかしこのままでは、\(x, y, z\)という3つの変数とそれらをどう使った計算をするかを定義しただけにすぎず、計算は実行されていません(その計算結果は評価されていません)。
試しにここで\(z\)の値を調べてみると、
print(z.d)
[3.3362283e+21]
\(x\)も\(y\)も定義されていないので、当然おかしな値が出てきます(また、ここでz.dでなくzを出力させようとすると、zの値ではなくzというVariableの情報が出てしまうので注意しましょう)。
では、ここで\(x\)と\(y\)に値を代入した場合はどうなったか見てみましょう。Neural Network LibrariesにおいてVariableに値を代入するときは、このように書きます。
x.d = 1.
y.d = 2.
z = x + y # (ii)
ここで書かれているように、数値にアクセスするだけでなく、Variable.dは数値の代入にも使うことができます。
1+2=3なので、当然\(z\)の値は3になっているはずです・・・が、
print(z.d)
[3.3362283e+21]
なっていません。
これは、Variableを用いて定義される演算は「静的計算グラフ」を構築するようになっており、ただ演算を定義しただけではその値は計算されないことに起因します。
イメージとしては、以下のような図のようになっています。まず、(i)の時点では以下のようなグラフが構築されていました。
(ii)の時点では\(x\)と\(y\)の値が代入されましたが、依然として計算は実行されておらず、以下のようなグラフが存在している状態です。
では、どうすれば演算を実行させることができるのでしょうか。
そのためには、VariableのメソッドのひとつであるVariable.forward()を使う必要があります。
z.forward()
print(z.d)
[3.]
やっと3が出てきてくれました。このように、Variableを使うときは、入力となる変数を用意し、計算グラフを定義した後、その計算の出力となるVariableに対してVariable.forward()を呼ぶことで初めて計算が実行されます(定義された計算結果が「評価」されます)。この時点のグラフは以下のようになります。
基本的にVariable.forward()を呼ぶのは計算グラフの末端のVariableです。途中のVariableに対しVariable.forward()を呼んでしまうと、そのVariableまでの計算しか実行されません。以下はその例です。
x, y, z, w = [nn.Variable((1, )) for _ in range(4)]
x.d = 1
y.d = 2
z = x + y
w = 2*z # (iii)
z.forward() # (iv)
print(z.d, w.d)
[2.] [3.3362328e+21]
この状態は以下の図が対応します。
もちろん、ここできちんとw.forward()とすると、\(w\)の計算結果(2*3=6)が得られます。
また、中間値である\(z\)の値はこの再度のw.forward()実行によって影響を受けることはありません。
w.forward()
print(w.d)
[6.]
ここでようやく上の図のような状態となります。
このように、Variableを用いた演算はVariable.forward()を忘れてしまうと思わぬ挙動を生じます。Examplesなどの実用例を見ると、ニューラルネットワークの学習時には定義した損失関数lossに対し、イテレーションごとにVariable.forward()を実行させているのはこのためです。
とはいえ、常にVariable.forward()を呼ばないといけないわけではありません。
dynamicモードを利用する事で、演算を定義した瞬間にその計算が実行されるようにすることも可能です。この機能については別の記事で紹介します。