import torch
import torch.nn as nn
from torch.jit import ScriptModule, script_method
from typing import List
class BatchNormList(ScriptModule):
__constants__ = ['mom', 'eps']
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
-- | |
-- ******** Summary ******** | |
-- General: | |
-- CMake version : 3.12.2 | |
-- CMake command : /private/home/wanchaol/.conda/envs/pt/bin/cmake | |
-- System : Linux | |
-- C++ compiler : /scratch/wanchaol/ccache/lib/c++ | |
-- C++ compiler id : GNU | |
-- C++ compiler version : 7.3.0 | |
-- BLAS : MKL |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import torch | |
import torch.nn as nn | |
from torch.jit import ScriptModule, script_method | |
from typing import List | |
class BatchNorm(ScriptModule): | |
__constants__ = ['mom', 'eps'] | |
def __init__(self, nf, mom=0.9, eps=1e-5): |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Benchmarking LSTMs... | |
name avg_fwd std_fwd avg_bwd std_bwd | |
cudnn_layernorm 32.71 0.7494 10.43 0.08965 | |
jit_layernorm 41.25 0.7082 98.66 2.56 | |
jit_layernorm_de 34.41 0.7501 113.3 1.037 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import torch | |
class Test(torch.jit.ScriptModule): | |
def __init__(self, b = None): | |
self.b = b | |
def forward(self, input): | |
x = input | |
if self.b is not None: | |
x = self.b(input) | |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import torch | |
def fn(input): | |
return torch.log(input + 1e-8) | |
input = torch.rand(5, 5) | |
output = fn(input) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import torch | |
import torch.nn as nn | |
import torch.nn.functional as F | |
class LeNet(nn.Module): | |
def __init__(self): | |
super(LeNet, self).__init__() | |
# 1 input image channel, 6 output channels, 5x5 square convolution |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import torch | |
def remove_sentence_boundary(tensor): | |
tensor_shape = list(tensor.data.shape) | |
new_shape = list(tensor_shape) | |
new_shape[1] = tensor_shape[1] - 2 | |
tensor_without_boundary_tokens = torch.zeros(new_shape, device=tensor.device) | |
return tensor_without_boundary_tokens | |
traced_fn = torch.jit.trace(remove_sentence_boundary, torch.rand(10, 20, 30)) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import torch | |
a = torch.tensor([[-0.5689, 1.3550, -1.7742, -0.2412, 0.2400], | |
[-1.1720, 0.6153, 0.0285, 0.7397, 0.3760], | |
[ 1.0568, -0.9253, -0.5579, 0.1791, 1.3932 ], | |
[ 0.4966, 0.9272, -1.3335, -0.2913, 0.8120 ], | |
[-0.5048, -0.9092, 0.2757, 1.3891, 1.1164]]) | |
print("% output:") |