Skip to content

Instantly share code, notes, and snippets.

@junyann
Created May 15, 2020 21:02
Show Gist options
  • Save junyann/722ee14eb4368d0ce333d50b1160e898 to your computer and use it in GitHub Desktop.
Save junyann/722ee14eb4368d0ce333d50b1160e898 to your computer and use it in GitHub Desktop.
from transformers import GPT2Tokenizer, GPT2LMHeadModel
import torch
from torch.nn import CrossEntropyLoss
from tqdm import trange
# Tokenizer padding/truncation length for every sentence; the BOS token added
# later makes the model input one position longer than this.
# batch_size is the number of sentences sent through the model per forward pass.
max_length, batch_size = 24, 200
class GPT2LMHeadModel_WO_REDUCTION(GPT2LMHeadModel):
    """GPT-2 LM head model whose loss is one summed value per sequence.

    Instead of collapsing the token-level cross-entropy into a single batch
    scalar, ``forward`` sums the token losses of each sequence separately, so
    the caller can derive a per-sentence perplexity. Label positions equal to
    ``-100`` (``CrossEntropyLoss``'s default ``ignore_index``) contribute 0.

    No ``__init__`` override is needed: the parent constructor already takes
    ``config``, which is all ``from_pretrained`` passes.
    """

    def forward(
        self,
        input_ids=None,
        past=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        labels=None,
    ):
        """Run the transformer and LM head; optionally add a per-sequence loss.

        All arguments except ``labels`` are forwarded unchanged to
        ``self.transformer``.

        Args:
            labels: optional 2-D tensor (batch, seq_len) of target token ids,
                aligned with ``input_ids``; positions set to -100 are ignored.

        Returns:
            A tuple ``(loss,) + (lm_logits,) + transformer_outputs[1:]`` when
            ``labels`` is given, otherwise ``(lm_logits,) + transformer_outputs[1:]``.
            ``loss`` has shape (batch,) and holds the SUM of the token losses of
            each sequence (not the mean).
        """
        transformer_outputs = self.transformer(
            input_ids,
            past=past,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
        )
        hidden_states = transformer_outputs[0]
        lm_logits = self.lm_head(hidden_states)
        outputs = (lm_logits,) + transformer_outputs[1:]

        if labels is not None:
            # Shift so that the logits at position t are scored against token t+1.
            shift_logits = lm_logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            bsz, shift_seq_len = shift_labels.size()
            # reduction='none' keeps one loss per token; ignored (-100) labels yield 0.
            loss_fct = CrossEntropyLoss(reduction='none')
            token_loss = loss_fct(
                shift_logits.view(-1, shift_logits.size(-1)),
                shift_labels.view(-1),
            )
            # Fold the flat token losses back to (batch, seq) and sum per sequence.
            loss = token_loss.view(bsz, shift_seq_len).sum(1)
            outputs = (loss,) + outputs

        # Original layout note: (loss), lm_logits, presents, (all hidden_states), (attentions)
        return outputs
all_sentence = [
    'there is a book on the desk',
    'there is a plane on the desk',
    'there is a book in the desk'
]

# Use the GPU when present, otherwise fall back to CPU instead of crashing.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
# GPT-2 has no pad token. NOTE: 'pad' is an ordinary word in the GPT-2
# vocabulary; that is harmless here only because every padded position is
# masked out of the labels (-100) below.
tokenizer.pad_token = 'pad'
model = GPT2LMHeadModel_WO_REDUCTION.from_pretrained('gpt2')
model.to(device)
model.eval()  # inference mode (e.g. disables dropout)

with torch.no_grad():
    all_ppl = []
    input_encoded = tokenizer.batch_encode_plus(
        all_sentence,
        max_length=max_length,
        pad_to_max_length=True,
        return_tensors='pt',
    )
    input_ids = input_encoded['input_ids']
    attention_mask = input_encoded['attention_mask']

    # Prepend BOS so the first real token is also scored (predicted from BOS).
    bos_input_ids = torch.full([input_ids.size(0), 1], tokenizer.bos_token_id, dtype=torch.long)
    input_ids_w_bos = torch.cat((bos_input_ids, input_ids), dim=1)
    labels_w_bos = input_ids_w_bos.clone()
    bos_attention_mask = torch.ones(input_ids.size(0), 1, dtype=torch.long)
    attention_mask_w_bos = torch.cat((bos_attention_mask, attention_mask), dim=1).bool()
    # -100 is CrossEntropyLoss's default ignore_index, so padding adds 0 loss.
    labels_w_bos.masked_fill_(~attention_mask_w_bos, -100)
    # Tokens actually scored per sentence: the BOS label is dropped by the
    # model's one-position shift, leaving exactly the unpadded tokens.
    sentence_length = attention_mask.sum(-1)

    for i in trange(0, len(all_sentence), batch_size):
        batch = slice(i, i + batch_size)
        batch_input_ids_w_bos = input_ids_w_bos[batch].to(device)
        # Pass the padding mask explicitly rather than relying on right padding
        # plus causal attention to keep pads out of the real tokens' context.
        batch_attention_mask_w_bos = attention_mask_w_bos[batch].long().to(device)
        batch_labels_w_bos = labels_w_bos[batch].to(device)
        batch_sentence_length = sentence_length[batch].to(device)
        batch_outputs = model(
            batch_input_ids_w_bos,
            attention_mask=batch_attention_mask_w_bos,
            labels=batch_labels_w_bos,
        )
        batch_loss = batch_outputs[0]  # summed NLL per sentence, shape (batch,)
        # Perplexity = exp(mean negative log-likelihood per scored token).
        batch_per_token_loss = batch_loss / batch_sentence_length
        batch_ppl = torch.exp(batch_per_token_loss).tolist()
        all_ppl += batch_ppl

print(all_ppl)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment