
@wanchaol
Last active August 21, 2018 20:35

Torch NN library support in JIT Script

The torch nn library is the first library we want to support in JIT Script. It consists of modules covering activation, conv, rnn, loss, etc. These modules inherit from torch.nn.Module; each wraps things up in its own class (argument preprocessing, etc.) and then calls nn.functional to do the real work. nn.functional is also exposed to users directly.

Proposed plan:

  1. identify the differences between the nn functional ops and the corresponding ATen ops, PR already out pytorch/pytorch#10409
  2. maintain a separate copy of nn.functional that uses script annotation, applying workarounds where needed so that every function generates IR
  3. open registration for nn.functional ops (and possibly nn modules), letting us directly inline the graph in C++
  4. script nn.modules so that each submodule runs in script mode

This document clarifies the current constraints in the JIT and details the proposed plan to support the torch.nn library in JIT script mode.

Constraints

In today's JIT world, we already have ScriptModule, which subclasses nn.Module with most of its methods implemented. Ideally we would directly annotate the entire torch.nn library and share a single source between JIT and non-JIT mode.

Direct annotation example:

# nn functional

import torch

@torch.jit.script
def threshold(input, threshold, value, inplace=False):
    r"""Thresholds each element of the input Tensor.

    See :class:`~torch.nn.Threshold` for more details.
    """
    if inplace:
        return torch._C._nn.threshold_(input, threshold, value)
    return torch._C._nn.threshold(input, threshold, value)

# activation module

import torch
import torch.nn.functional as F
from torch.jit import ScriptModule

class Threshold(ScriptModule):

    def __init__(self, threshold, value, inplace=False):
        super(Threshold, self).__init__()
        self.threshold = threshold
        self.value = value
        self.inplace = inplace

    @torch.jit.script_method
    def forward(self, input):
        return F.threshold(input, self.threshold, self.value, self.inplace)

While we would like to make the entire library work in script mode under direct annotation, there are multiple constraints in the JIT today that prevent us from pursuing that path. Some of them are noted below:

  • exception/error raising, warnings, etc. are not supported in script
  • perf regression with an overspecialized ScriptModule; it also does not support all methods and will likely break things
  • inconsistent bindings: functions bound in torch._C._nn are not available in the torch namespace
  • in-place operations are not supported in script mode
  • default arguments in function signatures are not supported in script
  • non-tensor attributes of a ScriptModule have to be in the constants set in order to script the module, which means users cannot modify those attributes afterwards (like changing the training flag via eval())
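To illustrate the last constraint, a plain-Python sketch of why attributes in the constants set become immutable after scripting. The class names and the `freeze` mechanism here are hypothetical stand-ins, not the actual ScriptModule implementation; they only mimic the behavior that constants are baked in at compile time:

```python
class FakeScriptModule:
    """Plain-Python mimic of the ScriptModule constants restriction.

    Attributes listed in __constants__ are fixed once the module is
    "compiled" (frozen), analogous to how JIT script specializes the
    graph on those values.  Hypothetical sketch, not the real API.
    """
    __constants__ = []

    def __init__(self):
        object.__setattr__(self, "_frozen", False)

    def freeze(self):
        # stands in for the point where scripting bakes constants in
        object.__setattr__(self, "_frozen", True)

    def __setattr__(self, name, value):
        if getattr(self, "_frozen", False) and name in self.__constants__:
            raise AttributeError(
                f"'{name}' is a scripting constant and cannot be changed")
        object.__setattr__(self, name, value)


class FakeGLU(FakeScriptModule):
    __constants__ = ["dim"]

    def __init__(self, dim=-1):
        super().__init__()
        self.dim = dim   # allowed: module not frozen yet
        self.freeze()    # after "compilation", dim is fixed


m = FakeGLU(dim=2)
print(m.dim)  # 2
try:
    m.dim = 0  # mimics changing a constant after scripting
except AttributeError as e:
    print("rejected:", e)
```

This is exactly the pattern that prevents a user from, say, toggling a scripted module's `training` flag after construction.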

The above constraints are the issues we discovered when we started looking into library support; the list may be incomplete, as there might be other special cases to deal with. Since these constraints are not easy fixes, removing them could take several months, so in order to quickly support torch.nn in JIT Script we need to adopt a different method for the initial version.

Details

nn.functional in script

Because of the above constraints with JIT script and some unforeseen issues with direct annotation, we decided to maintain two copies of torch.nn.functional. The new copy for JIT Script will live under torch.nn.script. The script version involves direct annotation of the ops plus some workarounds:

  • torch.nn.script.functional will primarily contain wrappers for the functional interfaces that require complicated pre/post-processing beyond calling the raw op; we annotate those parts with script
  • identify the differences between the nn functional ops and the corresponding ATen ops, PR already out pytorch/pytorch#10409
  • for the functions bound in the torch._C._nn namespace, we will try to bind them all in the torch namespace
  • a simple temporary workaround is that JIT script can directly replace a torch.nn.functional call with the corresponding ATen op, so for simple functions that only call into torch._C._nn we will use that replacement in the short term
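The short-term replacement in the last bullet can be pictured as a lookup table from functional names to their ATen counterparts. A minimal sketch, assuming such a mapping; the table name and `resolve` helper are hypothetical (the real substitution happens inside the JIT compiler), and the entries only mirror the threshold example from this document:

```python
# Hypothetical functional -> ATen replacement table; the real JIT
# keeps an equivalent mapping internally.
FUNCTIONAL_TO_ATEN = {
    "torch.nn.functional.threshold": "torch._C._nn.threshold",
    "torch.nn.functional.threshold_": "torch._C._nn.threshold_",
}

def resolve(qualified_name):
    """Return the ATen op a functional call should be lowered to,
    or the original name when no replacement is registered."""
    return FUNCTIONAL_TO_ATEN.get(qualified_name, qualified_name)

print(resolve("torch.nn.functional.threshold"))  # torch._C._nn.threshold
```

Functions with real pre/post-processing would not appear in such a table; those go through the annotated wrappers in torch.nn.script.functional instead.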

nn.Modules in script

There are a bunch of useful modules that give users a higher-level abstraction for constructing networks (e.g. nn.Sequential, nn.Linear, etc.). To support these modules in JIT script, we could maintain a separate copy of nn.modules under torch.nn.script (still under discussion), with an example ScriptModule as below:

import torch
import torch.nn.functional as F
from torch.jit import ScriptModule

class GLU(ScriptModule):
    __constants__ = ['dim']

    def __init__(self, dim=-1):
        super(GLU, self).__init__()
        self.dim = dim

    @torch.jit.script_method
    def forward(self, input):
        return F.glu(input, self.dim)

Open registration on nn.functional and nn.Modules

Instead of exposing the script versions of nn.functional and nn.modules directly to the end user, we will apply table registration for our nn.functional ops that directly translates nn.functional calls and modules into the corresponding graph in C++. So a user will write:

import torch.nn.functional as F

@torch.jit.script
def embedding_functional(input):
    # input/embedding processing
    input = ...
    embedding = ...

    return F.embedding(input, embedding)

and the JIT will directly inline the IR of F.embedding into the whole IR when the function is annotated with script. We register all the ops in nn.functional (and possibly nn.modules) in an in-memory table for the compiler to use in the inlining work.
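The in-memory table could look roughly like the following sketch: a decorator records each op under its qualified name, and the compiler looks the entry up when inlining. All names here (`OP_TABLE`, `register_nn_op`, `inline`) are hypothetical; the actual registration would live in C++ and store JIT graphs rather than Python functions:

```python
# Hypothetical in-memory op table keyed by qualified name; the real
# registration is done in C++ against actual JIT graphs.
OP_TABLE = {}

def register_nn_op(name):
    """Decorator: record a function (standing in for its compiled
    graph) under `name` so the compiler can inline it later."""
    def wrapper(fn):
        OP_TABLE[name] = fn
        return fn
    return wrapper

@register_nn_op("F.embedding")
def embedding(input_ids, table):
    # stand-in body; a real entry would be the op's JIT graph
    return [table[i] for i in input_ids]

def inline(name, *args):
    """Compiler-side lookup: fetch the registered entry and splice
    it in (here, just call the stand-in function)."""
    return OP_TABLE[name](*args)

print(inline("F.embedding", [0, 2], ["a", "b", "c"]))  # ['a', 'c']
```

The key property this models is that the user keeps writing plain `F.embedding` calls while the compiler, not the user, consults the registration table to substitute the graph.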
