
New-Style Function Implementation Guideline

This is a guideline that gives the standard way of implementing a new-style function in Chainer. By writing implementations based on this document, different developers can keep their coding style aligned.

Basics

A new-style function implementation consists of at least two elements: an implementation of FunctionNode and a user interface function.

FunctionNode is the new-style function node class, which has a structure similar to that of Function. You can write __init__ and forward (as well as forward_cpu and forward_gpu) like those of Function. There is only one difference: input arrays are NOT retained by default. If you want to use them in backward, call retain_inputs() explicitly.

The backward method, on the other hand, is totally different from that of Function. Please read the documentation of FunctionNode.backward() and FunctionNode.backward_accumulate(). Inside these methods, you can use self.get_retained_inputs() and self.get_retained_outputs() to retrieve the retained inputs/outputs. Note that grad_outputs and these retained inputs/outputs are all given as Variable objects, and backward() must return a tuple of Variable objects.
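For instance, here is a minimal sketch of a new-style node that retains its output (the Exp class and the exp wrapper are illustrative names, not the actual implementation of chainer.functions.exp):

from chainer import cuda, function_node

class Exp(function_node.FunctionNode):
    """Illustrative new-style node computing y = exp(x)."""

    def forward(self, inputs):
        x, = inputs
        xp = cuda.get_array_module(x)
        y = xp.exp(x)
        self.retain_outputs((0,))  # keep y for use in backward
        return y,

    def backward(self, indexes, grad_outputs):
        y, = self.get_retained_outputs()  # a Variable, not an ndarray
        gy, = grad_outputs                # also a Variable
        return y * gy,                    # must return a tuple of Variables

def exp(x):
    return Exp().apply((x,))[0]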

Using new-style API (basics)

Consider the following simple function written with the old API.

class MyFunc(function.Function):
    def __init__(self, args): …  # <set up attributes>

    def forward_cpu(self, inputs):
        …  # <compute outputs>
        return outputs

    def forward_gpu(self, inputs):
        …  # <compute outputs>
        return outputs

    def backward(self, inputs, grad_outputs):
        …  # <compute grad inputs>
        return grad_inputs_in_arrays

def my_func(inputs…):
    return MyFunc(args)(inputs…)

If backward is simple enough to be mapped directly to an implementation using existing differentiable functions, we can rewrite it with FunctionNode as follows.

class MyFunc(function_node.FunctionNode):
    def __init__(self, args): …  # not changed
    def forward_cpu(self, inputs): …  # not changed except for retain inputs/outputs
    def forward_gpu(self, inputs): …  # not changed except for retain inputs/outputs

    def backward(self, indexes, grad_outputs):
        # Compute the gradient w.r.t. inputs specified by indexes using F.*** functions.
        # Note that input variables are not given; use self.get_retained_inputs() to get them.
        return grad_inputs_in_variables

def my_func(inputs…):
    return MyFunc(args).apply((inputs, …))[0]  # [0] if only one output; omit it otherwise

Note that self.retain_inputs(()) in the existing code (i.e., letting the function retain nothing) can be omitted in the new-style API because inputs are not retained by default.
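To make the recipe concrete, here is a hedged before/after sketch for a hypothetical element-wise square function (the names SquareOld, Square, and square are made up for illustration):

from chainer import function, function_node

class SquareOld(function.Function):
    """Old style: backward receives and returns arrays."""

    def forward(self, inputs):
        x, = inputs
        return x * x,

    def backward(self, inputs, grad_outputs):
        x, = inputs
        gy, = grad_outputs
        return 2 * x * gy,

class Square(function_node.FunctionNode):
    """New style: inputs must be retained explicitly; backward works on Variables."""

    def forward(self, inputs):
        x, = inputs
        self.retain_inputs((0,))  # inputs are NOT retained by default
        return x * x,

    def backward(self, indexes, grad_outputs):
        x, = self.get_retained_inputs()  # tuple of Variables
        gy, = grad_outputs               # tuple of Variables
        # Written with differentiable operations, so double backward works.
        return 2 * x * gy,

def square(x):
    return Square().apply((x,))[0]  # one output, so take element [0]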

Using new-style API (optimized)

If the existing function has an optimized implementation of backward, you should keep that optimized code. To do so, you have to write a new function (a separate FunctionNode) that implements the first-order gradient of the original function.

Consider the above MyFunc example with an optimized backward implementation (probably written separately as backward_cpu and backward_gpu). The new-style version is then written as follows.

class MyFunc(function_node.FunctionNode):
    …  # __init__, forward_cpu, and forward_gpu are the same as above

    def backward(self, indexes, grad_outputs):
        ... # Prepare arguments
        return MyFuncGrad(args).apply((inputs_to_grad, …))

class MyFuncGrad(function_node.FunctionNode):
    def __init__(self, args): …  # if needed

    def forward_cpu(self, inputs):
        # paste the existing backward_cpu implementation here
        return outputs  # grad_inputs of the original function

    def forward_gpu(self, inputs):
        # paste the existing backward_gpu implementation here
        return outputs  # grad_inputs of the original function

    def backward(self, indexes, grad_outputs):
        # grad_outputs is actually the second order gradient w.r.t. inputs.
        # This method computes the second order gradient w.r.t. outputs of the original MyFunc.
        return grad_inputs  # grad_grad_outputs of the original function

def my_func(inputs…): …  # same as above
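Continuing the hypothetical square example from above, suppose the old-style version had hand-optimized backward_cpu/backward_gpu kernels. A sketch of the port then looks like this (SquareGrad is an illustrative name, and the plain arithmetic stands in for the optimized kernels):

from chainer import function_node

class Square(function_node.FunctionNode):
    """Same forward as before; backward now delegates to a grad node."""

    def forward(self, inputs):
        x, = inputs
        self.retain_inputs((0,))
        return x * x,

    def backward(self, indexes, grad_outputs):
        x, = self.get_retained_inputs()
        gy, = grad_outputs
        return SquareGrad().apply((x, gy))

class SquareGrad(function_node.FunctionNode):
    """First-order gradient of Square: gx = 2 * x * gy."""

    def forward(self, inputs):
        x, gy = inputs
        self.retain_inputs((0, 1))
        # In a real port, paste the optimized backward_cpu/backward_gpu code
        # of the old-style function here (split into forward_cpu and
        # forward_gpu if the CPU and GPU kernels differ).
        return 2 * x * gy,

    def backward(self, indexes, grad_outputs):
        x, gy = self.get_retained_inputs()
        ggx, = grad_outputs  # second-order gradient w.r.t. the original input x
        # gx = 2 * x * gy, so d(gx)/dx = 2 * gy and d(gx)/d(gy) = 2 * x.
        return 2 * gy * ggx, 2 * x * ggx

def square(x):
    return Square().apply((x,))[0]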

Testing the new-style function that supports double backward

You can use check_double_backward (see chainer/chainer#3096 until merged) to run a gradient check for the second-order gradient. This function runs two backpropagations: first to compute the gradient gx of y w.r.t. x, and second to compute the gradient of gx w.r.t. x. It is basically used like check_backward except for two important differences.

  1. It requires an additional argument, x_grad_grad, which is an array or a tuple of arrays used to initialize the gradient array of each gradient w.r.t. an input. In other words, this argument is used to initialize gx.grad for the second backprop.
  2. If the function to be tested is a linear function of its inputs, check_double_backward fails with the message RuntimeError: gradients of some arguments are not calculated. More precisely, any function whose gradients are all independent of at least one of the inputs (directly or indirectly, in a differentiable way) fails with this message, because backpropagation through such gradients does not reach that input. To avoid the error, make the tested function nonlinear, e.g. by squaring the output of the function you want to test, as in the sketch below.
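For example, a test for the hypothetical square wrapper sketched earlier might look like the following (the shapes and tolerances are arbitrary choices):

import numpy
from chainer import gradient_check

# `square` is the hypothetical wrapper sketched earlier in this guide.

def test_square_double_backward():
    shape = (3, 4)
    x = numpy.random.uniform(-1, 1, shape).astype(numpy.float32)
    gy = numpy.random.uniform(-1, 1, shape).astype(numpy.float32)   # y_grad
    ggx = numpy.random.uniform(-1, 1, shape).astype(numpy.float32)  # x_grad_grad

    def f(x):
        y = square(x)
        # square is already nonlinear, so y can be returned directly.
        # For a linear function, return y * y here instead so that the
        # gradient depends on the input and the check does not fail.
        return y

    gradient_check.check_double_backward(f, x, gy, ggx, atol=1e-4, rtol=1e-4)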