@simrit1
Forked from davidmezzetti/txtai-trainqa.py
Created September 15, 2021 00:33
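from transformers import pipeline
from txtai.pipeline import HFTrainer

# Training data: each record pairs a question with a short context and the
# answer, given as plain text that appears verbatim in the context
data = [
    {"question": "What ingredient?", "context": "Pour 1 can whole tomatoes", "answers": "tomatoes"},
    {"question": "What ingredient?", "context": "Dice 1 yellow onion", "answers": "onion"},
    {"question": "What ingredient?", "context": "Cut 1 red pepper", "answers": "pepper"},
    {"question": "What ingredient?", "context": "Peel and dice 1 clove garlic", "answers": "garlic"},
    {"question": "What ingredient?", "context": "Put 1/2 lb beef", "answers": "beef"},
]

# Fine-tune the base model on the records above for 10 epochs;
# HFTrainer returns the trained model and its tokenizer
trainer = HFTrainer()
model, tokenizer = trainer("bert-tiny-squadv2", data, task="question-answering", num_train_epochs=10)

# Wrap the fine-tuned model in a standard Hugging Face question-answering pipeline
questions = pipeline("question-answering", model=model, tokenizer=tokenizer)

# Ask about a context not seen during training; 'start' and 'end' are
# character offsets of the answer in the context (exact score varies with training)
print(questions(question="What ingredient?", context="Peel and dice 1 shallot"))
# {'answer': 'shallot', 'end': 23, 'score': 0.13187439739704132, 'start': 16}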
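The records above give each answer as plain text, while extractive QA training usually works with SQuAD-style answers: the answer text plus the character offset where it starts in the context. The sketch below shows that conversion for illustration. It assumes the answer should be located by a simple substring search (first occurrence wins); `to_squad` is a hypothetical helper, not part of txtai, and this is not a claim about how `HFTrainer` processes the data internally.

```python
# Hypothetical helper (not txtai code): convert records whose "answers" field is
# a plain string into SQuAD-style records with text + character start offsets.

def to_squad(records):
    """Return SQuAD-style copies of records whose answers are plain strings."""
    converted = []
    for record in records:
        text = record["answers"]
        # Assumption: the first occurrence of the answer in the context is the span
        start = record["context"].find(text)
        if start < 0:
            # An extractive model can only learn spans that occur in the context
            raise ValueError(f"answer {text!r} not found in context {record['context']!r}")
        converted.append({
            "question": record["question"],
            "context": record["context"],
            "answers": {"text": [text], "answer_start": [start]},
        })
    return converted


data = [
    {"question": "What ingredient?", "context": "Pour 1 can whole tomatoes", "answers": "tomatoes"},
    {"question": "What ingredient?", "context": "Dice 1 yellow onion", "answers": "onion"},
]

for row in to_squad(data):
    answer = row["answers"]
    start = answer["answer_start"][0]
    end = start + len(answer["text"][0])
    # Slicing context[start:end] recovers the answer text from its offsets
    print(row["context"][start:end], start, end)
```

The same offset convention explains the pipeline output above: `"Peel and dice 1 shallot"[16:23]` is `"shallot"`, matching the reported `start` of 16 and `end` of 23.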