Created
December 17, 2021 20:20
-
-
Save HarshTrivedi/c1f54b0b532f847cbddaaf39042dca2b to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#########################################################
# Experiment-level knobs.  `setting` chooses between the full dataset
# and a tiny fixture for smoke tests; both values arrive as external
# variables (jsonnet --ext-str).
local setting = std.extVar("setting"); # options: full, fixture
local num_cores = std.parseInt(std.extVar("num_cores")); # NOTE(review): read but not referenced below — confirm a consumer exists
#########################################################
# Set this for memory optimization (enables the fairscale activation
# checkpointing wrapper on the model further down).
local activation_checkpointing = false;
local seed = 100;
local transformer_model_name = "nielsr/nt5-small-rc1";
local batch_size = 32;
local accumulation_steps = 1;
local max_context_tokens = 700;
local max_question_tokens = 100;
local shuffle_context = false;
local skip_context = true;
local use_program_as_question = false;
local use_program_last_step_as_question = false;
local generate_final_answers = true;
local generate_intermediate_answers = true;
local generate_step_instructions = true;
local intermediate_chain_or_output = "output"; # options: chain, output
# Token budget for the generated target, derived from which target
# pieces are enabled.  (An earlier hard-coded `max_answer_tokens = 200`
# binding was shadowed by this derived value and has been removed.)
local max_answer_tokens = (
    (if generate_final_answers then 1 else 0) * 50 +
    (if generate_intermediate_answers && intermediate_chain_or_output == "chain" then 1 else 0) * 500 +
    (if generate_intermediate_answers && intermediate_chain_or_output == "output" then 1 else 0) * 200 +
    (if generate_step_instructions then 1 else 0) * 100
);
local target_is_serialized = true;
local fixture_path = 'fixtures/synthetic_data/sampled_for_mtl_one_hop.jsonl';
# Select a value by experiment setting; anything other than
# 'full'/'fixture' falls through to `default`.
local by_setting(full_value, fixture_value, default) =
    if setting == 'full'
    then full_value
    else if setting == 'fixture'
    then fixture_value
    else default;
local train_data_path = by_setting(
    'processed_data/synthetic_data/sampled_for_mtl/one_hop_train.jsonl',
    fixture_path,
    '-'
);
local validation_data_path = by_setting(
    'processed_data/synthetic_data/sampled_for_mtl/one_hop_dev.jsonl',
    fixture_path,
    '-'
);
local num_epochs = by_setting(20, 30, 2);
# Full runs wait 20 epochs for improvement; otherwise patience spans
# the whole run (fixture and fallback arms are identical by design).
local patience = by_setting(20, num_epochs, num_epochs);
# Reader configuration shared by train and validation passes; the type
# "synthetic_dbqa" is registered in the project code, and the knobs
# below mirror the locals defined at the top of this file.
local dataset_reader = {
    "type": "synthetic_dbqa",
    "transformer_model_name": transformer_model_name,
    "max_context_tokens": max_context_tokens,
    "max_question_tokens": max_question_tokens,
    "max_answer_tokens": max_answer_tokens,
    "shuffle_context": shuffle_context,
    "skip_context": skip_context,
    "use_program_as_question": use_program_as_question,
    "use_program_last_step_as_question": use_program_last_step_as_question,
    "generate_final_answers": generate_final_answers,
    "generate_intermediate_answers": generate_intermediate_answers,
    "generate_step_instructions": generate_step_instructions,
    "intermediate_chain_or_output": intermediate_chain_or_output,
    "add_additional_tokens": true
};
# Training data loader.  Streams instances with a bounded in-memory
# window rather than loading the whole dataset.
local data_loader = {
    "batch_size": batch_size,
    "shuffle": true,
    "num_workers": 20,
    "max_instances_in_memory": batch_size*50,
    # Computed-key idiom: this field only exists in fixture mode, where
    # it caps each epoch so smoke-test runs finish quickly.
    [if setting == 'fixture' then "batches_per_epoch"]: 200*accumulation_steps,
};
# Trainer callbacks.  Tensorboard is always attached; the Weights &
# Biases callback is added only when running on Beaker (gated on the
# "on_beaker" external variable in the trainer section).
local tensorboard_callback = {"type": "tensorboard"};
local wandb_callback = {
    "type": "wandb",
    "project": "synth2realmh",
    "entity": "harshtrivedi",
    # Run name is injected from the environment so runs are traceable.
    "name": std.extVar("WANDB_RUN_NAME"),
    "watch_model": false,
    "summary_interval": 1,
    # Parameter/LR statistics are disabled to keep wandb logging cheap.
    "should_log_parameter_statistics": false,
    "should_log_learning_rate": false,
};
# Top-level AllenNLP experiment configuration assembled from the
# locals above.
{
    "train_data_path": train_data_path,
    "validation_data_path": validation_data_path,
    "dataset_reader": dataset_reader,
    # Same reader settings for validation as for training.
    "validation_dataset_reader": dataset_reader,
    "model": {
        "type": "qa_t5",
        "model_name": transformer_model_name,
        "beam_search": {
            "beam_size": 5,
            # Cap decoding at the derived answer-token budget.
            "max_steps": max_answer_tokens,
        },
        "target_is_serialized": target_is_serialized,
        # Computed key: the wrapper field only exists when activation
        # checkpointing is enabled (memory optimization knob above).
        [if activation_checkpointing then "checkpoint_wrapper"]: {
            "type": "fairscale",
            "offload_to_cpu": true,
            "maintain_forward_counter": true,
        },
    },
    "data_loader": data_loader,
    # Validation reuses the training loader config but removes the
    # streaming window and per-epoch cap (null fields are dropped).
    "validation_data_loader": self.data_loader + {
        "max_instances_in_memory": null,
        "batches_per_epoch": null,
        "batch_size": batch_size,
    },
    "vocabulary": {
        "type": "empty",
    },
    "trainer": {
        "cuda_device": 0,
        "use_amp": false,
        "num_epochs": num_epochs,
        "patience": patience,
        "num_gradient_accumulation_steps": accumulation_steps,
        "optimizer": {
            "type": "huggingface_adafactor",
        },
        "grad_norm": 1.0,
        # wandb logging only on Beaker; tensorboard everywhere.
        "callbacks": if std.extVar("on_beaker") == "true"
                     then [tensorboard_callback, wandb_callback]
                     else [tensorboard_callback],
        # Early stopping / model selection track df_match_score (higher
        # is better, hence the leading '+').
        "validation_metric": "+df_match_score",
    },
    # Pin all three RNG seeds for reproducibility.
    "random_seed": seed,
    "numpy_seed": seed,
    "pytorch_seed": seed,
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment