-
chainerのメモリ削減の仕組み
-
逆伝播の計算時にいらない配列等を削除する仕組み
-
いらない配列とは?
-> +, -, sigmoidなど
(例:sigmoid)
function backward(f::Sigmoid, gy) y = f.grad_field.outputs[1] # backpropにはinputsが必要ない @. return (y * (1 - y)) * gy end
-
次のような計算を考える。(ただし, g(x)は逆伝播に入力の配列が必要ないもの。)
すると逆伝播の時は次のような計算グラフができる。
-
この時、g(x)の逆伝播にはx2が必要ないので、ユーザーがx2への参照を解放した時にx2を消しても良さそう.
-
ただ、gのinputsにはx2が入っているのでPythonオブジェクトの参照カウント周りでトリッキーなことをする必要がある。(gのinputsに参照が残っているので参照カウントが0にならないから、ユーザが参照を消しても検知できないってこと?)
-
そこで、ユーザーから直接参照されるオブジェクトと、逆伝播のグラフを分ける。
-
これをすると、ユーザーがその変数への直接参照を失ったときに検出して、不要な配列を消すことができる。
- 各Variableごとにこんな感じのを作る.
- 青いのを
Variable
, 橙をVariableNode
とする。
-
全体はこんな感じになる
-
ここで、x2はいらないので、x2 の
VariableNode
(橙)からdataへの参照を切る。
- すると、ユーザーがx2(の
Variable
?)への参照を解放すると、dataを解放することができる。 - 出力の場合も同じように管理できる(出力が逆伝播に必要な場合もある(sigmoidなど))