@wanchaol
Last active November 21, 2018 19:28
lenet trace test
import torch
import torch.nn as nn
import torch.nn.functional as F


class LeNet(nn.Module):
    def __init__(self):
        super(LeNet, self).__init__()
        # 1 input image channel, 6 output channels, 5x5 square convolution
        # kernel
        self.conv1 = nn.Conv2d(1, 6, 5)
        self.conv2 = nn.Conv2d(6, 16, 5)
        # an affine operation: y = Wx + b
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)
    def forward(self, x):
        # Max pooling over a (2, 2) window
        x = F.max_pool2d(F.relu(self.conv1(x)), (2, 2))
        # If the window is square, you can specify just a single number
        x = F.max_pool2d(F.relu(self.conv2(x)), 2)
        x = x.view(-1, self.num_flat_features(x))
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x
    def num_flat_features(self, x):
        size = x.size()[1:]  # all dimensions except the batch dimension
        num_features = 1
        for s in size:
            num_features *= s
        return num_features
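
# Note: num_flat_features() does its size arithmetic in Python, so tracing
# records that computation as graph ops (aten::size, prim::NumToTensor,
# aten::mul) rather than folding it into a constant; this is visible in the
# traced graph below.
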
net = LeNet().cuda()
print(net)
traced_lenet = torch.jit.trace(net, torch.randn(1, 1, 32, 32, device='cuda'))
print(traced_lenet)
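
# Quick sanity check (a minimal sketch): the traced module should reproduce
# the eager module's output for a fresh input of the same shape.
example = torch.randn(1, 1, 32, 32, device='cuda')
assert torch.allclose(net(example), traced_lenet(example))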
===================
LeNet(
  (conv1): Conv2d(1, 6, kernel_size=(5, 5), stride=(1, 1))
  (conv2): Conv2d(6, 16, kernel_size=(5, 5), stride=(1, 1))
  (fc1): Linear(in_features=400, out_features=120, bias=True)
  (fc2): Linear(in_features=120, out_features=84, bias=True)
  (fc3): Linear(in_features=84, out_features=10, bias=True)
)
graph(%input.1 : Float(1, 1, 32, 32)
      %1 : Float(6, 1, 5, 5)
      %2 : Float(6)
      %3 : Float(16, 6, 5, 5)
      %4 : Float(16)
      %weight.1 : Float(120, 400)
      %bias.1 : Float(120)
      %weight.2 : Float(84, 120)
      %bias.2 : Float(84)
      %weight : Float(10, 84)
      %bias : Float(10)) {
  %11 : int = prim::Constant[value=1](), scope: LeNet/Conv2d[conv1]
  %12 : int = prim::Constant[value=1](), scope: LeNet/Conv2d[conv1]
  %13 : int[] = prim::ListConstruct(%11, %12), scope: LeNet/Conv2d[conv1]
  %14 : int = prim::Constant[value=0](), scope: LeNet/Conv2d[conv1]
  %15 : int = prim::Constant[value=0](), scope: LeNet/Conv2d[conv1]
  %16 : int[] = prim::ListConstruct(%14, %15), scope: LeNet/Conv2d[conv1]
  %17 : int = prim::Constant[value=1](), scope: LeNet/Conv2d[conv1]
  %18 : int = prim::Constant[value=1](), scope: LeNet/Conv2d[conv1]
  %19 : int[] = prim::ListConstruct(%17, %18), scope: LeNet/Conv2d[conv1]
  %20 : bool = prim::Constant[value=0](), scope: LeNet/Conv2d[conv1]
  %21 : int = prim::Constant[value=0](), scope: LeNet/Conv2d[conv1]
  %22 : int = prim::Constant[value=0](), scope: LeNet/Conv2d[conv1]
  %23 : int[] = prim::ListConstruct(%21, %22), scope: LeNet/Conv2d[conv1]
  %24 : int = prim::Constant[value=1](), scope: LeNet/Conv2d[conv1]
  %25 : bool = prim::Constant[value=0](), scope: LeNet/Conv2d[conv1]
  %26 : bool = prim::Constant[value=0](), scope: LeNet/Conv2d[conv1]
  %27 : bool = prim::Constant[value=1](), scope: LeNet/Conv2d[conv1]
  %input.2 : Float(1, 6, 28, 28) = aten::_convolution(%input.1, %1, %2, %13, %16, %19, %20, %23, %24, %25, %26, %27), scope: LeNet/Conv2d[conv1]
  %input.3 : Float(1, 6, 28, 28) = aten::relu(%input.2), scope: LeNet
  %30 : int = prim::Constant[value=2](), scope: LeNet
  %31 : int = prim::Constant[value=2](), scope: LeNet
  %32 : int[] = prim::ListConstruct(%30, %31), scope: LeNet
  %33 : int[] = prim::ListConstruct(), scope: LeNet
  %34 : int = prim::Constant[value=0](), scope: LeNet
  %35 : int = prim::Constant[value=0](), scope: LeNet
  %36 : int[] = prim::ListConstruct(%34, %35), scope: LeNet
  %37 : int = prim::Constant[value=1](), scope: LeNet
  %38 : int = prim::Constant[value=1](), scope: LeNet
  %39 : int[] = prim::ListConstruct(%37, %38), scope: LeNet
  %40 : bool = prim::Constant[value=0](), scope: LeNet
  %input.4 : Float(1, 6, 14, 14), %42 : Long(1, 6, 14, 14) = aten::max_pool2d_with_indices(%input.3, %32, %33, %36, %39, %40), scope: LeNet
  %43 : int = prim::Constant[value=1](), scope: LeNet/Conv2d[conv2]
  %44 : int = prim::Constant[value=1](), scope: LeNet/Conv2d[conv2]
  %45 : int[] = prim::ListConstruct(%43, %44), scope: LeNet/Conv2d[conv2]
  %46 : int = prim::Constant[value=0](), scope: LeNet/Conv2d[conv2]
  %47 : int = prim::Constant[value=0](), scope: LeNet/Conv2d[conv2]
  %48 : int[] = prim::ListConstruct(%46, %47), scope: LeNet/Conv2d[conv2]
  %49 : int = prim::Constant[value=1](), scope: LeNet/Conv2d[conv2]
  %50 : int = prim::Constant[value=1](), scope: LeNet/Conv2d[conv2]
  %51 : int[] = prim::ListConstruct(%49, %50), scope: LeNet/Conv2d[conv2]
  %52 : bool = prim::Constant[value=0](), scope: LeNet/Conv2d[conv2]
  %53 : int = prim::Constant[value=0](), scope: LeNet/Conv2d[conv2]
  %54 : int = prim::Constant[value=0](), scope: LeNet/Conv2d[conv2]
  %55 : int[] = prim::ListConstruct(%53, %54), scope: LeNet/Conv2d[conv2]
  %56 : int = prim::Constant[value=1](), scope: LeNet/Conv2d[conv2]
  %57 : bool = prim::Constant[value=0](), scope: LeNet/Conv2d[conv2]
  %58 : bool = prim::Constant[value=0](), scope: LeNet/Conv2d[conv2]
  %59 : bool = prim::Constant[value=1](), scope: LeNet/Conv2d[conv2]
  %input.5 : Float(1, 16, 10, 10) = aten::_convolution(%input.4, %3, %4, %45, %48, %51, %52, %55, %56, %57, %58, %59), scope: LeNet/Conv2d[conv2]
  %input.6 : Float(1, 16, 10, 10) = aten::relu(%input.5), scope: LeNet
  %62 : int = prim::Constant[value=2](), scope: LeNet
  %63 : int = prim::Constant[value=2](), scope: LeNet
  %64 : int[] = prim::ListConstruct(%62, %63), scope: LeNet
  %65 : int[] = prim::ListConstruct(), scope: LeNet
  %66 : int = prim::Constant[value=0](), scope: LeNet
  %67 : int = prim::Constant[value=0](), scope: LeNet
  %68 : int[] = prim::ListConstruct(%66, %67), scope: LeNet
  %69 : int = prim::Constant[value=1](), scope: LeNet
  %70 : int = prim::Constant[value=1](), scope: LeNet
  %71 : int[] = prim::ListConstruct(%69, %70), scope: LeNet
  %72 : bool = prim::Constant[value=0](), scope: LeNet
  %x : Float(1, 16, 5, 5), %74 : Long(1, 16, 5, 5) = aten::max_pool2d_with_indices(%input.6, %64, %65, %68, %71, %72), scope: LeNet
  %78 : int = prim::Constant[value=1](), scope: LeNet
  %79 : int = aten::size(%x, %78), scope: LeNet
  %s.1 : Long() = prim::NumToTensor(%79), scope: LeNet
  %81 : int = prim::Constant[value=2](), scope: LeNet
  %82 : int = aten::size(%x, %81), scope: LeNet
  %s.2 : Long() = prim::NumToTensor(%82), scope: LeNet
  %84 : int = prim::Constant[value=3](), scope: LeNet
  %85 : int = aten::size(%x, %84), scope: LeNet
  %s : Long() = prim::NumToTensor(%85), scope: LeNet
  %87 : Long() = prim::Constant[value={1}](), scope: LeNet
  %num_features.1 : Long() = aten::mul(%s.1, %87), scope: LeNet
  %num_features : Long() = aten::mul(%num_features.1, %s.2), scope: LeNet
  %90 : Long() = aten::mul(%num_features, %s), scope: LeNet
  %91 : int = prim::TensorToNum(%90), scope: LeNet
  %92 : int = prim::Constant[value=-1](), scope: LeNet
  %93 : int[] = prim::ListConstruct(%92, %91), scope: LeNet
  %input.7 : Float(1, 400) = aten::view(%x, %93), scope: LeNet
  %95 : Float(400!, 120!) = aten::t(%weight.1), scope: LeNet/Linear[fc1]
  %96 : int = prim::Constant[value=1](), scope: LeNet/Linear[fc1]
  %97 : int = prim::Constant[value=1](), scope: LeNet/Linear[fc1]
  %input.8 : Float(1, 120) = aten::addmm(%bias.1, %input.7, %95, %96, %97), scope: LeNet/Linear[fc1]
  %input.9 : Float(1, 120) = aten::relu(%input.8), scope: LeNet
  %100 : Float(120!, 84!) = aten::t(%weight.2), scope: LeNet/Linear[fc2]
  %101 : int = prim::Constant[value=1](), scope: LeNet/Linear[fc2]
  %102 : int = prim::Constant[value=1](), scope: LeNet/Linear[fc2]
  %input.10 : Float(1, 84) = aten::addmm(%bias.2, %input.9, %100, %101, %102), scope: LeNet/Linear[fc2]
  %input : Float(1, 84) = aten::relu(%input.10), scope: LeNet
  %105 : Float(84!, 10!) = aten::t(%weight), scope: LeNet/Linear[fc3]
  %106 : int = prim::Constant[value=1](), scope: LeNet/Linear[fc3]
  %107 : int = prim::Constant[value=1](), scope: LeNet/Linear[fc3]
  %108 : Float(1, 10) = aten::addmm(%bias, %input, %105, %106, %107), scope: LeNet/Linear[fc3]
  return (%108);
}
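
A traced module like this can also be serialized and reloaded without the original Python class definition. A minimal sketch (the file name is illustrative, not part of the gist):

traced_lenet.save("lenet_traced.pt")          # serialize the traced module
loaded = torch.jit.load("lenet_traced.pt")    # reload; no LeNet class needed
print(loaded(torch.randn(1, 1, 32, 32, device='cuda')).shape)  # torch.Size([1, 10])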