Created
December 14, 2021 19:27
-
-
Save Pangoraw/8c6641023ac92c67b8a0937fd32d1605 to your computer and use it in GitHub Desktop.
A dataclass to ArgumentParser converter
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
"""
A small set of utilities to convert dataclasses into argparse parsers and to
build dataclass instances from parsed command-line arguments.
"""
import argparse
import dataclasses
from typing import Any, List, Optional
class _MISSING:
    """Sentinel type marking a ``Field`` that has no default value."""

    def __repr__(self) -> str:
        return "<MISSING>"


@dataclasses.dataclass(frozen=True)
class Field:
    """Argparse metadata for a dataclass attribute.

    Use an instance as the attribute's default value, e.g.
    ``name: str = Field(help="...", positional=True)``.

    The class is frozen so that its instances are hashable. Since Python 3.11,
    ``dataclasses`` refuses defaults whose class has ``__hash__ = None``, which
    is what a non-frozen ``eq=True`` dataclass gets. Without ``frozen=True``,
    a ``Field`` could not be used as a default at all.
    """

    # Help text for the argument; None means no explicit help was given.
    help: Optional[str] = None
    # Allowed values, forwarded to argparse ``choices``.
    choices: Optional[List[Any]] = None
    # True registers ``name`` as a positional argument instead of ``--name``.
    positional: bool = False
    # Default value; a _MISSING instance means "no default".
    default: Any = _MISSING()
    # Explicit ``required`` flag for options; None means "required iff no default".
    required: Optional[bool] = None
    # Name shown for the value in usage/help messages.
    metavar: Optional[str] = None
def from_args(cls, args=None):
    """
    Returns an instance of the dataclass with fields populated using the program arguments.

    Parameters
    ==========
    cls: A dataclass type
    args: Optional list of argument strings to parse. When None (the default),
        argparse reads the program's command line (sys.argv).

    Returns
    =======
    instance: cls - An instance of cls
    """
    parser = to_argparser(cls)
    namespace = parser.parse_args(args)

    # Build a fresh kwargs dict rather than popping from namespace.__dict__
    # while iterating it, which raises "dictionary changed size during
    # iteration" as soon as one entry must be dropped.
    # (Substituting Field defaults ideally belongs in cls.__init__.)
    kwargs = {}
    for name, value in vars(namespace).items():
        if isinstance(value, Field):
            if isinstance(value.default, _MISSING):
                # No default to substitute: omit the key entirely.
                continue
            value = value.default
        kwargs[name] = value
    return cls(**kwargs)
def to_argparser(cls) -> argparse.ArgumentParser:
    """
    Converts an annotated dataclass to a corresponding ArgumentParser.

    Each dataclass field accepted by ``cls.__init__`` becomes one argument
    named after the field: ``name`` for a positional ``Field``, ``--name``
    otherwise. Fields whose default is a ``Field`` instance take their
    argparse settings from it. Any other field is required unless it has a
    ``default`` or a ``default_factory``.

    NOTE: field annotations are passed unchanged as the argparse ``type``
    callable. String annotations (e.g. under ``from __future__ import
    annotations``) are not resolved, and argparse rejects them.

    Parameters
    ==========
    cls: A dataclass type

    Returns
    =======
    parser: argparse.ArgumentParser - an argument parser
    """
    # The class docstring describes the program, so it is the parser's
    # description. Passing it positionally would set ``prog`` instead, which
    # replaces the program name in the usage line.
    parser = argparse.ArgumentParser(description=cls.__dict__.get("__doc__"))
    # dataclasses.fields() omits ClassVar/InitVar pseudo-fields. init=False
    # fields are skipped here. Neither ClassVar nor init=False fields can be
    # passed to cls(...), so exposing them as options would break from_args.
    for dc_field in dataclasses.fields(cls):
        if not dc_field.init:
            continue
        if isinstance(dc_field.default, Field):
            _add_field_argument(parser, dc_field.name, dc_field.type, dc_field.default)
        else:
            _add_plain_argument(parser, dc_field)
    return parser


def _add_field_argument(parser, name, arg_type, meta):
    """Register ``name`` on ``parser`` using the settings of a ``Field`` default."""
    kwargs = {}
    has_default = not isinstance(meta.default, _MISSING)
    if has_default:
        kwargs["default"] = meta.default
        if meta.positional:
            # A positional only falls back to its default when it may be
            # omitted (nargs="?"); otherwise argparse still demands a value.
            kwargs["nargs"] = "?"
    if meta.choices is not None:
        kwargs["choices"] = meta.choices
    if not meta.positional:
        # argparse rejects ``required`` for positionals, so set it for options only.
        kwargs["required"] = meta.required if meta.required is not None else not has_default
    if meta.help is not None:
        kwargs["help"] = meta.help
    elif has_default:
        kwargs["help"] = f"{name.upper()} [default={meta.default}]"
    if meta.metavar is not None:
        kwargs["metavar"] = meta.metavar
    flag = name if meta.positional else f"--{name}"
    parser.add_argument(flag, type=arg_type, **kwargs)


def _add_plain_argument(parser, dc_field):
    """Register a ``--name`` option for a dataclass field without ``Field`` metadata."""
    has_default = dc_field.default is not dataclasses.MISSING
    has_factory = dc_field.default_factory is not dataclasses.MISSING
    kwargs = {"required": not (has_default or has_factory)}
    if has_default:
        kwargs["default"] = dc_field.default
        kwargs["help"] = f"[default={dc_field.default}]"
    elif has_factory:
        # Keep the attribute out of the namespace when the option is absent,
        # so cls(...) runs the factory itself (a fresh object per instance).
        kwargs["default"] = argparse.SUPPRESS
        kwargs["help"] = "[default: computed by default_factory]"
    parser.add_argument(f"--{dc_field.name}", type=dc_field.type, **kwargs)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment