PyTorch-like functions

Tuesday, March 19, 2019


Posted by Kazuki Yoshiyama

(Note that this does not indicate a change to the primary API.)

In NNabla, parametric functions (e.g., PF.convolution) are defined in a similar manner to TensorFlow 1.x. Trainable parameters are managed in a global scope of the process, using a dictionary that maps parameter names to trainable parameters together with a scope context (nn.parameter_scope) that determines those names.

This is intuitive and straightforward when writing code, but managing trainable parameters is somewhat complicated: they are held globally rather than in the local scope of a class, so it can be difficult to see the whole picture of a network at a glance.

On the other hand, if trainable parameters are held in a class, as in PyTorch or Chainer, it is very easy to see the whole picture of a network. The network definition tends to be longer, however, since each parametric function has to be written twice: once in the __init__ method, where it is created, and once in the __call__ (or forward) method, where it is applied. Each style has pros and cons, but many users seem to prefer the PyTorch-like definition of parametric functions.
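For comparison, here is a minimal pure-Python sketch of the class-based style, again with hypothetical names (`Parameter`, `Module`, `Dense`, `Net`) rather than any library's API. Each layer appears twice, once in `__init__` and once in `__call__`, and all parameters can be enumerated by walking the object's attributes.

```python
class Parameter(list):
    """A flat list of floats marked as trainable."""


class Module:
    """Hypothetical base class: parameters are attributes of the object, not global entries."""

    def parameters(self, prefix=""):
        """Recursively collect Parameter attributes, keyed by attribute path."""
        found = {}
        for attr, value in vars(self).items():
            if isinstance(value, Module):
                found.update(value.parameters(prefix + attr + "/"))
            elif isinstance(value, Parameter):
                found[prefix + attr] = value
        return found


class Dense(Module):
    def __init__(self, n_in, n_out):
        self.n_out = n_out
        self.W = Parameter([0.1] * (n_in * n_out))  # row-major (n_in, n_out)
        self.b = Parameter([0.0] * n_out)

    def __call__(self, x):
        return [sum(x[i] * self.W[i * self.n_out + j] for i in range(len(x))) + self.b[j]
                for j in range(self.n_out)]


class Net(Module):
    def __init__(self):
        # First mention of each layer: create it, together with its parameters.
        self.fc0 = Dense(2, 3)
        self.fc1 = Dense(3, 2)

    def __call__(self, x):
        # Second mention of each layer: apply it in the forward pass.
        return self.fc1(self.fc0(x))


net = Net()
y = net([1.0, 2.0])
print(list(net.parameters()))  # ['fc0/W', 'fc0/b', 'fc1/W', 'fc1/b']
```

The extra lines buy locality: reading `Net.__init__` shows every layer, and two `Net` instances never share weights by accident.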

|                      | Model by Class | Model by Function |
|----------------------|----------------|-------------------|
| Parameter management | Easy           | Hard              |
| Code length          | Redundant      | Brief             |
| Readability          | Good           | Not good          |

Code Snippet

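```python
import nnabla.functions as F
import nnabla.experimental.parametric_function_classes as PFC


class ResUnit(PFC.Module):
    """Bottleneck residual unit: 1x1 -> 3x3 -> 1x1 convolutions plus a shortcut."""

    def __init__(self, inmaps=64, outmaps=64):
        # Declare every parametric layer once; its parameters live on this object.
        self.conv0 = PFC.Conv2d(inmaps, inmaps // 2, (1, 1))
        self.bn0 = PFC.BatchNorm2d(inmaps // 2)
        # pad=(1, 1) keeps the spatial size of the 3x3 convolution, so h + s is well-defined.
        self.conv1 = PFC.Conv2d(inmaps // 2, inmaps // 2, (3, 3), pad=(1, 1))
        self.bn1 = PFC.BatchNorm2d(inmaps // 2)
        self.conv2 = PFC.Conv2d(inmaps // 2, outmaps, (1, 1))
        self.bn2 = PFC.BatchNorm2d(outmaps)
        self.act = F.relu

        # Project the shortcut only when the number of channels changes.
        self.shortcut_func = False
        if inmaps != outmaps:
            self.shortcut_func = True
            self.shortcut_conv = PFC.Conv2d(inmaps, outmaps, (3, 3), pad=(1, 1))
            self.shortcut_bn = PFC.BatchNorm2d(outmaps)

    def __call__(self, x, test=False):
        # Apply the layers declared in __init__; `test` switches batch norm to inference mode.
        s = x
        h = x
        h = self.act(self.bn0(self.conv0(h), test))
        h = self.act(self.bn1(self.conv1(h), test))
        h = self.bn2(self.conv2(h), test)
        if self.shortcut_func:
            s = self.shortcut_conv(s)
            s = self.shortcut_bn(s, test)
        h = self.act(h + s)
        return h
```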

See main.py for details.
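The residual addition `h + s` in ResUnit only works when both branches keep the same spatial size. A quick way to check a design like this is the standard convolution output-size formula, sketched below in plain Python. The helpers `conv_out_size` and `bottleneck_out_size` are hypothetical and independent of NNabla, and the sketch assumes stride 1 and no dilation for all three main-branch convolutions.

```python
def conv_out_size(size, kernel, pad=0, stride=1, dilation=1):
    """Output length of a convolution along one spatial axis (standard formula)."""
    return (size + 2 * pad - dilation * (kernel - 1) - 1) // stride + 1


def bottleneck_out_size(size, pad_3x3):
    """Follow one spatial axis through the main branch: 1x1 -> 3x3 -> 1x1 convolutions."""
    h = conv_out_size(size, kernel=1)
    h = conv_out_size(h, kernel=3, pad=pad_3x3)
    return conv_out_size(h, kernel=1)


for pad in (0, 1):
    out = bottleneck_out_size(16, pad_3x3=pad)
    # An identity shortcut (inmaps == outmaps) keeps the input size, 16.
    print(f"pad={pad}: main branch {out}, identity shortcut 16, h + s ok: {out == 16}")
```

Without padding, the 3x3 convolution shrinks a 16-pixel axis to 14, so an identity shortcut of size 16 cannot be added to it; with one pixel of padding on each side, the sizes match.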