@goddoe
Created May 23, 2022 17:54
lm_scorer.py
import numpy as np
import torch
from transformers import GPTJForCausalLM, AutoTokenizer


###########################################################
# Load Model
class CausalLMScorer:

    def __init__(self, pretrained_model_name_or_path="EleutherAI/gpt-j-6B"):
        self.model = GPTJForCausalLM.from_pretrained(pretrained_model_name_or_path)  # , torch_dtype=torch.float16
        self.model.eval()
        self.tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path, use_fast=True)

    def calc_conditional_ppl(self, text_pair_list):
        model = self.model
        tokenizer = self.tokenizer

        text_a_last_ch_idx_list = []
        x_list = []
        for text_a, text_b in text_pair_list:
            x = f"{text_a} {text_b}"
            x_list.append(x)
            # character index of the first character of text_b in the joined string
            text_a_last_ch_idx = len(text_a) + 1
            text_a_last_ch_idx_list.append(text_a_last_ch_idx)

        tokenizer.pad_token = tokenizer.eos_token
        t = tokenizer(x_list, return_tensors="pt", return_offsets_mapping=True, padding=True, truncation=True)

        input_ids_list = t.input_ids.tolist()
        offset_mapping_list = t.offset_mapping.tolist()

        first_token_idx_list = []
        last_token_idx_exclusive_list = []
        for elem_i, text_a_last_ch in enumerate(text_a_last_ch_idx_list):
            # find last token idx (the first padding position)
            tid_i = 0
            for tid in input_ids_list[elem_i]:
                if tid == tokenizer.pad_token_id:
                    break
                tid_i += 1
            last_token_idx_exclusive_list.append(tid_i)

            # find start token idx (the first token of text_b, via the offset mapping)
            prev_start_idx = 0
            for token_idx, (start_idx, end_idx) in enumerate(offset_mapping_list[elem_i]):
                if start_idx <= text_a_last_ch < end_idx:
                    first_token_idx_list.append(token_idx)
                    break
                elif text_a_last_ch <= prev_start_idx:
                    first_token_idx_list.append(token_idx - 1)
                    break
                prev_start_idx = start_idx

        with torch.no_grad():
            out = model(t.input_ids, attention_mask=t.attention_mask)
        pred = torch.softmax(out[0], dim=2)

        # lookup prob: pred[:, i] is the distribution over the token at position i + 1
        lookup_ids = t.input_ids[:, 1:]
        probs_lookup = pred[:, :-1].gather(dim=2, index=lookup_ids.unsqueeze(2)).squeeze(2).detach().cpu().numpy()

        ppl_list = []
        # calc ppl: mean negative log-likelihood of text_b given text_a
        # (probs_lookup[i] scores the token at position i + 1, so both slice bounds shift back by one;
        #  assumes text_a is non-empty so that first_token_idx >= 1)
        for prob, first_token_idx, last_token_idx_exclusive in zip(probs_lookup, first_token_idx_list, last_token_idx_exclusive_list):
            pr = prob[first_token_idx - 1: last_token_idx_exclusive - 1]
            ppl = -1. * np.sum(np.log(pr)) / len(pr)
            ppl_list.append(ppl)
        return ppl_list
if __name__ == '__main__':
    text_1_1 = "YMCA in South Australia"
    text_1_2 = "South Australia (SA)  has a unique position in Australia's history as, unlike the other states which were founded as colonies, South Australia began as a self governing province Many were attracted to this and Adelaide and SA developed as an independent and free thinking state."
    text_2_1 = "The overall decline seems to be due to the low survival rate of young birds , which may be caused by changes in agricultural practices ."
    text_2_2 = "The intensive farming methods used in northern Europe mean there is less pasture and meadow habitat available , and the supply of grassland invertebrates needed for the nestlings to thrive is <unk> reduced ."

    text_pair_list = [(text_1_1, text_1_2),
                      (text_2_1, text_2_2),
                      (text_1_1, text_2_2),
                      (text_2_1, text_1_2),
                      (text_1_1, text_2_1),
                      (text_1_2, text_2_2)]

    lm_scorer = CausalLMScorer()
    ppl_list = lm_scorer.calc_conditional_ppl(text_pair_list)
    print(ppl_list)

    # debugging snippet (the names below exist inside calc_conditional_ppl, not at module scope):
    # for input_ids, first_token_idx, last_token_idx_exclusive in zip(input_ids_list, first_token_idx_list, last_token_idx_exclusive_list):
    #     target_tokens = input_ids[first_token_idx: last_token_idx_exclusive]
    #     print(tokenizer.batch_decode(target_tokens))
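Note: the values returned by calc_conditional_ppl are mean negative log-likelihoods (i.e. log-perplexities), not perplexities themselves. A minimal sketch of converting them to true perplexities, meant to follow the print(ppl_list) line in the __main__ block above (ppl_list and np are already in scope there):

    # exponentiate the mean negative log-likelihood to obtain perplexity
    conditional_perplexities = [float(np.exp(nll)) for nll in ppl_list]
    print(conditional_perplexities)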