import torch
@torch.jit.script
def test_ad(a, b, c, d, e, f):
    # type: (List[int], Tensor, List[int], Tensor, Tensor, Tensor)
    v19 = f > 0
    v17 = v19.type_as(f)
    v14 = e.mul(v17)
    grad_input = v14.mm(d)
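In the visible lines of test_ad, v19/v17 build a 0/1 mask from f > 0, v14 multiplies e by that mask, and grad_input multiplies the result by d. One plausible reading, which is an assumption because the rest of the function is cut off, is the backward of a relu followed by the backward of a matrix multiply. The sketch below (relu_mm_input_grad is a hypothetical name) checks that reading against autograd for the assumed forward pass relu(x.mm(weight.t())).

```python
import torch


def relu_mm_input_grad(grad_out: torch.Tensor, pre_act: torch.Tensor,
                       weight: torch.Tensor) -> torch.Tensor:
    # Hypothetical helper mirroring test_ad: pre_act ~ f, grad_out ~ e, weight ~ d.
    mask = (pre_act > 0).type_as(pre_act)  # v19 / v17: 1.0 where relu was active
    grad_pre = grad_out.mul(mask)          # v14: gradient through relu
    return grad_pre.mm(weight)             # grad_input: gradient through x.mm(weight.t())


# Assumed forward pass for this check: out = relu(x.mm(weight.t())).
torch.manual_seed(0)
x = torch.randn(2, 3, requires_grad=True)
weight = torch.randn(4, 3)
pre_act = x.mm(weight.t())
grad_out = torch.randn(2, 4)
torch.relu(pre_act).backward(grad_out)

manual = relu_mm_input_grad(grad_out, pre_act.detach(), weight)
print(torch.allclose(manual, x.grad))
```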
LSTM forward graph after PR #20039:
graph(%input : Float(*, *),
      %input0 : Float(*, *),
      %cx : Float(*, *),
      %weight : Float(*, *),
      %weight0 : Float(*, *),
      %bias : Float(*),
      %bias0 : Float(*)):
  %9 : Float(*, *) = aten::t(%weight)
LSTM forward graph before PR #20039:
graph(%input : Float(*, *),
      %input0 : Float(*, *),
      %cx : Float(*, *),
      %weight : Float(*, *),
      %weight0 : Float(*, *),
      %bias : Float(*),
      %bias0 : Float(*)):
  %9 : Float(*, *) = aten::t(%weight)
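Both LSTM dumps are truncated previews, so the before/after difference is not visible here. For comparing full dumps (for example, graphs printed to files from each build), a plain text diff is usually enough. A minimal sketch using Python's difflib; diff_graphs and the two toy IR strings are made up for illustration.

```python
import difflib


def diff_graphs(before: str, after: str) -> str:
    """Unified diff of two IR dumps; returns "" when they are identical."""
    return "".join(
        difflib.unified_diff(
            before.splitlines(keepends=True),
            after.splitlines(keepends=True),
            fromfile="before",
            tofile="after",
        )
    )


# Toy IR strings, not taken from the dumps above.
before = "graph(%w : Float(*, *)):\n  %1 : Float(*, *) = aten::t(%w)\n  return (%1)\n"
after = "graph(%w : Float(*, *)):\n  %1 : Float(*, *) = aten::neg(%w)\n  return (%1)\n"
print(diff_graphs(before, after))
```

Value names such as %9 or %30 may be renumbered between builds, so a raw diff can include noise unrelated to the change being studied.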
loss value: tensor(7.5122, grad_fn=<_TransducerBackward>)
caffe2/aten/src/ATen/native/Normalization.cpp:311:18: runtime error: 6.17887e+39 is outside the range of representable values of type 'float'
    #0 0x7f877131e37a in std::tuple<at::Tensor, at::Tensor, at::Tensor> at::native::batch_norm_backward_cpu_template<float>(at::Tensor const&, at::Tensor const&, at::Tensor const&, at::Tensor const&, at::Tensor const&, at::Tensor const&, at::Tensor const&, bool, double, std::array<bool, 3ul>)::'lambda'(long, long)::operator()(long, long) const::'lambda0'(float&, float const&)::operator()(float&, float const&) const caffe2/aten/src/ATen/native/Normalization.cpp:311
    #1 0x7f877131ddc3 in void at::apply_op<std::tuple<at::Tensor, at::Tensor, at::Tensor> at::native::batch_norm_backward_cpu_template<float>(at::Tensor const&, at::Tensor const&, at::Tensor const&, at::Tensor const&, at::Tensor const&, at::Tensor const&, at::Tensor const&, bool, double, std::array<bool, 3ul>)::'lambda'(long, long)::operator()(long, l
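The report says a value of 6.17887e+39 was converted to float at Normalization.cpp:311. The largest finite float32 is about 3.4028e+38, and in C++ converting a double outside that range to float is undefined behavior, which is what UBSan's float-cast-overflow check flags. The stdlib sketch below (FLT_MAX and fits_in_float32 are hypothetical names) reproduces only the range test; it says nothing about why batch_norm_backward produced such a large value, which the log does not show.

```python
import struct

# Largest finite IEEE-754 single-precision value (bit pattern 0x7f7fffff).
FLT_MAX = struct.unpack("<f", struct.pack("<I", 0x7F7FFFFF))[0]


def fits_in_float32(x: float) -> bool:
    """True if a finite double x can be stored as float32 without becoming inf."""
    try:
        # struct raises OverflowError when a finite value rounds to inf in 'f' format.
        struct.pack("<f", x)
    except OverflowError:
        return False
    return True


print(FLT_MAX, fits_in_float32(6.17887e39))
```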
import torch
import torch.nn.functional as F
import torch.jit as jit
from torch import Tensor
@jit.script
def test_mm_back(input1, input2, normalized_shape):
    # type: (Tensor, Tensor, List[int]) -> Tensor
    return F.layer_norm(torch.mm(input1, input2), normalized_shape)
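To check that a scripted function like test_mm_back differentiates the same way as eager mode, one option is to compare its forward values and input gradients against a hand-written layer norm. A minimal sketch; mm_layer_norm, eager_reference and input_grads are hypothetical names, and the reference assumes no affine weight/bias and F.layer_norm's default eps of 1e-5.

```python
from typing import List

import torch
import torch.nn.functional as F


@torch.jit.script
def mm_layer_norm(input1: torch.Tensor, input2: torch.Tensor,
                  normalized_shape: List[int]) -> torch.Tensor:
    # Same body as test_mm_back, with Python 3 annotations instead of a type comment.
    return F.layer_norm(torch.mm(input1, input2), normalized_shape)


def eager_reference(input1: torch.Tensor, input2: torch.Tensor,
                    eps: float = 1e-5) -> torch.Tensor:
    # Layer norm over the last dim: (y - mean) / sqrt(biased variance + eps).
    y = input1.mm(input2)
    mean = y.mean(dim=-1, keepdim=True)
    var = ((y - mean) ** 2).mean(dim=-1, keepdim=True)
    return (y - mean) / torch.sqrt(var + eps)


def input_grads(fn, a, b, upstream):
    # Backpropagate a fixed upstream gradient; returns (d/da, d/db).
    a = a.detach().requires_grad_()
    b = b.detach().requires_grad_()
    (fn(a, b) * upstream).sum().backward()
    return a.grad, b.grad


torch.manual_seed(0)
a, b = torch.randn(3, 4), torch.randn(4, 5)
upstream = torch.randn(3, 5)  # a plain .sum() would give ~zero grads through layer norm
scripted = input_grads(lambda x, y: mm_layer_norm(x, y, [5]), a, b, upstream)
eager = input_grads(eager_reference, a, b, upstream)
```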
@torch.jit.script
def test_if_refinement(weight, bias):
    # type: (Optional[Tensor], Optional[Tensor]) -> Tuple[Optional[int], Optional[int]]
    if weight is not None and bias is not None:
        grad_weight = 1
        grad_bias = 2
    elif weight is not None:
        grad_weight = 2
        grad_bias = None
    elif bias is not None:
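The refinement example is cut off after its last elif. A minimal, self-contained sketch of the same idea follows; leading_sizes is a hypothetical function, not the gist's. Inside `if x is not None`, TorchScript narrows Optional[Tensor] to Tensor, so Tensor methods type-check, and annotating the results as Optional[int] up front keeps every branch type-consistent.

```python
from typing import Optional, Tuple

import torch


@torch.jit.script
def leading_sizes(weight: Optional[torch.Tensor],
                  bias: Optional[torch.Tensor]) -> Tuple[Optional[int], Optional[int]]:
    # Declared Optional[int] so assigning either an int or None is well typed.
    weight_size: Optional[int] = None
    bias_size: Optional[int] = None
    if weight is not None:
        # `weight` is refined from Optional[Tensor] to Tensor in this branch.
        weight_size = weight.size(0)
    if bias is not None:
        bias_size = bias.size(0)
    return weight_size, bias_size


print(leading_sizes(torch.zeros(3, 2), None))
```

Declaring the Optional[int] type first means the sketch does not rely on the compiler unifying int and None across branches.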
======================================================================
ERROR: test_max_pool2d (__main__.TestOperators)
----------------------------------------------------------------------
Traceback (most recent call last):
  File "test/test_operators.py", line 148, in test_max_pool2d
    max_pool2d(X)
  File "/data/users/wanchaol/pytorch/torch/nn/modules/module.py", line 493, in __call__
    result = self.forward(*input, **kwargs)
RuntimeError: [10:54:24] /home/wanchaol/local/tvm/src/relay/ir/error.cc:131:
Error(s) have occurred. We have annotated the program with them:
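The annotated Relay program is cut off, so the actual cause of this failure isn't visible. When comparing a backend's max_pool2d against PyTorch, it can help to know the output size PyTorch expects. The helper below (max_pool2d_out_size is a hypothetical name) follows the formula given in the nn.MaxPool2d documentation for ceil_mode=False, with stride defaulting to the kernel size.

```python
from typing import Optional


def max_pool2d_out_size(size: int, kernel: int, stride: Optional[int] = None,
                        padding: int = 0, dilation: int = 1) -> int:
    """Output length along one spatial dim, ceil_mode=False.

    floor((size + 2*padding - dilation*(kernel - 1) - 1) / stride + 1)
    """
    if stride is None:
        stride = kernel  # PyTorch's default stride for max pooling
    # Integer floor division matches floor(n / stride) + 1 for any integer n.
    out = (size + 2 * padding - dilation * (kernel - 1) - 1) // stride + 1
    if out < 1:
        raise ValueError(f"pooling window does not fit: output size {out}")
    return out


print(max_pool2d_out_size(32, 2), max_pool2d_out_size(7, 3, stride=2, padding=1))
```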
graph(%x : Float(*, *),
      %hx : Float(*, *),
      %cx : Float(*, *),
      %w_ih : Float(*, *),
      %w_hh : Float(*, *),
      %b_ih : Float(*),
      %b_hh : Float(*)):
  %hy : Float(*, *), %cy : Float(*, *) = prim::DifferentiableGraph_0(%cx, %b_hh, %b_ih, %hx, %w_hh, %x, %w_ih)
  %30 : (Float(*, *), Float(*, *)) = prim::TupleConstruct(%hy, %cy)
  return (%30)
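The graph's inputs (x, hx, cx, w_ih, w_hh, b_ih, b_hh) match the argument list of a standard LSTM cell. The source function isn't shown, so the cell below is an assumption: it is written the way nn.LSTMCell computes its gates (input, forget, cell, output order). Scripting it and printing `.graph` gives an IR dump to compare against. `.graph` is the graph before executor optimizations, so a prim::DifferentiableGraph node is not expected there; the exact optimized graph depends on the PyTorch version and on whether the inputs require grad.

```python
from typing import Tuple

import torch


@torch.jit.script
def lstm_cell(x: torch.Tensor, hx: torch.Tensor, cx: torch.Tensor,
              w_ih: torch.Tensor, w_hh: torch.Tensor,
              b_ih: torch.Tensor, b_hh: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
    # One fused pre-activation of shape (batch, 4 * hidden) feeds all four gates.
    gates = x.mm(w_ih.t()) + hx.mm(w_hh.t()) + b_ih + b_hh
    ingate, forgetgate, cellgate, outgate = gates.chunk(4, 1)
    ingate = torch.sigmoid(ingate)
    forgetgate = torch.sigmoid(forgetgate)
    cellgate = torch.tanh(cellgate)
    outgate = torch.sigmoid(outgate)
    cy = forgetgate * cx + ingate * cellgate
    hy = outgate * torch.tanh(cy)
    return hy, cy


# Unoptimized IR of the scripted cell.
print(lstm_cell.graph)
```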
graph(%input.1 : Float(*, *),
      %weight.1 : Float(*, *),
      %bias.1 : Float(*),
      %weight.2 : Float(*, *),
      %bias.2 : Float(*),
      %weight.3 : Float(*, *),
      %bias.3 : Float(*),
      %weight.4 : Float(*, *),
      %bias.4 : Float(*),
      %weight.5 : Float(*, *),