Created
March 2, 2025 20:02
-
-
Save praveenc/e49e753754624593e2b71ec0aedddde6 to your computer and use it in GitHub Desktop.
LanceDB vector store operations using a custom, OpenAI-compatible embedding function
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| from pathlib import Path | |
| from typing import Any, Dict, List, Literal, Union | |
| import lancedb | |
| import numpy as np | |
| import pandas as pd | |
| import pyarrow as pa | |
| from lancedb.embeddings import ( | |
| EmbeddingFunction, | |
| EmbeddingFunctionRegistry, | |
| get_registry, | |
| ) | |
| from lancedb.pydantic import LanceModel, Vector | |
| from lancedb.rerankers import LinearCombinationReranker | |
| from loguru import logger | |
| from openai import OpenAI | |
| from pydantic import Field | |
| from rich import print | |
@EmbeddingFunctionRegistry.get_instance().register("local-embeddings")
class LocalEmbeddingFunction(EmbeddingFunction):
    """LanceDB embedding function backed by an OpenAI-compatible endpoint.

    Registered in the LanceDB embedding registry as ``"local-embeddings"``.
    Embeddings are requested through the ``openai`` client pointed at
    ``base_url``.
    """

    # Base URL of the OpenAI-compatible server.
    base_url: str = "http://localhost:1234/v1"
    # API key handed to the OpenAI client.
    api_key: str = "lm-studio"
    # Model identifier sent with every embeddings request.
    embed_model_name: str = "text-embedding-nomic-embed-text-v1.5@f16"
    # Maximum number of texts sent in a single embeddings request.
    batch_size: int = 64

    def ndims(self) -> int:
        """Return the embedding dimensionality reported to LanceDB (768)."""
        return 768

    def compute_query_embeddings(
        self, query: Union[str, List[str]], *args, **kwargs
    ) -> np.ndarray:
        """Embed one or more texts.

        Args:
            query: A single string or a list of strings.

        Returns:
            A 2-D array with one row per input text. An empty input yields
            an array of shape ``(0, ndims())`` without contacting the server.
        """
        if isinstance(query, str):
            query = [query]
        # Newlines are flattened to spaces before embedding.
        cleaned = [text.replace("\n", " ") for text in query]
        if not cleaned:
            return np.empty((0, self.ndims()))

        client = OpenAI(base_url=self.base_url, api_key=self.api_key)
        step = max(1, self.batch_size)
        embeddings = []
        # One request per batch instead of one request per text.
        for start in range(0, len(cleaned), step):
            response = client.embeddings.create(
                input=cleaned[start : start + step],
                model=self.embed_model_name,
            )
            # Relies on the endpoint returning embeddings in input order.
            embeddings.extend(item.embedding for item in response.data)
        return np.array(embeddings)

    def compute_source_embeddings(
        self,
        texts: Union[str, List[str], pa.Array, pa.ChunkedArray, np.ndarray],
        *args,
        **kwargs,
    ) -> np.ndarray:
        """Embed source texts, normalising the container to a Python list.

        Args:
            texts: A string, list of strings, pyarrow array/chunked array,
                or numpy array of strings.

        Returns:
            The result of ``compute_query_embeddings`` on the normalised list.
        """
        if isinstance(texts, str):
            texts = [texts]
        elif isinstance(texts, (pa.Array, pa.ChunkedArray)):
            texts = texts.to_pylist()
        elif isinstance(texts, np.ndarray):
            texts = texts.tolist()
        return self.compute_query_embeddings(texts)
# Location of the LanceDB database and the name of the demo table.
db_uri = Path("data/lancedb")
table_name = "awsblogs_test"

# Shared embedding function instance. add_records() reads this module-level
# name, which the file referenced without ever defining it.
embed_func = get_registry().get("local-embeddings").create()


class AWSBlogs(LanceModel):
    """Row schema for the blog-post table.

    NOTE(review): this class was referenced but missing from the file. The
    fields are reconstructed from the columns add_records() writes; confirm
    the types against the intended schema.
    """

    # Source text; the FTS index below is built on this column.
    text: str = embed_func.SourceField()
    # Embedding of ``text``; width comes from the embedding function.
    vector: Vector(embed_func.ndims()) = embed_func.VectorField()
    url: str
    title: str
    published: str
    summary: str
    authors: List[str] = Field(default_factory=list)


# Reranker used as the default for hybrid searches.
reranker = LinearCombinationReranker(0.3)
db = lancedb.connect(db_uri)

# Create the table (and its full-text index) only on first run.
if table_name not in db.table_names():
    logger.info(f"creating table: {table_name}")
    table = db.create_table(
        name=table_name,
        schema=AWSBlogs.to_arrow_schema(),
        mode="overwrite",
    )
    logger.debug("creating index on text column")
    table.create_fts_index("text", use_tantivy=False)

table = db.open_table(name=table_name)
# A single sample post; also reused as the last entry of ``batch_posts``.
single_post = {
    "title": "AWS Security Best Practices",
    "content": "Here are some AWS security best practices: 1. Use IAM roles 2. Enable CloudTrail 3. Encrypt data at rest",
    "url": "https://aws.amazon.com/blogs/security/best-practices",
    "published": "2025-03-01",
    "authors": ["John Doe"],
    "summary": "A guide to AWS security best practices",
}

# Sample posts for the batch-insert path.
batch_posts = [
    {
        "title": "AWS Lambda Updates",
        "content": "AWS Lambda now supports function URLs and more memory options",
        "url": "https://aws.amazon.com/blogs/compute/lambda-updates",
        "published": "2025-03-02",
        "authors": ["Jane Smith"],
        "summary": "New features in AWS Lambda",
    },
    {
        "title": "Amazon EKS Security",
        "content": "Learn about Amazon EKS security features and best practices",
        "url": "https://aws.amazon.com/blogs/containers/eks-security",
        "published": "2025-03-02",
        "authors": ["Bob Wilson", "Alice Brown"],
        "summary": "Security features in Amazon EKS",
    },
    # Same content as ``single_post``, with its own ``authors`` list so the
    # two records share no mutable state.
    {**single_post, "authors": list(single_post["authors"])},
]
def add_records(
    records: Union[Dict[str, Any], List[Dict[str, Any]]], table
) -> int:
    """Embed blog posts and append them to ``table``.

    Args:
        records: One post dict or a list of post dicts. Each post is read
            with the keys ``content``, ``url``, ``title``, ``published``,
            ``summary`` and ``authors``; ``content`` is stored as ``text``.
        table: Table object exposing ``add(data=...)``.

    Returns:
        The number of rows added (0 for empty input).
    """
    if isinstance(records, dict):
        records = [records]
    # Nothing to embed or insert; an empty frame has no "text" column.
    if not records:
        return 0

    logger.info(records)
    rows = [
        {
            "text": post.get("content"),
            "url": post.get("url"),
            "title": post.get("title"),
            "published": post.get("published"),
            "summary": post.get("summary"),
            "authors": post.get("authors", []),
        }
        for post in records
    ]
    df = pd.DataFrame(rows)

    # Embeddings are computed with the module-level ``embed_func`` and stored
    # alongside the text as plain Python lists.
    embeddings = embed_func.compute_source_embeddings(df["text"].tolist())
    df["vector"] = embeddings.tolist()
    table.add(data=df)

    rec_count = len(df.index)
    logger.info(f"Added {rec_count} records to DB")
    return rec_count
def search_lancedb(
    query: str,
    query_type: Literal["hybrid", "fts", "vector"] = "vector",
    top_k: int = 10,
    reranker=reranker,
):
    """Search the module-level ``table`` and print the matches, best first.

    Args:
        query: Text to search for.
        query_type: ``"hybrid"``, ``"fts"`` or ``"vector"``.
        top_k: Maximum number of rows to fetch.
        reranker: Reranker applied to hybrid searches only.

    Returns:
        The result rows (dicts) in the order they were printed.

    Raises:
        ValueError: If ``query_type`` is not one of the supported values.
    """
    if query_type == "hybrid":
        results = (
            table.search(query, query_type=query_type)
            .rerank(reranker=reranker)
            .limit(top_k)
            .to_list()
        )
        # Relevance score: larger is better.
        score_key, descending = "_relevance_score", True
    elif query_type == "vector":
        results = table.search(query, query_type=query_type).limit(top_k).to_list()
        # Distance: smaller is better, so sort ascending.
        score_key, descending = "_distance", False
    elif query_type == "fts":
        results = table.search(query, query_type=query_type).limit(top_k).to_list()
        # Full-text score: larger is better.
        score_key, descending = "_score", True
    else:
        raise ValueError(f"Unsupported query_type: {query_type!r}")

    logger.info(type(results))
    sorted_results = sorted(
        results, key=lambda row: row[score_key], reverse=descending
    )
    for row in sorted_results:
        print(row.get("title"))
        print(row.get("url"))
        print(row.get("text"))
        print(row.get(score_key))
        print("---" * 10)
    return sorted_results
def delete_records(predicate: str, table) -> int:
    """Delete the rows of ``table`` that match an SQL filter.

    Args:
        predicate: SQL filter expression, e.g.
            ``"url = 'https://aws.amazon.com/blogs/containers/eks-security'"``.
        table: Table object exposing ``count_rows(filter)`` and
            ``delete(where)``.

    Returns:
        The number of rows that matched the predicate before deletion.
    """
    # Count with the predicate as a filter; passing it to search() would
    # treat the SQL text as a search query instead.
    match_count = table.count_rows(predicate)
    logger.info(f" Found {match_count}..")
    table.delete(predicate)
    logger.info(f"Deleted {match_count} rows!!")
    return match_count
# Demo run: insert the sample posts, search, then delete one post by URL.
# The guard must compare against "__main__"; "main" never matches, so the
# demo was silently skipped.
if __name__ == "__main__":
    recs = add_records(single_post, table)
    print(recs)
    recs = add_records(batch_posts, table)
    search_lancedb(query="security")
    predicate = "url = 'https://aws.amazon.com/blogs/security/best-practices'"
    del_count = delete_records(predicate, table)
    # Log the number of rows left in the table after the delete.
    logger.info(f"{table.count_rows()}")
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment