Last active
June 7, 2019 22:45
-
-
Save wanchaol/b6f71e4301dc141b811c6c1f3aac1d57 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
pip install git+https://github.com/arraiyopensource/kornia | |
==== | |
import torch | |
import torch.nn as nn | |
from torch.testing import assert_allclose | |
import kornia | |
@torch.jit.script | |
def op_script(input, height, | |
width) -> torch.Tensor: | |
return kornia.normalize_pixel_coordinates(input, int(height), int(width)) | |
class MyTestModule(nn.Module): | |
def __init__(self): | |
super(MyTestModule, self).__init__() | |
def forward(self, input): | |
height, width = input.shape[-2:] | |
height = torch.tensor(height) | |
height = torch.tensor(width) | |
return op_script(input, height, width) | |
my_test_op = MyTestModule() | |
op_traced = torch.jit.trace(my_test_op, torch.rand(1,4,4,2)) | |
# create points grid | |
height, width = 5, 5 | |
points = kornia.create_meshgrid( | |
height, width, normalized_coordinates=False) # 1xHxWx2 | |
# we expect that the traced function generalises with different | |
# input shapes. Ideally we might want to infer to traced the h and w. | |
assert_allclose(op_traced(points), | |
kornia.normalize_pixel_coordinates(points, height, width)) | |
==== Error msg | |
raceback (most recent call last): | |
File "/scratch/wanchaol/local/pytorch/torch/jit/__init__.py", line 565, in run_mod_and_filter_tensor_outputs | |
outs = wrap_retval(mod(*_clone_inputs(inputs))) | |
File "/scratch/wanchaol/local/pytorch/torch/nn/modules/module.py", line 494, in __call__ | |
result = self.forward(*input, **kwargs) | |
File "test.py", line 20, in forward | |
return op_script(input, height, width) | |
RuntimeError: op_script() expected a value of type 'Tensor' for argument 'width' but instead found type 'int'. | |
Inferred 'width' to be of type 'Tensor' because it was not annotated with an explicit type. | |
Position: 2 | |
Value: 2 | |
Declaration: op_script(Tensor input, Tensor height, Tensor width) -> Tensor | |
During handling of the above exception, another exception occurred: | |
Traceback (most recent call last): | |
File "test.py", line 24, in <module> | |
op_traced = torch.jit.trace(my_test_op, torch.rand(1,4,4,2)) | |
File "/scratch/wanchaol/local/pytorch/torch/jit/__init__.py", line 730, in trace | |
check_tolerance, _force_outplace, _module_class) | |
File "/scratch/wanchaol/local/pytorch/torch/jit/__init__.py", line 859, in trace_module | |
check_tolerance, _force_outplace, True) | |
File "/scratch/wanchaol/local/pytorch/torch/autograd/grad_mode.py", line 43, in decorate_no_grad | |
return func(*args, **kwargs) | |
File "/scratch/wanchaol/local/pytorch/torch/jit/__init__.py", line 604, in _check_trace | |
fn_outs = run_mod_and_filter_tensor_outputs(func, inputs, 'Python function') | |
File "/scratch/wanchaol/local/pytorch/torch/jit/__init__.py", line 571, in run_mod_and_filter_tensor_outputs | |
' with test inputs.\nException:\n' + indent(str(e))) | |
torch.jit.TracingCheckError: Tracing failed sanity checks! | |
Encountered an exception while running the Python function with test inputs. | |
Exception: | |
op_script() expected a value of type 'Tensor' for argument 'width' but instead found type 'int'. | |
Inferred 'width' to be of type 'Tensor' because it was not annotated with an explicit type. | |
Position: 2 | |
Value: 2 | |
Declaration: op_script(Tensor input, Tensor height, Tensor width) -> Tensor |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment