from transformers import GPT2Tokenizer, GPT2LMHeadModel
import torch
from torch.nn import CrossEntropyLoss
from tqdm import trange

max_length = 24   # maximum number of tokens per sentence (before the BOS token is prepended)
batch_size = 200  # number of sentences scored per forward pass

class GPT2LMHeadModel_WO_REDUCTION(GPT2LMHeadModel):
    """GPT2LMHeadModel whose forward returns one summed loss per sequence
    instead of a single loss averaged over the whole batch."""

    def __init__(self, config):
        super().__init__(config)

    def forward(
        self,
        input_ids=None,
        past=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        labels=None,
    ):
        transformer_outputs = self.transformer(
            input_ids,
            past=past,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
        )
        hidden_states = transformer_outputs[0]
        lm_logits = self.lm_head(hidden_states)
        outputs = (lm_logits,) + transformer_outputs[1:]
        if labels is not None:
            # Shift so that tokens < n predict n
            shift_logits = lm_logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            # Flatten the tokens and compute the unreduced token-level loss;
            # positions labelled -100 contribute zero loss
            bsz, shift_seq_len = shift_labels.size()
            loss_fct = CrossEntropyLoss(reduction='none')
            loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
            # Sum over the sequence dimension to get one summed loss per sentence
            loss = loss.view(bsz, shift_seq_len).sum(1)
            outputs = (loss,) + outputs
        return outputs  # (loss), lm_logits, presents, (all hidden_states), (attentions)

all_sentence = [
    'there is a book on the desk',
    'there is a plane on the desk',
    'there is a book in the desk'
]

tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
# GPT-2 has no pad token; the placeholder choice does not affect the scores because
# padding is on the right and padded positions are masked out of the loss below.
tokenizer.pad_token = 'pad'

model = GPT2LMHeadModel_WO_REDUCTION.from_pretrained('gpt2')
model.to(torch.device('cuda'))
model.eval()

with torch.no_grad():
    all_ppl = []

    # Tokenize and right-pad every sentence to max_length
    input_encoded = tokenizer.batch_encode_plus(all_sentence, max_length=max_length, pad_to_max_length=True, return_tensors='pt')
    input_ids = input_encoded['input_ids']
    attention_mask = input_encoded['attention_mask']

    # Prepend the BOS token so the first real token of each sentence is also scored
    bos_input_ids = torch.full([input_ids.size(0), 1], tokenizer.bos_token_id, dtype=torch.long)
    input_ids_w_bos = torch.cat((bos_input_ids, input_ids), dim=1)
    labels_w_bos = input_ids_w_bos.clone()

    # Set the labels of padded positions to -100 so CrossEntropyLoss ignores them
    bos_attention_mask = torch.ones(input_ids.size(0), 1, dtype=torch.long)
    attention_mask_w_bos = torch.cat((bos_attention_mask, attention_mask), dim=1).bool()
    labels_w_bos.masked_fill_(~attention_mask_w_bos, -100)

    # Number of real (non-pad) tokens per sentence, used to average the summed loss
    sentence_length = attention_mask.sum(-1)

    for i in trange(0, len(all_sentence), batch_size):
        batch_input_ids_w_bos = input_ids_w_bos[i: i + batch_size].to(torch.device('cuda'))
        batch_labels_w_bos = labels_w_bos[i: i + batch_size].to(torch.device('cuda'))
        batch_sentence_length = sentence_length[i: i + batch_size].to(torch.device('cuda'))

        # With right padding and causal attention, real tokens never attend to pad
        # tokens, so the attention mask can be omitted here without changing the loss.
        batch_outputs = model(batch_input_ids_w_bos, labels=batch_labels_w_bos)
        batch_loss = batch_outputs[0]

        # Perplexity per sentence: exp(mean token-level negative log-likelihood)
        batch_per_token_loss = batch_loss / batch_sentence_length
        batch_ppl = torch.exp(batch_per_token_loss).tolist()
        all_ppl += batch_ppl

print(all_ppl)
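
# Optional sanity check (a sketch, not part of the original gist): for a single,
# unpadded sentence, the per-sentence perplexity computed above should match
# exp(mean token loss) from the stock GPT2LMHeadModel, which averages the shifted
# token losses internally. Assumes the same CUDA device and tokenizer as above.
with torch.no_grad():
    stock_model = GPT2LMHeadModel.from_pretrained('gpt2')
    stock_model.to(torch.device('cuda'))
    stock_model.eval()

    check_ids = tokenizer.encode(all_sentence[0], return_tensors='pt')
    check_ids = torch.cat(
        (torch.full([1, 1], tokenizer.bos_token_id, dtype=torch.long), check_ids), dim=1
    ).to(torch.device('cuda'))

    # The stock model returns the mean cross-entropy over the shifted tokens,
    # so exp(loss) is the sentence perplexity; it should be close to all_ppl[0].
    mean_loss = stock_model(check_ids, labels=check_ids)[0]
    print(torch.exp(mean_loss).item())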