Last active
October 15, 2019 19:16
-
-
Save ProGamerGov/e495448fdc665a8570cf7b57257470aa to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Inspired by: https://github.com/torch/nn/blob/master/GPU.lua
# And: https://github.com/jcjohnson/neural-style/blob/master/neural_style.lua#L360
# As seen in: https://github.com/ProGamerGov/neural-style-pt
import torch
import torch.nn as nn
class ModelParallel(nn.Module):
    r"""Split a sequential network into chunks placed on different devices.

    The network is cut at the given layer indices; each resulting chunk is
    moved to the corresponding entry of ``device_ids`` and ``forward`` hands
    activations from one device to the next.

    Args:
        net (nn.Sequential): the sequential model to split across devices.
        device_ids (str): comma-separated, zero-indexed GPU ids, with ``'c'``
            standing for the CPU (e.g. ``"0,1,2"`` or ``"c,0"``).
        device_splits (str): comma-separated layer indices at which to end a
            chunk (e.g. ``"2,5"``); ``""`` keeps the net as a single chunk.

    Example::
        >>> net = ModelParallel(model, device_ids="0,1,2", device_splits="2,5")
        >>> net = ModelParallel(model, device_ids="c,0", device_splits="5")
    """

    def __init__(self, net, device_ids, device_splits):
        super(ModelParallel, self).__init__()
        self.device_list = self.name_devices(device_ids.split(','))
        # Bug fix: chunks_to_devices() takes the device list as its second
        # argument; the original call omitted it and raised a TypeError.
        self.chunks = self.chunks_to_devices(
            self.split_net(net, device_splits.split(',')), self.device_list)

    def name_devices(self, input_list):
        r"""Convert a list of zero-indexed GPU ids and CPU markers to PyTorch device names.

        Arguments:
            input_list (list): zero-indexed GPU ids, and 'c' for the CPU.

        Returns:
            list of str: PyTorch device names such as ``"cuda:0"`` or ``"cpu"``.
        """
        device_list = []
        for device in input_list:
            if str(device).lower() != 'c':
                device_list.append("cuda:" + str(device))
            else:
                device_list.append("cpu")
        return device_list

    def split_net(self, net, device_splits):
        r"""Split a sequential net into chunks at the given layer indices.

        Arguments:
            net (nn.Sequential): the model to split.
            device_splits (list of str): layer indices (as strings) at which
                the current chunk ends; matched entries are consumed.

        Returns:
            list of nn.Sequential: the chunks, in layer order.
        """
        chunks, cur_chunk = [], nn.Sequential()
        for i, layer in enumerate(net):
            cur_chunk.add_module(str(i), layer)
            # Bug fix: the original guard `device_splits != ''` compared a
            # list against a string (always True), and `del device_splits[0]`
            # could drop an unmatched index; remove the matched entry instead.
            if device_splits and str(i) in device_splits:
                device_splits.remove(str(i))
                chunks.append(cur_chunk)
                cur_chunk = nn.Sequential()
        chunks.append(cur_chunk)
        return chunks

    def chunks_to_devices(self, chunks, device_list):
        r"""Move each Sequential chunk onto its corresponding device.

        Arguments:
            chunks (list): a list of Sequential nets.
            device_list (list of str): PyTorch device names, one per chunk.

        Returns:
            list: the same chunks, now resident on their devices.
        """
        # Bug fix: use the device_list parameter instead of silently
        # ignoring it in favor of self.device_list.
        for i, chunk in enumerate(chunks):
            chunk.to(device_list[i])
        return chunks

    def c(self, input, i):
        r"""Cast ``input`` to the tensor backend of ``self.device_list[i]``.

        Only float tensors are converted (CPU float <-> CUDA float),
        mirroring the Lua helpers this code was ported from.

        Arguments:
            input (Tensor): a float or CUDA float tensor.
            i (int): index into ``self.device_list``.
        """
        if input.type() == 'torch.FloatTensor' and 'cuda' in self.device_list[i]:
            input = input.type('torch.cuda.FloatTensor')
        elif input.type() == 'torch.cuda.FloatTensor' and 'cpu' in self.device_list[i]:
            input = input.type('torch.FloatTensor')
        return input

    def forward(self, input):
        r"""Run ``input`` through each chunk, hopping devices between chunks."""
        for i, chunk in enumerate(self.chunks):
            if i < len(self.chunks) - 1:
                # Move the activation to this chunk's device, run the chunk,
                # then hand the result over to the next chunk's backend/device.
                input = self.c(input, i).to(self.device_list[i])
                input = self.c(chunk(input), i + 1).to(self.device_list[i + 1])
            else:
                # Last chunk: its output stays on the final device.
                input = chunk(input)
        return input
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment