Last active
July 13, 2024 22:20
-
-
Save BramVanroy/f78530673b1437ed0d6be7c61cdbdd7c to your computer and use it in GitHub Desktop.
Overwrite HfArgumentParser config options with CLI arguments
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# See https://gist.github.com/BramVanroy/f78530673b1437ed0d6be7c61cdbdd7c
#
# Load default values from a JSON config file and let every CLI option that
# follows the config file on the command line overwrite the matching value.
parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments, HyperOptArguments))

# Assumes that the first .json argument is the config file (if any)
config_file = next((arg for arg in sys.argv if arg.endswith(".json")), None)

run_name_specified = False
if config_file:
    config_args = parser.parse_json_file(json_file=os.path.abspath(config_file))
    raw_config_json = json.loads(Path(config_file).read_text(encoding="utf-8"))

    # Only the arguments that come after the config file are treated as overrides
    config_arg_idx = sys.argv.index(config_file)
    other_args = sys.argv[config_arg_idx + 1:]
    # Option names without the leading "--". Both spellings "--name value" and
    # "--name=value" must map to "name", otherwise the latter is silently ignored.
    arg_names = {arg[2:].split("=", 1)[0] for arg in other_args if arg.startswith("--")}

    if "run_name" in arg_names or "run_name" in raw_config_json:
        run_name_specified = True

    # Required options that were not given on the CLI get a placeholder value so
    # that argparse does not abort. The placeholders never reach the final
    # arguments because only options listed in arg_names are copied over below.
    required_args = [
        token
        for act in parser._actions
        if act.required and not any(act_s[2:] in arg_names for act_s in act.option_strings)
        for token in (act.option_strings[0], "dummy")
    ]

    cli_args = other_args + required_args
    cli_args = parser.parse_args_into_dataclasses(args=cli_args, look_for_args_file=False)

    all_args = []
    for cfg_dc, cli_dc in zip(config_args, cli_args):
        # Have to check explicitly for no_ for the automatically added negated boolean arguments
        # E.g. find_unused... vs no_find_unused...
        cli_d = {k: v for k, v in dataclasses.asdict(cli_dc).items() if k in arg_names or f"no_{k}" in arg_names}
        all_args.append(dataclasses.replace(cfg_dc, **cli_d))
    model_args, data_args, training_args, hyperopt_args = all_args
else:
    model_args, data_args, training_args, hyperopt_args = parser.parse_args_into_dataclasses()
    # Without a config file an explicit run_name can only come from the CLI.
    # It must be detected here too, otherwise it is overwritten with output_dir below.
    run_name_specified = any(
        arg == "--run_name" or arg.startswith("--run_name=") for arg in sys.argv[1:]
    )

# Normally, post_init of training_args sets run_name to output_dir (defaults to "results/" in our config file)
# But if we overwrite output_dir with a CLI option, then we do not correctly update
# run_name to the same value. Which in turn will lead to wandb to use the original "results/" as a run name
# see: https://github.com/huggingface/transformers/blob/fe861e578f50dc9c06de33cd361d2f625017e624/src/transformers/integrations.py#L741-L742
# Instead we explicitly have to set run_name to the output_dir again -- but of course only if the user
# did not specifically specify run_name in the config or in the CLI
if not run_name_specified:
    training_args.run_name = training_args.output_dir
I document this here if anybody else is looking for something similar. For this to work, you need to give up the convenience of the yaml
format.
I finally came up with this:
- If I pass a `.yaml` file, it reads the config from this file only (no command line arguments).
- I allow for arguments to be passed via a `.args` file; the reason is that the `HfArgumentParser` can already overwrite arguments, just not in `yaml` format, but in the `.args` file consisting of a command line argument per line in the file.
- Additionally, one can pass command line arguments (as usual).
- And finally, if there are no arguments at all, it looks for a default config file `<caller_file>.args`.
In summary, I can call e.g., `python train.py --config_file experiment.args --learning_rate 0.001`. This loads `train.args`, overwrites its contents with `experiment.args`, and overwrites the `learning_rate`.
# Command line flag for additional .args files whose contents overwrite the
# main configuration; it is handed to HfArgumentParser as ``args_file_flag``.
CONFIG_OVERWRITE_ARGS_FLAG: Final[str] = "--config_file"


def _is_config_file_arg(arg: str, suffix: str) -> bool:
    """Return True if *arg* is a non-flag argument ending in *suffix* (case-insensitive)."""
    return arg.lower().endswith(suffix) and not arg.startswith("--")


def _parse_args_file(parser, args_file_path, positional_file_arg=None):
    """Parse *args_file_path* plus the command line and return the parsed argument objects.

    Args:
        parser: The ``HfArgumentParser`` to parse with.
        args_file_path: ``Path`` of the main ``.args`` file.
        positional_file_arg: The config file name exactly as it was typed on the
            command line, or ``None`` if it was not given there. The parser does
            not know this positional argument, so it is reported back as unused
            and has to be removed before checking for unknown arguments.

    Raises:
        FileNotFoundError: If *args_file_path* is not an existing file.
        ValueError: If arguments remain that the parser does not know.
    """
    if not args_file_path.is_file():
        raise FileNotFoundError(f"Config file {args_file_path} not found")
    *parsed_args, unused_args = parser.parse_args_into_dataclasses(
        args_filename=args_file_path.as_posix(),
        return_remaining_strings=True,
        args_file_flag=CONFIG_OVERWRITE_ARGS_FLAG,
    )
    if positional_file_arg is not None:
        unused_args.remove(positional_file_arg)  # must use original string
    if unused_args:
        raise ValueError(f"Unknown configuration arguments: {unused_args}")
    return tuple(parsed_args)


def get_training_config() -> tuple[argparse.Namespace, ...]:
    """Build the training configuration from config files and/or the command line.

    Supported invocations (read from ``sys.argv``):

    * a single ``.yaml`` file and nothing else: the config is read from it only;
    * a single ``.args`` file, optionally followed by
      ``--config_file <other.args>`` overwrites and ordinary command line
      arguments;
    * no config file at all: the default ``<caller_file>.args`` inside
      ``DEFAULT_CONFIG_DIR`` is used, again with optional overwrites.

    Returns:
        The parsed ``(model_args, data_args, training_args, additional_args)``.
        NOTE(review): the annotation says ``argparse.Namespace`` but
        ``HfArgumentParser`` is documented to return dataclass instances;
        the signature is left unchanged for callers.

    Raises:
        ValueError: If the combination of config files is invalid (both formats,
            more than one file, a ``.yaml`` file plus extra arguments) or if
            unknown arguments remain after parsing.
        FileNotFoundError: If the selected config file does not exist.
    """
    # Strip the overwrite flag (and its values) so that only the remaining
    # arguments are inspected for main config files
    config_file_parser = argparse.ArgumentParser()
    config_file_parser.add_argument(
        CONFIG_OVERWRITE_ARGS_FLAG, type=str, action="append"
    )
    _, remaining_args = config_file_parser.parse_known_args()

    # Determine .args and .yaml files in arguments. Lists (not sets) are used on
    # purpose: passing the same file twice must count as two config files.
    args_file_names = [
        arg for arg in remaining_args if _is_config_file_arg(arg, ".args")
    ]
    yaml_file_names = [
        arg for arg in remaining_args if _is_config_file_arg(arg, ".yaml")
    ]

    # Check if arguments make sense
    if args_file_names and yaml_file_names:
        raise ValueError(
            "Both .args and .yaml config files specified, can only use one configuration file (format)"
        )
    if len(args_file_names) > 1 or len(yaml_file_names) > 1:
        raise ValueError(
            "Can only create configuration from a single configuration file"
        )
    if yaml_file_names and len(sys.argv) > 2:
        raise ValueError(
            "Configuration with a .yaml file does not allow for additional command line arguments"
        )

    # Parse the arguments, taking care of the different config scenarios
    parser = HfArgumentParser(
        (
            ModelArguments,
            DataArguments,
            TrainingArguments,
            AdditionalArguments,
        ),
        description="Train a model on a dataset",
    )

    # we have a .yaml file as a single argument
    if len(sys.argv) == 2 and sys.argv[1].lower().endswith(".yaml"):
        yaml_file_path = Path(sys.argv[1]).resolve()
        if not yaml_file_path.is_file():
            raise FileNotFoundError(f"Config file {yaml_file_path} not found")
        model_args, data_args, training_args, additional_args = parser.parse_yaml_file(
            yaml_file_path.as_posix(),
        )
    # we have a (main) .args file, plus potentially overwrites
    elif args_file_names:
        main_args_file_name = args_file_names[0]
        model_args, data_args, training_args, additional_args = _parse_args_file(
            parser,
            Path(main_args_file_name).resolve(),
            positional_file_arg=main_args_file_name,
        )
    # no config given, look for default config
    else:
        default_args_file_path = (
            Path(DEFAULT_CONFIG_DIR) / Path(sys.argv[0]).with_suffix(".args").name
        )
        model_args, data_args, training_args, additional_args = _parse_args_file(
            parser, default_args_file_path
        )

    return model_args, data_args, training_args, additional_args
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Thank you, will give this a try. I was looking for a way to combine an “external” (to HF) hyperparameter optimizer (e.g., the `wandb` sweeps) with the `HfArgumentParser`. I want to leave the parameters in a file (not to be optimized or as defaults) and only provide the changeable hyperparameters on the command line. This seems to do just that.
Before using HF I used `hydra`, which does this out of the box (but has its own inherent complexity compared to simple config files).