Last active
July 13, 2024 22:20
-
-
Save BramVanroy/f78530673b1437ed0d6be7c61cdbdd7c to your computer and use it in GitHub Desktop.
Overwrite HfArgumentParser config options with CLI arguments
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# See https://gist.github.com/BramVanroy/f78530673b1437ed0d6be7c61cdbdd7c
# Parse arguments from an optional JSON config file; any `--option` given on the command line
# *after* that file overrides the corresponding value from the config.
# Example: python train.py config.json --learning_rate 1e-4 --output_dir=out/


def _cli_option_names(args):
    """Return the field names of all ``--options`` present in *args*.

    Both ``--name value`` and ``--name=value`` forms are recognised, and dashes are
    normalised to underscores so that dashed aliases such as ``--learning-rate`` or
    ``--no-foo`` map onto the dataclass field name (field names cannot contain dashes).
    """
    names = set()
    for arg in args:
        if arg.startswith("--"):
            name = arg[2:].split("=", 1)[0]
            names.add(name.replace("-", "_"))
    return names


parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments, HyperOptArguments))
# Assumes that the first .json argument is the config file (if any)
config_file = next((arg for arg in sys.argv if arg.endswith(".json")), None)

if config_file:
    config_args = parser.parse_json_file(json_file=os.path.abspath(config_file))
    raw_config_json = json.loads(Path(config_file).read_text(encoding="utf-8"))
    # Only arguments that come *after* the config file are treated as overrides
    other_args = sys.argv[sys.argv.index(config_file) + 1:]
    arg_names = _cli_option_names(other_args)
    run_name_specified = "run_name" in arg_names or "run_name" in raw_config_json

    # argparse refuses to parse when a required option is missing, even though the config file
    # already provides it. Append a placeholder for every required option not given on the CLI;
    # the parsed placeholder is discarded below because its name is not in `arg_names`.
    # NOTE(review): the placeholder "dummy" only parses for str-typed required fields.
    placeholder_args = [
        token
        for act in parser._actions
        if act.required and not any(opt[2:] in arg_names for opt in act.option_strings)
        for token in (act.option_strings[0], "dummy")
    ]
    cli_dataclasses = parser.parse_args_into_dataclasses(
        args=other_args + placeholder_args, look_for_args_file=False
    )

    all_args = []
    for cfg_dc, cli_dc in zip(config_args, cli_dataclasses):
        # Have to check explicitly for no_ for the automatically added negated boolean arguments
        # E.g. find_unused... vs no_find_unused...
        overrides = {
            k: v for k, v in dataclasses.asdict(cli_dc).items() if k in arg_names or f"no_{k}" in arg_names
        }
        all_args.append(dataclasses.replace(cfg_dc, **overrides))
    model_args, data_args, training_args, hyperopt_args = all_args
else:
    model_args, data_args, training_args, hyperopt_args = parser.parse_args_into_dataclasses()
    # Without this, an explicit --run_name on the CLI would be overwritten with output_dir below
    run_name_specified = "run_name" in _cli_option_names(sys.argv[1:])

# Normally, post_init of training_args sets run_name to output_dir (defaults to "results/" in our config file)
# But if we overwrite output_dir with a CLI option, then we do not correctly update
# run_name to the same value. Which in turn will lead to wandb to use the original "results/" as a run name
# see: https://github.com/huggingface/transformers/blob/fe861e578f50dc9c06de33cd361d2f625017e624/src/transformers/integrations.py#L741-L742
# Instead we explicitly have to set run_name to the output_dir again -- but of course only if the user
# did not specifically specify run_name in the config or in the CLI
if not run_name_specified:
    training_args.run_name = training_args.output_dir
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
I am documenting this here in case anybody else is looking for something similar. For this to work, you need to give up the convenience of the `yaml` format. I finally came up with this:
- … `.yaml` file, it reads the config from this file only (no command line arguments).
- … `.args` file; the reason is that the `HfArgumentParser` can already overwrite arguments, just not in `yaml` format, but in an `.args` file that contains one command line argument per line, in the file `<caller_file>.args`.
In summary, I can call e.g. `python train.py --config_file experiment.args --learning_rate 0.001`. This loads `train.args`, overwrites its contents with `experiment.args`, and then overwrites the `learning_rate`.