import yaml
import argparse


class _ConfigDict(dict):
    """A dict subclass whose instances must contain every field in ``fields``.

    Subclasses (normally created with :func:`config_dict`) set ``fields`` to a
    sequence of ``(name, type, default, help)`` tuples. The same specification
    drives the command-line parser from :meth:`get_cmd_parser` and the
    required-key check performed on construction.
    """

    # Change this to customize it: a sequence of (name, type, default, help)
    # tuples. Only the first four items of each tuple are read.
    fields = []

    @classmethod
    def get_cmd_parser(cls, parser=None):
        """Return an argparse parser with one option per field.

        Args:
            parser: an existing ``argparse.ArgumentParser`` to add the options
                to; a new parser is created when ``None``.

        Returns:
            The parser, with a ``-<name>`` option added for every field, using
            the field's type, default and help text.
        """
        if parser is None:
            parser = argparse.ArgumentParser()
        for field in cls.fields:
            # Single-dash multi-letter options (e.g. "-lr") are kept so that
            # existing command lines keep working.
            parser.add_argument('-' + field[0], type=field[1],
                                default=field[2], help=field[3])
        return parser

    def __init__(self, **kwargs):
        """Create the config from keyword arguments.

        Raises:
            AssertionError: if any field name in ``fields`` is missing from
                ``kwargs``.
        """
        self.validate_config_dict(kwargs)
        super(_ConfigDict, self).__init__(**kwargs)

    @classmethod
    def from_yaml(cls, yaml_file):
        """Load a config of this class from ``yaml_file``.

        The file is parsed with ``yaml.safe_load``, so only plain YAML types
        are accepted (no arbitrary Python object construction).

        Returns:
            A new instance, or ``None`` if the file is not valid YAML (the
            parse error is printed).

        Raises:
            TypeError: if the document is not a mapping.
            AssertionError: if required fields are missing.
        """
        with open(yaml_file, 'r') as stream:
            try:
                data = yaml.safe_load(stream)
            except yaml.YAMLError as exc:
                print(exc)
                return None
        # An empty document parses to None; treat it as an empty mapping so
        # the caller gets the regular missing-fields error instead of a
        # confusing "argument after ** must be a mapping" TypeError.
        if data is None:
            data = {}
        if not isinstance(data, dict):
            raise TypeError('Expected a YAML mapping in {}, got {}'.format(
                yaml_file, type(data).__name__))
        return cls(**data)

    def to_yaml(self, yaml_file):
        """Save the config to ``yaml_file`` in block (non-flow) style.

        ``yaml.safe_dump`` is used so the output only contains standard YAML
        tags and can always be read back by :meth:`from_yaml`.
        """
        # Convert to a plain dict so it serialises as an ordinary mapping.
        data = dict(self)
        with open(yaml_file, 'w') as out:
            yaml.safe_dump(data, out, default_flow_style=False)

    @classmethod
    def validate_config_dict(cls, config_dict):
        """Check that ``config_dict`` has a key for every name in ``fields``.

        Raises:
            AssertionError: naming the missing fields. Raised explicitly rather
                than via ``assert`` so the check still runs under ``python -O``.
        """
        required_fields = set(field[0] for field in cls.fields)
        missing_fields = required_fields - set(config_dict)
        if missing_fields:
            raise AssertionError(
                'Missing fields for game config: ' + str(missing_fields))

    def display(self):
        """Print every ``key:value`` pair, one per line."""
        for k, v in self.items():
            print('{}:{}'.format(k, v))


def config_dict(name, fields):
    """Return a new ``_ConfigDict`` subclass named ``name`` with ``fields``.

    Args:
        name: class name for the generated type.
        fields: sequence of ``(name, type, default, help)`` tuples.
    """
    return type(name, (_ConfigDict,), {'fields': fields})


if __name__ == '__main__':
    # Create customizable config dict class
    my_fields = [('lr', float, 0.01, 'learning rate'),
                 ('mom', float, 0.9, 'sgd momentum')]
    MyConfigDict = config_dict('MyConfig', my_fields)
    parser = MyConfigDict.get_cmd_parser()
    args = parser.parse_args()

    # Create config dict
    config = MyConfigDict(**vars(args))
    print('old config')
    config.display()

    # save and load
    config.to_yaml('config.yaml')

    # Load use class
    config_new = MyConfigDict.from_yaml('config.yaml')
    print('new config')
    config_new.display()

    # Load use object
    config_new_2 = config.from_yaml('config.yaml')
    print('new config 2')
    config_new_2.display()

    # Create by hand
    config_by_args = MyConfigDict(lr=5, mom=0.9)
    print('create config by directly input args')
    config_by_args.display()