@tudor-berariu
Created March 30, 2017 14:48
PyTorch code hanging with latest commit

Minimal reproduction: SimpleModel's parameters are moved into shared memory via share_memory(), the parent process runs one forward/backward pass, and then each spawned Worker runs its own forward/backward pass on the same shared model. With the latest commit, the script hangs instead of running to completion; the per-phase prints below narrow down where it stops.
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.multiprocessing as mp
from torch.autograd import Variable
import sys

IN_SIZE = 10
OUT_SIZE = 17
BATCH_SIZE = 32
WORKERS_NO = 1


class SimpleModel(nn.Module):
    def __init__(self, in_size, out_size):
        super(SimpleModel, self).__init__()
        h_size = in_size + (out_size - in_size) // 2
        self.l1 = nn.Linear(in_size, h_size)
        self.l2 = nn.Linear(h_size, out_size)

    def forward(self, inputs):
        return F.tanh(self.l2(F.relu(self.l1(inputs))))


class Worker(mp.Process):
    def __init__(self, pid, shared_model):
        super(Worker, self).__init__()
        self.my_pid = pid  # worker index, not an OS pid
        self.shared_model = shared_model

    def run(self):
        shared_model = self.shared_model
        my_pid = self.my_pid
        sys.stdout.write("Worker {:d} started.\n".format(my_pid))
        # Forward pass through the shared-memory model on random data.
        loss = F.smooth_l1_loss(
            shared_model(Variable(torch.rand(BATCH_SIZE, IN_SIZE))),
            Variable(torch.rand(BATCH_SIZE, OUT_SIZE))
        )
        sys.stdout.write("Forward phase done for {:d}!\n".format(my_pid))
        sys.stdout.flush()
        # Backward in the child process; the prints above and below
        # bracket the hang reported in the title.
        loss.backward()
        sys.stdout.write("Backward phase done for {:d}!\n".format(my_pid))
        sys.stdout.flush()


def main():
    shared_model = SimpleModel(IN_SIZE, OUT_SIZE)
    shared_model.share_memory()  # move parameters to shared memory
    # One forward/backward pass in the parent process.
    loss = F.smooth_l1_loss(
        shared_model(Variable(torch.rand(BATCH_SIZE, IN_SIZE))),
        Variable(torch.rand(BATCH_SIZE, OUT_SIZE))
    )
    loss.backward()
    workers = [Worker(i, shared_model) for i in range(WORKERS_NO)]
    for worker in workers:
        worker.start()
    for worker in workers:
        worker.join()


if __name__ == "__main__":
    main()
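
For comparison, the usual Hogwild/A3C-style pattern does not call backward() through shared-memory parameters in the child process: each worker keeps a process-local copy of the model, syncs its weights from the shared model, runs backward on the local copy, and then copies gradients back to the shared parameters before an optimizer step. Below is a minimal sketch of that pattern, reusing SimpleModel and the constants from the gist above; LocalWorker and worker_id are names introduced here for illustration, and whether this sidesteps the hang on the affected commit is an assumption, not something verified here.

# Hypothetical workaround sketch (not part of the original gist):
# backpropagate through a process-local model instead of through the
# shared-memory parameters directly.
class LocalWorker(mp.Process):
    def __init__(self, worker_id, shared_model):
        super(LocalWorker, self).__init__()
        self.worker_id = worker_id
        self.shared_model = shared_model

    def run(self):
        # Private model living entirely in this process's memory.
        local_model = SimpleModel(IN_SIZE, OUT_SIZE)
        # Sync: copy the shared parameters into the local copy.
        local_model.load_state_dict(self.shared_model.state_dict())
        loss = F.smooth_l1_loss(
            local_model(Variable(torch.rand(BATCH_SIZE, IN_SIZE))),
            Variable(torch.rand(BATCH_SIZE, OUT_SIZE))
        )
        loss.backward()  # gradients accumulate on the local parameters
        # Point the shared parameters' gradients at the local ones, the
        # same trick used by the public pytorch-examples A3C code.
        for local_p, shared_p in zip(local_model.parameters(),
                                     self.shared_model.parameters()):
            if shared_p.grad is None:
                shared_p._grad = local_p.grad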