Last active
February 22, 2019 06:23
-
-
Save JACKHAHA363/46831312f26b8d1df87c6a4baaeeb3c8 to your computer and use it in GitHub Desktop.
A principled way to define a config dictionary that can be saved and restored, supports predefined types and default values, and supports command-line parsing.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import yaml | |
import argparse | |
class _ConfigDict(dict):
    """A dict subclass whose keys must cover a predefined set of fields.

    Subclasses (normally produced by ``config_dict``) set ``fields`` to a list
    of ``(name, type, default, help)`` tuples. The type, default and help
    entries are only used when building a command-line parser; the
    constructor only checks that every field name is present as a key.
    """

    # Each entry is a (name, type, default, help) tuple. Override in
    # subclasses; the empty base list means the base class requires nothing.
    fields = []

    @classmethod
    def get_cmd_parser(cls, parser=None):
        """Return an argparse parser with one ``-<name>`` option per field.

        Args:
            parser: an existing ``argparse.ArgumentParser`` to extend. A new
                parser is created when this is ``None``.

        Returns:
            The parser, with an option added for every entry in ``fields``,
            using that entry's type, default and help text.
        """
        if parser is None:
            parser = argparse.ArgumentParser()
        for field in cls.fields:
            parser.add_argument('-' + field[0], type=field[1], default=field[2],
                                help=field[3])
        return parser

    def __init__(self, **kwargs):
        """Build the config from keyword arguments.

        Values are stored exactly as given; no type conversion is applied.

        Raises:
            AssertionError: if any name listed in ``fields`` is missing.
        """
        self.validate_config_dict(kwargs)
        super(_ConfigDict, self).__init__(**kwargs)

    @classmethod
    def from_yaml(cls, yaml_file):
        """Load a config from a YAML file (e.g. one written by ``to_yaml``).

        ``yaml.safe_load`` is used so that the file cannot construct arbitrary
        Python objects. Parse errors are no longer printed and converted into
        a ``None`` return; they propagate to the caller.

        Raises:
            yaml.YAMLError: if the file is not valid YAML.
            AssertionError: if required fields are missing.
        """
        with open(yaml_file, 'r') as stream:
            data = yaml.safe_load(stream)
        # An empty document loads as None. Treat it as an empty mapping so the
        # missing-field check gives a clear error instead of a TypeError.
        if data is None:
            data = {}
        return cls(**data)

    def to_yaml(self, yaml_file):
        """Write the config to ``yaml_file`` as a block-style YAML mapping."""
        # Dump a plain dict copy. Dumping the subclass itself would otherwise
        # be tagged as a Python object rather than a plain mapping.
        with open(yaml_file, 'w') as out:
            yaml.dump(dict(self), out, default_flow_style=False)

    @classmethod
    def validate_config_dict(cls, config_dict):
        """Check that ``config_dict`` has a key for every name in ``fields``.

        Extra keys are allowed.

        Raises:
            AssertionError: naming the missing fields.
        """
        missing_fields = {field[0] for field in cls.fields} - set(config_dict)
        # Raise explicitly instead of using ``assert`` so the check still runs
        # under ``python -O``. AssertionError is kept for existing callers.
        if missing_fields:
            raise AssertionError(
                'Missing fields for game config: ' + str(missing_fields))

    def display(self):
        """Print each entry as ``key:value``, one per line."""
        for k, v in self.items():
            print('{}:{}'.format(k, v))
def config_dict(name, fields):
    """Create a ``_ConfigDict`` subclass named ``name`` with the given fields.

    Args:
        name: class name for the generated type.
        fields: list of ``(name, type, default, help)`` tuples, stored as the
            new class's ``fields`` attribute.

    Returns:
        The newly created class.
    """
    namespace = dict(fields=fields)
    bases = (_ConfigDict,)
    return type(name, bases, namespace)
if __name__ == '__main__':
    # Define a config class with two float fields.
    my_fields = [
        ('lr', float, 0.01, 'learning rate'),
        ('mom', float, 0.9, 'sgd momentum'),
    ]
    MyConfigDict = config_dict('MyConfig', my_fields)

    # Fill it from the command line; unspecified options use the defaults above.
    cli_args = MyConfigDict.get_cmd_parser().parse_args()
    config = MyConfigDict(**vars(cli_args))
    print('old config')
    config.display()

    # Round-trip through a YAML file in the current directory.
    config.to_yaml('config.yaml')

    # from_yaml is a classmethod, so it works when called on the class...
    config_new = MyConfigDict.from_yaml('config.yaml')
    print('new config')
    config_new.display()

    # ...and also when called on an instance.
    config_new_2 = config.from_yaml('config.yaml')
    print('new config 2')
    config_new_2.display()

    # Direct construction: values are stored as given (lr stays an int),
    # because the constructor only checks keys and does no type conversion.
    config_by_args = MyConfigDict(lr=5, mom=0.9)
    print('create config by directly input args')
    config_by_args.display()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment