Created January 1, 2017 23:15
Save anonymous/b142a40d1ed309251cd1850fb8b622ca to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import torch.multiprocessing as mp
from torch.multiprocessing import Semaphore
import sys

if sys.version_info[0] == 3:
    # Python 3's multiprocessing ships a real multi-process barrier.
    Barrier = mp.Barrier
else:  # Python 2
    # Adapted from http://stackoverflow.com/a/26703365/117844
    class Barrier(object):
        """Single-use barrier that blocks until `n` processes arrive.

        FIX: the original shim kept `count` as a plain Python int, which
        is NOT shared between processes -- each worker receives its own
        copy when the barrier is pickled to a child, so `count` could
        never reach `n` for n > 1 and every process deadlocked in
        wait(). The counter now lives in a shared mp.Value, protected
        by the Value's built-in lock (replacing the separate mutex
        semaphore).
        """
        def __init__(self, n):
            self.n = n
            self.count = mp.Value('i', 0)  # shared arrival counter
            self.barrier = Semaphore(0)    # closed turnstile

        def wait(self):
            """Block until all `n` participants have called wait()."""
            with self.count.get_lock():
                self.count.value += 1
                # The last arrival opens the turnstile.
                if self.count.value == self.n:
                    self.barrier.release()
            # Each process passes through and re-opens it for the next.
            self.barrier.acquire()
            self.barrier.release()
class ParameterServer(object):
    """Distributes one model from rank 0 to `n_processes - 1` workers.

    Rank 0 puts the model on a queue once per worker; every other rank
    pulls it off. All ranks then clone their gradient tensors so that
    gradient storage is private to each process, and meet at a barrier
    before returning.
    """

    def __init__(self, n_processes):
        self.queue = mp.Queue()
        self.n_processes = n_processes
        self.barrier = Barrier(n_processes)

    # Custom (un)pickling: only the synchronization primitives and the
    # process count travel to child processes.
    def __getstate__(self):
        return (self.queue, self.barrier, self.n_processes)

    def __setstate__(self, state):
        self.queue, self.barrier, self.n_processes = state

    def sync_model(self, rank, model=None):
        """Return the synchronized model for this `rank`.

        Args:
            rank: 0 for the producing process, anything else for a
                consumer.
            model: required on rank 0; ignored on other ranks.

        Raises:
            ValueError: if rank 0 calls without a model. (The original
                used `assert`, which is stripped under `python -O`.)
        """
        if rank == 0:
            if model is None:
                raise ValueError("rank 0 must supply the model")
            for _ in range(self.n_processes - 1):
                self.queue.put(model)
        else:
            model = self.queue.get()
        # Clone the gradients to break the sharing between processes.
        # FIX: freshly created parameters have grad == None, which used
        # to crash on .clone(); skip parameters with no gradient yet.
        for param in model.parameters():
            if param.grad is not None:
                param._grad = param.grad.clone()
        self.barrier.wait()
        return model
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.