@Felflare
Created February 10, 2020 03:03
XLNet model with a span classification head on top for extractive question-answering tasks like SQuAD (a linear layer on top of the hidden-states output to compute `span start logits` and `span end logits`). Simple demo of loss and logits.
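In outline, that head is just a linear layer applied to every final hidden state and split into per-token start and end scores. A minimal sketch of the idea (class name and details here are illustrative, not the library's exact implementation):

import torch
import torch.nn as nn

class SpanClassificationHead(nn.Module):
    # Hedged sketch of the head described above: one linear layer over the final
    # hidden states, split into per-token start and end scores.
    def __init__(self, hidden_size):
        super().__init__()
        self.qa_outputs = nn.Linear(hidden_size, 2)

    def forward(self, hidden_states):             # (batch, seq_len, hidden_size)
        logits = self.qa_outputs(hidden_states)   # (batch, seq_len, 2)
        start_logits, end_logits = logits.split(1, dim=-1)
        return start_logits.squeeze(-1), end_logits.squeeze(-1)  # each (batch, seq_len)

The gist's demo of the actual `XLNetForQuestionAnsweringSimple` model follows.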
from transformers import XLNetTokenizer, XLNetForQuestionAnsweringSimple
import torch

tokenizer = XLNetTokenizer.from_pretrained('xlnet-base-cased')
model = XLNetForQuestionAnsweringSimple.from_pretrained('xlnet-base-cased')

# Encode a toy sequence; add_special_tokens appends XLNet's <sep> and <cls> tokens.
input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute", add_special_tokens=True)).unsqueeze(0)  # Batch size 1
print(f'Encoded sequence ids -- {input_ids.tolist()[0]}')
# Encoded sequence ids -- [17, 11368, 19, 94, 2288, 27, 10920, 4, 3]

# Token indices of the labelled answer span. Note the model treats end_positions as the
# index of the last answer token, so the [start:end) slice printed below shows one token fewer.
start_positions = torch.tensor([1])
end_positions = torch.tensor([3])
print(f'Extracted sequence ids -- {input_ids.tolist()[0][start_positions.tolist()[0]:end_positions.tolist()[0]]}')
# Extracted sequence ids -- [11368, 19]
print(f'Extracted sequence is -- {tokenizer.decode(input_ids.tolist()[0][start_positions.tolist()[0]:end_positions.tolist()[0]])}')
# Extracted sequence is -- Hello,

# Passing start_positions/end_positions makes the model return the span-classification loss
# as the first element of its output tuple.
outputs = model(input_ids, start_positions=start_positions, end_positions=end_positions)
loss = outputs[0]
print(f'Loss value of the selected span is -- {loss.tolist()}')
# Loss value of the selected span is -- 2.9509589672088623
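
The description also mentions the span start/end logits, but the demo above only surfaces the loss. Below is a minimal sketch of inspecting them, assuming the same transformers 2.x tuple-return API used above (when no label positions are passed, the output tuple starts with start_logits and end_logits). The QA head of `xlnet-base-cased` is randomly initialized, so the argmax span is arbitrary until the model is fine-tuned on SQuAD-style data.

# Sketch: run the same inputs without labels to get the raw span logits.
with torch.no_grad():
    start_logits, end_logits = model(input_ids)[:2]  # each of shape (batch_size, sequence_length)
print(f'Start logits shape -- {start_logits.shape}')
print(f'End logits shape -- {end_logits.shape}')

# Greedy span prediction: argmax of each head's per-token scores. The head is untrained
# here, so the decoded span is arbitrary; shown only to illustrate the mechanics.
pred_start = torch.argmax(start_logits, dim=1).item()
pred_end = torch.argmax(end_logits, dim=1).item()
if pred_start <= pred_end:
    print(f'Predicted span is -- {tokenizer.decode(input_ids.tolist()[0][pred_start:pred_end + 1])}')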
@LifeIsStrange

Could this be used to outperform SpanBERT on coreference resolution?
https://paperswithcode.com/sota/coreference-resolution-on-ontonotes
