# pytorch hogwild
import torch
import torch.multiprocessing as mp
import queue
import resource
import os
class SimpleSGD(torch.optim.Optimizer):
    """Plain SGD with no momentum or weight decay, so step() reduces to a
    single in-place add_ on each parameter."""

    def __init__(self, params, lr=.01):
        defaults = dict(lr=lr)
        super(SimpleSGD, self).__init__(params, defaults)

    def step(self, closure=None, lr=.01):
        loss = None
        if closure is not None:
            loss = closure()
        for group in self.param_groups:
            for p in group['params']:
                if p.grad is None:
                    continue
                d_p = p.grad.data
                # in-place update; add_(-lr, d_p) is the old, deprecated form
                p.data.add_(d_p, alpha=-lr)
        return loss
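
# What makes this bare-bones optimizer suitable for Hogwild-style training is
# precisely its simplicity: each worker applies an in-place add_ to parameter
# tensors living in shared memory, with no locking at all, and overlapping
# writes from other workers are simply tolerated (Niu et al., "Hogwild!", 2011).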
def run_processes(model, comm):
    # give this worker its own gradient buffers: the parameters stay in
    # shared memory, but gradients must not be shared between processes
    for param in model.parameters():
        if param.grad is not None:
            param.grad.data = param.grad.data.clone()
    model.reset()
    pid = os.getpid()
    cnt = 0
    #optim = torch.optim.Adadelta(model.parameters())
    optim = SimpleSGD(model.parameters())
    try:
        while True:
            cnt_i, (idx, itm), alpha = comm.get(timeout=30)
            cnt += 1
            try:
                model.reset()
                optim.zero_grad()
                ###############################################################
                # CHANGE THIS TO YOUR MODEL
                # (learned_parameters is a module-level array the caller is
                # expected to define; see the usage sketch at the bottom)
                loss, encoded, _ = model(itm)
                if idx < learned_parameters.shape[0]:
                    learned_parameters[idx, :] = encoded.data.numpy()
                print(pid, cnt, cnt_i, loss)
                # retain_variables was renamed retain_graph in modern pytorch
                loss.backward(retain_graph=True)
                optim.step(lr=alpha)
                ################################################################
            except Exception:
                print('###################')
                print('FAILED: ', cnt, cnt_i)
                import traceback
                traceback.print_exc()
                print('###################')
            ram_usage = resource.getrusage(resource.RUSAGE_SELF).ru_maxrss
            if ram_usage > 1024*1024*1024*5.5 or cnt > 150:
                # then we must have leaked somewhere; kill this worker and let
                # the main process start a replacement
                # (note: ru_maxrss is in bytes on macOS but kilobytes on
                # Linux, so the threshold may need adjusting per platform)
                print('###################')
                print('KILLING SELF, using: ', ram_usage)
                print('SELF: ', pid)
                print('###################')
                os._exit(1)
    except queue.Empty:
        # nothing arrived for 30 seconds: treat the queue as drained and exit
        pass
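
# Design note: the os._exit(1) above and the liveness check in run_parallel
# below form a crude but effective leak mitigation. Rather than hunting down
# whatever holds memory across iterations, each worker processes a bounded
# amount of work, exits, and is replaced by a fresh process whose address
# space starts clean.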
def run_parallel(model, data_x, n_threads=10, save='/home/mfran/data/ce/m6-{}.pth'):
    model.share_memory()  # move the parameters into shared memory for hogwild
    import gc
    gc.collect()
    gc.disable()
    comm = mp.Queue(n_threads * 2 + 10)
    processes = []
    for _ in range(n_threads):
        p = mp.Process(target=run_processes, args=(model, comm))
        p.start()
        processes.append(p)
    cnt = 0
    for itm in data_x:
        cnt += 1
        # decaying learning rate; args is a module-level namespace the caller
        # is expected to define (see the usage sketch at the bottom)
        alpha = args.alpha / (cnt ** .6)
        comm.put((cnt, itm, alpha))
        print('added: ', cnt)
        if cnt % (n_threads - 2) == 0:
            # check that all the processes are still alive, otherwise replace
            for i in range(n_threads):
                p = processes[i]
                if not p.is_alive():
                    # start a new process after an existing one killed itself
                    p.join()
                    p = mp.Process(target=run_processes, args=(model, comm))
                    p.start()
                    processes[i] = p
        if cnt % 5000 == 0:
            torch.save(model, save.format(cnt))
    for p in processes:
        p.join()
############################################################
# use by doing something like:
#   run_parallel(model, [data_point_1, data_point_2, ...])
# where data_x is any iterable and each data point is an (idx, item) pair
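
############################################################
# A minimal end-to-end sketch (not part of the original gist): the model,
# learned_parameters, and args below are hypothetical stand-ins for the
# module-level names the code above assumes. This relies on the default
# 'fork' start method on Linux, so the forked workers inherit these globals;
# note also that worker writes into a plain numpy learned_parameters stay in
# each child's copy-on-write pages, so a shared-memory array would be needed
# for the parent to see the encoded vectors.
if __name__ == '__main__':
    import numpy as np
    import torch.nn as nn

    class TinyAutoEncoder(nn.Module):
        """Hypothetical model matching the (loss, encoded, _) contract above."""
        def __init__(self, dim=16):
            super().__init__()
            self.enc = nn.Linear(dim, 4)
            self.dec = nn.Linear(4, dim)

        def reset(self):
            pass  # stateless; a recurrent model would clear its state here

        def forward(self, x):
            encoded = self.enc(x)
            loss = ((self.dec(encoded) - x) ** 2).mean()
            return loss, encoded, None

    class Args:  # stand-in for a parsed argparse namespace
        alpha = 0.1
    args = Args()

    n_points, dim = 100, 16
    learned_parameters = np.zeros((n_points, 4), dtype=np.float32)

    model = TinyAutoEncoder(dim)
    data = [(i, torch.randn(dim)) for i in range(n_points)]  # (idx, item) pairs
    run_parallel(model, data, n_threads=4, save='/tmp/hogwild-{}.pth')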