Skip to content

Instantly share code, notes, and snippets.

@kingjr
Created May 23, 2023 13:05
Show Gist options
  • Save kingjr/684c5fd8bb050d4abe4d5ffb4e437fcb to your computer and use it in GitHub Desktop.
Save kingjr/684c5fd8bb050d4abe4d5ffb4e437fcb to your computer and use it in GitHub Desktop.
from transformers import AutoModel, AutoTokenizer
import torch
from tqdm.notebook import trange, tqdm
import pandas as pd
from bm.studies.utils import match_list
import numpy as np
import matplotlib.pyplot as plt
class LM():
    """Word-level contextual embeddings from a GPT-2 style causal language model.

    Calling an instance on a DataFrame of words (one row per word, with a
    ``word`` column) tokenizes the space-joined text, runs the transformer
    over the token sequence with a sliding window, and aggregates the token
    embeddings back to one embedding per word.
    """

    def __init__(self, model='gpt2', device='cuda', stride=200, agg='mean', layers=None):
        """Load the tokenizer and the model, and store the extraction settings.

        Parameters
        ----------
        model : name or path handed to ``AutoTokenizer`` / ``AutoModel``.
        device : device the model is moved to.
        stride : number of new tokens embedded by each window after the first.
        agg : how the tokens of one word are combined ('mean', 'last' or 'sum').
        layers : optional index (or list of indices) applied to the last (layer)
            axis of the token embeddings; ``None`` keeps every layer.
        """
        # NOTE: the original passed a misspelled `add_special_token=False` here;
        # the option belongs to `encode` and is now passed there instead.
        self.tokenizer = AutoTokenizer.from_pretrained(model)
        self.model = AutoModel.from_pretrained(model)
        self.model.to(device)
        # inference only: make sure dropout is disabled
        self.model.eval()
        self.stride = stride
        self.agg = agg
        self.layers = layers

    def __call__(self, words: pd.DataFrame) -> torch.Tensor:
        """Return one embedding per row of `words`.

        `words` must have a ``word`` column; a ``token_id`` column is added to
        it in place (see `_get_tokens`).
        """
        # assign token id for each word
        words, tokens = self._get_tokens(words)
        # retrieve transformer embedding
        token_embs = self._get_token_embs(tokens)
        # re-align to word embedding
        return self._get_word_embs(words, token_embs, self.agg)

    def _get_token_embs(self, tokens: torch.Tensor) -> torch.Tensor:
        """Embed every token with a sliding window over the sequence.

        The first window covers up to `max_length` tokens and all of its
        outputs are kept; each following window ends `stride` tokens further
        (the last one ends exactly at the end of the sequence) and only the
        outputs of its new tokens are kept, so every token is embedded with as
        much left context as the model allows.

        Parameters
        ----------
        tokens : tensor of token ids of shape (1, n_tokens).

        Returns
        -------
        Tensor of shape (n_tokens, dim, n_layers) on CPU, where the layer axis
        holds the raw token embedding (``wte``) followed by every entry of the
        model's ``hidden_states``; restricted to `self.layers` when it is set.

        Raises
        ------
        ValueError : if the sequence is empty or the stride is not in
            ``(0, max_length]``.
        """
        config = self.model.config
        n_positions = config.max_position_embeddings
        # retrieve size of model buffer (older configs expose it as `n_ctx`)
        max_length = getattr(config, 'n_ctx', n_positions)
        assert 768 <= max_length <= 1024, max_length

        stride = int(self.stride)
        if not 0 < stride <= max_length:
            # a larger stride would skip tokens between two windows
            raise ValueError(f'stride must be in (0, {max_length}], got {self.stride!r}')

        n_tokens = tokens.shape[1]
        if n_tokens == 0:
            raise ValueError('cannot embed an empty token sequence')

        # End (exclusive) of each window. The original indexed `starts[-1]` on
        # an empty array and crashed whenever n_tokens <= max_length.
        stops = list(range(min(max_length, n_tokens), n_tokens, stride))
        stops.append(n_tokens)

        # Absolute positions wrapped into the valid range of position ids.
        # (`torch.range` is deprecated and end-inclusive; the previous modulus
        # `max_position_embeddings - 1` made position 1023 wrap to 0.)
        positions = torch.arange(n_tokens, device=self.model.device) % n_positions

        latents = []
        previous_stop = 0
        with torch.no_grad():
            for stop in tqdm(stops):
                begin = max(0, stop - max_length)
                inpt = tokens[:, begin:stop]
                pos = positions[begin:stop].unsqueeze(0)
                # contextual word embedding for all layers
                out = self.model(inpt,
                                 position_ids=pos,
                                 output_hidden_states=True)
                cwe = torch.stack(out.hidden_states)[:, 0]
                # add word embedding (token lookup only, no position added)
                wte = self.model.base_model.wte(inpt)
                latent = torch.cat([wte, cwe], dim=0)
                # keep only the tokens not covered by a previous window
                n_new = stop - previous_stop
                latents.append(latent[:, -n_new:, :].cpu())
                previous_stop = stop

        # (layers, tokens, dim) -> (tokens, dim, layers)
        latents = torch.cat(latents, dim=1).permute(1, 2, 0)
        # `is not None` so that a layer index of 0 is honoured
        if self.layers is not None:
            latents = latents[..., self.layers]
        return latents

    @staticmethod
    def _to_chars(ids, strings) -> pd.DataFrame:
        """Explode strings into one row per character.

        Returns a DataFrame with columns ``word_id`` (the id of the string the
        character comes from) and ``char``; the columns exist even when there
        is no character at all.
        """
        rows = [
            dict(word_id=idx, char=char)
            for idx, text in zip(ids, strings)
            for char in text
        ]
        return pd.DataFrame(rows, columns=['word_id', 'char'])

    def _get_tokens(self, words: pd.DataFrame):
        """Tokenize the word sequence and find which tokens belong to each word.

        The words are joined with spaces and encoded in one go; tokens and
        words are then aligned at the character level.

        Parameters
        ----------
        words : DataFrame with a ``word`` column (strings). It is modified in
            place: a ``token_id`` column is added (or overwritten) holding, for
            each word, the list of positions of its tokens in the token
            sequence (empty when no character of the word was matched).

        Returns
        -------
        (words, tokens) : the same DataFrame, and the token ids as a tensor of
            shape (1, n_tokens) on the model's device.
        """
        # tokenize word sequence
        text = ' '.join(words['word'])
        tokens = self.tokenizer.encode(text, return_tensors="pt",
                                       add_special_tokens=False)
        assert len(tokens) == 1

        # characters of each decoded token, tagged with the token position
        token_ids = tokens[0].tolist()
        pieces = [self.tokenizer.decode([token]) for token in token_ids]
        token_chars = self._to_chars(range(len(pieces)), pieces)
        # characters of each word, tagged with the word's index label
        word_chars = self._to_chars(words.index, words['word'])

        # match at the character level
        # NOTE(review): `match_list` is a project helper; from its use here it
        # is assumed to return paired row indices (into token_chars, into
        # word_chars) of the matching characters -- confirm against bm.
        i, j = match_list(token_chars.char.values, word_chars.char.values)
        word_chars['match'] = np.nan
        word_chars.loc[j, 'match'] = i

        # assign corresponding token positions to each word
        per_word = {}
        for wid, group in word_chars.groupby('word_id'):
            # remove all mismatch characters (NaN fails the comparison)
            matched = group['match']
            matched = matched[matched >= 0].astype(int)
            # identify corresponding token positions, in order of appearance
            per_word[wid] = token_chars.loc[matched, 'word_id'].unique().tolist()

        # Store real lists directly in an object column instead of the former
        # str(...) / eval(...) round trip.
        column = np.empty(len(words), dtype=object)
        for row, wid in enumerate(words.index):
            column[row] = list(per_word.get(wid, []))
        words['token_id'] = column

        tokens = tokens.to(self.model.device)
        return words, tokens

    @staticmethod
    def _get_word_embs(words: pd.DataFrame, latents: torch.Tensor, method='mean') -> torch.Tensor:
        """Aggregate token embeddings into one embedding per word.

        Parameters
        ----------
        words : DataFrame with a ``token_id`` column holding, per row, the list
            of indices into the first axis of `latents`.
        latents : token embeddings, indexed by token on the first axis.
        method : 'mean', 'last' or 'sum' over the tokens of a word.

        Returns
        -------
        Tensor of shape ``(len(words), *latents.shape[1:])``; rows of words
        without any token are left at zero. Rows follow the positional order
        of `words`, whatever its index labels are.

        Raises
        ------
        ValueError : if `method` is not one of the supported aggregations.
        """
        aggregators = {
            'mean': lambda x: x.mean(0),
            'last': lambda x: x[-1],
            'sum': lambda x: x.sum(0),
        }
        if method not in aggregators:
            raise ValueError(
                f'unknown aggregation method {method!r}; '
                f'expected one of {sorted(aggregators)}'
            )
        aggregate = aggregators[method]

        out = torch.zeros(len(words), *latents.shape[1:])
        # enumerate gives the row position; index labels need not be 0..n-1
        for row, token_ids in enumerate(words['token_id']):
            # if the word did not get a token, continue
            if not len(token_ids):
                continue
            # aggregate multiple tokens for each word
            out[row] = aggregate(latents[list(token_ids)])
        return out
# Demo: embed the words of a recording and plot one layer of the result.
try:
    # Preferred input: the word events of the first Narrative2020 recording.
    from bm.studies.narrative2020 import Narrative2020Recording
    rec = next(Narrative2020Recording.iter())
    events = rec.events()
    words = events.query("kind=='word'").copy().reset_index(drop=True)
except Exception:
    # Fallback when the study cannot be imported or loaded: random lowercase
    # words. `Exception` (not a bare except) so Ctrl-C / SystemExit still work.
    import random
    import string

    def random_text(n_words):
        """Return a DataFrame with a `word` column of `n_words` random words."""
        df = []
        for _ in range(n_words):
            # randint is inclusive on both ends, so empty words can occur
            word_length = random.randint(0, 10)
            word = ''.join(random.choice(string.ascii_letters) for _ in range(word_length))
            df.append(dict(word=word.lower()))
        return pd.DataFrame(df)

    words = random_text(100)

lm = LM(stride=100)
embs = lm(words)
# index 3 on the last axis of the embeddings, shown with one column per word
plt.matshow(embs[:, :, 3].T, vmin=-1, vmax=1)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment