Profilerの使い方
今回は、nnablaのutilsとして用意されているGraphProfiler
について紹介します。
このモジュールは、nnablaで構築したニューラルネットワークのグラフにおいて、それぞれのlayer(function)でかかる実行速度を計測するツールです。
モデルのデプロイ時などに、どのfunctionの処理に時間がかかっているのかを効率的に確認することができます。
まずは、必要なモジュールをimportします。
通常nnablaを利用するときのモジュールに加えて、
4行目のようにnnabla.utils.profiler
からGraphProfiler
をimportしておきます。
(以下の例では、続くコードサンプルで登場するモジュールのみimportしています。)
import nnabla as nn
from nnabla.ext_utils import get_extension_context
import nnabla.solvers as S
from nnabla.utils.profiler import GraphProfiler
続いて、nnablaのglobal contextを設定します。
contextによって、計算を実行するデバイスの切替(cpu / gpu等)などを指定することができます。
今回はcpuでの計測を行いますが、gpuでの実行速度を計測したい場合はdevice = "cudnn"
と指定してください。
# Set up nnabla context
device = "cpu" # you can also use GPU ("cudnn")
ctx = get_extension_context(device)
nn.set_default_context(ctx)
続いて、モデルを構築します。
以下の例は簡略化して書いた一例です。
モデル、損失関数、最適化手法については任意のものを利用することができます。
# Create graph and set solver
x = nn.Variable(shape=...)
t = nn.Variable(shape=...)
y = model(x)
loss = F.mean(F.softmax_cross_entropy(y, t))
solver = S.Sgd()
solver.set_parameters(nn.get_parameters())
構築したモデルの出力をGraphProfilerに渡して.run()
を実行することで、そこから入力まで遡って全ての関数の実行速度を計測します。
この例では、モデルの末端のloss
を渡していますが、中間出力を渡すことで、入力からそこまでの関数だけを計測することもできます。
profiler = GraphProfiler(loss, solver=solver, device_id=0, ext_name=device, n_run=1000, max_measure_execution_time=1)
profiler.run()
ここで、特に計測に影響するGraphProfiler
の引数について少し説明しておきます。
solver
:
最適化に利用するsolverの指定。
この引数を指定するとパラメータのupdateにかかる時間も計測されます。-
n_run
:
各関数で何回の計測を行うか。
この実行回数での平均値が報告される。 -
max_measure_excution_time
:
1関数あたりの計測にかける時間の最大値。
単位は[sec]となっていて、計測時間がこの時間を超えると次の関数に移行します。
優先度はn_run
より高く、計測時間によってはn_run
で指定した回数よりも少ない回数しか計測しない場合もあります。
計測した結果は、以下の例のようにGraphProfilerCsvWriter
を利用することで、自動でcsvファイルにまとめて書き出すことができます。
from nnabla.utils.profiler import GraphProfilerCsvWriter
with open("./profile.csv", "w") as f:
writer = GraphProfilerCsvWriter(profiler, file=f)
writer.write()
このようにプロファイリングした結果を、グラフにしてみると各関数での演算速度を効果的に比較することが可能です。
今回は一例として、nnabla-examplesで公開しているlenetでProfilingしてみた結果をお見せします。
また、Profilingを含めたnnablaのデバッグに関するチュートリアルも用意してございますので、そちらも参考にしてみてください。
今回の内容は以上となります。
誰でも簡単に実行速度が計測できるAPI、ぜひご活用ください!