Skip to content

Instantly share code, notes, and snippets.

@BramVanroy
Last active July 13, 2024 22:20
Show Gist options
  • Save BramVanroy/f78530673b1437ed0d6be7c61cdbdd7c to your computer and use it in GitHub Desktop.
Save BramVanroy/f78530673b1437ed0d6be7c61cdbdd7c to your computer and use it in GitHub Desktop.
Overwrite HfArgumentParser config options with CLI arguments
# See https://gist.github.com/BramVanroy/f78530673b1437ed0d6be7c61cdbdd7c
# Parse the arguments from a JSON config file (if one is given) and let the CLI options that
# follow the config file on the command line overwrite the values read from that file.
parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments, HyperOptArguments))

# Assumes that the first .json file is the config file (if any)
config_file = next((arg for arg in sys.argv if arg.endswith(".json")), None)

run_name_specified = False
if config_file:
    config_args = parser.parse_json_file(json_file=os.path.abspath(config_file))
    raw_config_json = json.loads(Path(config_file).read_text(encoding="utf-8"))

    # Only the options given after the config file are treated as overrides
    config_arg_idx = sys.argv.index(config_file)
    other_args = sys.argv[config_arg_idx + 1:]
    # Strip the leading "--" and, for the "--key=value" spelling, the "=value" part, so that
    # both "--key value" and "--key=value" are recognised as an override of "key"
    arg_names = {arg[2:].split("=", 1)[0] for arg in other_args if arg.startswith("--")}

    if "run_name" in arg_names or "run_name" in raw_config_json:
        run_name_specified = True

    # Required options that were not given on the CLI get a placeholder value so that the
    # CLI-only parse below does not fail on them. The placeholders never reach the final
    # arguments: only fields whose name is in arg_names are copied over in the merge loop.
    required_args = [
        token
        for act in parser._actions
        if act.required and not any(act_s[2:] in arg_names for act_s in act.option_strings)
        for token in (act.option_strings[0], "dummy")
    ]

    cli_args = other_args + required_args
    cli_args = parser.parse_args_into_dataclasses(args=cli_args, look_for_args_file=False)

    all_args = []
    for cfg_dc, cli_dc in zip(config_args, cli_args):
        # Have to check explicitly for no_ for the automatically added negated boolean arguments
        # E.g. find_unused... vs no_find_unused...
        cli_d = {k: v for k, v in dataclasses.asdict(cli_dc).items() if k in arg_names or f"no_{k}" in arg_names}
        all_args.append(dataclasses.replace(cfg_dc, **cli_d))

    model_args, data_args, training_args, hyperopt_args = all_args
else:
    model_args, data_args, training_args, hyperopt_args = parser.parse_args_into_dataclasses()
    # Without a config file everything comes from the CLI; an explicit --run_name must still
    # be respected by the run_name fix-up below
    run_name_specified = any(arg.split("=", 1)[0] == "--run_name" for arg in sys.argv[1:])

# Normally, post_init of training_args sets run_name to output_dir (defaults to "results/" in our config file)
# But if we overwrite output_dir with a CLI option, then we do not correctly update
# run_name to the same value. Which in turn will lead to wandb to use the original "results/" as a run name
# see: https://github.com/huggingface/transformers/blob/fe861e578f50dc9c06de33cd361d2f625017e624/src/transformers/integrations.py#L741-L742
# Instead we explicitly have to set run_name to the output_dir again -- but of course only if the user
# did not specifically specify run_name in the config or in the CLI
if not run_name_specified:
    training_args.run_name = training_args.output_dir
@hogru
Copy link

hogru commented Mar 6, 2023

I am documenting this here in case anybody else is looking for something similar. For this to work, you need to give up the convenience of the YAML format.

I finally came up with this:

  • If I pass a .yaml file, it reads the config from this file only (no command line arguments)
  • I allow arguments to be passed via a .args file. The reason is that HfArgumentParser can already overwrite arguments, just not from a yaml file: it needs a .args file, which contains one command line argument per line
  • Additionally, one can pass command line arguments (as usual)
  • And finally, if there are no arguments at all, it looks for a default config file <caller_file>.args

In summary, I can call e.g., python train.py --config_file experiment.args --learning_rate 0.001. This loads train.args, overwrites its contents with experiment.args and overwrites the learning_rate.

# Command line flag used to pass additional .args file(s); it is handed to HfArgumentParser
# as `args_file_flag` and stripped from the arguments before the config files are detected.
CONFIG_OVERWRITE_ARGS_FLAG: Final[str] = "--config_file"

def _positional_config_files(args: list[str], suffix: str) -> set[str]:
    """Return the non-flag arguments in *args* that end with *suffix* (case-insensitive)."""
    return {arg for arg in args if arg.lower().endswith(suffix) and not arg.startswith("--")}


def _parse_with_args_file(parser, args_file_path: Path) -> tuple[list, list[str]]:
    """Parse the command line on top of the base ``.args`` file *args_file_path*.

    Returns a pair ``(parsed dataclasses, leftover strings)``; the leftover strings are the
    arguments the parser could not assign to any dataclass field.

    Raises:
        FileNotFoundError: if *args_file_path* is not an existing file.
    """
    if not args_file_path.is_file():
        raise FileNotFoundError(f"Config file {args_file_path} not found")
    *parsed_dataclasses, unused_args = parser.parse_args_into_dataclasses(
        args_filename=args_file_path.as_posix(),
        return_remaining_strings=True,
        args_file_flag=CONFIG_OVERWRITE_ARGS_FLAG,
    )
    return parsed_dataclasses, unused_args


def get_training_config() -> tuple[argparse.Namespace, ...]:
    """Build the training configuration from ``sys.argv``.

    Three scenarios are supported:

    * a single ``.yaml`` file as the only argument: the configuration is read from it;
    * one positional ``.args`` file: it is used as the main configuration, together with
      the other command line arguments and any ``--config_file`` files;
    * no positional ``.args`` file: a default ``<script name>.args`` file is looked up in
      ``DEFAULT_CONFIG_DIR`` and used as the main configuration in the same way.

    Returns:
        The parsed model, data, training and additional arguments, in that order.
        NOTE(review): the annotation says ``argparse.Namespace``, but the values come from
        HfArgumentParser's dataclass-parsing methods -- presumably dataclass instances; confirm.

    Raises:
        ValueError: if the combination of config files and arguments is not allowed, or if
            arguments are left over that the parser does not know.
        FileNotFoundError: if the selected config file does not exist.
    """
    # Drop the --config_file flags (and their values) so that only the other arguments are
    # inspected when looking for positional config files
    config_file_parser = argparse.ArgumentParser()
    config_file_parser.add_argument(CONFIG_OVERWRITE_ARGS_FLAG, type=str, action="append")
    _, remaining_args = config_file_parser.parse_known_args()

    # Determine .args and .yaml files in arguments
    args_file_names = _positional_config_files(remaining_args, ".args")
    yaml_file_names = _positional_config_files(remaining_args, ".yaml")

    # Check if arguments make sense
    if args_file_names and yaml_file_names:
        raise ValueError(
            "Both .args and .yaml config files specified, can only use one configuration file (format)"
        )
    if len(args_file_names) > 1 or len(yaml_file_names) > 1:
        raise ValueError("Can only create configuration from a single configuration file")
    if yaml_file_names and len(sys.argv) > 2:
        raise ValueError(
            "Configuration with a .yaml file does not allow for additional command line arguments"
        )

    # Parse the arguments, taking care of the different config scenarios
    parser = HfArgumentParser(
        (
            ModelArguments,
            DataArguments,
            TrainingArguments,
            AdditionalArguments,
        ),
        description="Train a model on a dataset",
    )

    # we have a .yaml file as a single argument
    if len(sys.argv) == 2 and sys.argv[1].lower().endswith(".yaml"):
        yaml_file_path = Path(sys.argv[1]).resolve()
        if not yaml_file_path.is_file():
            raise FileNotFoundError(f"Config file {yaml_file_path} not found")
        model_args, data_args, training_args, additional_args = parser.parse_yaml_file(
            yaml_file_path.as_posix(),
        )
        return model_args, data_args, training_args, additional_args

    if args_file_names:
        # we have a (main) .args file, plus potentially overwrites; exactly one name is left
        # here because more than one was rejected above
        (main_args_file_name,) = args_file_names
        parsed_dataclasses, unused_args = _parse_with_args_file(
            parser, Path(main_args_file_name).resolve()
        )
        # The file name itself is still on the command line and is expected among the
        # leftover strings; it is not an unknown argument
        unused_args.remove(main_args_file_name)  # must use original string
    else:
        # no config given, look for default config
        default_args_file_path = (
            Path(DEFAULT_CONFIG_DIR) / Path(sys.argv[0]).with_suffix(".args").name
        )
        parsed_dataclasses, unused_args = _parse_with_args_file(parser, default_args_file_path)

    if unused_args:
        raise ValueError(f"Unknown configuration arguments: {unused_args}")

    model_args, data_args, training_args, additional_args = parsed_dataclasses
    return model_args, data_args, training_args, additional_args

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment