pytorch hogwild
import torch
import torch.multiprocessing as mp
import queue
import resource
import os
class SimpleSGD(torch.optim.Optimizer):
    """Minimal SGD optimizer whose learning rate can be overridden on each call to step()."""

    def __init__(self, params, lr=.01):
        defaults = dict(lr=lr)
        super(SimpleSGD, self).__init__(params, defaults)

    def step(self, closure=None, lr=.01):
        loss = None
        if closure is not None:
            loss = closure()
        for group in self.param_groups:
            for p in group['params']:
                if p.grad is None:
                    continue
                d_p = p.grad.data
                # plain SGD update: p <- p - lr * grad, with lr supplied per step
                p.data.add_(d_p, alpha=-lr)
        return loss
def run_processes(model, comm):
    # worker process: give this worker its own gradient buffers so only the
    # parameters themselves are shared (hogwild-style, lock-free updates)
    for param in model.parameters():
        if param.grad is not None:
            param.grad.data = param.grad.data.clone()
    model.reset()  # reset() is a method of the author's particular model
    pid = os.getpid()
    cnt = 0
    # optim = torch.optim.Adadelta(model.parameters())
    optim = SimpleSGD(model.parameters())
    try:
        while True:
            cnt_i, (idx, itm), alpha = comm.get(timeout=30)
            cnt += 1
            try:
                model.reset()
                optim.zero_grad()
                ###############################################################
                # CHANGE THIS TO YOUR MODEL
                loss, encoded, _ = model(itm)
                # learned_parameters is assumed to be defined at module level elsewhere
                if idx < learned_parameters.shape[0]:
                    learned_parameters[idx, :] = encoded.data.numpy()
                print(pid, cnt, cnt_i, loss)
                loss.backward(retain_graph=True)  # retain_variables in very old PyTorch
                optim.step(lr=alpha)
                ################################################################
            except Exception:
                print('###################')
                print('FAILED: ', cnt, cnt_i)
                import traceback
                traceback.print_exc()
                print('###################')
            # note: ru_maxrss is reported in kilobytes on Linux and in bytes on macOS
            ram_usage = resource.getrusage(resource.RUSAGE_SELF).ru_maxrss
            if ram_usage > 1024*1024*1024*5.5 or cnt > 150:
                # then we must have leaked memory somewhere;
                # kill ourselves and let the main process start a replacement worker
                print('###################')
                print('KILLING SELF, using: ', ram_usage)
                print('SELF: ', pid)
                print('###################')
                os._exit(1)
    except queue.Empty:
        # the queue has been empty for 30 seconds, so we are done with processing
        pass
def run_parallel(model, data_x, n_threads=10, save='/home/mfran/data/ce/m6-{}.pth'):
    # put the model's parameters in shared memory so all worker processes
    # update the same tensors
    model.share_memory()
    import gc
    gc.collect()
    gc.disable()
    comm = mp.Queue(n_threads * 2 + 10)
    processes = []
    for _ in range(n_threads):
        p = mp.Process(target=run_processes, args=(model, comm))
        p.start()
        processes.append(p)
    cnt = 0
    for p in data_x:
        cnt += 1
        # decaying learning rate; args.alpha is assumed to be defined at module level
        alpha = args.alpha / (cnt ** .6)
        comm.put((cnt, p, alpha))
        print('added: ', cnt)
        if cnt % (n_threads - 2) == 0:
            # check that all the processes are still alive, otherwise replace them
            for i in range(n_threads):
                p = processes[i]
                if not p.is_alive():
                    # start a new process after an existing one killed itself
                    p.join()
                    p = mp.Process(target=run_processes, args=(model, comm))
                    p.start()
                    processes[i] = p
        if cnt % 5000 == 0:
            torch.save(model, save.format(cnt))
    for p in processes:
        p.join()
############################################################
# use by doing something like:
#   run_parallel(model, [data_point_1, data_point_2, ....])  # or pass any iterator of data points
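
# --------------------------------------------------------------------------
# Minimal usage sketch (not from the original gist): it assumes a hypothetical
# ToyModel that matches the interface the code above expects -- callable as
# model(itm) -> (loss, encoded, _) and exposing a reset() method -- and it
# defines the module-level names args and learned_parameters that
# run_processes reads. Note: with the default fork start method each worker
# gets its own copy of learned_parameters unless it is backed by shared memory.
if __name__ == '__main__':
    import argparse
    import numpy as np
    import torch.nn as nn

    class ToyModel(nn.Module):
        def __init__(self, dim=8):
            super().__init__()
            self.encoder = nn.Linear(dim, dim)

        def reset(self):
            pass  # nothing to clear in this toy model

        def forward(self, itm):
            encoded = self.encoder(itm)
            loss = (encoded ** 2).mean()
            return loss, encoded, None

    parser = argparse.ArgumentParser()
    parser.add_argument('--alpha', type=float, default=0.01)
    args = parser.parse_args()

    dim = 8
    n_items = 100
    learned_parameters = np.zeros((n_items, dim))
    model = ToyModel(dim)
    data = [(i, torch.randn(dim)) for i in range(n_items)]  # (idx, itm) pairs
    run_parallel(model, data, n_threads=4, save='/tmp/hogwild-{}.pth')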