Skip to content

Instantly share code, notes, and snippets.

@praveenc
Created March 2, 2025 20:02
Show Gist options
  • Save praveenc/e49e753754624593e2b71ec0aedddde6 to your computer and use it in GitHub Desktop.
LanceDB vector store operations using OpenAI compatible custom embeddings function
from pathlib import Path
from typing import Any, Dict, List, Literal, Union
import lancedb
import numpy as np
import pandas as pd
import pyarrow as pa
from lancedb.embeddings import (
EmbeddingFunction,
EmbeddingFunctionRegistry,
get_registry,
)
from lancedb.pydantic import LanceModel, Vector
from lancedb.rerankers import LinearCombinationReranker
from loguru import logger
from openai import OpenAI
from pydantic import Field
from rich import print
@EmbeddingFunctionRegistry.get_instance().register("local-embeddings")
class LocalEmbeddingFunction(EmbeddingFunction):
    """Embedding function backed by an OpenAI-compatible local server.

    Defaults target LM Studio's endpoint serving a nomic-embed-text model.
    Registered with LanceDB under the name ``local-embeddings``.
    """

    # Connection settings for the OpenAI-compatible server.
    base_url: str = "http://localhost:1234/v1"
    api_key: str = "lm-studio"
    embed_model_name: str = "text-embedding-nomic-embed-text-v1.5@f16"

    def ndims(self) -> int:
        """Return the embedding dimensionality (nomic-embed-text v1.5 is 768-d)."""
        return 768

    def compute_query_embeddings(
        self, query: Union[str, List[str]], *args, **kwargs
    ) -> np.ndarray:
        """Embed one query string or a list of strings.

        Returns an ``(n, ndims)`` float array, one row per input text.
        """
        if isinstance(query, str):
            query = [query]
        client = OpenAI(base_url=self.base_url, api_key=self.api_key)
        # Newlines can degrade embedding quality for some models; flatten them.
        cleaned = [text.replace("\n", " ") for text in query]
        # Single batched request instead of one HTTP round-trip per text.
        response = client.embeddings.create(
            input=cleaned, model=self.embed_model_name
        )
        return np.array([item.embedding for item in response.data])

    def compute_source_embeddings(
        self,
        texts: Union[str, List[str], pa.Array, pa.ChunkedArray, np.ndarray],
        *args,
        **kwargs,
    ) -> np.ndarray:
        """Embed source documents; normalizes Arrow/NumPy inputs to a list of str."""
        if isinstance(texts, str):
            texts = [texts]
        elif isinstance(texts, (pa.Array, pa.ChunkedArray)):
            texts = texts.to_pylist()
        elif isinstance(texts, np.ndarray):
            texts = texts.tolist()
        return self.compute_query_embeddings(texts)
# --- Module-level database setup (runs at import time) ---
db_uri = Path("data/lancedb")
table_name = "awsblogs_test"
# Weighted reranker for hybrid search: 0.3 vector score / 0.7 FTS score.
reranker = LinearCombinationReranker(0.3)
db = lancedb.connect(db_uri)
if table_name not in db.table_names():
    logger.info(f"creating table: {table_name}")
    # NOTE(review): `AWSBlogs` (a LanceModel schema) is not defined anywhere in
    # this file — it must be defined or imported elsewhere, otherwise this
    # branch raises NameError on a fresh database. Confirm before running.
    table = db.create_table(
        name=table_name,
        schema=AWSBlogs.to_arrow_schema(),
        mode="overwrite",
    )
    logger.debug("creating index on text column")
    # Native (non-Tantivy) full-text index on the "text" column for FTS/hybrid search.
    table.create_fts_index("text", use_tantivy=False)
table = db.open_table(name=table_name)
# Fixture data: one standalone post plus a three-item batch.
# (The last batch entry carries the same values as `single_post`.)
single_post = dict(
    title="AWS Security Best Practices",
    content="Here are some AWS security best practices: 1. Use IAM roles 2. Enable CloudTrail 3. Encrypt data at rest",
    url="https://aws.amazon.com/blogs/security/best-practices",
    published="2025-03-01",
    authors=["John Doe"],
    summary="A guide to AWS security best practices",
)
batch_posts = [
    dict(
        title="AWS Lambda Updates",
        content="AWS Lambda now supports function URLs and more memory options",
        url="https://aws.amazon.com/blogs/compute/lambda-updates",
        published="2025-03-02",
        authors=["Jane Smith"],
        summary="New features in AWS Lambda",
    ),
    dict(
        title="Amazon EKS Security",
        content="Learn about Amazon EKS security features and best practices",
        url="https://aws.amazon.com/blogs/containers/eks-security",
        published="2025-03-02",
        authors=["Bob Wilson", "Alice Brown"],
        summary="Security features in Amazon EKS",
    ),
    dict(single_post),  # copy with identical values
]
def add_records(records: List[Dict[str, Any]], table) -> int:
    """Embed and insert one or more blog-post records into `table`.

    Args:
        records: A post dict or a list of post dicts with keys
            title/content/url/published/authors/summary.
        table: An open LanceDB table whose schema matches the columns below.

    Returns:
        The number of records added.
    """
    # Accept a single record for convenience (use builtin `dict`, not typing.Dict).
    if isinstance(records, dict):
        records = [records]
    # Avoid dumping full record payloads at INFO level.
    logger.debug(f"Adding {len(records)} record(s)")
    df = pd.DataFrame(
        [
            {
                "text": post.get("content"),
                "url": post.get("url"),
                "title": post.get("title"),
                "published": post.get("published"),
                "summary": post.get("summary"),
                "authors": post.get("authors", []),
            }
            for post in records
        ]
    )
    # NOTE(review): `embed_func` is not defined in this file — presumably an
    # instance of LocalEmbeddingFunction created elsewhere. Confirm it exists
    # at call time, otherwise this raises NameError.
    embeddings = embed_func.compute_source_embeddings(df["text"].tolist())
    df["vector"] = embeddings.tolist()
    table.add(data=df)
    rec_count = len(df.index)
    logger.info(f"Added {rec_count} records to DB")
    return rec_count
def search_lancedb(
    query: str,
    query_type: Literal["hybrid", "fts", "vector"] = "vector",
    top_k: int = 10,
    reranker=reranker,
):
    """Search the module-level `table` and pretty-print the hits.

    Bug fixes vs. the original:
    - the `query` argument is now actually used (it was hard-coded to "Lambda");
    - the "fts" branch exists (previously `score_key`/`results` were unbound,
      raising NameError);
    - `top_k` limits every query type (hybrid was hard-coded to 2).

    Args:
        query: Free-text search string.
        query_type: "vector" (semantic), "fts" (full-text), or "hybrid".
        top_k: Maximum number of hits to fetch.
        reranker: Reranker applied to hybrid results.
    """
    if query_type == "hybrid":
        results = (
            table.search(query, query_type=query_type)
            .rerank(reranker=reranker)
            .limit(top_k)
            .to_list()
        )
        score_key = "_relevance_score"
    elif query_type == "fts":
        results = table.search(query, query_type=query_type).limit(top_k).to_list()
        score_key = "_score"
    else:  # "vector"
        results = table.search(query, query_type=query_type).limit(top_k).to_list()
        score_key = "_distance"
    # Distances rank ascending (smaller = closer); relevance/FTS scores descending.
    sorted_results = sorted(
        results, key=lambda hit: hit[score_key], reverse=(score_key != "_distance")
    )
    for hit in sorted_results:
        print(hit.get("title"))
        print(hit.get("url"))
        print(hit.get("text"))
        print(hit.get(score_key))
        print("---" * 10)
def delete_records(predicate: str, table) -> int:
    """Delete rows matching a SQL filter predicate; return how many matched.

    Args:
        predicate: SQL filter, e.g. "url = 'https://.../eks-security'".
        table: An open LanceDB table.

    Returns:
        The number of rows that matched the predicate (now an int, matching
        the annotation — the original returned the row list).
    """
    # Count via a filter. The original called table.search(predicate), which
    # treats the predicate as a *query string*, not a row filter.
    match_count = table.count_rows(filter=predicate)
    logger.info(f" Found {match_count}..")
    table.delete(predicate)
    logger.info(f"Deleted {match_count} rows!!")
    return match_count
if __name__ == "main":
recs = add_records(single_post, table)
print(recs)
recs = add_records(batch_posts, table)
search_lancedb(query="security")
predicate = "url = 'https://aws.amazon.com/blogs/security/best-practices'"
del_count = delete_records(predicate, table)
logger.info(f"{table.count_rows()}")
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment