This guideline describes the standard way of implementing a new-style function in Chainer. By writing implementations based on this document, different developers can keep a consistent coding style.
A new-style function implementation consists of at least two elements: an implementation of `FunctionNode` and a user interface function.
`FunctionNode` is a new-style function node class, which has a structure similar to that of `Function`. You can write `__init__` and `forward` (as well as `forward_cpu` and `forward_gpu`) like those of `Function`. There is only one difference: input arrays are NOT retained by default. If you want to use them in `backward`, call `retain_inputs()` explicitly.
The `backward` method, on the other hand, is totally different from that of `Function`. Please read the documentation of `FunctionNode.backward()` and `FunctionNode.backward_accumulate()`. Inside these methods, you can use `self.get_retained_inputs()` and `self.get_retained_outputs()` to retrieve the retained inputs/outputs. Note that `grad_outputs` and these retained inputs/outputs are all given as `Variable` objects, and `backward()` must return a tuple of `Variable` objects.
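For concreteness, here is a minimal sketch of a new-style node (a hypothetical element-wise exponential, not part of the `MyFunc` example below) that follows these rules; `forward_gpu` is omitted for brevity:

```python
import numpy

from chainer import function_node


class ExpNode(function_node.FunctionNode):

    def forward_cpu(self, inputs):
        x, = inputs                       # plain arrays here
        self.retain_outputs((0,))         # nothing is retained unless requested
        return numpy.exp(x),

    def backward(self, indexes, grad_outputs):
        y, = self.get_retained_outputs()  # a Variable
        gy, = grad_outputs                # also a Variable
        return y * gy,                    # must be a tuple of Variables
```

Because `backward` is written with differentiable operations on `Variable` objects, it can itself be differentiated, which is what enables higher-order gradients.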
Consider the following simple function written with the old API.
```python
class MyFunc(function.Function):

    def __init__(self, args): …  # <set up attributes>

    def forward_cpu(self, inputs):
        …  # <compute outputs>
        return outputs

    def forward_gpu(self, inputs):
        …  # <compute outputs>
        return outputs

    def backward(self, inputs, grad_outputs):
        …  # <compute grad inputs>
        return grad_inputs_in_arrays


def my_func(inputs…):
    return MyFunc(args)(inputs…)
```
If `backward` is so simple that it can be directly mapped to an implementation using existing differentiable functions, we can rewrite it with `FunctionNode` as follows.
```python
class MyFunc(function_node.FunctionNode):

    def __init__(self, args): …  # not changed

    def forward_cpu(self, inputs): …  # not changed except for retain inputs/outputs

    def forward_gpu(self, inputs): …  # not changed except for retain inputs/outputs

    def backward(self, indexes, grad_outputs):
        # Compute the gradient w.r.t. inputs specified by indexes using F.*** functions.
        # Note that input variables are not given; use self.get_retained_inputs() to get them.
        return grad_inputs_in_variables


def my_func(inputs…):
    return MyFunc(args).apply((inputs, …))[0]  # [0] if only one output; omit it otherwise
```
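As the comment above notes, `apply()` returns a tuple of `Variable` objects even when there is a single output, which is why the user interface function indexes with `[0]`. A short usage sketch, reusing the hypothetical `ExpNode` from the earlier sketch:

```python
import numpy as np

import chainer

x = chainer.Variable(np.random.randn(3, 4).astype(np.float32))
y, = ExpNode().apply((x,))     # a tuple of Variables, even for a single output
y.grad = np.ones_like(y.data)
y.backward()                   # fills x.grad through ExpNode.backward
```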
Note that `self.retain_inputs(())` in the existing code (i.e., letting the function retain nothing) can be omitted in the new-style API because the inputs are not retained by default.
If the existing function has an optimized implementation of `backward`, you should keep the current optimized code. In order to do that, you have to write a new function that implements the first-order gradient of the original function. Consider the above `MyFunc` example with an optimized `backward` implementation (probably written separately as `backward_cpu` and `backward_gpu`). The new-style version will then be written as follows.
```python
class MyFunc(function_node.FunctionNode):

    …  # __init__, forward_cpu, and forward_gpu are same as above

    def backward(self, indexes, grad_outputs):
        …  # Prepare arguments
        return MyFuncGrad(args).apply((inputs_to_grad, …))


class MyFuncGrad(function_node.FunctionNode):

    def __init__(self, args): …  # if needed

    def forward_cpu(self, inputs):
        # paste the existing backward_cpu implementation here
        return outputs  # grad_inputs of the original function

    def forward_gpu(self, inputs):
        # paste the existing backward_gpu implementation here
        return outputs  # grad_inputs of the original function

    def backward(self, indexes, grad_outputs):
        # grad_outputs is actually the second order gradient w.r.t. inputs.
        # This method computes the second order gradient w.r.t. outputs of the original MyFunc.
        return grad_inputs  # grad_grad_outputs of the original function


def my_func(inputs…): …  # same as above
```
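To make this pattern concrete, here is a minimal sketch (a hypothetical element-wise square, not taken from the Chainer code base) in which the first-order gradient is computed by a dedicated `FunctionNode`; only `forward_cpu` is shown:

```python
from chainer import function_node


class MySquare(function_node.FunctionNode):
    """y = x * x, with its gradient computed by MySquareGrad."""

    def forward_cpu(self, inputs):
        x, = inputs
        self.retain_inputs((0,))
        return x * x,

    def backward(self, indexes, grad_outputs):
        x, = self.get_retained_inputs()
        gy, = grad_outputs
        return MySquareGrad().apply((x, gy))


class MySquareGrad(function_node.FunctionNode):
    """gx = 2 * x * gy; an optimized backward_cpu/backward_gpu would go here."""

    def forward_cpu(self, inputs):
        x, gy = inputs
        self.retain_inputs((0, 1))
        return 2 * x * gy,

    def backward(self, indexes, grad_outputs):
        x, gy = self.get_retained_inputs()
        ggx, = grad_outputs  # second-order gradient w.r.t. the original input
        grads = {0: 2 * gy * ggx,   # gradient w.r.t. x
                 1: 2 * x * ggx}    # gradient w.r.t. gy (grad_grad_outputs of MySquare)
        return tuple(grads[i] for i in indexes)


def my_square(x):
    return MySquare().apply((x,))[0]
```

Here `MySquareGrad.forward_cpu` plays the role of the original optimized `backward_cpu`, and its own `backward` provides the second-order gradient.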
You can use `check_double_backward` (see chainer/chainer#3096 until merged) to run a gradient check for the second-order gradient. This function runs two backpropagations: the first computes the gradient `gx` of `y` w.r.t. `x`, and the second computes the gradient of `gx` w.r.t. `x`. It is used basically like `check_backward`, except for two important differences.
- It requires an additional argument, `x_grad_grad`, which is an array or a tuple of arrays used for initializing the gradient array of each gradient w.r.t. an input. In other words, this argument is used to initialize `gx.grad` for the second backprop.
- If the function to be tested is a linear function of its inputs, `check_double_backward` will fail with the message `RuntimeError: gradients of some arguments are not calculated`. More precisely, a function that has an input on which none of its gradients depend, directly or indirectly (in a differentiable way), will fail with this message, because the backward pass through those gradients does not reach that input. To avoid this error, make the function nonlinear, e.g. by applying `x * x` to the output of the function you want to test (see the sketch after this list).
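A usage sketch, assuming the interface proposed in chainer/chainer#3096 (`check_double_backward(func, x_data, y_grad, x_grad_grad, ...)`) and a hypothetical shape-preserving function `my_func` under test:

```python
import numpy as np

from chainer import gradient_check

x = np.random.uniform(-1, 1, (3, 4)).astype(np.float32)
gy = np.random.uniform(-1, 1, (3, 4)).astype(np.float32)   # initializes y.grad for the first backprop
ggx = np.random.uniform(-1, 1, (3, 4)).astype(np.float32)  # initializes gx.grad for the second backprop


def f(x):
    y = my_func(x)
    return y * y  # make the tested function nonlinear so the second backprop reaches x


gradient_check.check_double_backward(f, x, gy, ggx)
```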