-
-
Save afspies/d7ca78f834ec3d3242a43500e7d7d75d to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
"""Configuration library for experiments. | |

Based on a version by Nuri C.: https://gist.github.com/nuric/e9df5e38c34c804ec6c298ffca7e2d9b

-- USAGE --

In each module that defines its own options:

    import config.configlib as configlib
    from config.configlib import C

    parser = configlib.add_parser("File Spec Config")
    parser.add_argument("--mlp_size", default=128, type=int, help="Size of Output MLP")

In the main file:

    configlib.parse(run_name='run_name')
    print("Running with Configuration:")
    configlib.print_config()
    run()

NOTE: The run_name parameter of parse() names the command-line argument that
holds the run name; it is not the run name itself. For example, with
run_name="execution_name", running "python my_file.py --execution_name fish"
registers "fish" as the run name (used as the file name when parse() is given
a save_folder).

Parameters may also be specified at run time from within a file (e.g. for
testing) via the extra_args parameter:

    configlib.parse(run_name='run_name',
                    extra_args=['--load_config', 'path/to/config.json',
                                '--run_name', 'test_alignnet'])

Here a config file is loaded and its run_name is then overridden, because
arguments later in the list are processed after the loaded values.
""" | |
from typing import Dict, Any | |
import logging | |
import pprint | |
import sys, os | |
import json | |
import argparse | |
logger = logging.getLogger(__name__)


class LoadFromFile(argparse.Action):
    """Argparse action that loads further arguments from a file.

    The format is chosen by file extension:

    * ``.txt`` -- the contents are split on whitespace and parsed as ordinary
      command-line arguments.
    * anything else -- the file is read as a JSON object mapping argument
      names (without the leading ``--``) to values, e.g. a config written by
      ``parse``.

    A JSON entry is skipped when its value is ``None`` or when the same option
    was given explicitly in ``sys.argv`` (as ``--key value`` or
    ``--key=value``), so command-line values are not overwritten by the file.
    """

    def __call__(self, parser, namespace, values, option_string=None):
        # ``values`` is the open file object produced by ``type=open``.
        with values as fh:
            if fh.name.endswith(".txt"):
                parser.parse_args(fh.read().split(), namespace)
                return
            loaded = json.load(fh)

        cli_args = sys.argv[1:]
        args = []
        for key, val in loaded.items():
            flag = "--" + key
            given_on_cli = any(a == flag or a.startswith(flag + "=") for a in cli_args)
            if val is None or given_on_cli:
                continue
            # NOTE: _option_string_actions is argparse-internal, but it is the
            # only way to map an option string back to its Action.
            action = parser._option_string_actions.get(flag)
            if action is not None and action.nargs == 0:
                # Zero-argument options (store_true/store_false/count/...)
                # reject a value such as "--debug 1", so assign it directly.
                setattr(namespace, action.dest, val)
                continue
            args.append(flag)
            if isinstance(val, list):
                args.extend(map(str, val))
            elif isinstance(val, bool):
                args.append(str(int(val)))
            else:
                args.append(str(val))
        parser.parse_args(args, namespace)
# Shared top-level parser; modules add their own argument groups through
# add_parser(). fromfile_prefix_chars="@" lets "@file" on the command line be
# expanded into the arguments stored in that file.
parser = argparse.ArgumentParser(description=__doc__, fromfile_prefix_chars="@")
parser.add_argument(
    "--load_config",
    type=open,
    action=LoadFromFile,
    help="Load arguments from a JSON config file (or a .txt argument list).",
)

# Global configuration, filled in by parse(). Exposed under two names that
# refer to the same dict: ``config`` (what ``from configlib import config as C``
# expects) and the short alias ``C``.
config: Dict[str, Any] = {}
C: Dict[str, Any] = config
def add_parser(title: str, description: str = ""):
    """Create an argument group on the shared parser and hand it back.

    Each module calls this to get its own group and registers its options on
    it; all groups are parsed together by ``parse``.

    Args:
        title: Heading of the group in ``--help`` output.
        description: Optional description text for the group.

    Returns:
        The argparse argument group.
    """
    group = parser.add_argument_group(title=title, description=description)
    return group
