Created
November 23, 2017 21:59
-
-
Save Dref360/72dc8036a88f270c2c3bf493bee6daf0 to your computer and use it in GitHub Desktop.
Test issue multiprocessing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| import multiprocessing | |
| import time | |
| import numpy as np | |
| import keras.backend as K | |
| import keras.layers as KL | |
| from keras import Model | |
| from keras.callbacks import Callback | |
| from keras.utils import Sequence | |
class superCbk(Callback):
    """Callback that runs a prediction pass at the end of every epoch.

    Args:
        model: Object whose ``predict_generator`` is invoked on epoch end.
        seq: Data source handed to ``predict_generator``.

    Attributes:
        results: Return value of the most recent ``predict_generator`` call,
            or ``None`` if no epoch has finished yet.
    """

    def __init__(self, model, seq):
        # Let the base Callback initialise its own state before adding ours.
        super(superCbk, self).__init__()
        self.model = model
        self.seq = seq
        # Defined up front so that reading ``results`` before the first
        # epoch has ended yields None instead of raising AttributeError.
        self.results = None

    def on_epoch_end(self, epoch, logs=None):
        """Predict 10 steps from ``seq`` using 2 multiprocessing workers.

        Args:
            epoch: Index of the epoch that just finished (unused).
            logs: Metrics dictionary supplied by the caller (unused).
        """
        self.results = self.model.predict_generator(
            self.seq,
            steps=10,
            workers=2,
            use_multiprocessing=True,
        )
        # Deliberate 10 s pause after each prediction pass.
        # NOTE(review): presumably simulates slow post-processing for the
        # multiprocessing repro -- confirm before shortening or removing.
        time.sleep(10)
# Tiny model with no trainable layers: a Lambda layer takes the maximum over
# the last axis of a 30x30x3 input. Compiled with SGD and mean squared error.
inp = KL.Input(shape=[30, 30, 3])
res = KL.Lambda(lambda tensor: K.max(tensor, axis=-1))(inp)
model = Model(inputs=inp, outputs=res)
model.compile(optimizer='sgd', loss='mse')
class DummySeq(Sequence):
    """Sequence of random single-sample batches.

    Args:
        test: When truthy, items are inputs only; otherwise each item is an
            ``(x, y)`` pair.
    """

    def __init__(self, test):
        self.test = test

    def __getitem__(self, item):
        """Return a fresh random batch; ``item`` is ignored.

        ``x`` has shape (1, 30, 30, 3) and ``y`` has shape (1, 30, 30).
        """
        x = np.random.rand(1, 30, 30, 3)
        if self.test:
            return x
        y = np.random.rand(1, 30, 30)
        return x, y

    def __len__(self):
        """Report a fixed length of 100 batches."""
        return 100
# Training and validation sequences yield (x, y) pairs; the test sequence
# yields inputs only.
train_seq, val_seq = DummySeq(test=False), DummySeq(test=False)
test_seq = DummySeq(test=True)
def run():
    """Fit the module-level ``model`` for 10 epochs with 2 multiprocessing workers.

    Trains on ``train_seq``, validates on ``val_seq``, and attaches a
    ``superCbk`` that predicts on ``test_seq`` at the end of each epoch.
    """
    epoch_end_cbk = superCbk(model, test_seq)
    model.fit_generator(
        train_seq,
        epochs=10,
        validation_data=val_seq,
        use_multiprocessing=True,
        workers=2,
        callbacks=[epoch_end_cbk],
    )
if __name__ == '__main__':
    # Guarded entry point: with the "spawn" start method every child process
    # re-imports this module, and without the guard each child would try to
    # launch its own set of processes again.
    acc = []
    for _ in range(5):
        pcs = multiprocessing.Process(target=run)
        # Left non-daemonic on purpose: a daemonic process is not allowed to
        # create child processes, and run() requests multiprocessing workers.
        pcs.start()
        acc.append(pcs)

    # Print each process handle, then block until that process has finished.
    for pcs in acc:
        print(pcs)
        pcs.join()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment