Created
October 21, 2020 21:27
-
-
Save sandeepkumar-skb/2a9fa7950a01d0856fc7e678e8768961 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import time

import torch
import torch.nn as nn
from torch import Tensor
class Net(nn.Module):
    """A stack of equal-width fully connected layers applied in sequence.

    Args:
        features: Input and output width of every linear layer.
        num_layers: Number of stacked ``nn.Linear`` layers. Defaults to 100,
            matching the original hard-coded depth, so existing callers are
            unaffected.
    """

    def __init__(self, features: int, num_layers: int = 100):
        super().__init__()
        # nn.Sequential registers the sub-modules, so their parameters are
        # tracked even though fc_layers itself is a plain Python list.
        self.fc_layers = [nn.Linear(features, features) for _ in range(num_layers)]
        self.layers = nn.Sequential(*self.fc_layers)

    def forward(self, inp: Tensor) -> Tensor:
        """Apply all linear layers in order; the trailing dim stays `features`."""
        return self.layers(inp)
class Model(nn.Module):
    """Wraps a ``Net`` and exposes two ways of running it twice on the same
    input: back-to-back (``forward``) or as TorchScript async tasks
    (``forward_fork``). Both return the element-wise sum of the two passes.
    """

    def __init__(self, features):
        super().__init__()
        self.net = Net(features)

    @torch.jit.export
    def forward_fork(self, inp):
        """Launch both Net passes concurrently via torch.jit.fork, then sum."""
        fut_a = torch.jit.fork(self.net, inp)
        fut_b = torch.jit.fork(self.net, inp)
        return torch.jit.wait(fut_a) + torch.jit.wait(fut_b)

    @torch.jit.export
    def forward(self, inp):
        """Run both Net passes sequentially, then sum."""
        first = self.net(inp)
        second = self.net(inp)
        return first + second
if __name__ == "__main__":
    # Benchmark: sequential vs. fork-parallel double forward pass of the
    # scripted Model, averaged over the post-warmup iterations.
    features = 1024
    ITERS = 200
    WARMUP = 10  # iterations discarded while CUDA/JIT warm up

    # Fall back to CPU so the benchmark still runs on machines without CUDA.
    device = "cuda" if torch.cuda.is_available() else "cpu"

    def _sync() -> None:
        # Kernels launch asynchronously on CUDA; synchronize before reading
        # the clock so timings cover the actual work.
        if device == "cuda":
            torch.cuda.synchronize()

    inp = torch.rand((features, features), device=device)
    net = torch.jit.script(Model(features).to(device).eval())

    count = 0
    times_forked = 0.0
    times = 0.0
    with torch.no_grad():  # inference-only benchmark; skip autograd bookkeeping
        for i in range(ITERS):
            _sync()
            # perf_counter() is monotonic and higher-resolution than time.time()
            start = time.perf_counter()
            out = net.forward_fork(inp)
            _sync()
            end = time.perf_counter()
            if i > WARMUP:
                count += 1
                times_forked += end - start

            _sync()
            start = time.perf_counter()
            out = net.forward(inp)
            _sync()
            end = time.perf_counter()
            if i > WARMUP:
                times += (end - start)

    print(f"Time taken for jitted model without fork: {times/count * 1000} ms")
    print(f"Time taken for jitted model with fork: {times_forked/count * 1000} ms")
Author
sandeepkumar-skb
commented
Oct 22, 2020
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment