Created
June 20, 2023 12:35
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Parameter files tracked by this pipeline.
# params.yaml is only a default: it is overwritten at runtime from the Hydra config.
params:
  - params.yaml

stages:
  # Every stage exports DVC_STAGE_NAME so the script can tell it is running
  # under DVC, and passes --hoist_params_path so only that stage's sub-tree
  # of params is exposed to the script.
  train:
    cmd:
      - DVC_STAGE_NAME=train python scripts/fine_tune/fine_tune_and_evaluate.py --hoist_params_path="trainer_task.train"
    params:
      - trainer_task.train
    deps:
      - scripts/fine_tune/fine_tune_and_evaluate.py
  evaluate:
    cmd:
      - DVC_STAGE_NAME=evaluate python scripts/fine_tune/fine_tune_and_evaluate.py --hoist_params_path="trainer_task.eval"
    params:
      - trainer_task.eval
    deps:
      - scripts/fine_tune/fine_tune_and_evaluate.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def get_dvc_stage_name():
    """Return the stage name stored in the ``DVC_STAGE_NAME`` environment variable.

    Returns:
        ``None`` when the variable is unset, empty, or contains only
        whitespace; otherwise the variable's value exactly as set
        (it is not stripped).
    """
    stage = os.getenv("DVC_STAGE_NAME")
    # A blank value is treated the same as "not set".
    if stage is None or not stage.strip():
        return None
    return stage
def hoist_params(params_dictionary, hoist_params_path):
    """Return the sub-tree of ``params_dictionary`` addressed by a dotted path.

    Example: ``hoist_params(params, "trainer_task.train")`` returns
    ``params["trainer_task"]["train"]``.

    Args:
        params_dictionary: Nested mapping of parameters to descend into.
        hoist_params_path: Dot-separated sequence of keys, e.g. ``"a.b.c"``.

    Returns:
        The value at the end of the path. If a value along the path is
        ``None``, a warning is logged and ``None`` is returned.

    Raises:
        KeyError: If a key along the path does not exist. The message names
            both the missing key and the full path, so a typo in the
            configured path is easy to spot.
    """
    for hoist_params_path_part in hoist_params_path.split("."):
        try:
            params_dictionary = params_dictionary[hoist_params_path_part]
        except KeyError as err:
            # Re-raise as the same exception type, but with the full path for context.
            raise KeyError(
                f"hoist_params_path = {hoist_params_path} is invalid. "
                f"Key {hoist_params_path_part!r} not found"
            ) from err
        if params_dictionary is None:
            logger.warning(
                f"hoist_params_path = {hoist_params_path} is invalid. It's None at {hoist_params_path_part}"
            )
            break
    return params_dictionary
def get_dvc_params_if_in_dvc():
    """Load the current stage's DVC params when the script is run by a DVC stage.

    DVC execution is detected through ``get_dvc_stage_name()``. When a stage
    name is present, params are read with ``dvc.api.params_show`` for that
    stage. If ``--hoist_params_path`` was given on the command line, the
    sub-tree at that dotted path (see ``hoist_params``) replaces the full
    params.

    Returns:
        The (possibly hoisted) params, or ``None`` when not run by DVC.
    """
    dvc_stage_name = get_dvc_stage_name()
    if dvc_stage_name is None:
        logger.info("Not being run by DVC. (No 'DVC_STAGE_NAME' env variable set) ")
        return None

    dvc_params = dvc.api.params_show(stages=dvc_stage_name)
    logger.info(f"DVC Stage: '{dvc_stage_name}'")

    # This parser reads the process's full sys.argv. Use parse_known_args and
    # disable -h/--help so that arguments (and --help) meant for the host
    # script do not make argparse exit here with "unrecognized arguments".
    parser = argparse.ArgumentParser(add_help=False)
    parser.add_argument("--hoist_params_path", nargs=1, required=False)
    args, _unknown_args = parser.parse_known_args()
    if args.hoist_params_path is not None:
        # nargs=1 stores a one-element list.
        hoist_params_path = args.hoist_params_path[0]
        logger.info(f"Hoisting parameters to top level from: {hoist_params_path}")
        dvc_params = hoist_params(dvc_params, hoist_params_path)
    logger.info(
        f"DVC Params being used by this script: {json.dumps(dvc_params, indent=4)}"
    )
    return dvc_params
# This will be set to the current DVC stage's params (hoisted to
# --hoist_params_path when that flag is given), or None when this script is
# not being run by DVC. It is evaluated once, when this module executes.
dvc_params_dict = get_dvc_params_if_in_dvc()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Generated at runtime: DVC builds this file from the Hydra config, so any
# manual edits here are overwritten on the next run.
trainer_task:
  train:
    do_train: true
    do_eval: false
  eval:
    do_train: false
    do_eval: true
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment