Created February 10, 2017 05:26
-
-
Save bnaul/e6a16323cda2e266eb040175a9112073 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from argparse import Namespace | |
from multiprocessing import Pool, current_process | |
import sys | |
import tempfile | |
from sklearn.model_selection import ParameterGrid | |
from keras_util import parse_model_args, limited_memory_session | |
from autoencoder import main as autoencoder | |
from period import main as period | |
from asas import main as asas | |
from asas_full import main as asas_full | |
def set_session(num_gpus, procs_per_gpu, tmpdirname):
    """Pool initializer: assign this worker a GPU and send its stdout to a log file.

    Called once in every worker process of the ``multiprocessing.Pool``
    (via ``initializer=``), before any task runs.

    Args:
        num_gpus: Number of GPUs to spread the workers across.
        procs_per_gpu: Number of worker processes that share one GPU; each
            is given an equal share of 0.96 of that GPU.
        tmpdirname: Directory in which this worker's log file is created.
            The caller owns the directory and its cleanup.

    Side effects:
        Calls ``limited_memory_session`` and replaces ``sys.stdout`` with
        the log file for the remainder of this worker's life. The log file's
        path is printed to the original stdout first so it can be found.
    """
    # Keep ~4% of each GPU free and split the rest evenly among its workers.
    gpu_frac = 0.96 / procs_per_gpu
    # Assumes the process name ends in '-<int>' (CPython names pool workers
    # e.g. 'ForkPoolWorker-1'); round-robin the workers over the GPUs.
    gpu_id = int(current_process().name.split('-')[-1]) % num_gpus
    # NOTE(review): presumably limits the Keras/TF session to `gpu_frac` of
    # GPU `gpu_id` -- see keras_util.limited_memory_session.
    limited_memory_session(gpu_frac, gpu_id)

    # Place the log in the caller's temporary directory so it is cleaned up
    # with that directory rather than left behind in the system temp dir
    # when the pool terminates its workers. Line buffering (buffering=1)
    # keeps the log current while a long job runs and avoids losing
    # buffered output when the worker is killed.
    log_file = tempfile.NamedTemporaryFile('w', buffering=1, dir=tmpdirname)
    # Flush so the path reaches the console before stdout is redirected.
    print(log_file.name, flush=True)
    sys.stdout = log_file
if __name__ == '__main__':
    # Hardware layout: the pool gets PROCS_PER_GPU workers for each GPU.
    NUM_GPUS = 2
    PROCS_PER_GPU = 1

    # Experiment entry point run for every grid point; the file also
    # imports `period`, `asas` and `asas_full` as alternative choices.
    simulation = autoencoder
    model_types = ['conv', 'gru']

    # Hyperparameter lists shared by every model type. Each list is one
    # axis of the search; ParameterGrid expands their Cartesian product.
    params = dict(
        sim_type=['test'],
        size=[16, 32],
        num_layers=[1, 2],
        drop_frac=[0.25],
        n_min=[20], n_max=[20], sigma=[0.5],
        nb_epoch=[25], lr=[5e-4],
    )
    # Extra axes that only apply to convolutional model types.
    conv_only_args = dict(filter_length=[5], batch_norm=[True])
    conv_model_types = ('conv', 'atrous')

    # Build one sub-grid per model type so conv-only options are never
    # combined with non-convolutional models.
    grid_specs = []
    for model_type in model_types:
        spec = {'model_type': [model_type]}
        spec.update(params)
        if model_type in conv_model_types:
            spec.update(conv_only_args)
        grid_specs.append(spec)
    param_grid = ParameterGrid(grid_specs)

    # NOTE(review): presumably expands each parameter dict into the full
    # argument set the simulation expects -- see keras_util.parse_model_args.
    full_param_grids = list(map(parse_model_args, param_grid))

    # Every worker runs set_session once at start-up. The pool is exited
    # before the temporary directory, so the directory outlives the workers.
    pool_size = NUM_GPUS * PROCS_PER_GPU
    with tempfile.TemporaryDirectory() as tmpdirname, \
            Pool(pool_size, initializer=set_session,
                 initargs=(NUM_GPUS, PROCS_PER_GPU, tmpdirname)) as pool:
        pool.map(simulation, full_param_grids)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.