Skip to content

Instantly share code, notes, and snippets.

@sandeepkumar-skb
Created October 21, 2020 21:27
Show Gist options
  • Save sandeepkumar-skb/2a9fa7950a01d0856fc7e678e8768961 to your computer and use it in GitHub Desktop.
Save sandeepkumar-skb/2a9fa7950a01d0856fc7e678e8768961 to your computer and use it in GitHub Desktop.
import torch
import torch.nn as nn
import time
from torch import Tensor
class Net(nn.Module):
    """A stack of square ``nn.Linear`` layers applied one after another.

    No activation is placed between the layers; the input is passed through
    each ``nn.Linear(features, features)`` in order.

    Args:
        features: input and output size of every ``nn.Linear`` layer.
        num_layers: number of layers in the stack. Defaults to 100, the
            depth that was previously hard-coded.
    """

    def __init__(self, features, num_layers=100):
        super().__init__()
        # A plain Python list is NOT registered as a submodule by nn.Module;
        # it is kept only for backward compatibility. The same layer objects
        # are registered (and therefore moved by .cuda()/.to() and seen by
        # .parameters()) through the nn.Sequential below.
        self.fc_layers = [nn.Linear(features, features) for _ in range(num_layers)]
        self.layers = nn.Sequential(*self.fc_layers)

    def forward(self, inp):
        """Apply every layer in order to ``inp`` and return the result."""
        return self.layers(inp)
class Model(nn.Module):
    """Runs the same ``Net`` twice on one input and sums the two outputs.

    ``forward`` performs the two passes one after the other, while
    ``forward_fork`` launches each pass with ``torch.jit.fork`` so that, once
    the module is scripted, TorchScript may execute them as concurrent tasks.

    Args:
        features: layer width passed through to ``Net``.
    """

    def __init__(self, features):
        super().__init__()
        self.net = Net(features)

    @torch.jit.export
    def forward_fork(self, inp: Tensor) -> Tensor:
        """Fork two ``self.net`` calls on ``inp``, wait for both, return their sum."""
        # Both futures are created before either is awaited; awaiting the
        # first before forking the second would leave nothing to overlap.
        left_future = torch.jit.fork(self.net, inp)
        right_future = torch.jit.fork(self.net, inp)
        left = torch.jit.wait(left_future)
        right = torch.jit.wait(right_future)
        return left + right

    @torch.jit.export
    def forward(self, inp: Tensor) -> Tensor:
        """Call ``self.net`` on ``inp`` twice, sequentially, and return the sum."""
        return self.net(inp) + self.net(inp)
def _timed_call(fn, inp, sync):
    """Run ``fn(inp)`` once and return the elapsed time in seconds.

    When ``sync`` is true, ``torch.cuda.synchronize()`` is called before the
    clock starts and before it stops, so asynchronously launched CUDA work is
    included in the measurement. ``time.perf_counter`` is used because it is
    monotonic and high-resolution, unlike ``time.time``.
    """
    if sync:
        torch.cuda.synchronize()
    start = time.perf_counter()
    fn(inp)
    if sync:
        torch.cuda.synchronize()
    return time.perf_counter() - start


def main(features=1024, iters=200, warmup=11):
    """Compare the scripted ``Model`` with and without ``torch.jit.fork``.

    Each iteration times ``forward_fork`` and then ``forward`` on the same
    random ``(features, features)`` input. The first ``warmup`` iterations are
    excluded from the averages, and the mean time of each variant is printed
    in milliseconds. Runs on CUDA when available, otherwise on CPU.

    Raises:
        ValueError: if ``iters <= warmup`` (no iterations would be measured).
    """
    if iters <= warmup:
        raise ValueError(
            f"iters ({iters}) must be greater than warmup ({warmup})"
        )
    use_cuda = torch.cuda.is_available()
    device = "cuda" if use_cuda else "cpu"
    inp = torch.rand((features, features), device=device)
    net = torch.jit.script(Model(features).to(device).eval())

    count = 0
    times_forked = 0.0
    times = 0.0
    for i in range(iters):
        elapsed_forked = _timed_call(net.forward_fork, inp, use_cuda)
        elapsed = _timed_call(net.forward, inp, use_cuda)
        # Skip warm-up iterations (scripting/optimization, allocator growth).
        if i >= warmup:
            count += 1
            times_forked += elapsed_forked
            times += elapsed

    print(f"Time taken for jitted model without fork: {times/count * 1000} ms")
    print(f"Time taken for jitted model with fork: {times_forked/count * 1000} ms")


if __name__ == "__main__":
    main()
@sandeepkumar-skb
Copy link
Author

Screen Shot 2020-10-21 at 7 53 00 PM

@sandeepkumar-skb
Copy link
Author

Multiple threads launched by forking:
image

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment