Neural Network Libraries Step by Step 2

2019年4月26日 金曜日

Tips , チュートリアル

Posted by Takuya Yashima

前回の記事ではVariableの基本的な使い方を見てきましたが、そこではVariableによる演算を定義しても、計算グラフが構築されるだけであって演算処理自体は行われず、演算の実行にはVariable.forwardを使う必要があるということを説明しました。 Neural Network Librariesのデフォルトの挙動ではこのような「静的な」グラフを構築するようになっていますが、「動的」な計算グラフを構築することも可能です。すなわち、演算を定義した瞬間にその演算が実行されるようになります。

この「動的」計算グラフを構築するには、Neural Network LibrariesをDynamic Modeに切り替える必要があります。以下でその例を見ていきましょう。

import nnabla as nn
nn.set_auto_forward(True)
a = nn.Variable((1,))
b = nn.Variable((1,))
a.d = 3.
b.d = -2.

nn.set_auto_forward(True)によって、Dynamic Modeに入りました。それ以降、演算処理の挙動が変わります。 実際に演算を定義してみましょう。

c = a + b

これまでは、ここでc.forward()を呼び出さないと計算が実行されず、print(c.d)を行っても所望の値は出力されませんでした。 今回がどうなるか見てみましょう。

print(c.d)
[1.]

計算が実行されているのがわかります。 この状態でさらに計算グラフの構築を続けていきましょう。

d = 2 * c
print(d.d)
[2.]

一度nn.set_auto_forward(True)を呼び出すと、それ以降は常にDynamic Modeで処理が行われることがわかります。 もう一度デフォルトの挙動(静的グラフを構築するStatic Mode)にしたいときは、どうすればよいでしょうか。

nn.set_auto_forward(False)
e = d / 4.
print(e.d)
[8.038705e+10]

nn.set_auto_forward(False)を実行すると、Dynamic Modeから抜け、Static Modeに戻ります。
Dynamic Modeを抜けた後に定義した演算が実行されていないのが分かります。もちろん、forward()を行うことで演算はきちんと実行されます。

e.forward()
print(e.d)
[0.5]

nn.set_auto_forward()を適切なタイミングで実行してStatic ModeとDynamic Modeを切り替えれば、 処理の中でVariableのメソッドであるある部分だけはStatic Modeで処理を行い、その他の部分ではDynamic Modeを使うといった処理を行うことも可能です。 また、こういった使い分けをしたい場合、以下のようなコンテキストマネージャを使った手法が便利です。

x = nn.Variable((1, ))
x.d = 3.
y = nn.Variable((1, ))
y.d = 2.
with nn.auto_forward():
    z = x + y
    w = -2. * z
w_static = x + y
print("w = {}, w_static={}".format(w.d, w_static.d))
w = [-10.], w_static=[6.]

このように、withブロックの中だけがDynamic Modeで処理されます。 w_staticはブロックの外なので演算は定義されただけで実行はされず、結果は予期したものとは異なっているのがわかります。 この場合きちんと後でw_static.forward()を実行すれば問題ありませんが、Static ModeとDynamic Modeを混在させるときはそれらを切り替えるタイミングに注意しましょう。

w_static.forward()
print(w_static.d)
[5.]

 

Tips:ワンライナーでVariableに値をセットする

これまではVariableの定義、そしてそのVariableに値をセットする際に二行を使っていましたが、実際にはこれを一行でまとめることができます。 Variableのメソッドであるapply()を使うか、

x = nn.Variable((1,1)).apply(d=1.0)

numpyを用いた変数を扱う場合にはfrom_numpy_array()を使うことでコードがきれいにまとまります。

x = nn.Variable.from_numpy_array(np.ones((1,1)))

いずれもprint(x.d)を使うと、値がセットされていることが確認できます。

[[1.]]