# Created January 25, 2020 19:30
# Gist: innocenat/87e5c24d6be25094301188ea5cdb8bca
#
# Just some helper tools I created to help me save all research data, ever.
#
# Note: this file was flagged as containing bidirectional Unicode text that may
# be interpreted or compiled differently than it appears. To review, open the
# file in an editor that reveals hidden Unicode characters.
import json | |
import os | |
import re | |
import time | |
from os import path | |
from typing import TextIO, List, Dict, Callable, AnyStr, Any, Tuple, Pattern, Optional | |
import matplotlib.pyplot as plt | |
import numpy as np | |
class NumpyEncoder(json.JSONEncoder):
    """JSON encoder that also accepts NumPy arrays and NumPy scalars.

    Arrays are written as (nested) lists; NumPy floating, integer and boolean
    scalars are converted to the matching built-in Python type.
    """

    def default(self, obj):
        """Return a JSON-serializable replacement for *obj*.

        Called by the base encoder only for objects it cannot serialize itself.

        Raises:
            TypeError: raised by ``json.JSONEncoder.default`` for any type
                that is not handled here.
        """
        if isinstance(obj, np.ndarray):
            return obj.tolist()
        # Cover every width of NumPy scalar, not just float32/float64, so
        # values such as np.int64 or np.bool_ no longer abort the dump.
        if isinstance(obj, np.floating):
            return float(obj)
        if isinstance(obj, np.integer):
            return int(obj)
        if isinstance(obj, np.bool_):
            return bool(obj)
        return super().default(obj)
class LRUCache:
    """Placeholder for an in-memory cache; currently stores nothing."""

    def __init__(self):
        """Create the cache. No state is kept yet."""

    def cache(self, data_key: str, data: Any, metadata: Any) -> None:
        """Accept an entry for *data_key*; the stub discards it."""

    def read(self, data_key: str) -> Any:
        """Look up *data_key*; the stub always reports a miss (None)."""
        return None
class Serializer:
    """Reads and writes stored data as JSON text."""

    @staticmethod
    def write(fp: TextIO, data: Any, metadata: Any) -> None:
        """Dump *data* to *fp* as JSON, using NumpyEncoder for NumPy values.

        Raises:
            Exception: if *metadata* is not None; metadata storage is not
                implemented.
        """
        if metadata is None:
            json.dump(data, fp, cls=NumpyEncoder)
            return
        raise Exception('Cannot store metadata yet')

    @staticmethod
    def read(fp: TextIO) -> Tuple[Any, Any]:
        """Parse JSON from *fp* and return ``(data, metadata)``.

        The metadata slot is always None because none is ever written.
        """
        payload = json.load(fp)
        return payload, None
