Skip to content

Instantly share code, notes, and snippets.

@ThilinaRajapakse
Created November 12, 2021 14:12
Show Gist options
  • Save ThilinaRajapakse/a6bca35f2ab1a69ecbb0a4495d6b09e8 to your computer and use it in GitHub Desktop.
Save ThilinaRajapakse/a6bca35f2ab1a69ecbb0a4495d6b09e8 to your computer and use it in GitHub Desktop.
import logging
from simpletransformers.retrieval import RetrievalModel, RetrievalArgs
# Paths handed to train_model() below as the training and evaluation data.
train_data = "data/nq-train.json"
eval_data = "data/nq-dev.json"

# No single model name is supplied; the context and query encoders are
# named separately and both use the same checkpoint name.
model_type = "custom"
model_name = None
context_name = "bert-base-uncased"
query_name = "bert-base-uncased"

# Log at INFO level overall, but restrict the "transformers" logger to
# WARNING and above.
logging.basicConfig(level=logging.INFO)
transformers_logger = logging.getLogger("transformers")
transformers_logger.setLevel(logging.WARNING)

# Start from the RetrievalArgs defaults and override the settings below,
# applied in the listed order.
model_args = RetrievalArgs()
for _option, _value in {
    "reprocess_input_data": True,
    "overwrite_output_dir": True,
    "use_cached_eval_features": False,
    "retrieve_n_docs": 100,
    "hard_negatives": False,
    "max_seq_length": 256,
    "num_train_epochs": 40,
    "train_batch_size": 40,
    "eval_batch_size": 128,
    "use_hf_datasets": True,
    "learning_rate": 1e-5,
    "save_steps": -1,
    "evaluate_during_training": False,
    "wandb_project": "Training DPR on NQ",
    "save_model_every_epoch": False,
}.items():
    setattr(model_args, _option, _value)

# Create the RetrievalModel from the two encoder names and the arguments
# configured above.
model = RetrievalModel(
    model_type=model_type,
    model_name=model_name,
    context_encoder_name=context_name,
    query_encoder_name=query_name,
    args=model_args,
    force_redownload=True,
)

# Train on train_data, passing eval_data along as the evaluation set.
model.train_model(train_data, eval_data=eval_data)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment