Created July 11, 2018 03:02
Save tbenst/bc79eb802df48bea37a6143567c78c85 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| import time | |
| import torch as T | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| import torch.backends.cudnn as cudnn | |
# Turn on cuDNN benchmark mode before any model below is built. Per the
# PyTorch docs, cuDNN then times candidate convolution algorithms the first
# time it sees a given input configuration and caches the fastest, so the
# first iterations for each shape are slow and should be excluded from timing.
T.backends.cudnn.benchmark = True
class SuperResBlock(nn.Module):
    """Subpixel-convolution upsampler built from raw weights and ``F.conv2d``.

    The network is three 'conv -> BatchNorm -> ReLU' stages (64, 64 and 32
    output channels; 5x5 then 3x3 kernels, padded so the spatial size is kept),
    a final bias-free 3x3 convolution to ``upscale_factor**2`` channels, and an
    ``nn.PixelShuffle(upscale_factor)``. The first weight has one input
    channel, so the input is expected to be single-channel. Before returning,
    a singleton dimension is inserted at axis 2.

    Reference: https://arxiv.org/pdf/1609.05158.pdf
    """

    def __init__(self, upscale_factor):
        super().__init__()

        def uninitialized_kernel(*shape):
            # Same storage the legacy ``T.FloatTensor(*shape)`` call gives:
            # an uninitialized float32 CPU tensor. initialize_weights()
            # overwrites it.
            return nn.Parameter(T.empty(*shape, dtype=T.float32, device="cpu"))

        self.activation = nn.ReLU()

        # Stage 1: 1 -> 64 channels, 5x5 kernel, padding 2.
        self.dconv1 = uninitialized_kernel(64, 1, 5, 5)
        self.dpad1 = (2, 2)
        self.dbn1 = nn.BatchNorm2d(64)

        # Stage 2: 64 -> 64 channels, 3x3 kernel, padding 1.
        self.dconv2 = uninitialized_kernel(64, 64, 3, 3)
        self.dpad2 = (1, 1)
        self.dbn2 = nn.BatchNorm2d(64)

        # Stage 3: 64 -> 32 channels, 3x3 kernel, padding 1.
        self.dconv3 = uninitialized_kernel(32, 64, 3, 3)
        self.dpad3 = (1, 1)
        self.dbn3 = nn.BatchNorm2d(32)

        # Projection to r*r channels, which PixelShuffle folds into space.
        self.dconv4 = uninitialized_kernel(upscale_factor**2, 32, 3, 3)
        self.dpad4 = (1, 1)
        self.pixel_shuffle = nn.PixelShuffle(upscale_factor)

        self.initialize_weights()

    def forward(self, x):
        hidden_stages = (
            (self.dconv1, self.dpad1, self.dbn1),
            (self.dconv2, self.dpad2, self.dbn2),
            (self.dconv3, self.dpad3, self.dbn3),
        )
        for kernel, pad, norm in hidden_stages:
            x = self.activation(norm(F.conv2d(x, kernel, padding=pad)))
        upsampled = self.pixel_shuffle(F.conv2d(x, self.dconv4, padding=self.dpad4))
        # add back single channel
        return upsampled.unsqueeze(2)

    def initialize_weights(self):
        """Orthogonal-init the kernels and reset the BatchNorm affine params."""
        relu_gain = nn.init.calculate_gain('relu')
        # Same order as before, so the RNG is consumed identically.
        for kernel in (self.dconv1, self.dconv2, self.dconv3):
            nn.init.orthogonal_(kernel, relu_gain)
        nn.init.orthogonal_(self.dconv4)
        for norm in (self.dbn1, self.dbn2, self.dbn3):
            nn.init.ones_(norm.weight)
            nn.init.zeros_(norm.bias)
class SuperResBlockNotFunctional(nn.Module):
    """Subpixel-convolution upsampler built from ``nn.Conv2d`` modules.

    This is intended to be the module-API counterpart of ``SuperResBlock``
    (which uses raw weights and ``F.conv2d``), so the two can be benchmarked
    against each other. The network is three 'conv -> BatchNorm -> ReLU'
    stages (64, 64 and 32 output channels; 5x5 then 3x3 kernels, stride 1,
    padded so the spatial size is kept). These are followed by a bias-free
    3x3 convolution to ``upscale_factor**2`` channels and an
    ``nn.PixelShuffle(upscale_factor)``. The first conv has one input
    channel. Before returning, a singleton dimension is inserted at axis 2.

    Reference: https://arxiv.org/pdf/1609.05158.pdf
    """

    def __init__(self, upscale_factor):
        super(SuperResBlockNotFunctional, self).__init__()
        self.activation = nn.ReLU()
        # nn.Conv2d's positional order is (in, out, kernel_size, stride, ...).
        # Writing ``nn.Conv2d(1, 64, 5, 5)`` therefore sets stride=5, not a
        # 5x5 kernel. kernel_size is passed by keyword so stride stays 1.
        # bias=False mirrors the bias-free F.conv2d calls in SuperResBlock.
        self.dpad1 = (2, 2)
        self.dconv1 = nn.Conv2d(1, 64, kernel_size=(5, 5), padding=self.dpad1, bias=False)
        self.dbn1 = nn.BatchNorm2d(64)
        self.dpad2 = (1, 1)
        self.dconv2 = nn.Conv2d(64, 64, kernel_size=(3, 3), padding=self.dpad2, bias=False)
        self.dbn2 = nn.BatchNorm2d(64)
        self.dpad3 = (1, 1)
        self.dconv3 = nn.Conv2d(64, 32, kernel_size=(3, 3), padding=self.dpad3, bias=False)
        self.dbn3 = nn.BatchNorm2d(32)
        self.dpad4 = (1, 1)
        self.dconv4 = nn.Conv2d(32, upscale_factor**2, kernel_size=(3, 3),
                                padding=self.dpad4, bias=False)
        self.pixel_shuffle = nn.PixelShuffle(upscale_factor)
        self.initialize_weights()

    def forward(self, x):
        x = self.activation(self.dbn1(self.dconv1(x)))
        x = self.activation(self.dbn2(self.dconv2(x)))
        x = self.activation(self.dbn3(self.dconv3(x)))
        x = self.dconv4(x)
        x = self.pixel_shuffle(x)
        # add back single channel
        x = x[:, :, None]
        return x

    def initialize_weights(self):
        """Orthogonal-init the conv kernels and reset the BatchNorm affine params."""
        nn.init.orthogonal_(self.dconv1.weight, nn.init.calculate_gain('relu'))
        nn.init.orthogonal_(self.dconv2.weight, nn.init.calculate_gain('relu'))
        nn.init.orthogonal_(self.dconv3.weight, nn.init.calculate_gain('relu'))
        nn.init.orthogonal_(self.dconv4.weight)
        for bn in [self.dbn1, self.dbn2, self.dbn3]:
            nn.init.constant_(bn.weight, 1)
            nn.init.constant_(bn.bias, 0)
def _make_input(dtype=T.float32):
    """Return a random (64, 1, 224, 224) CUDA input that requires grad.

    The tensor is created directly on the GPU so that it is the autograd
    leaf. With ``randn(..., requires_grad=True).cuda()`` the CPU tensor would
    be the leaf, and every backward pass would also copy the input gradient
    back to the host and accumulate it there, skewing the timings.
    """
    return T.randn(64, 1, 224, 224, device="cuda", dtype=dtype, requires_grad=True)


def _benchmark(label, net, inp, *, fp32_loss=False, warmup=5, iters=100):
    """Time forward+backward passes of ``net`` on ``inp`` and print throughput.

    First runs ``warmup`` untimed iterations, so one-off costs such as cuDNN's
    algorithm search under ``cudnn.benchmark`` are excluded. Then times
    ``iters`` iterations and prints ``label`` followed by iterations per second.

    Args:
        label: text printed before the iterations-per-second figure.
        net: the module to run.
        inp: input tensor passed to ``net``.
        fp32_loss: cast the output to float32 before summing it into the loss.
            The half-precision runs use this.
        warmup: number of untimed iterations.
        iters: number of timed iterations.
    """
    def step():
        net.zero_grad()
        inp.grad = None  # keep the input gradient from accumulating across steps
        out = net(inp)
        loss = (out.float() if fp32_loss else out).sum()
        loss.backward()

    for _ in range(warmup):
        step()
    # CUDA kernels run asynchronously; sync so the timer brackets real work.
    T.cuda.synchronize()
    start = time.perf_counter()
    for _ in range(iters):
        step()
    T.cuda.synchronize()
    elapsed = time.perf_counter() - start
    print(label, iters / elapsed)


net = SuperResBlock(2).cuda()
inp = _make_input()
_benchmark("Functional convolution FP32 Iterations per second: ", net, inp)

net = SuperResBlockNotFunctional(2).cuda()
inp = _make_input()
_benchmark("FP32 Iterations per second: ", net, inp)

net = SuperResBlock(2).cuda().half()
inp = _make_input(T.float16)
_benchmark("Functional convolution FP16 Iterations per second: ", net, inp, fp32_loss=True)

net = SuperResBlockNotFunctional(2).cuda().half()
inp = _make_input(T.float16)
_benchmark("FP16 Iterations per second: ", net, inp, fp32_loss=True)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.