class Provider:
    """Describes one function that produces named data for a DataFlow.

    Attributes:
        function: callable run with a DataFlowInstanced to produce the data.
        provides: data names this provider is allowed to store.
        dataset: compiled pattern limiting the datasets this provider serves,
            or None to serve every dataset.
        depends: data names the function may request.
        requires: option names the function may read.
        consumes: option names removed from the list reported by
            ``required_options_name``.
        type: ``TYPE_DATA`` or ``TYPE_GRAPH``.
    """

    TYPE_DATA = 0  # Default
    TYPE_GRAPH = 1

    function: Optional[Callable[['DataFlowInstanced'], None]]
    provides: List[str]
    dataset: Optional[Pattern[str]]
    depends: List[str]
    requires: List[str]
    consumes: List[str]
    type: int

    def __init__(self, function: Optional[Callable[['DataFlowInstanced'], None]], provides: List[str],
                 dataset: Optional[str] = None, depends: Optional[List[str]] = None,
                 requires: Optional[List[str]] = None, consumes: Optional[List[str]] = None,
                 provider_type: int = 0):
        """Store the provider description.

        Args:
            function: callable invoked by ``do_provide``.
            provides: data names the function stores.
            dataset: regular expression source matched against the dataset
                name; None means no restriction.
            depends: data names the function requests (defaults to none).
            requires: option names the function reads (defaults to none).
            consumes: option names to hide from dependants (defaults to none).
            provider_type: ``Provider.TYPE_DATA`` or ``Provider.TYPE_GRAPH``.
        """
        self.function = function
        self.provides = provides
        self.dataset = re.compile(dataset) if dataset is not None else None
        # A fresh list per instance avoids sharing one default between providers.
        self.depends = depends if depends is not None else []
        self.requires = requires if requires is not None else []
        self.consumes = consumes if consumes is not None else []
        self.type = provider_type

    def can_provide(self, dataset: str, data_name: str) -> bool:
        """Return True when this provider serves *data_name* for *dataset*."""
        if data_name not in self.provides:
            return False
        if self.dataset is None:
            return True
        # Pattern.match returns a Match object or None; convert to a real bool
        # so the declared return type holds.
        return self.dataset.match(dataset) is not None

    def do_provide(self, flow: 'DataFlowInstanced') -> None:
        """Run the provider function with *flow* as its only argument."""
        self.function(flow)

    def required_options_name(self, flow: 'DataFlow') -> List[str]:
        """Collect the option names needed by this provider and its dependencies.

        Own ``requires`` come first, followed by the options of every entry
        in ``depends`` (resolved recursively through *flow*). Each name is
        reported once, and every name listed in ``consumes`` is dropped.

        Raises:
            Exception: if *flow* has no provider for one of the dependencies.
        """
        options = list(self.requires)
        for data_name in self.depends:
            provider = flow.provider(data_name)
            if provider is None:
                raise Exception('Cannot find provider for {}'.format(data_name))
            options.extend(provider.required_options_name(flow))
        consumed = set(self.consumes)
        # dict.fromkeys de-duplicates while keeping first-seen order. Filtering
        # (rather than list.remove) drops a consumed option even when several
        # dependencies contributed it.
        return [name for name in dict.fromkeys(options) if name not in consumed]

    def is_depend_on(self, data_name: str) -> bool:
        """Return True when *data_name* is a declared dependency."""
        return data_name in self.depends

    def has_option(self, option_name: str) -> bool:
        """Return True when *option_name* is a declared required option."""
        return option_name in self.requires
class DataFlow:
    """Registry of providers plus the on-disk layout of their stored data.

    Files live under ``<data_directory>/<dataset>/`` and are named after the
    data name and the option values used to produce them (see ``filepath``).
    """

    _data_directory: str
    _dataset: str
    providers: List['Provider']
    _remap: Dict[str, str]
    _reverse_remap: Dict[str, str]
    # Global options; set through the options() method.
    _options: Dict[str, Any]
    _plt_current_fig: Any
    _plt_plot_options: Dict[str, Any]

    def __init__(self, data_directory: str):
        """Create an empty flow rooted at *data_directory*."""
        self._data_directory = data_directory
        self._dataset = '__default__'
        self.providers = []
        self._remap = {}
        self._reverse_remap = {}
        self._options = {}
        self._plt_current_fig = None
        self._plt_plot_options = {}

    def add_provider(self, provider: 'Provider') -> None:
        """Register *provider*; earlier registrations win on lookup."""
        self.providers.append(provider)

    def provider(self, data_name: str) -> Optional['Provider']:
        """Return the first provider serving *data_name* in the current dataset.

        *data_name* is first translated through the remap table. Returns None
        when no registered provider matches.
        """
        data_name = self._actual_data_name(data_name)
        for candidate in self.providers:
            if candidate.can_provide(self._dataset, data_name):
                return candidate
        return None

    def dataset(self, dataset: str) -> None:
        """Select *dataset* and make sure its directory exists."""
        self._dataset = dataset
        # exist_ok avoids the check-then-create race of isdir() + makedirs().
        os.makedirs("{}/{}".format(self._data_directory, dataset), exist_ok=True)

    def options(self, options: Dict[str, Any]) -> None:
        """Replace the global options handed to every top-level request."""
        self._options = options

    def remap(self, src: str, dst: str) -> None:
        """Make requests for *dst* resolve to the provider of *src*."""
        self._remap[dst] = src
        self._reverse_remap[src] = dst

    def dataset_match(self, regexp: Pattern[AnyStr]) -> bool:
        """Return True when *regexp* matches the start of the dataset name."""
        return re.match(regexp, self._dataset) is not None

    def request(self, data_name: str, data_options: Optional[Dict[str, Any]] = None) -> Any:
        """Request *data_name* from outside any provider.

        Delegates to a provider-less DataFlowInstanced carrying the global
        options; *data_options* override them for this request.
        """
        return DataFlowInstanced(self, None, self._options).request(data_name, data_options)

    def plot(self, data_name: str, data_options: Optional[Dict[str, Any]] = None):
        """Run the plotter for *data_name*, save the figure as EPS and show it.

        Raises:
            Exception: if the request did not create a figure through
                ``get_plt_axe``.
        """
        plt.clf()
        self._plt_current_fig = None
        self.request(data_name, data_options)
        fig = self._plt_current_fig
        if fig is None:
            raise Exception('Plotter {} did not plot'.format(data_name))
        fig.savefig(self.filepath(data_name, self._plt_plot_options, 'eps'), dpi=fig.dpi)
        plt.show()

    def get_plt_axe(self, nrows=1, ncols=1):
        """Create a figure with an *nrows* x *ncols* grid and remember it for ``plot``."""
        fig, ax = plt.subplots(nrows, ncols)
        self._plt_current_fig = fig
        return fig, ax

    def _actual_data_name(self, data_name: str) -> str:
        """Translate *data_name* through the remap table, if it has an entry."""
        return self._remap.get(data_name, data_name)

    def filepath(self, data_name: str, data_options: Dict[str, Any], ext: str = 'dat') -> str:
        """Build the storage path for *data_name* produced with *data_options*.

        Options are rendered as ``key-value`` fragments, sorted so the path
        does not depend on dict order, and joined with ``__``.
        """
        fragments = sorted('{}-{}'.format(k, v) for k, v in data_options.items())
        option_string = '__' + '__'.join(fragments) if fragments else ''
        return "{}/{}/{}{}.{}".format(self._data_directory, self._dataset, data_name, option_string, ext)
class DataFlowInstanced:
    """The view of a DataFlow handed to one provider while it runs.

    It carries the options resolved for that provider and checks every
    request, option read and store against what the provider declared.
    An instance created without a provider (the top-level entry point)
    skips those declaration checks.
    """

    _flow: 'DataFlow'
    _provider: Optional['Provider']
    _options: Dict[str, Any]

    # Plotting environment
    is_graphics: bool

    def __init__(self, flow: 'DataFlow', provider: Optional['Provider'], options: Dict[str, Any]):
        """Bind *provider* (or None for the top level) to *flow* with *options*."""
        self._flow = flow
        self._provider = provider
        self._options = options
        self.is_graphics = provider is not None and provider.type == Provider.TYPE_GRAPH
        if self.is_graphics:
            # The flow names the saved figure after the plotter's options.
            flow._plt_plot_options = dict(options)

    def store(self, data_name: str, data: Any) -> None:
        """Persist *data* under *data_name* for the current options.

        Raises:
            Exception: if there is no provider, the provider does not declare
                *data_name* in ``provides``, or the provider is a plotter.
        """
        if self._provider is None:
            raise Exception('Cannot store data without associated provider')
        if data_name not in self._provider.provides:
            raise Exception('Provider for {} cannot provide data {}'.format(self._provider.provides, data_name))
        if self.is_graphics:
            raise Exception('Cannot store graphics data {}'.format(data_name))

        filepath = self._flow.filepath(data_name, self._options)
        # Write to a sibling temp file and move it into place afterwards, so a
        # failed serialization never leaves a truncated file that a later
        # request would try to read as cached data.
        tmp_path = '{}.tmp'.format(filepath)
        try:
            with open(tmp_path, "w") as fp:
                Serializer.write(fp, data, None)
            os.replace(tmp_path, filepath)
        finally:
            # Only still present when writing or moving failed.
            if path.exists(tmp_path):
                os.remove(tmp_path)

    def _load(self, data_name: str, data_options: Dict[str, Any]) -> Any:
        """Read stored *data_name* for *data_options*; None when no file exists.

        Raises:
            Exception: if the bound provider did not declare *data_name* as a
                dependency.
        """
        if self._provider is not None and not self._provider.is_depend_on(data_name):
            raise Exception('Data {} cannot have dependency on {}'.format(self._provider.provides, data_name))

        filepath = self._flow.filepath(data_name, data_options)
        if not path.isfile(filepath):
            return None
        with open(filepath, "r") as fp:
            data, _metadata = Serializer.read(fp)  # metadata is not used yet
        return data

    def get_plt(self, nrows=1, ncols=1):
        """Return ``(fig, ax)`` from the flow; only valid inside a plotter.

        Raises:
            Exception: if the bound provider is not of type TYPE_GRAPH.
        """
        if not self.is_graphics:
            raise Exception('Cannot get plotting environment for non-graphics data')
        return self._flow.get_plt_axe(nrows, ncols)

    def request(self, data_name: str, data_options: Optional[Dict[str, Any]] = None) -> Any:
        """Return *data_name*, running its provider when nothing is stored yet.

        The options visible here, overridden by *data_options*, are narrowed
        to the ones the target provider needs. Plotters are always executed
        and yield None.

        Raises:
            Exception: if the bound provider did not declare the dependency,
                no provider exists for *data_name*, or a needed option is
                missing.
        """
        if self._provider is not None and not self._provider.is_depend_on(data_name):
            raise Exception('Data {} cannot have dependency on {}'.format(self._provider.provides, data_name))

        current_options = self._options.copy()
        if data_options is not None:
            current_options.update(data_options)

        provider = self._flow.provider(data_name)
        if provider is None:
            raise Exception('Cannot find provider for data "{}"'.format(data_name))

        target_options = {}
        for target_option_name in provider.required_options_name(self._flow):
            if target_option_name not in current_options:
                raise Exception("Option {} not provided for data {}".format(target_option_name, data_name))
            target_options[target_option_name] = current_options[target_option_name]

        is_data = provider.type != Provider.TYPE_GRAPH
        if is_data:
            # NOTE(review): the stored file is looked up under the requested
            # name, while the provider lookup above is delegated to the flow;
            # confirm both resolve to the same name when remapping is in use.
            data = self._load(data_name, target_options)
            if data is not None:
                # TODO Validate last modified chain
                return data

        sub_instance = DataFlowInstanced(self._flow, provider, target_options)
        print('Executing {}...'.format(data_name))
        t0 = time.time()
        provider.do_provide(sub_instance)
        print('Elapsed ({}): {} seconds.'.format(data_name, time.time() - t0))

        if is_data:
            # None here means the provider ran without storing data_name.
            return self._load(data_name, target_options)
        return None

    def requires(self, options_list: List[str]) -> List[Any]:
        """Return the values of several options, in the order requested."""
        return [self.option(name) for name in options_list]

    def option(self, option_name: str) -> Any:
        """Return the value of *option_name*.

        Raises:
            Exception: if the bound provider did not declare the option in
                ``requires``, or the option has no value.
        """
        if self._provider is not None and not self._provider.has_option(option_name):
            raise Exception('Data {} cannot have option {}'.format(self._provider.provides, option_name))
        if option_name not in self._options:
            raise Exception('Option {} not provided.'.format(option_name))
        return self._options[option_name]
