In nnabla, QAT (quantization-aware training) is performed in two stages: RECORDING and TRAINING.
In the RECORDING stage, we collect and record the dynamic range of each parameter and buffer.
In the TRAINING stage, we insert Quantization & Dequantization nodes to simulate the quantization effect.
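The two stages can be illustrated with a small, library-free sketch. It assumes a simple min/max affine scheme, which is not necessarily the scheme nnabla uses; `record_range` and `fake_quantize` are hypothetical helper names introduced only for this example.

```python
def record_range(values):
    """RECORDING stage: observe the dynamic range (min, max) of a tensor."""
    return min(values), max(values)


def fake_quantize(values, lo, hi, num_bits=8):
    """TRAINING stage: quantize to integer codes, then dequantize to floats."""
    levels = 2 ** num_bits - 1
    # Guard against a zero-width range (all values identical).
    scale = (hi - lo) / levels if hi > lo else 1.0
    out = []
    for v in values:
        clipped = min(max(v, lo), hi)          # clamp to the recorded range
        code = round((clipped - lo) / scale)   # integer code in [0, levels]
        out.append(code * scale + lo)          # dequantized float value
    return out


weights = [-1.0, -0.25, 0.0, 0.5, 1.0]
lo, hi = record_range(weights)
# With 2 bits only four distinct values survive, so the rounding is visible.
simulated = fake_quantize(weights, lo, hi, num_bits=2)
```

The network keeps computing in floating point, but every value passing through such a node is snapped to the grid a low-precision deployment would use, so training can adapt to the rounding error.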
We provide QATScheduler to support Quantization-Aware-Training.
Creating a QATScheduler:
    from nnabla.utils.qnn import QATScheduler, QATConfig, PrecisionMode

    # Create training network
    pred = model(image, test=False)
    # Create validation network
    vpred = model(vimage, test=True)

    # Configure the QATScheduler
    config = QATConfig()
    config.bn_folding = True
    config.bn_self_folding = True
    config.channel_last = False
    config.precision_mode = PrecisionMode.SIM_QNN
    config.skip_bias = True
    config.niter_to_recording = 1
    config.niter_to_training = steps_per_epoch * 2

    # Create a QATScheduler object
    qat_scheduler = QATScheduler(config=config, solver=solver)

    # Register the training network to QATScheduler
    qat_scheduler(pred)
    # Register the validation network to QATScheduler
    qat_scheduler(vpred, training=False)
Shorthand for TensorRT:
    from nnabla.utils.qnn import QATScheduler, QATTensorRTConfig

    config = QATTensorRTConfig()
    qat_scheduler = QATScheduler(config=config, solver=solver)
    qat_scheduler(pred)
    qat_scheduler(vpred, training=False)
Modifying your training loop:
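A minimal sketch of such a loop, assuming the scheduler is advanced by calling a `step()` method once per training iteration. So that the example runs without nnabla, it uses `StandInScheduler`, a hypothetical stand-in whose stage names and switching rule (enter RECORDING at `niter_to_recording`, enter TRAINING at `niter_to_training`) are assumptions made for illustration, not the library's documented behavior.

```python
class StandInScheduler:
    """Hypothetical stand-in for QATScheduler; not part of nnabla."""

    def __init__(self, niter_to_recording, niter_to_training):
        self.niter_to_recording = niter_to_recording
        self.niter_to_training = niter_to_training
        self.iteration = 0
        self.stage = "FLOAT"  # assumed name for the stage before QAT starts

    def step(self):
        # Assumed rule: switch stage once the iteration count reaches a threshold.
        self.iteration += 1
        if self.iteration >= self.niter_to_training:
            self.stage = "TRAINING"
        elif self.iteration >= self.niter_to_recording:
            self.stage = "RECORDING"


steps_per_epoch = 3
qat_scheduler = StandInScheduler(niter_to_recording=1,
                                 niter_to_training=steps_per_epoch * 2)

stages = []
for i in range(steps_per_epoch * 3):
    qat_scheduler.step()  # advance the QAT schedule once per iteration
    # ... forward pass, backward pass, and solver update would go here ...
    stages.append(qat_scheduler.stage)
```

In a real training script the `qat_scheduler` object created above would take the place of the stand-in, and the rest of the loop body would stay as it is.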
More reinforcement learning examples have been added to nnabla-examples!
Each example is implemented with nnabla-rl and solves tasks such as CartPole and Pendulum, as well as Atari and MuJoCo environments.
Please try it!
Dynamically load MPI and NCCL libraries
The OpenMPI and NCCL libraries used in distributed training are now dynamically loaded at runtime.
Until now, it was necessary to install
nnabla-ext-cuda110-nccl2-mpi2-1-1 if the system had OpenMPI v2, and
nnabla-ext-cuda110-nccl2-mpi3-1-6 if it had OpenMPI v3.
From now on, you can simply use
nnabla-ext-cuda110 to run distributed training, no matter which version of OpenMPI/NCCL is installed on your system.
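The general idea of runtime loading can be sketched with Python's `ctypes`. This is only an illustration of the technique, not nnabla's actual loader: `load_first_available` is a hypothetical helper, and the candidate file names are assumptions rather than the list nnabla probes.

```python
import ctypes


def load_first_available(candidates, loader=ctypes.CDLL):
    """Return (name, handle) for the first library that loads, else None."""
    for name in candidates:
        try:
            return name, loader(name)
        except OSError:
            continue  # this version is not installed; try the next candidate
    return None


# Illustrative candidate names only; each result is None if nothing is found.
mpi = load_first_available(["libmpi.so.40", "libmpi.so.20", "libmpi.so"])
nccl = load_first_available(["libnccl.so.2", "libnccl.so"])
```

Because the library is resolved when the program runs rather than when the package is built, one package can work with whichever version happens to be installed.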
- Derived class from Module wraps all methods for parameter scope
- fix inf serialization bug
- save no_image_normalization in executor if exist
- support libnccl.so.2
- fix segmentation fault caused by derived struct invalid
- fix crash bug for max_pooling function
- Build flatc from latest upstream release (CPU / GPU)
- Update gpg keys for centos based nvidia docker
- update nvidia gpg key
- use numpy 1.20.0 and later (CPU / GPU)
- fix the problem of tensorboardX support document
- Update protobuf version (CPU / GPU)
- Adjust protobuf version dependency on Win/Non-Win platform