local model embedding
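This script generates embeddings for rows of the tirukkural table in a local SQLite database using a locally loaded Hugging Face model, and writes each vector back to the embeddings column as a BLOB.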
import sqlite3
import torch
from transformers import AutoTokenizer, AutoModel
from tqdm import tqdm
import gc
# Environment setup (run once before executing this script):
#   python3 -m venv path/to/venv
#   source path/to/venv/bin/activate
# Connect to the SQLite database
conn = sqlite3.connect('data.sqlite')
cursor = conn.cursor()
# Ensure the table and the embeddings column exist
cursor.execute("CREATE TABLE IF NOT EXISTS tirukkural (kno INTEGER PRIMARY KEY, efirstline TEXT, esecondline TEXT, explanation TEXT, embeddings BLOB)")
try:
    # CREATE TABLE IF NOT EXISTS does not add columns to a pre-existing table
    cursor.execute("ALTER TABLE tirukkural ADD COLUMN embeddings BLOB")
except sqlite3.OperationalError:
    pass  # column already present
# Load the model and tokenizer
#model_name = "sentence-transformers/all-MiniLM-L6-v2"  # A smaller, more efficient model
# Note: mlx-community 4-bit checkpoints target Apple's MLX and may not load cleanly via transformers
model_name = "mlx-community/Meta-Llama-3-8B-Instruct-4bit"
tokenizer = AutoTokenizer.from_pretrained(model_name)
# Llama-style tokenizers ship without a pad token; padding=True below fails without one
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
model = AutoModel.from_pretrained(model_name, ignore_mismatched_sizes=True)
# Move model to GPU if available, otherwise use CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
model.eval()  # inference only
# Function to generate embeddings
def generate_embedding(texts, batch_size=32):
    all_embeddings = []
    for i in range(0, len(texts), batch_size):
        batch = texts[i:i + batch_size]
        inputs = tokenizer(batch, return_tensors="pt", padding=True, truncation=True, max_length=256)
        inputs = {k: v.to(device) for k, v in inputs.items()}
        with torch.no_grad():
            outputs = model(**inputs)
        # Mean-pool the last hidden states to get one vector per text
        embeddings = outputs.last_hidden_state.mean(dim=1).cpu().numpy()
        all_embeddings.extend(embeddings)
        # Clear CUDA cache after each batch if using GPU
        if device.type == "cuda":
            torch.cuda.empty_cache()
    return all_embeddings
# Fetch the data you want to embed
cursor.execute("SELECT kno, efirstline, esecondline, explanation FROM tirukkural WHERE embeddings IS NULL")
rows = cursor.fetchall()
# Process data in chunks to manage memory
chunk_size = 100
for chunk_start in range(0, len(rows), chunk_size):
    chunk_end = min(chunk_start + chunk_size, len(rows))
    chunk = rows[chunk_start:chunk_end]
    # Prepare data for batch processing
    ids = [row[0] for row in chunk]
    texts = [f"{row[1]} {row[2]} {row[3]}" for row in chunk]
    # Generate embeddings in batches
    batch_size = 32
    embeddings = generate_embedding(texts, batch_size)
    # Update database and commit once per chunk
    for kno, embedding in zip(ids, embeddings):
        cursor.execute("UPDATE tirukkural SET embeddings = ? WHERE kno = ?", (embedding.tobytes(), kno))
    conn.commit()
    # Clear some memory between chunks
    del embeddings, texts, ids
    gc.collect()
# Close connection
conn.close()
print("Embedding generation complete.")
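
Once the embeddings are written, they can be read back for a quick semantic-search sanity check. The sketch below is not part of the original gist: it assumes the script above has already run in the same session (so generate_embedding, the model, and the tokenizer are still defined) and that the stored vectors are float32, which holds when the model runs in full precision. np.frombuffer is the inverse of the embedding.tobytes() call used above.

import sqlite3
import numpy as np

def search(query, top_k=5):
    # Embed the query with the same model and pooling used at index time
    query_vec = np.asarray(generate_embedding([query])[0], dtype=np.float32)
    conn = sqlite3.connect('data.sqlite')
    cursor = conn.cursor()
    cursor.execute("SELECT kno, efirstline, embeddings FROM tirukkural WHERE embeddings IS NOT NULL")
    scored = []
    for kno, firstline, blob in cursor.fetchall():
        vec = np.frombuffer(blob, dtype=np.float32)  # decode the stored BLOB
        # Cosine similarity between the query and the stored vector
        sim = float(np.dot(query_vec, vec) / (np.linalg.norm(query_vec) * np.linalg.norm(vec)))
        scored.append((sim, kno, firstline))
    conn.close()
    scored.sort(reverse=True)  # highest similarity first
    return scored[:top_k]

for sim, kno, line in search("the value of learning"):
    print(f"{sim:.3f}  kural {kno}: {line}")

Cosine similarity is used rather than a raw dot product so that differences in vector magnitude across kurals do not dominate the ranking.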