import yaml
import argparse

class _ConfigDict(dict):
    """ A subclass of dict that supports required args
    """
    # Change this to customize it
    fields = []

    @classmethod
    def get_cmd_parser(cls, parser=None):
        """ Return a argparse.ArgumentParser or add argument to existing one """
        if parser is None:
            parser = argparse.ArgumentParser()
        for field in cls.fields:
            parser.add_argument('-' + field[0], type=field[1], default=field[2],
                                help=field[3])
        return parser

    def __init__(self, **kwargs):
        """ set up fields """
        self.validate_config_dict(kwargs)
        super(_ConfigDict, self).__init__(**kwargs)

    @classmethod
    def from_yaml(cls, yaml_file):
        """ Load from yaml file """
        with open(yaml_file, 'r') as stream:
            try:
                return cls(**yaml.load(stream))
            except yaml.YAMLError as exc:
                print(exc)

    def to_yaml(self, yaml_file):
        """ Save to yaml file """
        data = {k: v for k,v in self.items()}
        with open(yaml_file, 'w') as out:
            yaml.dump(data, out, default_flow_style=False)

    @classmethod
    def validate_config_dict(cls, config_dict):
        """ Given a dict, verify whether it has all the fields """
        required_field = [field[0] for field in cls.fields]
        kwargs_field = list(config_dict.keys())
        missing_fields = set(required_field) - set(kwargs_field)
        assert len(missing_fields) == 0, 'Missing fields for game config: ' + str(missing_fields)

    def display(self):
        """ Display the config """
        for k, v in self.items():
            print('{}:{}'.format(k, v))

def config_dict(name, fields):
    """Return a new ``_ConfigDict`` subclass named ``name`` with ``fields``.

    Args:
        name: Name given to the generated class.
        fields: Iterable of ``(name, type, default, help)`` tuples describing
            the required keys. It is copied into a new list, so a one-shot
            iterator (e.g. a generator) is not exhausted by the first
            ``get_cmd_parser`` call, and later changes to the caller's list
            do not alter the class.

    Returns:
        The generated subclass of ``_ConfigDict``.
    """
    return type(name, (_ConfigDict,), {'fields': list(fields)})

if __name__ == '__main__':
    # Demo: a config class with two SGD hyper-parameters.
    demo_fields = [
        ('lr', float, 0.01, 'learning rate'),
        ('mom', float, 0.9, 'sgd momentum'),
    ]
    DemoConfig = config_dict('MyConfig', demo_fields)

    # Fill it from the command line; unset options fall back to the defaults.
    cli_args = DemoConfig.get_cmd_parser().parse_args()
    original = DemoConfig(**vars(cli_args))
    print('old config')
    original.display()

    # Round-trip through a YAML file in the working directory.
    original.to_yaml('config.yaml')

    # from_yaml is a classmethod: it can be called on the class ...
    loaded_via_class = DemoConfig.from_yaml('config.yaml')
    print('new config')
    loaded_via_class.display()

    # ... or on an existing instance, with the same result.
    loaded_via_instance = original.from_yaml('config.yaml')
    print('new config 2')
    loaded_via_instance.display()

    # Build one directly from keyword arguments.
    manual = DemoConfig(lr=5, mom=0.9)
    print('create config by directly input args')
    manual.display()