import torch
import torch.nn as nn
from torch.jit import ScriptModule, script_method
from typing import List
class BatchNormList(ScriptModule):
__constants__ = ['mom', 'eps']
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Benchmarking LSTMs... | |
name avg_fwd std_fwd avg_bwd std_bwd | |
cudnn_layernorm 32.71 0.7494 10.43 0.08965 | |
jit_layernorm 41.25 0.7082 98.66 2.56 | |
jit_layernorm_de 34.41 0.7501 113.3 1.037 |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import torch | |
import torch.nn as nn | |
from torch.jit import ScriptModule, script_method | |
from typing import List | |
class BatchNorm(ScriptModule): | |
__constants__ = ['mom', 'eps'] | |
def __init__(self, nf, mom=0.9, eps=1e-5): |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
-- | |
-- ******** Summary ******** | |
-- General: | |
-- CMake version : 3.12.2 | |
-- CMake command : /private/home/wanchaol/.conda/envs/pt/bin/cmake | |
-- System : Linux | |
-- C++ compiler : /scratch/wanchaol/ccache/lib/c++ | |
-- C++ compiler id : GNU | |
-- C++ compiler version : 7.3.0 | |
-- BLAS : MKL |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
graph(%input.1 : Float(*, *), | |
%weight.1 : Float(*, *), | |
%bias.1 : Float(*), | |
%weight.2 : Float(*, *), | |
%bias.2 : Float(*), | |
%weight.3 : Float(*, *), | |
%bias.3 : Float(*), | |
%weight.4 : Float(*, *), | |
%bias.4 : Float(*), | |
%weight.5 : Float(*, *), |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
graph(%x : Float(*, *), | |
%hx : Float(*, *), | |
%cx : Float(*, *), | |
%w_ih : Float(*, *), | |
%w_hh : Float(*, *), | |
%b_ih : Float(*), | |
%b_hh : Float(*)): | |
%hy : Float(*, *), %cy : Float(*, *) = prim::DifferentiableGraph_0(%cx, %b_hh, %b_ih, %hx, %w_hh, %x, %w_ih) | |
%30 : (Float(*, *), Float(*, *)) = prim::TupleConstruct(%hy, %cy) | |
return (%30) |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
graph(%x : Float(*, *), | |
%hx : Float(*, *), | |
%cx : Float(*, *), | |
%w_ih : Float(*, *), | |
%w_hh : Float(*, *), | |
%b_ih : Float(*), | |
%b_hh : Float(*)): | |
%hy : Float(*, *), %cy : Float(*, *) = prim::DifferentiableGraph_0(%cx, %b_hh, %b_ih, %hx, %w_hh, %x, %w_ih) | |
%30 : (Float(*, *), Float(*, *)) = prim::TupleConstruct(%hy, %cy) | |
return (%30) |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
====================================================================== | |
ERROR: test_max_pool2d (__main__.TestOperators) | |
---------------------------------------------------------------------- | |
Traceback (most recent call last): | |
File "test/test_operators.py", line 148, in test_max_pool2d | |
max_pool2d(X) | |
File "/data/users/wanchaol/pytorch/torch/nn/modules/module.py", line 493, in __call__ | |
result = self.forward(*input, **kwargs) | |
RuntimeError: [10:54:24] /home/wanchaol/local/tvm/src/relay/ir/error.cc:131: | |
Error(s) have occurred. We have annotated the program with them: |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
@torch.jit.script | |
def test_if_refinement(weight, bias): | |
# type: (Optional[Tensor], Optional[Tensor]) -> Tuple[Optional[int], Optional[int]] | |
if weight is not None and bias is not None: | |
grad_weight = 1 | |
grad_bias = 2 | |
elif weight is not None: | |
grad_weight = 2 | |
grad_bias = None | |
elif bias is not None: |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import torch | |
import torch.nn.functional as F | |
import torch.jit as jit | |
from torch import Tensor | |
@jit.script | |
def test_mm_back(input1, input2, normalized_shape): | |
# type: (Tensor, Tensor, List[int]) -> Tensor | |
return F.layer_norm(torch.mm(input1, input2), normalized_shape) |