Skip to content

Instantly share code, notes, and snippets.

@nilesh0109
Last active December 22, 2020 15:11
Show Gist options
  • Save nilesh0109/da7ec2d228325298384edca2eb3d76c0 to your computer and use it in GitHub Desktop.
PyTorch model split across different GPUs
from torch import nn
class Network(nn.Module):
    """Two-stage network that can optionally be split across two GPUs.

    This is a naive model-parallel template: ``module1`` and ``module2`` are
    placeholders (``...``) that must be replaced with real ``nn.Module``
    instances before the class is used. When ``split_gpus`` is true, the
    first stage is placed on CUDA device ``devices[0]`` and the second on
    ``devices[1]``, and :meth:`forward` copies activations between them.

    Args:
        split_gpus: If true, place each stage on its own CUDA device.
        devices: Pair of CUDA device indices ``(first, second)`` used when
            ``split_gpus`` is true. Defaults to ``(0, 1)``, i.e. the indices
            that were previously hard-coded.
    """

    def __init__(self, split_gpus=False, devices=(0, 1)):
        super().__init__()
        # Placeholders: replace with real sub-modules (e.g. nn.Sequential(...)).
        self.module1 = ...
        self.module2 = ...
        self.split_gpus = split_gpus
        # Unpacking enforces exactly two device indices (only two stages exist).
        first, second = devices
        self.devices = (first, second)
        if split_gpus:
            # Module.cuda(idx) moves the module's parameters/buffers in place.
            self.module1.cuda(first)
            self.module2.cuda(second)

    def forward(self, x):
        """Run ``x`` through ``module1`` and then ``module2``.

        When ``split_gpus`` is true, the input is copied to ``devices[0]``
        before the first stage and the intermediate activation to
        ``devices[1]`` before the second. The returned tensor therefore lives
        on ``devices[1]``, and targets/loss must be placed there too. When
        ``split_gpus`` is false, no device copies are made.
        """
        first, second = self.devices
        if self.split_gpus:
            x = x.cuda(first)
        x = self.module1(x)
        if self.split_gpus:
            # Hand the activation over to the device holding the second stage.
            x = x.cuda(second)
        return self.module2(x)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment