`run_simulation` with ipyparallel
"""Common functions for doing stuff.""" | |
from copy import copy | |
from datetime import timedelta | |
import functools | |
from glob import glob | |
from itertools import product | |
import os | |
import time | |
import subprocess | |
import sys | |
import numpy as np | |
import pandas as pd | |
assert sys.version_info >= (3, 6), 'Use Python ≥3.6' | |

def run_simulation(lview, func, vals, parameters, fname_i, N=None,
                   overwrite=False):
    """Run a simulation that loops over `vals`.

    The simulation yields ``len(vals)`` results; by passing `N` the work
    is split into parts of length `N`, each saved to its own HDF5 file.

    Parameters
    ----------
    lview : ipyparallel.client.view.LoadBalancedView
        LoadBalancedView used for the asynchronous map.
    func : callable
        Function that is mapped over `vals`; it is called once per
        element of `vals`.
    vals : list
        Arguments for `func`.
    parameters : dict
        Dictionary of constant parameters that is saved together with
        the data.
    fname_i : str
        Name for the resulting HDF5 files. If the simulation is split
        into parts with the `N` argument, it needs to be a formattable
        string, for example 'file_{}'.
    N : int, optional
        Number of results in each pandas.DataFrame (and file).
    overwrite : bool
        Overwrite the file even if it already exists.
    """
    if N is None:
        N = 1000000
        if len(vals) > N:
            raise Exception('You need to split up `vals` in smaller parts')

    N_files = len(vals) // N + (0 if len(vals) % N == 0 else 1)
    print('`vals` will be split in {} files.'.format(N_files))
    time_elapsed = 0
    parts_done = 0

    # Drop parameters that cannot be stored in the HDF5 file.
    parameters_new = {}
    for k, v in parameters.items():
        if callable(v):
            warnings.warn(
                'parameters["{}"] is a function and is not saved!'.format(k))
        else:
            parameters_new[k] = v
    parameters = parameters_new

    if N < len(vals) and fname_i.format('1') == fname_i.format('2'):
        raise Exception('Use a formattable string for `fname_i`.')

    for i, chunk in enumerate(partition_all(N, vals)):
        fname = fname_i.replace('{}', '{:03d}').format(i)
        print('Busy with file: {}.'.format(fname))
        if not os.path.exists(fname) or overwrite:
            map_async = lview.map_async(func, chunk)
            map_async.wait_interactive()
            result = map_async.result()
            df = pd.DataFrame(result)
            common_keys = common_elements(df.columns, parameters.keys())
            if common_keys:
                raise Exception('Parameters in both function result and '
                                'function input: {}'.format(common_keys))
            df = df.assign(**parameters)
            # df = df.assign(git_hash=get_git_revision_hash())
            os.makedirs(os.path.dirname(fname), exist_ok=True)
            df.to_hdf(fname, 'all_data', mode='w',
                      complib='zlib', complevel=9)
            # Print progress and a rough estimate of the remaining time.
            N_files_left = N_files - (i + 1)
            parts_done += 1
            time_elapsed += map_async.elapsed
            time_left = timedelta(seconds=(time_elapsed / parts_done) *
                                  N_files_left)
            print_str = ('Saved {}, {} more files to go, {} time left '
                         'before everything is done.')
            print(print_str.format(fname, N_files_left, time_left))
        else:
            print('File: {} was already done.'.format(fname))
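
# --- Example usage (sketch) -------------------------------------------------
# A minimal sketch of how `run_simulation` is meant to be called, assuming an
# ipyparallel cluster is already running (e.g. started with `ipcluster start`).
# The simulation function `simulate`, the parameter names, and the file paths
# are hypothetical placeholders, not part of this module.
def _example_run_simulation():
    import ipyparallel

    def simulate(val):
        # Toy "simulation": return one dict of results per parameter point.
        # In practice this function must also be available on the engines.
        return dict(x=val['x'], y=val['y'], result=val['x'] * val['y'])

    client = ipyparallel.Client()
    lview = client.load_balanced_view()

    vals = named_product(x=range(10), y=range(10))      # 100 parameter points
    parameters = dict(T=0.1, description='toy sweep')   # constants saved with the data

    # Splits the 100 points into files of 25 results each:
    # 'data/toy_sweep_000.hdf', ..., 'data/toy_sweep_003.hdf'.
    run_simulation(lview, simulate, vals, parameters,
                   fname_i='data/toy_sweep_{}.hdf', N=25)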

def common_elements(list1, list2):
    """Return the elements of `list1` that also occur in `list2`."""
    return [element for element in list1 if element in list2]

def parse_params(params):
    """Evaluate string-valued parameters back into Python objects where possible."""
    for k, v in params.items():
        if isinstance(v, str):
            try:
                params[k] = eval(v)
            except (NameError, SyntaxError):
                pass
    return params
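
# A minimal example of `parse_params` (the keys and values below are made up):
# string-encoded values, e.g. read back from an HDF5 file, are turned into
# Python objects where possible, and strings that do not evaluate are kept.
def _example_parse_params():
    params = {'L': '10', 'phases': '(0.5, 1.5)', 'tag': 'hello'}
    return parse_params(params)
    # -> {'L': 10, 'phases': (0.5, 1.5), 'tag': 'hello'}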

def combine_dfs(pattern, fname=None):
    """Combine the HDF5 files matching `pattern` into a single DataFrame.

    If `fname` is given, the combined DataFrame is also saved to that file.
    """
    files = glob(pattern)
    df = pd.concat([pd.read_hdf(f) for f in sorted(files)])
    df = df.reset_index(drop=True)
    if fname is not None:
        os.makedirs(os.path.dirname(fname), exist_ok=True)
        df.to_hdf(fname, 'all_data', mode='w', complib='zlib', complevel=9)
    return df
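
# Example (sketch): combine the per-chunk files written by `run_simulation`
# in the sketch above into one DataFrame; the paths are hypothetical.
def _example_combine_dfs():
    return combine_dfs('data/toy_sweep_*.hdf',
                       fname='data/toy_sweep_combined.hdf')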

def memoize(obj):
    """Cache results of `obj`, keyed on the string form of its arguments."""
    cache = obj.cache = {}

    @functools.wraps(obj)
    def memoizer(*args, **kwargs):
        key = str(args) + str(kwargs)
        if key not in cache:
            cache[key] = obj(*args, **kwargs)
        return cache[key]
    return memoizer
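
# Small demonstration of `memoize`; `_slow_square` is a made-up example.
# The first call with a given argument does the (fake) expensive work, a
# repeated call returns the cached result immediately.
@memoize
def _slow_square(x):
    time.sleep(0.1)  # stand-in for an expensive computation
    return x * x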

def named_product(**items):
    """Return a list of dicts with all combinations of the keyword arguments."""
    names = items.keys()
    vals = items.values()
    return [dict(zip(names, res)) for res in product(*vals)]
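
# For example (the keys are arbitrary), `named_product(x=[1, 2], y=['a', 'b'])`
# gives the full parameter grid as a list of keyword dicts:
def _example_named_product():
    grid = named_product(x=[1, 2], y=['a', 'b'])
    assert grid == [{'x': 1, 'y': 'a'}, {'x': 1, 'y': 'b'},
                    {'x': 2, 'y': 'a'}, {'x': 2, 'y': 'b'}]
    return grid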

def get_git_revision_hash():
    """Get the git hash to save with data to ensure reproducibility."""
    git_output = subprocess.check_output(['git', 'rev-parse', 'HEAD'])
    return git_output.decode("utf-8").replace('\n', '')

def remove_unhashable_columns(df):
    """Return a copy of `df` without columns whose values cannot be hashed."""
    df = df.copy()
    for col in df.columns:
        if not hashable(df[col].iloc[0]):
            df.drop(col, axis=1, inplace=True)
    return df

def hashable(v):
    """Determine whether `v` can be hashed."""
    try:
        hash(v)
    except TypeError:
        return False
    return True

def drop_constant_columns(df):
    """Drop columns that hold a single constant value.

    Taken from http://stackoverflow.com/a/20210048/3447047
    """
    df = remove_unhashable_columns(df)
    df = df.reset_index(drop=True)
    return df.loc[:, (df != df.iloc[0]).any()]
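
# Example (made-up column names): the constant parameter column 'T' that
# `run_simulation` attaches is dropped, while the varying column 'x' is kept.
def _example_drop_constant_columns():
    df = pd.DataFrame({'x': [1, 2, 3], 'T': [0.1, 0.1, 0.1]})
    return drop_constant_columns(df)  # only the 'x' column remains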