This doc is to talk about the difference in ATen/JIT function schema and find a way to align each other.
Currently in ATen function schema when we want to define a list, we have:
input = torch.randn(2,2) | |
part = torch.tensor([[1, 2]]) | |
ind0 = torch.arange(0,1) | |
ind1 = torch.arange(0,2) | |
input = torch.index_put(input, (ind0, ind1), part) | |
import torch | |
import torch.nn as nn | |
from torch.autograd import Variable | |
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence | |
import torch.nn.functional as F | |
import itertools | |
import tempfile | |
def pack_pad_seq(seq_tensor, seq_lengths): |
import torch | |
import torch.nn as nn | |
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence | |
import torch.nn.functional as F | |
import itertools | |
import tempfile | |
def pack_pad_seq(seq_tensor, seq_lengths): |
import torch | |
class Test(torch.nn.Module): | |
def __init__(self): | |
super(Test, self).__init__() | |
def forward(self, input): | |
# y = input.size(0) + 1 | |
y = 0 | |
for i in range(10): |
import torch | |
a = torch.tensor([[-0.5689, 1.3550, -1.7742, -0.2412, 0.2400], | |
[-1.1720, 0.6153, 0.0285, 0.7397, 0.3760], | |
[ 1.0568, -0.9253, -0.5579, 0.1791, 1.3932 ], | |
[ 0.4966, 0.9272, -1.3335, -0.2913, 0.8120 ], | |
[-0.5048, -0.9092, 0.2757, 1.3891, 1.1164]]) | |
print("% output:") |
import torch | |
def remove_sentence_boundary(tensor): | |
tensor_shape = list(tensor.data.shape) | |
new_shape = list(tensor_shape) | |
new_shape[1] = tensor_shape[1] - 2 | |
tensor_without_boundary_tokens = torch.zeros(new_shape, device=tensor.device) | |
return tensor_without_boundary_tokens | |
traced_fn = torch.jit.trace(remove_sentence_boundary, torch.rand(10, 20, 30)) |
import torch | |
import torch.nn as nn | |
import torch.nn.functional as F | |
class LeNet(nn.Module): | |
def __init__(self): | |
super(LeNet, self).__init__() | |
# 1 input image channel, 6 output channels, 5x5 square convolution |
import torch | |
def fn(input): | |
return torch.log(input + 1e-8) | |
input = torch.rand(5, 5) | |
output = fn(input) |
import torch | |
class Test(torch.jit.ScriptModule): | |
def __init__(self, b = None): | |
self.b = b | |
def forward(self, input): | |
x = input | |
if self.b is not None: | |
x = self.b(input) | |