Skip to content

Instantly share code, notes, and snippets.

Show Gist options
  • Save jmdfm/9836ea2174d135d2306cda0002b86882 to your computer and use it in GitHub Desktop.
Save jmdfm/9836ea2174d135d2306cda0002b86882 to your computer and use it in GitHub Desktop.
import functools
import os

import torch
from transformers import AutoTokenizer, AutoModel
# MODEL_CKPT is a local copy of the Hugging Face checkpoint
# "jinaai/jina-embeddings-v2-base-en" (context length of 8192 tokens).
# Set the MODEL_CKPT environment variable (a local path or the hub id) to run
# on a machine where the default path below does not exist.
MODEL_CKPT = os.environ.get(
    "MODEL_CKPT", "/Users/rohan/3_Resources/ai_models/jina-embeddings-v2-base-en"
)
def recursive_splitter(text: str, separators: list[str], chunk_size: int) -> list[str]:
    """Split *text* into chunks of at most *chunk_size* words.

    The text is split on ``separators[0]``, and each resulting piece is split
    recursively on the remaining separators. Once no separators are left, the
    piece is cut into runs of ``chunk_size`` whitespace-delimited words, each
    run joined with single spaces.

    Empty or whitespace-only pieces (e.g. produced by leading, trailing or
    repeated separators) yield no chunks. Returning "" would give a chunk with
    no tokens and therefore a NaN embedding downstream.

    Args:
        text: The text to split.
        separators: Separators tried in order, coarsest first.
        chunk_size: Maximum number of words per chunk; must be >= 1.

    Returns:
        Non-empty chunks in document order.

    Raises:
        ValueError: If ``chunk_size`` is less than 1.
    """
    if chunk_size < 1:
        raise ValueError(f"chunk_size must be >= 1, got {chunk_size}")

    if not separators:
        # str.split() with no argument splits on any whitespace run and never
        # yields empty "words", so chunk_size counts real words only.
        words = text.split()
        return [
            " ".join(words[i:i + chunk_size])
            for i in range(0, len(words), chunk_size)
        ]

    head, rest = separators[0], separators[1:]
    chunks: list[str] = []
    for piece in text.split(head):
        chunks.extend(recursive_splitter(piece, rest, chunk_size))
    return chunks
@functools.lru_cache(maxsize=1)
def _load_tokenizer_and_model(model_ckpt: str):
    """Load the tokenizer and eval-mode model for *model_ckpt*, memoized.

    Loading weights dominates the cost of an embedding call, so it is done
    once per process and reused. maxsize=1 keeps at most one model in memory.
    """
    tokenizer = AutoTokenizer.from_pretrained(model_ckpt)
    # NOTE(review): trust_remote_code=True executes Python code shipped with
    # the checkpoint (presumably required for the Jina v2 architecture); only
    # point MODEL_CKPT at checkpoints you trust.
    model = AutoModel.from_pretrained(model_ckpt, trust_remote_code=True)
    model.eval()  # inference mode (disables dropout)
    return tokenizer, model


def embed_using_late_chunking(chunks: list[str]) -> list[list[float]]:
    """Embed each chunk with late chunking, using a single forward pass.

    The chunks are tokenized without special tokens, concatenated in order and
    wrapped in one [CLS] ... [SEP] pair. Every token's hidden state is thus
    computed with the whole document as context. Each chunk's embedding is the
    mean of the last-layer hidden states over that chunk's own token span.

    The concatenated sequence is not truncated here, so it must fit the model's
    context window (8192 tokens for jina-embeddings-v2-base-en, per the note on
    MODEL_CKPT).

    Args:
        chunks: Text chunks in document order (e.g. from recursive_splitter).

    Returns:
        One embedding (a list of floats) per chunk, in input order. An empty
        *chunks* returns [] without loading the model.

    Raises:
        ValueError: If a chunk tokenizes to zero tokens; mean-pooling an empty
            span would produce NaN.
    """
    if not chunks:
        return []
    tokenizer, model = _load_tokenizer_and_model(MODEL_CKPT)

    # add_special_tokens=False (rather than slicing [1:-1]) makes no assumption
    # about which special tokens the tokenizer would add around each chunk.
    token_lists = tokenizer(chunks, add_special_tokens=False)["input_ids"]

    all_tokens = [tokenizer.cls_token_id]
    spans = []  # (start, end) of each chunk's tokens inside all_tokens
    for idx, toks in enumerate(token_lists):
        if not toks:
            raise ValueError(
                f"chunk {idx} produced no tokens; cannot mean-pool an empty span"
            )
        start = len(all_tokens)
        all_tokens.extend(toks)
        spans.append((start, len(all_tokens)))
    all_tokens.append(tokenizer.sep_token_id)

    # unsqueeze(0) gives shape (1, seq_len): a batch holding ONE sequence.
    # unsqueeze(-1) would give (seq_len, 1), i.e. seq_len separate one-token
    # sequences with no shared context.
    input_ids = torch.tensor(all_tokens).unsqueeze(0)
    with torch.no_grad():
        outputs = model(input_ids=input_ids)
    hidden = outputs.last_hidden_state[0]  # drop batch dim -> (seq_len, hidden)
    return [hidden[start:end].mean(dim=0).tolist() for start, end in spans]
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment