主要APIの変更通知ではないです!
NNablaでは,パラメトリック関数(例:PF.convolution)のパラメータは,TensorFlow V1.xのように,スコープコンテキストと辞書を使ってプロセス単位でグローバルに管理されています.この方法はコードを書くときには書きやすいのですが,パラメータの管理がグローバルなので,ネットワークの全体像が見えずらい場合もあります.一方,PyTorchやChainerでは,パラメトリック関数はクラスで定義されており,コードを書くときには,同じ関数を2行書かなくてはならない場合(initでの初期値化とcall (foward)での呼び出し)もありコードは冗長なのですが,ネットワークの全体像は見えやすかったりします.どっちもプロコンありますが,最近は後者の方が皆さん好きそうです.NNablaで,パラメトリック関数クラスを作ったので,ここではそのコードスニペットを紹介します.
Model by Class | Model by Function | |
---|---|---|
パラメータ管理 | 簡単 | 煩雑 |
長さ | 冗長 | 簡潔 |
読みやすさ | (誰が書いても)読みやすい | (書き方に依存して)読みづらい |
コードスニペット
...
import nnabla.experimental.parametric_function_classes as PFC
...
...
class ResUnit(PFC.Module):
def __init__(self, inmaps=64, outmaps=64):
self.conv0 = PFC.Conv2d(inmaps, inmaps // 2, (1, 1))
self.bn0 = PFC.BatchNorm2d(inmaps // 2)
self.conv1 = PFC.Conv2d(inmaps // 2, inmaps // 2, (3, 3))
self.bn1 = PFC.BatchNorm2d(inmaps // 2)
self.conv2 = PFC.Conv2d(maps // 2, outmaps, (1, 1))
self.bn2 = PFC.BatchNorm2d(outmaps)
self.act = F.relu
self.shortcut_func = False
if inmaps != outmaps:
self.shortcut_func = True
self.shortcut_conv = PFC.Conv2d(inmaps, outmaps, (3, 3))
self.shortcut_bn = PFC.BatchNorm2d(outmaps)
def __call__(self, x, test=False):
s = x
h = x
h = self.act(self.bn0(self.conv0(h), test))
h = self.act(self.bn1(self.conv1(h), test))
h = self.bn2(self.conv1(h), test)
if self.shortcut_func:
s = self.shortcut_conv(s)
s = self.shortcut_bn(s)
h = self.act(h + s)
return h
...
詳細は,main.pyをご覧ください.