def _unique_run_name(save_folder: str, name: str) -> str:
    """Return ``name``, or ``name_<i>`` with the smallest ``i >= 1``, whose
    ``<save_folder>/<result>.json`` does not exist yet."""
    candidate, suffix = name, 0
    while os.path.exists(os.path.join(save_folder, candidate + ".json")):
        suffix += 1
        candidate = f"{name}_{suffix}"
    return candidate


def parse(save_folder: str = "", run_name: str = None, extra_args=None) -> Dict[str, Any]:
    """Parse command-line arguments into the global config ``C``.

    Arguments are read from ``sys.argv[1:]`` followed by ``extra_args``, so
    entries in ``extra_args`` are processed last and override earlier values
    for the same option, including values loaded via ``--load_config``.

    Args:
        save_folder: If non-empty, the parsed arguments are written as JSON to
            ``<save_folder>/<name>.json``; the folder is created if missing.
        run_name: Name (dest) of the argument whose value is the run name,
            e.g. ``"run_name"`` for ``--run_name``. When it is not given or the
            argument has no value, ``"config_default"`` is used as the file
            name. If that file already exists, a suffix ``_1``, ``_2``, ... is
            appended, and the suffixed name is written back into the config.
        extra_args: Optional sequence of extra argument strings appended after
            ``sys.argv[1:]``.

    Returns:
        The global config dict ``C``.
    """
    args = parser.parse_args(sys.argv[1:] + list(extra_args or []))
    parsed = vars(args)
    C.update(parsed)
    logger.info("Parsed %i arguments.", len(parsed))

    if save_folder:
        current = parsed.get(run_name) if run_name else None
        name = _unique_run_name(save_folder, str(current) if current else "config_default")
        if current:
            # Keep the stored run name in sync with the file actually written.
            parsed[run_name] = name
            C[run_name] = name
        os.makedirs(save_folder, exist_ok=True)
        save_fname = os.path.join(save_folder, name)
        with open(save_fname + ".json", "w") as fout:
            json.dump(parsed, fout, indent=4)
        logger.info("Saving arguments to %s.", save_fname)
    return C
def print_config():
    """Write a pretty-printed view of the global config ``C`` to stdout."""
    print(pprint.pformat(C))
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
"""Some data loading module.""" | |
from typing import List | |
import random | |
import configlib | |
from configlib import config as C | |
# Dataset options, registered on the shared configlib parser when this
# module is imported.
parser = configlib.add_parser("Dataset config")
parser.add_argument(
    "--data_length",
    type=int,
    default=4,
    help="Length of random list.",
)
def load_data() -> List[int]:
    """Return the integers ``0 .. C["data_length"] - 1`` in shuffled order."""
    length = C["data_length"]
    values = [*range(length)]
    random.shuffle(values)  # in-place shuffle using the global `random` state
    return values
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
"""Some training script that uses data module.""" | |
import configlib | |
from configlib import config as C | |
import data | |
# Name of a previously saved run; its JSON config is loaded in __main__.
load_run = "Name of prev run"

# Training options, registered on the shared configlib parser.
parser = configlib.add_parser("Train config")
parser.add_argument("--debug", help="Enable debug mode.", action="store_true")
def train():
    """Example training entry point: report debug mode and show one data sample."""
    debug_enabled = C["debug"]
    if debug_enabled:
        print("Debugging mode enabled.")
    print("Example dataset:")
    sample = data.load_data()
    print(sample)
if __name__ == "__main__":
    # Parse once: start from a saved config, override some values, and save
    # the resulting configuration under a new run name.
    configlib.parse(
        save_folder="/media/home/alex/slot_exp/models/align_net/",  # where the new config is saved
        run_name="run_name",  # dest of the argument used as the run's name
        extra_args=[
            "--load_config", f"path/to/{load_run}.json",  # load a previous run's config
            # NOTE(review): assumes some imported module registers --batch_size;
            # otherwise argparse rejects it as an unrecognized argument.
            "--batch_size", "10",  # override a value loaded from the JSON
            "--run_name", "train_new",  # save the new config under a new name
        ],
    )
    print("Running with configuration:")
    configlib.print_config()
    train()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment