(Note that this is not an indication to change the primary API.)
In NNabla, parametric functions (e.g., PF.convolution) are defined in a manner similar to TensorFlow 1.x: trainable parameters are managed in a process-wide global scope, using a dictionary that maps parameter names to trainable parameters, together with a scope context.
This style is intuitive and straightforward when writing code, but managing the trainable parameters is somewhat complicated. Because they are held globally rather than in the local scope of a class, it can be difficult to see the whole picture of a neural network at a glance.
On the other hand, if trainable parameters are held in a class, as in PyTorch or Chainer, it is very easy to see the whole picture of a network. The network definition may be lengthier, however, since each parametric function has to be written twice: once in the `__init__` method and once in the `__call__` (or `forward`) method of the class. There are pros and cons to each approach, but people seem to prefer the PyTorch-like way of defining parametric functions.
Aspect | Model by Class | Model by Function |
---|---|---|
Parameter Management | Easy | Hard |
Length | Redundant | Brief |
Readability | Good | Poor |
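
The class-based column of this comparison can be illustrated with a small pure-Python sketch. The `Module`, `Affine`, and `MLP` classes below are hypothetical and written only for this illustration; they are not the `PFC.Module` implementation, and parameters are plain Python lists.

```python
class Module:
    """Hypothetical minimal base class; not the PFC.Module implementation."""

    def parameters(self, prefix=""):
        """Collect parameters from attributes, recursing into sub-modules."""
        params = {}
        for name, value in vars(self).items():
            if isinstance(value, Module):
                params.update(value.parameters(prefix + name + "/"))
            elif isinstance(value, list):
                params[prefix + name] = value
        return params


class Affine(Module):
    def __init__(self, n_in, n_out):
        self.n_in, self.n_out = n_in, n_out
        self.W = [0.1] * (n_in * n_out)  # weights owned by this instance
        self.b = [0.0] * n_out

    def __call__(self, x):
        return [sum(x[i] * self.W[i * self.n_out + j] for i in range(self.n_in))
                + self.b[j] for j in range(self.n_out)]


class MLP(Module):
    def __init__(self):
        # First mention of each layer: declare it with its sizes.
        self.fc1 = Affine(2, 3)
        self.fc2 = Affine(3, 2)

    def __call__(self, x):
        # Second mention of each layer: wire it into the forward pass.
        return self.fc2(self.fc1(x))


model = MLP()
class_param_names = sorted(model.parameters())
out = model([1.0, 2.0])
```

Each layer appears twice, which is the redundancy noted in the table, but all parameters are reachable from the `model` object itself, with no global state involved.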
Code Snippet
...
import nnabla.functions as F
import nnabla.experimental.parametric_function_classes as PFC
...
...
class ResUnit(PFC.Module):
    def __init__(self, inmaps=64, outmaps=64):
        # Bottleneck path: 1x1 reduce -> 3x3 -> 1x1 expand
        self.conv0 = PFC.Conv2d(inmaps, inmaps // 2, (1, 1))
        self.bn0 = PFC.BatchNorm2d(inmaps // 2)
        self.conv1 = PFC.Conv2d(inmaps // 2, inmaps // 2, (3, 3))
        self.bn1 = PFC.BatchNorm2d(inmaps // 2)
        self.conv2 = PFC.Conv2d(inmaps // 2, outmaps, (1, 1))
        self.bn2 = PFC.BatchNorm2d(outmaps)
        self.act = F.relu
        # Project the shortcut only when the channel counts differ
        self.shortcut_func = False
        if inmaps != outmaps:
            self.shortcut_func = True
            self.shortcut_conv = PFC.Conv2d(inmaps, outmaps, (3, 3))
            self.shortcut_bn = PFC.BatchNorm2d(outmaps)

    def __call__(self, x, test=False):
        s = x
        h = x
        h = self.act(self.bn0(self.conv0(h), test))
        h = self.act(self.bn1(self.conv1(h), test))
        h = self.bn2(self.conv2(h), test)
        if self.shortcut_func:
            s = self.shortcut_conv(s)
            s = self.shortcut_bn(s, test)
        # Residual connection followed by the activation
        h = self.act(h + s)
        return h
...
See main.py for details.