Skip to content

Instantly share code, notes, and snippets.

graph(%input.1 : Float(*, *),
%weight.1 : Float(*, *),
%bias.1 : Float(*),
%weight.2 : Float(*, *),
%bias.2 : Float(*),
%weight.3 : Float(*, *),
%bias.3 : Float(*),
%weight.4 : Float(*, *),
%bias.4 : Float(*),
%weight.5 : Float(*, *),
--
-- ******** Summary ********
-- General:
-- CMake version : 3.12.2
-- CMake command : /private/home/wanchaol/.conda/envs/pt/bin/cmake
-- System : Linux
-- C++ compiler : /scratch/wanchaol/ccache/lib/c++
-- C++ compiler id : GNU
-- C++ compiler version : 7.3.0
-- BLAS : MKL
import torch
import torch.nn as nn
from torch.jit import ScriptModule, script_method
from typing import List


class BatchNormList(ScriptModule):
    __constants__ = ['mom', 'eps']
import torch
import torch.nn as nn
from torch.jit import ScriptModule, script_method
from typing import List
class BatchNorm(ScriptModule):
__constants__ = ['mom', 'eps']
def __init__(self, nf, mom=0.9, eps=1e-5):
Benchmarking LSTMs...
name avg_fwd std_fwd avg_bwd std_bwd
cudnn_layernorm 32.71 0.7494 10.43 0.08965
jit_layernorm 41.25 0.7082 98.66 2.56
jit_layernorm_de 34.41 0.7501 113.3 1.037
@wanchaol
wanchaol / schema.md
Last active December 20, 2018 23:35

ATen/JIT Function Schema

This doc describes the differences between the ATen and JIT function schemas and looks for a way to align them with each other.

List syntax

Currently, when we want to define a list in the ATen function schema, we write:

@wanchaol
wanchaol / none.py
Last active November 28, 2018 19:56
import torch
class Test(torch.jit.ScriptModule):
def __init__(self, b = None):
self.b = b
def forward(self, input):
x = input
if self.b is not None:
x = self.b(input)
import torch
def fn(input):
    """Return the element-wise natural logarithm of ``input + 1e-8``.

    Args:
        input: Value passed to ``torch.log`` after the offset is added.

    Returns:
        The result of ``torch.log(input + 1e-8)``.
    """
    # The 1e-8 offset is applied before the log, so an element that is
    # exactly zero produces a finite result rather than -inf.
    return torch.log(input + 1e-8)
# Call fn eagerly on a random 5x5 tensor.
input = torch.rand(5, 5)
output = fn(input)
import torch
import torch.nn as nn
import torch.nn.functional as F
class LeNet(nn.Module):
def __init__(self):
super(LeNet, self).__init__()
# 1 input image channel, 6 output channels, 5x5 square convolution
import torch
def remove_sentence_boundary(tensor):
    """Return a zero-filled tensor shaped like ``tensor`` minus 2 along dim 1.

    Args:
        tensor: Tensor with at least two dimensions; only its shape and
            device are read.

    Returns:
        A ``torch.zeros`` tensor on ``tensor.device`` whose shape equals the
        input shape except that dimension 1 is smaller by 2.
    """
    # NOTE(review): despite the name, no values are copied from `tensor`;
    # the result is all zeros. Confirm whether this is the intended behavior.
    tensor_shape = list(tensor.data.shape)
    # Copy the shape so the original list is left untouched.
    new_shape = list(tensor_shape)
    new_shape[1] = tensor_shape[1] - 2
    tensor_without_boundary_tokens = torch.zeros(new_shape, device=tensor.device)
    return tensor_without_boundary_tokens
# Trace remove_sentence_boundary using a random (10, 20, 30) example input.
traced_fn = torch.jit.trace(remove_sentence_boundary, torch.rand(10, 20, 30))