Last active
July 20, 2022 15:36
-
-
Save danielkelshaw/ecd184da888532415f20a90e12832007 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import enum
import functools as ft
import itertools as it
import operator
from pathlib import Path
from typing import Iterator, List, Tuple

import einops
import h5py
import matplotlib.pyplot as plt
import numpy as np
import opt_einsum
import torch

from picr.utils.config import ExperimentConfig
from picr.utils.enums import eSolverFunction
from picr.utils.loss_tracker import LossTracker
class Run:
    """A single run stored in its own directory.

    The constructor derives the locations of ``results.csv``,
    ``autoencoder.pt`` and ``config.yml`` beneath ``run_path`` and loads the
    configuration straight away; the results table is only read on request.
    """

    def __init__(self, run_path: Path) -> None:
        """Record the run directory and load its configuration.

        Parameters
        ----------
        run_path: Path
            Directory containing the files belonging to this run.
        """

        self.run_path = run_path

        # files expected inside the run directory
        self.results_path = run_path / 'results.csv'
        self.model_path = run_path / 'autoencoder.pt'
        self.config_path = run_path / 'config.yml'

        # the configuration is read eagerly
        self.config: ExperimentConfig = self._load_config(self.config_path)

    def __repr__(self) -> str:
        return f'Run(system={self.system.name}, freq={self.freq}, mag={self.mag})'

    @staticmethod
    def _load_config(config_path: Path) -> ExperimentConfig:
        """Create an ``ExperimentConfig`` and populate it from ``config_path``."""

        loaded = ExperimentConfig()
        loaded.load_config(config_path)

        return loaded

    def load_results(self) -> np.ndarray:
        """Read ``results.csv`` as a float array, skipping its first (header) row."""

        table = np.loadtxt(self.results_path, delimiter=',', skiprows=1)
        return table

    @property
    def system(self) -> eSolverFunction:
        """``SOLVER_FN`` entry of the run's configuration."""
        return self.config.SOLVER_FN

    @property
    def freq(self) -> float:
        """``PHI_FREQ`` entry of the run's configuration."""
        return self.config.PHI_FREQ

    @property
    def mag(self) -> float:
        """``PHI_LIMIT`` entry of the run's configuration."""
        return self.config.PHI_LIMIT
def generate_eResults() -> enum.EnumMeta:
    """Build the ``eResults`` integer enumeration.

    Member names are ``EPOCH`` followed by the upper-cased names returned by
    ``LossTracker.get_fields`` for ``training=True`` and then for
    ``training=False``. Values are consecutive integers starting at 0.

    Returns
    -------
    enum.EnumMeta
        The newly created ``IntEnum`` class.
    """

    tracker = LossTracker()

    # NOTE(review): presumably this order mirrors the columns of results.csv -- confirm
    field_names = ['epoch']
    field_names.extend(tracker.get_fields(training=True))
    field_names.extend(tracker.get_fields(training=False))

    member_names = [name.upper() for name in field_names]

    return enum.IntEnum('eResults', member_names, start=0)
class eExperimentType(enum.Enum):
    """Kinds of experiment; the value of each member equals its name."""

    FREQ = 'FREQ'
    MAG = 'MAG'
# integer enumeration used to index columns of a results table; built once at import
eResults = generate_eResults()
def get_glob(experiment_type: eExperimentType) -> str:
    """Return the glob pattern used to find the run directories of an experiment.

    Parameters
    ----------
    experiment_type: eExperimentType
        Kind of experiment whose runs are being searched for.

    Returns
    -------
    str
        Pattern relative to the experiment directory.

    Raises
    ------
    KeyError
        If ``experiment_type`` matches no known member.
    """

    known_patterns = (
        (eExperimentType.FREQ, './FREQ*/*'),
        (eExperimentType.MAG, './MAG*/*'),
    )

    for kind, pattern in known_patterns:
        if experiment_type == kind:
            return pattern

    raise KeyError('Invalid eExperimentType')
class BaseExperiment:
    """Every run found beneath one experiment directory.

    Concrete experiments set ``experiment_type`` and override ``__iter__``
    and ``filter_fn``; the versions defined here only raise ``ValueError``.
    """

    # set by subclasses; selects the glob pattern used to discover runs
    experiment_type: eExperimentType

    def __init__(self, experiment_path: Path) -> None:
        """Discover and load all runs of the experiment.

        Parameters
        ----------
        experiment_path: Path
            Directory searched with the pattern from ``get_glob``.
        """

        self.experiment_path = experiment_path

        # load all runs in experiment
        self._run_paths = list(self.experiment_path.glob(get_glob(self.experiment_type)))
        self.runs = [Run(run_path=i) for i in self._run_paths]

        # distinct values across the loaded runs, in ascending order
        self.freqs = sorted({run.freq for run in self.runs})
        self.mags = sorted({run.mag for run in self.runs})

    def __iter__(self) -> Iterator[float]:
        """Iterate over the swept values; must be overridden.

        Raises
        ------
        ValueError
            Always, in this base implementation.
        """
        raise ValueError('Must Override.')

    @property
    def n_runs(self) -> int:
        """Number of loaded runs."""
        return len(self.runs)

    @property
    def n_freqs(self) -> int:
        """Number of distinct ``freq`` values."""
        return len(self.freqs)

    @property
    def n_mags(self) -> int:
        """Number of distinct ``mag`` values."""
        return len(self.mags)

    @staticmethod
    def filter_fn(run: Run, x: float) -> bool:
        """Tell whether ``run`` belongs to the swept value ``x``; must be overridden.

        Raises
        ------
        ValueError
            Always, in this base implementation.
        """
        raise ValueError('Must override BaseExperiment::filter_fn()')
class FreqExperiment(BaseExperiment):
    """Experiment whose runs are grouped by their ``freq`` value.

    Construction is inherited unchanged from ``BaseExperiment``.
    """

    experiment_type: eExperimentType = eExperimentType.FREQ

    def __iter__(self) -> Iterator[float]:
        """Yield each distinct ``freq`` value in ascending order."""
        yield from self.freqs

    @staticmethod
    def filter_fn(run: Run, x: float) -> bool:
        """Return True when ``run.freq`` equals ``x``."""
        return run.freq == x
class MagExperiment(BaseExperiment):
    """Experiment whose runs are grouped by their ``mag`` value.

    Construction is inherited unchanged from ``BaseExperiment``.
    """

    experiment_type: eExperimentType = eExperimentType.MAG

    def __iter__(self) -> Iterator[float]:
        """Yield each distinct ``mag`` value in ascending order."""
        yield from self.mags

    @staticmethod
    def filter_fn(run: Run, x: float) -> bool:
        """Return True when ``run.mag`` equals ``x``."""
        return run.mag == x
@ft.lru_cache()
def extract_mu_sigma(experiment: BaseExperiment, column: eResults) -> Tuple[List[float], List[float]]:
    """Mean and spread of one result column at its best epoch, per swept value.

    For each value produced by iterating ``experiment``, the result tables of
    the matching runs are stacked and reduced across runs. The row at which
    the mean of ``column`` is smallest is located, and the mean and standard
    deviation of ``column`` at that row are collected.

    Parameters
    ----------
    experiment: BaseExperiment
        Experiment providing the runs, the swept values and ``filter_fn``.
    column: eResults
        Column of the results table to analyse.

    Returns
    -------
    Tuple[List[float], List[float]]
        Means and standard deviations, one entry per swept value.
    """

    best_means = []
    best_stds = []

    for value in experiment:

        # runs belonging to the current swept value
        matching = [run for run in experiment.runs if experiment.filter_fn(run, x=value)]

        # one results table per run, stacked along a new leading axis
        stacked = np.stack([run.load_results() for run in matching], axis=0)

        # reduce across runs
        mean_table = np.mean(stacked, axis=0)
        std_table = np.std(stacked, axis=0)

        # the epoch column of both tables is replaced by a plain 0..n-1 index
        mean_table[:, eResults.EPOCH] = np.arange(mean_table.shape[0])
        std_table[:, eResults.EPOCH] = np.arange(std_table.shape[0])

        column_mean = mean_table[:, column]
        column_std = std_table[:, column]

        # row at which the run-averaged column is smallest
        best_row = np.argmin(column_mean)

        best_means.append(column_mean[best_row])
        best_stds.append(column_std[best_row])

    return best_means, best_stds
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment