Create embeddings from a given LM
import argparse

import torch
import transformers
from pytictoc import TicToc


def load_model_and_tokenizer(model_name):
    tokenizer = transformers.AutoTokenizer.from_pretrained(model_name)
    model = transformers.AutoModel.from_pretrained(model_name)
    return model, tokenizer
def read_text_file(file_path):
    with open(file_path, "r") as f:
        text = f.read()
    return text


def split_text_into_parts(text, max_length):
    text_parts = [text[i:i + max_length] for i in range(0, len(text), max_length)]
    return text_parts
def tokenize_text_parts(text_parts, tokenizer):
    toks = []
    t = TicToc()
    t.tic()
    for part in text_parts:
        toks.extend(tokenizer.tokenize(part))
    t.toc('tokenized')
    return toks
def generate_embeddings(model, tokenizer, tokens, output_file):
    # NOTE: uses the module-level `device` that is set below, before this function is called.
    num_tokens = len(tokens)
    # Leave room for the special tokens added by the tokenizer (e.g. [CLS]/[SEP])
    max_position_embeddings = model.config.max_position_embeddings - 2
    num_parts = (num_tokens - 1) // max_position_embeddings + 1
    with open(output_file, "w") as f:
        t = TicToc()
        t.tic()
        for i in range(num_parts):
            start = i * max_position_embeddings
            end = min(num_tokens, start + max_position_embeddings)
            part = tokens[start:end]
            part_text = tokenizer.convert_tokens_to_string(part)
            # Adding special tokens and truncating if necessary
            input_ids = tokenizer.encode(part_text, return_tensors="pt",
                                         max_length=max_position_embeddings, truncation=True)
            # The batch holds a single, unpadded sequence, so every position is attended to.
            # (Masking with `input_ids > 0` would wrongly drop token id 0, which some
            # tokenizers, e.g. RoBERTa's <s>, use for a real token.)
            attention_mask = torch.ones_like(input_ids)
            input_ids = input_ids.to(device)
            attention_mask = attention_mask.to(device)
            print(f"Part {i}: input_ids.shape = {input_ids.shape}, attention_mask.shape = {attention_mask.shape}")
            if input_ids.size(1) > model.config.max_position_embeddings:
                print(f"Warning: skipping part {i} because it exceeds the maximum sequence length")
                continue
            with torch.no_grad():
                outputs = model(input_ids, attention_mask=attention_mask)
            # Mean-pool the last hidden state, excluding the special tokens at both ends
            embeddings = outputs[0][0, 1:-1, :].mean(dim=0)
            f.write(f"Embedding for part {i}: {embeddings.tolist()}\n")
        t.toc('generated embeddings')
parser = argparse.ArgumentParser(description="Create embeddings for a text file with a Transformer LM.")
parser.add_argument("--input_file", type=str, required=True, help="Path to the input text file.")
parser.add_argument("--model", type=str, required=True, help="Name of the Transformer model to use.")
parser.add_argument("--output_file", type=str, required=True, help="Path to the output file for embeddings.")
args = parser.parse_args()
model, tokenizer = load_model_and_tokenizer(args.model)
if torch.cuda.is_available():
    device = torch.device("cuda")
    print(f"Using GPU {torch.cuda.get_device_name(0)}")
else:
    device = torch.device("cpu")
    print("Using CPU")
model.to(device)
text = read_text_file(args.input_file)
text_parts = split_text_into_parts(text, model.config.max_position_embeddings)
tokens = tokenize_text_parts(text_parts, tokenizer)
generate_embeddings(model, tokenizer, tokens, args.output_file)
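
A minimal sketch of working with the output, assuming the script is saved as create_embeddings.py and run as `python create_embeddings.py --input_file input.txt --model bert-base-uncased --output_file embeddings.txt` (the script name, file names, and model name are placeholders, not part of the gist). The snippet parses the lines written by generate_embeddings back into tensors and compares two chunk embeddings with cosine similarity:

import ast

import torch


def load_embeddings(path):
    # Each output line has the form "Embedding for part {i}: [v0, v1, ...]"
    vectors = []
    with open(path, "r") as f:
        for line in f:
            _, _, vec = line.partition(": ")
            vectors.append(torch.tensor(ast.literal_eval(vec)))
    return vectors


embs = load_embeddings("embeddings.txt")  # placeholder path
if len(embs) >= 2:
    sim = torch.nn.functional.cosine_similarity(embs[0], embs[1], dim=0)
    print(f"Cosine similarity between parts 0 and 1: {sim.item():.4f}")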