lm_scorer.py
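A small scorer built on GPT-J that measures the conditional perplexity of text_b given text_a: each pair is concatenated, the tokenizer's offset mappings locate the tokens belonging to text_b, and the mean negative log-likelihood over that span is exponentiated into a perplexity. Lower scores indicate that text_b is a more natural continuation of text_a.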
import numpy as np
import torch
from transformers import GPTJForCausalLM, AutoTokenizer


###########################################################
# Load Model
class CausalLMScorer:
    def __init__(self, pretrained_model_name_or_path="EleutherAI/gpt-j-6B"):
        # GPT-J-6B needs roughly 24 GB of memory in float32; passing
        # torch_dtype=torch.float16 roughly halves that.
        self.model = GPTJForCausalLM.from_pretrained(pretrained_model_name_or_path)  # , torch_dtype=torch.float16
        self.model.eval()
        self.tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path, use_fast=True)
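        # Note (not in the original gist): for a quick smoke test on small
        # hardware, swapping GPTJForCausalLM for AutoModelForCausalLM would
        # let this class load e.g. pretrained_model_name_or_path="gpt2",
        # since the scoring logic below relies only on logits and a fast
        # tokenizer that returns offset mappings.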
    def calc_conditional_ppl(self, text_pair_list):
        model = self.model
        tokenizer = self.tokenizer

        # Concatenate each (text_a, text_b) pair and remember the character
        # index where text_b starts, so the scored span can be recovered
        # from the tokenizer's offset mapping below.
        text_a_last_ch_idx_list = []
        x_list = []
        for text_a, text_b in text_pair_list:
            x = f"{text_a} {text_b}"
            x_list.append(x)
            text_a_last_ch_idx = len(text_a) + 1  # first character of text_b
            text_a_last_ch_idx_list.append(text_a_last_ch_idx)
        # GPT-J has no pad token, so reuse EOS for batch padding.
        tokenizer.pad_token = tokenizer.eos_token
        t = tokenizer(x_list, return_tensors="pt", return_offsets_mapping=True, padding=True, truncation=True)
        input_ids_list = t.input_ids.tolist()
        offset_mapping_list = t.offset_mapping.tolist()

        first_token_idx_list = []
        last_token_idx_exclusive_list = []
        for elem_i, text_a_last_ch in enumerate(text_a_last_ch_idx_list):
            # Find the last (exclusive) token index: the first padding
            # position, or the sequence length if nothing was padded.
            tid_i = 0
            for tid in input_ids_list[elem_i]:
                if tid == tokenizer.pad_token_id:
                    break
                tid_i += 1
            last_token_idx_exclusive_list.append(tid_i)

            # Find the first token of text_b: the token whose character
            # span contains the first character of text_b.
            prev_start_idx = 0
            for token_idx, (start_idx, end_idx) in enumerate(offset_mapping_list[elem_i]):
                if start_idx <= text_a_last_ch < end_idx:
                    first_token_idx_list.append(token_idx)
                    break
                elif text_a_last_ch <= prev_start_idx:
                    first_token_idx_list.append(token_idx - 1)
                    break
                prev_start_idx = start_idx
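        # Illustrative example (made-up token boundaries, not from the gist):
        # for x = "YMCA in South Australia South Australia (SA) has ...",
        # offset_mapping might begin [(0, 4), (4, 7), (7, 13), (13, 23),
        # (23, 29), ...], where each pair is a token's (start, end) character
        # span; the span (23, 29) contains character 24 = len(text_a) + 1,
        # so that token is the first token of text_b.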
        # Score without building a graph; this is pure inference.
        with torch.no_grad():
            out = model(t.input_ids)
        pred = torch.softmax(out.logits, dim=2)

        # The logits at position i predict the token at position i + 1, so
        # gather each next-token's probability: probs_lookup[:, i] is the
        # model's probability of input_ids[:, i + 1].
        lookup_ids = t.input_ids[:, 1:]
        probs_lookup = pred[:, :-1].gather(dim=2, index=lookup_ids.unsqueeze(2)).squeeze(2).cpu().numpy()
        ppl_list = []
        # Perplexity of text_b given text_a: exp of the mean negative
        # log-likelihood over text_b's tokens. probs_lookup is shifted one
        # position left relative to input_ids, so the span covering tokens
        # [first_token_idx, last_token_idx_exclusive) starts one index earlier.
        for prob, first_token_idx, last_token_idx_exclusive in zip(probs_lookup, first_token_idx_list, last_token_idx_exclusive_list):
            pr = prob[first_token_idx - 1: last_token_idx_exclusive - 1]
            ppl = np.exp(-np.sum(np.log(pr)) / len(pr))
            ppl_list.append(ppl)
        return ppl_list
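# Minimal sanity check of the perplexity arithmetic (illustrative numbers,
# not model output): a span with token probabilities [0.5, 0.25, 0.125] has
# mean negative log-likelihood -(log(0.5) + log(0.25) + log(0.125)) / 3
# ~= 1.386, so its perplexity is exp(1.386) ~= 4.0 -- the geometric mean of
# the inverse probabilities.
# >>> np.exp(-np.mean(np.log([0.5, 0.25, 0.125])))  # -> 4.0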
if __name__ == '__main__':
    text_1_1 = "YMCA in South Australia"
    text_1_2 = "South Australia (SA) has a unique position in Australia's history as, unlike the other states which were founded as colonies, South Australia began as a self governing province Many were attracted to this and Adelaide and SA developed as an independent and free thinking state."
    text_2_1 = "The overall decline seems to be due to the low survival rate of young birds , which may be caused by changes in agricultural practices ."
    text_2_2 = "The intensive farming methods used in northern Europe mean there is less pasture and meadow habitat available , and the supply of grassland invertebrates needed for the nestlings to thrive is <unk> reduced ."

    text_pair_list = [(text_1_1, text_1_2),
                      (text_2_1, text_2_2),
                      (text_1_1, text_2_2),
                      (text_2_1, text_1_2),
                      (text_1_1, text_2_1),
                      (text_1_2, text_2_2)]

    lm_scorer = CausalLMScorer()
    ppl_list = lm_scorer.calc_conditional_ppl(text_pair_list)
    print(ppl_list)
    # Debug helper (only runs inside calc_conditional_ppl, where these
    # variables live): decode the scored spans to confirm they cover text_b.
    # for input_ids, first_token_idx, last_token_idx_exclusive in zip(input_ids_list, first_token_idx_list, last_token_idx_exclusive_list):
    #     target_tokens = input_ids[first_token_idx: last_token_idx_exclusive]
    #     print(tokenizer.batch_decode(target_tokens))
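    # Rough expectation (not verified output): the matched pairs
    # (text_1_1, text_1_2) and (text_2_1, text_2_2) should come back with
    # lower perplexity than the shuffled pairs, since in those cases text_b
    # is a topical continuation of text_a.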