PyTorch-like functions

2019年3月19日 火曜日

API

Posted by Kazuki Yoshiyama

主要APIの変更通知ではないです!

NNablaでは,パラメトリック関数(例:PF.convolution)のパラメータは,TensorFlow V1.xのように,スコープコンテキストと辞書を使ってプロセス単位でグローバルに管理されています.この方法はコードを書くときには書きやすいのですが,パラメータの管理がグローバルなので,ネットワークの全体像が見えずらい場合もあります.一方,PyTorchやChainerでは,パラメトリック関数はクラスで定義されており,コードを書くときには,同じ関数を2行書かなくてはならない場合(initでの初期値化とcall (foward)での呼び出し)もありコードは冗長なのですが,ネットワークの全体像は見えやすかったりします.どっちもプロコンありますが,最近は後者の方が皆さん好きそうです.NNablaで,パラメトリック関数クラスを作ったので,ここではそのコードスニペットを紹介します.

Model by Class Model by Function
パラメータ管理 簡単 煩雑
長さ 冗長 簡潔
読みやすさ (誰が書いても)読みやすい (書き方に依存して)読みづらい

コードスニペット

...
import nnabla.experimental.parametric_function_classes as PFC
...

...
class ResUnit(PFC.Module):
    def __init__(self, inmaps=64, outmaps=64):
        self.conv0 = PFC.Conv2d(inmaps, inmaps // 2, (1, 1))
        self.bn0 = PFC.BatchNorm2d(inmaps // 2)
        self.conv1 = PFC.Conv2d(inmaps // 2, inmaps // 2, (3, 3))
        self.bn1 = PFC.BatchNorm2d(inmaps // 2)
        self.conv2 = PFC.Conv2d(maps // 2, outmaps, (1, 1))
        self.bn2 = PFC.BatchNorm2d(outmaps)
        self.act = F.relu

        self.shortcut_func = False
        if inmaps != outmaps:
            self.shortcut_func = True
            self.shortcut_conv = PFC.Conv2d(inmaps, outmaps, (3, 3))
            self.shortcut_bn = PFC.BatchNorm2d(outmaps)

    def __call__(self, x, test=False):
        s = x
        h = x
        h = self.act(self.bn0(self.conv0(h), test))
        h = self.act(self.bn1(self.conv1(h), test))
        h = self.bn2(self.conv1(h), test)
        if self.shortcut_func:
          s = self.shortcut_conv(s)
          s = self.shortcut_bn(s)
        h = self.act(h + s)
        return h
...

詳細は,main.pyをご覧ください.