Skip to content

Instantly share code, notes, and snippets.

@tbenst
Created July 11, 2018 03:02
Show Gist options
  • Select an option

  • Save tbenst/bc79eb802df48bea37a6143567c78c85 to your computer and use it in GitHub Desktop.

Select an option

Save tbenst/bc79eb802df48bea37a6143567c78c85 to your computer and use it in GitHub Desktop.
import time
import torch as T
import torch.nn as nn
import torch.nn.functional as F
import torch.backends.cudnn as cudnn
# Let cuDNN time candidate convolution algorithms and cache the fastest one per
# input configuration; worthwhile here because every benchmarked input has the
# same fixed shape.
T.backends.cudnn.benchmark = True
class SuperResBlock(nn.Module):
    """Upsample a single-channel image with sub-pixel convolution (ESPCN).

    Functional variant: the convolution kernels are raw ``nn.Parameter``
    tensors applied with ``F.conv2d`` (no bias terms), so this class can be
    benchmarked against the ``nn.Conv2d``-based ``SuperResBlockNotFunctional``.

    Reference: https://arxiv.org/pdf/1609.05158.pdf

    Args:
        upscale_factor: integer factor ``r`` by which height and width grow.

    Shape:
        input:  ``(N, 1, H, W)`` -- the first kernel has a single input channel.
        output: ``(N, 1, 1, H * r, W * r)``.
    """

    def __init__(self, upscale_factor):
        super(SuperResBlock, self).__init__()
        self.activation = nn.ReLU()
        # Kernels use F.conv2d's (out_channels, in_channels, kH, kW) layout.
        # They are allocated uninitialized here and filled by initialize_weights().
        # Each padding keeps H and W unchanged for its odd kernel size.
        self.dconv1 = nn.Parameter(T.empty(64, 1, 5, 5))
        self.dpad1 = (2, 2)
        self.dbn1 = nn.BatchNorm2d(64)
        self.dconv2 = nn.Parameter(T.empty(64, 64, 3, 3))
        self.dpad2 = (1, 1)
        self.dbn2 = nn.BatchNorm2d(64)
        self.dconv3 = nn.Parameter(T.empty(32, 64, 3, 3))
        self.dpad3 = (1, 1)
        self.dbn3 = nn.BatchNorm2d(32)
        # r**2 output channels are rearranged by the pixel shuffle into one
        # channel that is r times larger in each spatial dimension.
        self.dconv4 = nn.Parameter(T.empty(upscale_factor ** 2, 32, 3, 3))
        self.dpad4 = (1, 1)
        self.pixel_shuffle = nn.PixelShuffle(upscale_factor)
        self.initialize_weights()

    def forward(self, x):
        """Return the upsampled tensor; see the class docstring for shapes."""
        x = self.activation(self.dbn1(F.conv2d(x, self.dconv1, padding=self.dpad1)))
        x = self.activation(self.dbn2(F.conv2d(x, self.dconv2, padding=self.dpad2)))
        x = self.activation(self.dbn3(F.conv2d(x, self.dconv3, padding=self.dpad3)))
        x = F.conv2d(x, self.dconv4, padding=self.dpad4)
        x = self.pixel_shuffle(x)  # (N, r**2, H, W) -> (N, 1, H*r, W*r)
        # Insert a singleton axis at dim 2: (N, 1, H*r, W*r) -> (N, 1, 1, H*r, W*r)
        # (presumably a depth axis for volume-shaped consumers -- TODO confirm).
        x = x[:, :, None]
        return x

    def initialize_weights(self):
        """Orthogonally initialize the kernels and reset BatchNorm to identity."""
        relu_gain = nn.init.calculate_gain('relu')
        # Kernels whose output feeds a ReLU use the ReLU gain; the last kernel
        # feeds no activation, so it keeps orthogonal_'s default gain of 1.
        for kernel in (self.dconv1, self.dconv2, self.dconv3):
            nn.init.orthogonal_(kernel, relu_gain)
        nn.init.orthogonal_(self.dconv4)
        for bn in (self.dbn1, self.dbn2, self.dbn3):
            nn.init.constant_(bn.weight, 1)
            nn.init.constant_(bn.bias, 0)
class SuperResBlockNotFunctional(nn.Module):
    """Upsample a single-channel image with sub-pixel convolution (ESPCN).

    Module variant: the same network as ``SuperResBlock`` but built from
    ``nn.Conv2d`` layers instead of raw parameters and ``F.conv2d``.

    Reference: https://arxiv.org/pdf/1609.05158.pdf

    Args:
        upscale_factor: integer factor ``r`` by which height and width grow.

    Shape:
        input:  ``(N, 1, H, W)``.
        output: ``(N, 1, 1, H * r, W * r)``.
    """

    def __init__(self, upscale_factor):
        super(SuperResBlockNotFunctional, self).__init__()
        self.activation = nn.ReLU()
        # kernel_size is passed by keyword: nn.Conv2d's 4th positional
        # parameter is `stride`, so nn.Conv2d(1, 64, 5, 5) would mean stride 5
        # and shrink the image instead of preserving its size. Stride stays 1
        # and each padding keeps H and W unchanged for its odd kernel size.
        # bias=False mirrors the bias-free F.conv2d calls in SuperResBlock (a
        # bias directly before BatchNorm is redundant anyway).
        self.dpad1 = (2, 2)
        self.dconv1 = nn.Conv2d(1, 64, kernel_size=5, padding=self.dpad1, bias=False)
        self.dbn1 = nn.BatchNorm2d(64)
        self.dpad2 = (1, 1)
        self.dconv2 = nn.Conv2d(64, 64, kernel_size=3, padding=self.dpad2, bias=False)
        self.dbn2 = nn.BatchNorm2d(64)
        self.dpad3 = (1, 1)
        self.dconv3 = nn.Conv2d(64, 32, kernel_size=3, padding=self.dpad3, bias=False)
        self.dbn3 = nn.BatchNorm2d(32)
        self.dpad4 = (1, 1)
        # r**2 output channels are rearranged by the pixel shuffle.
        self.dconv4 = nn.Conv2d(32, upscale_factor ** 2, kernel_size=3,
                                padding=self.dpad4, bias=False)
        self.pixel_shuffle = nn.PixelShuffle(upscale_factor)
        self.initialize_weights()

    def forward(self, x):
        """Return the upsampled tensor; see the class docstring for shapes."""
        x = self.activation(self.dbn1(self.dconv1(x)))
        x = self.activation(self.dbn2(self.dconv2(x)))
        x = self.activation(self.dbn3(self.dconv3(x)))
        x = self.dconv4(x)
        x = self.pixel_shuffle(x)  # (N, r**2, H, W) -> (N, 1, H*r, W*r)
        # Insert a singleton axis at dim 2: (N, 1, H*r, W*r) -> (N, 1, 1, H*r, W*r)
        # (presumably a depth axis for volume-shaped consumers -- TODO confirm).
        x = x[:, :, None]
        return x

    def initialize_weights(self):
        """Orthogonally initialize the conv weights and reset BatchNorm to identity."""
        relu_gain = nn.init.calculate_gain('relu')
        # Convs whose output feeds a ReLU use the ReLU gain; the last conv
        # feeds no activation, so it keeps orthogonal_'s default gain of 1.
        for conv in (self.dconv1, self.dconv2, self.dconv3):
            nn.init.orthogonal_(conv.weight, relu_gain)
        nn.init.orthogonal_(self.dconv4.weight)
        for bn in (self.dbn1, self.dbn2, self.dbn3):
            nn.init.constant_(bn.weight, 1)
            nn.init.constant_(bn.bias, 0)
def benchmark(net, inp, iters=100, warmup=5):
    """Measure forward+backward training iterations per second.

    Each iteration zeroes ``net``'s gradients, runs ``net(inp)``, reduces the
    output to a scalar in FP32 and backpropagates. ``warmup`` untimed
    iterations run first so one-off costs (such as cuDNN autotuning when
    ``cudnn.benchmark`` is enabled) are not measured.

    Args:
        net: module to benchmark.
        inp: input tensor, already on the device/dtype ``net`` expects.
        iters: number of timed iterations (must be >= 1).
        warmup: number of untimed iterations executed beforehand (>= 0).

    Returns:
        Timed iterations per second, as a float.

    Raises:
        ValueError: if ``iters`` < 1 or ``warmup`` < 0.
    """
    if iters < 1:
        raise ValueError("iters must be >= 1, got {}".format(iters))
    if warmup < 0:
        raise ValueError("warmup must be >= 0, got {}".format(warmup))

    def sync():
        # CUDA kernels launch asynchronously; wait for them before reading the clock.
        if inp.is_cuda:
            T.cuda.synchronize(inp.device)

    def step():
        net.zero_grad()
        out = net(inp)
        # Accumulate the loss in FP32 so FP16 runs do not sum millions of
        # elements in half precision; .float() is a no-op for FP32 outputs.
        out.float().sum().backward()

    for _ in range(warmup):
        step()
    sync()
    start = time.perf_counter()
    for _ in range(iters):
        step()
    sync()
    return iters / (time.perf_counter() - start)


if __name__ == "__main__":
    if not T.cuda.is_available():
        raise SystemExit("This benchmark requires a CUDA-capable GPU.")
    runs = (
        ("Functional convolution FP32 Iterations per second: ", SuperResBlock, T.float32),
        ("FP32 Iterations per second: ", SuperResBlockNotFunctional, T.float32),
        ("Functional convolution FP16 Iterations per second: ", SuperResBlock, T.float16),
        ("FP16 Iterations per second: ", SuperResBlockNotFunctional, T.float16),
    )
    for label, block_cls, dtype in runs:
        net = block_cls(2).to(device="cuda", dtype=dtype)
        # Allocate the input directly on the GPU so it is the autograd leaf
        # there. `T.randn(..., requires_grad=True).cuda()` leaves the leaf on
        # the CPU, so every backward pass would also copy the input gradient
        # back to the host and accumulate it there, distorting the timing.
        inp = T.randn(64, 1, 224, 224, device="cuda", dtype=dtype, requires_grad=True)
        print(label, benchmark(net, inp))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment