Skip to content

Instantly share code, notes, and snippets.

@Dref360
Created November 23, 2017 21:59
Show Gist options
  • Select an option

  • Save Dref360/72dc8036a88f270c2c3bf493bee6daf0 to your computer and use it in GitHub Desktop.

Select an option

Save Dref360/72dc8036a88f270c2c3bf493bee6daf0 to your computer and use it in GitHub Desktop.
Test case reproducing a Keras multiprocessing issue
import multiprocessing
import time
import numpy as np
import keras.backend as K
import keras.layers as KL
from keras import Model
from keras.callbacks import Callback
from keras.utils import Sequence
class superCbk(Callback):
    """Callback that runs a multiprocessing ``predict_generator`` after every epoch.

    In this script it is passed to ``fit_generator(..., use_multiprocessing=True)``
    inside ``run``, so prediction worker processes are started from within an
    ongoing multiprocessing training call.

    Args:
        model: Model used for prediction.
        seq: Sequence of input batches handed to ``predict_generator``.
    """

    def __init__(self, model, seq):
        # Initialise the base Callback before setting our own attributes, in
        # case the base initialiser (re)sets any of them (e.g. ``model``).
        super(superCbk, self).__init__()
        self.model = model
        self.seq = seq
        # Filled in by on_epoch_end; defined here so the attribute exists
        # even before the first epoch has finished.
        self.results = None

    def on_epoch_end(self, epoch, logs=None):
        """Predict 10 batches of ``self.seq`` with 2 worker processes, then sleep 10 s.

        Args:
            epoch: Index of the epoch that just ended (unused).
            logs: Metric dict supplied by Keras (unused).
        """
        self.results = self.model.predict_generator(
            self.seq, steps=10, workers=2, use_multiprocessing=True)
        # NOTE(review): the sleep presumably widens the window in which the
        # multiprocessing issue shows up — confirm with the author.
        time.sleep(10)
# Minimal compiled model: a single Lambda layer that takes the max over the
# last axis of a (30, 30, 3) input, giving a (30, 30) output per sample.
inp = KL.Input(shape=[30, 30, 3])
channel_max = KL.Lambda(lambda tensor: K.max(tensor, axis=-1))
res = channel_max(inp)
model = Model(inp, res)
model.compile(optimizer='sgd', loss='mse')
class DummySeq(Sequence):
    """Keras ``Sequence`` of 100 batches of freshly drawn random data.

    Every batch holds one sample: inputs of shape (1, 30, 30, 3) and, unless
    ``test`` is truthy, targets of shape (1, 30, 30).

    Args:
        test: When truthy, ``__getitem__`` returns only the inputs (for
            prediction); otherwise it returns an ``(inputs, targets)`` pair.
    """

    def __init__(self, test):
        self.test = test

    def __getitem__(self, item):
        # The batch index is ignored: every call draws new random arrays.
        # Inputs are always drawn first, before any targets.
        inputs = np.random.rand(1, 30, 30, 3)
        if self.test:
            return inputs
        targets = np.random.rand(1, 30, 30)
        return inputs, targets

    def __len__(self):
        return 100
# Training and validation data yield (inputs, targets) pairs; the test
# sequence yields inputs only, for prediction.
train_seq = DummySeq(test=False)
val_seq = DummySeq(test=False)
test_seq = DummySeq(test=True)
def run():
    """Train the module-level ``model`` for 10 epochs using 2 worker processes.

    A ``superCbk`` instance is attached, so a multiprocessing
    ``predict_generator`` over ``test_seq`` runs at the end of every epoch,
    from within this multiprocessing ``fit_generator`` call.
    """
    epoch_callback = superCbk(model, test_seq)
    model.fit_generator(
        train_seq,
        epochs=10,
        validation_data=val_seq,
        use_multiprocessing=True,
        workers=2,
        callbacks=[epoch_callback],
    )
# Launch several independent training runs, each in its own child process,
# then wait for all of them.
# The __main__ guard is required for start methods that re-import the main
# module in the child (spawn / forkserver, e.g. Windows and macOS on Python
# 3.8+). Without it, each child would try to start processes of its own while
# bootstrapping. It also keeps a plain import of this module side-effect free.
if __name__ == "__main__":
    NUM_PROCESSES = 5
    processes = []
    for _ in range(NUM_PROCESSES):
        proc = multiprocessing.Process(target=run)
        # Deliberately left non-daemonic: a daemonic process is not allowed to
        # create child processes, and run() asks Keras for process-based
        # workers (use_multiprocessing=True).
        proc.start()
        processes.append(proc)
    for proc in processes:
        print(proc)
        proc.join()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment