
@ThilinaRajapakse
Created November 12, 2021 23:23
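import logging

from simpletransformers.retrieval import RetrievalModel

# Show INFO-level progress logs, but keep the transformers library quiet
logging.basicConfig(level=logging.INFO)
transformers_logger = logging.getLogger("transformers")
transformers_logger.setLevel(logging.WARNING)

# Path to the evaluation data
eval_data = "data/nq-dev.json"

# Load a previously trained model from the local "outputs" directory
model_type = "custom"
model_name = "outputs"

# Create a RetrievalModel
model = RetrievalModel(
    model_type=model_type,
    model_name=model_name,
)

# eval_model returns several values; only the first (the evaluation
# results) is kept here
results, *_ = model.eval_model(eval_data=eval_data)
print(results)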
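`eval_model` scores how well the model retrieves relevant passages for each query. As a rough illustration of what retrieval metrics of this kind measure, here is a minimal sketch of top-k accuracy and mean reciprocal rank (MRR@k) over plain lists of passage IDs. The helper `retrieval_metrics` and its metric key names are hypothetical; this is not simpletransformers' implementation, and the exact metrics `eval_model` reports may differ.

```python
from typing import Dict, Sequence, Set


def retrieval_metrics(
    ranked_ids: Sequence[Sequence[str]],
    gold_ids: Sequence[Set[str]],
    k_values: Sequence[int] = (1, 5, 10),
) -> Dict[str, float]:
    """Hypothetical helper: top-k accuracy and MRR@k for a batch of queries.

    ranked_ids[i] lists the passage IDs retrieved for query i, best first.
    gold_ids[i] is the set of passage IDs counted as correct for query i.
    """
    if len(ranked_ids) != len(gold_ids):
        raise ValueError("ranked_ids and gold_ids must have the same length")
    n = len(ranked_ids)
    metrics: Dict[str, float] = {}
    for k in k_values:
        hits = 0
        reciprocal_rank_sum = 0.0
        for ranking, gold in zip(ranked_ids, gold_ids):
            # Find the 1-based rank of the first gold passage within the top k
            for rank, pid in enumerate(ranking[:k], start=1):
                if pid in gold:
                    hits += 1
                    reciprocal_rank_sum += 1.0 / rank
                    break
        # Queries with no gold passage in the top k contribute 0 to both metrics
        metrics[f"top_{k}_accuracy"] = hits / n if n else 0.0
        metrics[f"mrr@{k}"] = reciprocal_rank_sum / n if n else 0.0
    return metrics


if __name__ == "__main__":
    ranked = [["p3", "p1", "p7"], ["p2", "p9", "p4"], ["p5", "p6", "p8"]]
    gold = [{"p1"}, {"p2"}, {"p0"}]
    print(retrieval_metrics(ranked, gold, k_values=(1, 3)))
```

In this toy batch, a gold passage appears in the top 3 for two of the three queries (ranks 2 and 1), so `top_3_accuracy` is 2/3 and `mrr@3` is (0.5 + 1) / 3 = 0.5.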