@bearpelican
Created November 21, 2018 00:28
import os
import torch
import torch.nn as nn
import torch.backends.cudnn as cudnn
import torch.utils.data
import torch.utils.data.distributed
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel
cudnn.benchmark = True
print('Distributed initializing process group')
# Single-process group (world_size=1, rank=0) so the script can be run standalone
# torch.cuda.set_device(0)
dist.init_process_group(backend='nccl', init_method='tcp://localhost:8008', world_size=1, rank=0)
print('Loading model')
first_layer = nn.Linear(10, 10)
second_layer = nn.Linear(10, 10)
model = nn.Sequential(first_layer, second_layer).cuda()
# Freeze the first layer's parameters before wrapping the model in DistributedDataParallel
for p in first_layer.parameters(): p.requires_grad_(False)
print('Loading distributed')
model = DistributedDataParallel(model, device_ids=[0], output_device=0)
print('Forward')
out = model(torch.ones([1,10]).cuda())
# Backprop
print('Backward')
loss = out.sum()
loss.backward()
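# --- Optional sanity check (not in the original gist; a minimal sketch) ---
# Assumption: after loss.backward(), only the unfrozen second layer should have
# accumulated gradients, while the frozen first layer's .grad stays None.
# model.module unwraps the DistributedDataParallel container.
for name, p in model.module.named_parameters():
    print(name, 'requires_grad =', p.requires_grad,
          'grad shape =', None if p.grad is None else tuple(p.grad.shape))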
print('DONE')