Skip to content

Instantly share code, notes, and snippets.

Benchmarking LSTMs...
name avg_fwd std_fwd avg_bwd std_bwd
cudnn_layernorm 32.71 0.7494 10.43 0.08965
jit_layernorm 41.25 0.7082 98.66 2.56
jit_layernorm_de 34.41 0.7501 113.3 1.037
import torch
import torch.nn as nn
from torch.jit import ScriptModule, script_method
from typing import List
class BatchNorm(ScriptModule):
__constants__ = ['mom', 'eps']
def __init__(self, nf, mom=0.9, eps=1e-5):
import torch
import torch.nn as nn
from torch.jit import ScriptModule, script_method
from typing import List


class BatchNormList(ScriptModule):
    __constants__ = ['mom', 'eps']
--
-- ******** Summary ********
-- General:
-- CMake version : 3.12.2
-- CMake command : /private/home/wanchaol/.conda/envs/pt/bin/cmake
-- System : Linux
-- C++ compiler : /scratch/wanchaol/ccache/lib/c++
-- C++ compiler id : GNU
-- C++ compiler version : 7.3.0
-- BLAS : MKL
graph(%input.1 : Float(*, *),
%weight.1 : Float(*, *),
%bias.1 : Float(*),
%weight.2 : Float(*, *),
%bias.2 : Float(*),
%weight.3 : Float(*, *),
%bias.3 : Float(*),
%weight.4 : Float(*, *),
%bias.4 : Float(*),
%weight.5 : Float(*, *),
graph(%x : Float(*, *),
%hx : Float(*, *),
%cx : Float(*, *),
%w_ih : Float(*, *),
%w_hh : Float(*, *),
%b_ih : Float(*),
%b_hh : Float(*)):
%hy : Float(*, *), %cy : Float(*, *) = prim::DifferentiableGraph_0(%cx, %b_hh, %b_ih, %hx, %w_hh, %x, %w_ih)
%30 : (Float(*, *), Float(*, *)) = prim::TupleConstruct(%hy, %cy)
return (%30)
@wanchaol
wanchaol / gist:ded928ba77c4703e41fa1e681172bec5
Created April 12, 2019 18:27
lstmcell_forward_unoptimized
graph(%x : Float(*, *),
%hx : Float(*, *),
%cx : Float(*, *),
%w_ih : Float(*, *),
%w_hh : Float(*, *),
%b_ih : Float(*),
%b_hh : Float(*)):
%hy : Float(*, *), %cy : Float(*, *) = prim::DifferentiableGraph_0(%cx, %b_hh, %b_ih, %hx, %w_hh, %x, %w_ih)
%30 : (Float(*, *), Float(*, *)) = prim::TupleConstruct(%hy, %cy)
return (%30)
======================================================================
ERROR: test_max_pool2d (__main__.TestOperators)
----------------------------------------------------------------------
Traceback (most recent call last):
File "test/test_operators.py", line 148, in test_max_pool2d
max_pool2d(X)
File "/data/users/wanchaol/pytorch/torch/nn/modules/module.py", line 493, in __call__
result = self.forward(*input, **kwargs)
RuntimeError: [10:54:24] /home/wanchaol/local/tvm/src/relay/ir/error.cc:131:
Error(s) have occurred. We have annotated the program with them:
@torch.jit.script
def test_if_refinement(weight, bias):
# type: (Optional[Tensor], Optional[Tensor]) -> Tuple[Optional[int], Optional[int]]
if weight is not None and bias is not None:
grad_weight = 1
grad_bias = 2
elif weight is not None:
grad_weight = 2
grad_bias = None
elif bias is not None:
import torch
import torch.nn.functional as F
import torch.jit as jit
from torch import Tensor
@jit.script
def test_mm_back(input1, input2, normalized_shape):
    # type: (Tensor, Tensor, List[int]) -> Tensor
    """Matrix-multiply two tensors, then layer-normalize the product.

    Args:
        input1: left operand of ``torch.mm``.
        input2: right operand of ``torch.mm``.
        normalized_shape: trailing dimensions over which ``F.layer_norm``
            normalizes the matmul result.

    Returns:
        ``F.layer_norm(torch.mm(input1, input2), normalized_shape)``.
    """
    # Name the intermediate product so the two stages of the graph read separately.
    product = torch.mm(input1, input2)
    return F.layer_norm(product, normalized_shape)