Last active
March 15, 2017 14:25
-
-
Save bobchennan/d414da80f2f5939e1660898aab42a415 to your computer and use it in GitHub Desktop.
keras sharedmem
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from multiprocessing import Pipe, Process, Manager
from time import sleep

import sharedmem
# Manager-backed dict used to hand batch arrays to the training process.
# Entries are looked up by key in train_process and deleted once consumed.
# NOTE(review): the Manager is created at import time; this presumably relies
# on the "fork" start method so the worker inherits this proxy -- confirm
# before running on a platform that spawns fresh interpreters.
batches = Manager().dict()
def train_process(model, num_batches, save_path, q):
    """Consume batch keys from a pipe, train on each batch, then save weights.

    Runs until a pair whose first element is ``None`` is received.

    Args:
        model: Object exposing ``train_on_batch(x, y)`` and
            ``save_weights(path)`` (a Keras model in the original usage).
        num_batches: Unused; kept so existing callers keep working.
        save_path: Destination passed to ``model.save_weights``.
        q: Connection end whose ``recv()`` yields ``(x_key, y_key)`` pairs
            that index into the module-level ``batches`` mapping.
    """
    while True:
        x, y = q.recv()
        if x is None:
            # Sentinel from the producer: no more batches.
            break
        model.train_on_batch(batches[x], batches[y])
        # Drop the consumed entries so the shared mapping shrinks again.
        del batches[x], batches[y]
    # The Keras method is `save_weights`; the original `save_weight` would
    # raise AttributeError after all training work had already been done.
    model.save_weights(save_path)
class Training():
    """Feed post-processed batches to a training process via shared memory.

    Args:
        build_network: Zero-argument callable returning the model to train.
        post_process: Callable applied to each sample of an input batch.
        generator: Iterator yielding ``(x, y)`` batches; must also provide a
            ``size()`` method giving the number of batches per iteration.
        max_limit: Upper bound on the number of entries allowed in the shared
            ``batches`` mapping before the producer waits. Each batch
            contributes two entries (inputs and targets).
    """

    def __init__(self, build_network, post_process, generator, max_limit):
        self.model = build_network
        self.post = post_process
        self.gen = generator
        self.maxl = max_limit

    def work(self, num_itr):
        """Run ``num_itr`` passes, each in a freshly started worker process.

        For every pass a worker running ``train_process`` is started, all
        batches from the generator are copied into shared-memory arrays and
        announced over a pipe, and the worker is joined after the sentinel.

        NOTE(review): the worker saves weights to 'model.hdf5' but this
        process never reloads them, so each pass presumably starts from the
        same initial weights -- confirm whether that is intended.
        """
        # The factory is stored on the instance; the bare name
        # `build_network` is not in scope here.
        model = self.model()
        c1, c2 = Pipe()
        for _ in range(num_itr):
            num_batches = self.gen.size()
            p = Process(
                target=train_process,
                args=(model, num_batches, 'model.hdf5', c2),
            )
            # Without start() nothing consumes the pipe and join() fails.
            p.start()
            count = -1
            for cas in range(num_batches):
                # Back-pressure: wait until the worker has drained entries.
                while len(batches) >= self.maxl:
                    sleep(0.05)
                x, y = next(self.gen)
                retx = sharedmem.zeros(x.shape)
                rety = sharedmem.zeros(y.shape)
                rety[:] = y[:]
                for batch_idx in range(x.shape[0]):
                    retx[batch_idx] = self.post(x[batch_idx])
                # Consecutive integer keys: inputs first, then targets.
                count += 1
                batches[count] = retx
                count += 1
                batches[count] = rety
                c1.send((count - 1, count))
            # Sentinel tells the worker to stop and save its weights.
            c1.send((None, None))
            p.join()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment