The torch.nn library is the first library we want to support in JIT script. It consists of modules covering activations, convolutions, RNNs, losses, etc. These modules inherit from torch.nn.Module; each wraps up bookkeeping (argument preprocessing, etc.) in its own class and then calls into nn.functional to do the real work. nn.functional is also exposed directly to users.
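For example, the existing eager-mode pattern looks roughly like the following simplified sketch (mirroring how a module such as torch.nn.ReLU is implemented; the details are illustrative):

```python
import torch.nn as nn
import torch.nn.functional as F

# the module class stores configuration; forward() dispatches to the
# functional API, which does the real work
class ReLU(nn.Module):
    def __init__(self, inplace=False):
        super(ReLU, self).__init__()
        self.inplace = inplace

    def forward(self, input):
        return F.relu(input, inplace=self.inplace)
```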
At a high level, the plan is to:
- identify the differences between nn.functional ops and the corresponding ATen ops (PR already out: pytorch/pytorch#10409)
- maintain a separate copy of nn.functional that carries script annotations, applying workarounds where needed so that every function generates IR
- open registration for nn.functional ops (and possibly nn modules), so that we can directly inline their graphs in C++
- script nn.modules so that each submodule runs in script mode
This document clarifies the current constraints in the JIT and details the proposed plan to support the torch.nn library in JIT script mode.
In today's JIT world, we already have ScriptModule, which subclasses nn.Module and implements most of its methods. Ideally we would directly annotate the entire torch.nn library and share a single source between JIT and non-JIT modes.
Direct annotation example:
```python
import torch
import torch.nn.functional as F
from torch.jit import ScriptModule

# nn functional
@torch.jit.script
def threshold(input, threshold, value, inplace=False):
    r"""Thresholds each element of the input Tensor.

    See :class:`~torch.nn.Threshold` for more details.
    """
    if inplace:
        return torch._C._nn.threshold_(input, threshold, value)
    return torch._C._nn.threshold(input, threshold, value)

# activation module
class Threshold(ScriptModule):
    def __init__(self, threshold, value, inplace=False):
        super(Threshold, self).__init__()
        self.threshold = threshold
        self.value = value
        self.inplace = inplace

    @torch.jit.script_method
    def forward(self, input):
        return F.threshold(input, self.threshold, self.value, self.inplace)
```
While we would like to make the entire library work in script mode through direct annotation, there are multiple constraints in the JIT today that prevent us from pursuing that path; some of them are noted below:
- raising exceptions/errors, emitting warnings, etc. is not supported in script
- ScriptModule is over-specialized, causing performance regressions, and it does not support all nn.Module methods, which will likely break things
- inconsistent bindings: functions bound in torch._C._nn are not available in the torch namespace
- in-place operations are not supported in script mode
- default arguments in function signatures are not supported in script
- non-tensor attributes of a ScriptModule have to be declared in its __constants__ set in order to script the module, which means users cannot modify those attributes afterwards (e.g. toggling the training flag via eval()); see the sketch after this list
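To illustrate the last constraint, here is a minimal sketch (the module and attribute names are hypothetical, chosen only for illustration): once an attribute is listed in __constants__, its value is baked into the compiled graph, so it cannot be reassigned afterwards:

```python
import torch
from torch.jit import ScriptModule

class Scale(ScriptModule):
    __constants__ = ['alpha']  # non-tensor attribute used in forward()

    def __init__(self, alpha=2.0):
        super(Scale, self).__init__()
        self.alpha = alpha

    @torch.jit.script_method
    def forward(self, input):
        return input * self.alpha

m = Scale()
# m.alpha = 3.0  # not allowed: the constant was folded into the graph
```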
The above constraints are the issues we discovered when we started looking into library support; the list may be incomplete, as there may be other special cases to deal with. Since these constraints are not easy fixes, it might take us several months to remove them, so in order to support torch.nn in JIT script quickly, we need to adopt a different approach for the initial version.
Because of the above constraints in JIT script, and some unforeseen issues with direct annotation, we decided to maintain two copies of torch.nn.functional. The new copy for JIT script will live under torch.nn.script. The script version involves direct annotation of the ops plus some temporary workarounds:
- torch.nn.script.functional will primarily contain wrappers for the functional interfaces that require complicated pre/post-processing beyond calling the raw op; we annotate those parts with script
- identify the differences between nn.functional ops and the corresponding ATen ops (PR already out: pytorch/pytorch#10409)
- for the functions bound in the torch._C._nn namespace, we will try to bind them all in the torch namespace as well
- a simple temporary workaround is for JIT script to directly replace a torch.nn.functional call with the corresponding ATen op; for simple functions that only call into torch._C._nn, we will use this replacement in the short term (see the sketch below)
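Conceptually, the short-term replacement could be driven by a small lookup table along these lines (a hypothetical sketch; the table name and entries are illustrative, not the actual implementation):

```python
# hypothetical mapping from nn.functional entries to ATen ops: when the
# script compiler sees a call to a key, it emits the corresponding ATen
# op instead of trying to compile the Python wrapper
FUNCTIONAL_TO_ATEN = {
    'torch.nn.functional.threshold': 'aten::threshold',
    'torch.nn.functional.glu': 'aten::glu',
    'torch.nn.functional.relu': 'aten::relu',
}
```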
torch.nn also provides a bunch of useful modules that give users a higher-level abstraction for constructing networks (e.g. nn.Sequential, nn.Linear). To support those modules in JIT script, we could maintain a separate copy of nn.modules under torch.nn.script (still under discussion), with an example ScriptModule as below:
```python
import torch
import torch.nn.functional as F
from torch.jit import ScriptModule

class GLU(ScriptModule):
    __constants__ = ['dim']

    def __init__(self, dim=-1):
        super(GLU, self).__init__()
        self.dim = dim

    @torch.jit.script_method
    def forward(self, input):
        return F.glu(input, self.dim)
```
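Usage would look the same as the eager-mode module (a hypothetical example; F.glu halves the size of the chosen dimension):

```python
x = torch.randn(4, 6)
m = GLU(dim=1)
y = m(x)  # forward() runs through the compiled graph; y has shape (4, 3)
```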
Instead of exposing the script versions of nn.functional and nn.modules directly to the end user, we will apply table registration for our nn.functional ops, which directly translates nn.functional calls (and modules) to the corresponding graph in C++. So a user will write:
```python
import torch
import torch.nn.functional as F

@torch.jit.script
def embedding_functional(input):
    # input/embedding processing
    input = ...
    embedding = ...
    return F.embedding(input, embedding)
```
and the JIT will directly inline the IR of F.embedding into the enclosing graph when the function is annotated with script. We register all the ops in nn.functional (and possibly nn.modules) in an in-memory table for the compiler to use when doing the inlining.
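A minimal sketch of what such a registration table could look like, written in Python for clarity (all names here are hypothetical; the actual table would live on the C++ side):

```python
# hypothetical in-memory registry mapping qualified names to compiled graphs
_nn_op_table = {}

def register_nn_op(qualified_name, scripted_fn):
    # store the compiled graph so the compiler can inline it at call sites
    _nn_op_table[qualified_name] = scripted_fn.graph

# e.g. registering a scripted wrapper for F.embedding:
# register_nn_op('torch.nn.functional.embedding', scripted_embedding)
```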