Last active: December 22, 2020 15:11
-
-
Save nilesh0109/da7ec2d228325298384edca2eb3d76c0 to your computer and use it in GitHub Desktop.
PyTorch model split across different GPUs
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from torch import nn
class Network(nn.Module):
    """Two-stage network whose stages can be placed on two different GPUs.

    This is a minimal model-parallel pattern. ``module1`` runs first. Its
    output is copied to ``module2``'s device, and ``module2`` produces the
    result.

    Args:
        split_gpus: If True, place ``module1`` on ``devices[0]`` and
            ``module2`` on ``devices[1]``, and move activations between
            them in ``forward``. If False, the stages are left on whatever
            device they already occupy and inputs are not moved.
        module1: First stage. Defaults to an ``nn.Identity()`` placeholder;
            pass (or assign) the real sub-network.
        module2: Second stage. Defaults to an ``nn.Identity()`` placeholder;
            pass (or assign) the real sub-network.
        devices: The two CUDA device indices used when ``split_gpus`` is
            True. Defaults to ``(0, 1)``.

    Raises:
        ValueError: If ``devices`` does not contain exactly two entries.

    Note:
        Calling ``.cuda()`` or ``.to(device)`` on the whole model afterwards
        moves every parameter to a single device and undoes the split.
        When ``split_gpus`` is True, the output lives on ``devices[1]``, so
        loss targets must be moved there as well.
    """

    def __init__(
        self,
        split_gpus=False,
        module1=None,
        module2=None,
        devices=(0, 1),
    ):
        super().__init__()
        # Identity placeholders keep the class constructible and callable
        # until real sub-networks are supplied.
        self.module1 = nn.Identity() if module1 is None else module1
        self.module2 = nn.Identity() if module2 is None else module2
        self.split_gpus = split_gpus
        self.devices = tuple(devices)
        if len(self.devices) != 2:
            raise ValueError(
                f"devices must contain exactly two entries, got {self.devices!r}"
            )
        if split_gpus:
            self.module1.cuda(self.devices[0])
            self.module2.cuda(self.devices[1])

    def forward(self, x):
        """Run ``module1`` then ``module2``, hopping devices if split."""
        if self.split_gpus:
            x = x.cuda(self.devices[0])
        x = self.module1(x)
        if self.split_gpus:
            # The intermediate activation must be on the second stage's
            # device before module2 can consume it.
            x = x.cuda(self.devices[1])
        return self.module2(x)
Sign up for free to join this conversation on GitHub.
Already have an account?
Sign in to comment