if __name__ == '__main__':
    # Usage guide: two chained data providers feeding one plotter.

    def provider1(container: DataFlowInstanced):
        """Store a fixed list as 'data-1'."""
        container.store('data-1', [1, 2, 3, 4, 5])

    def provider2(container: DataFlowInstanced):
        """Scale 'data-1' by the 'multiple' option and store it as 'data-2'."""
        source = container.request('data-1')
        factor = container.option('multiple')
        container.store('data-2', [value * factor for value in source])

    def plotter(container: DataFlowInstanced):
        """Plot 'data-2' against its index."""
        fig, ax = container.get_plt()
        series = container.request('data-2')
        indices = list(range(len(series)))
        ax.plot(indices, series)

    flow = DataFlow('/tmp/dataset')
    flow.dataset('dataset-1')

    # Registration order matches the dependency chain: data-1 -> data-2 -> plot.
    demo_providers = [
        Provider(provider1, ['data-1']),
        Provider(provider2, ['data-2'], depends=['data-1'], requires=['multiple']),
        Provider(plotter, ['plt-data'], depends=['data-2'], provider_type=Provider.TYPE_GRAPH),
    ]
    for demo_provider in demo_providers:
        flow.add_provider(demo_provider)

    flow.options({'multiple': 3})
    flow.plot('plt-data')
# Sign up for free to join this conversation on GitHub.
# Already have an account? Sign in to comment.