Created May 4, 2023 17:40
Create embeddings from a given LM: split an input text into chunks that fit the model's context window and write one mean-pooled embedding per chunk to an output file.
import argparse

import torch
import transformers
from pytictoc import TicToc


def load_model_and_tokenizer(model_name):
    tokenizer = transformers.AutoTokenizer.from_pretrained(model_name)
    model = transformers.AutoModel.from_pretrained(model_name)
    return model, tokenizer


def read_text_file(file_path):
    with open(file_path, "r") as f:
        return f.read()


def split_text_into_parts(text, max_length):
    # Split the raw text into character chunks so the tokenizer does not
    # have to process the whole file in one call.
    return [text[i:i + max_length] for i in range(0, len(text), max_length)]


def tokenize_text_parts(text_parts, tokenizer):
    toks = []
    t = TicToc()
    t.tic()
    for part in text_parts:
        toks.extend(tokenizer.tokenize(part))
    t.toc('tokenized')
    return toks


def generate_embeddings(model, tokenizer, tokens, output_file, device):
    num_tokens = len(tokens)
    # Reserve room for the special tokens (e.g. [CLS]/[SEP] or <s>/</s>)
    # that encode() adds around each part.
    max_position_embeddings = model.config.max_position_embeddings - 2
    num_parts = (num_tokens - 1) // max_position_embeddings + 1
    with open(output_file, "w") as f:
        t = TicToc()
        t.tic()
        for i in range(num_parts):
            start = i * max_position_embeddings
            end = min(num_tokens, start + max_position_embeddings)
            part = tokens[start:end]
            part_text = tokenizer.convert_tokens_to_string(part)
            # Add special tokens and truncate if necessary.
            input_ids = tokenizer.encode(part_text, return_tensors="pt",
                                         max_length=max_position_embeddings,
                                         truncation=True)
            # A single unpadded sequence: every position is attended to.
            # (Testing input_ids > 0 would wrongly mask token id 0, which is
            # a real token such as <s> in RoBERTa vocabularies.)
            attention_mask = torch.ones_like(input_ids)
            input_ids = input_ids.to(device)
            attention_mask = attention_mask.to(device)
            print(f"Part {i}: input_ids.shape = {input_ids.shape}, "
                  f"attention_mask.shape = {attention_mask.shape}")
            if input_ids.size(1) > model.config.max_position_embeddings:
                print(f"Warning: skipping part {i}: exceeds the maximum sequence length")
                continue
            with torch.no_grad():
                outputs = model(input_ids, attention_mask=attention_mask)
            # Mean-pool the last hidden states over all positions except the
            # special tokens at the first and last index.
            embeddings = outputs[0][0, 1:-1, :].mean(dim=0)
            f.write(f"Embedding for part {i}: {embeddings.tolist()}\n")
        t.toc('generated embeddings')


parser = argparse.ArgumentParser()
parser.add_argument("--input_file", type=str, help="Path to the input text file.")
parser.add_argument("--model", type=str, help="Name of the Transformer model to use.")
parser.add_argument("--output_file", type=str, help="Path to the output file for embeddings.")
args = parser.parse_args()

model, tokenizer = load_model_and_tokenizer(args.model)

if torch.cuda.is_available():
    device = torch.device("cuda")
    print(f"Using GPU {torch.cuda.get_device_name(0)}")
else:
    device = torch.device("cpu")
    print("Using CPU")
model.to(device)

text = read_text_file(args.input_file)
text_parts = split_text_into_parts(text, model.config.max_position_embeddings)
tokens = tokenize_text_parts(text_parts, tokenizer)
generate_embeddings(model, tokenizer, tokens, args.output_file, device)
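
Example usage, as a sketch: the script name create_embeddings.py and the file names corpus.txt / embeddings.txt below are placeholders, and any Hugging Face encoder checkpoint such as bert-base-uncased should work with AutoModel/AutoTokenizer.

pip install torch transformers pytictoc
python create_embeddings.py --input_file corpus.txt --model bert-base-uncased --output_file embeddings.txt

Each "Embedding for part i" line in the output file then holds one fixed-size vector per chunk, obtained by mean-pooling the model's last hidden states over the non-special token positions.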