-
-
Save ochafik/df77f85cea9309d0686fb3a7a7f01bcf to your computer and use it in GitHub Desktop.
sqlite-vec + sqlite-lembed + sqlite-rembed in one neat package (+ semantic text file indexing example)
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| # Copyright 2024 Google LLC. | |
| # SPDX-License-Identifier: Apache-2.0 | |
| # /// script | |
| # requires-python = ">=3.11" | |
| # dependencies = [ | |
| # "aiosqlite", | |
| # "gguf", | |
| # "huggingface-hub", | |
| # "more-itertools", | |
| # "openai", | |
| # "pydantic", | |
| # "sentencepiece", | |
| # "sqlite-lembed", | |
| # "sqlite-rembed", | |
| # "sqlite-vec", | |
| # ] | |
| # /// | |
| ''' | |
| Library to help set up sqlite-vec + sqlite-lembed / sqlite-rembed | |
| for indexing and searching of text columns. | |
| Given an existing TEXT column, SQLiteVecHelper creates and populates | |
| an embedding index (auto updated w/ triggers), and returns a search function | |
| that hides much of the complexity. | |
| This file can also be run standalone and acts as a mini lines indexer + interactive search. | |
| ''' | |
| import aiosqlite | |
| import asyncio | |
| from contextlib import asynccontextmanager | |
| import logging | |
| from more_itertools import unzip | |
| import os | |
| from pydantic import BaseModel | |
| import sqlite_lembed | |
| import sqlite_rembed | |
| import sqlite_vec | |
| import sys | |
| from typing import Dict, Optional, Union, Literal | |
class ContextOptions(BaseModel):
    # llama.cpp-style context options, forwarded to sqlite-lembed's
    # lembed_context_options() in SQLiteVecHelper.setup_options.
    n_ctx: int = 2048  # context window size, in tokens
    rope_scaling_type: Literal['none', 'linear', 'yarn']  # RoPE scaling strategy (required; no default)
    rope_freq_scale: float = 0.0  # RoPE frequency scale; presumably 0.0 means "model default" — TODO confirm against llama.cpp
class HfModelInfo(BaseModel):
    # Pointer to a GGUF embedding model hosted on the Hugging Face Hub
    # (downloaded on demand by get_hf_config).
    repo_id: str  # Hub repository, e.g. 'nomic-ai/nomic-embed-text-v1.5-GGUF'
    file_name: str  # GGUF file to fetch from that repository
    context_options: Optional[ContextOptions] = None  # optional context overrides for this model
class LocalSQLiteEmbeddingsConfig(BaseModel):
    # Configuration for in-process embedding via sqlite-lembed (local GGUF model).
    config_name: Optional[str] = 'default'  # name the model is registered under in lembed_models
    embedding_length: int  # embedding vector dimension; sizes the vec0 column
    gguf_model_file: str  # local filesystem path to the GGUF model
    n_gpu_layers: int = 99  # layers offloaded to GPU; 99 presumably means "all" — TODO confirm
    context_options: ContextOptions  # required (unlike HfModelInfo's optional field)
class RemoteSQLiteEmbeddingsConfig(BaseModel):
    # Configuration for remote embedding via sqlite-rembed (HTTP embeddings server).
    config_name: Optional[str] = 'default'  # name the client is registered under in rembed_clients
    embedding_length: int  # embedding vector dimension; sizes the vec0 column
    endpoint: str  # base URL of the embeddings server (without /v1 suffix)
    api_key: Optional[str] = None  # optional API key passed to rembed_client_options
# Either flavor of config can drive SQLiteVecHelper.
SQLiteEmbeddingsConfig = Union[LocalSQLiteEmbeddingsConfig, RemoteSQLiteEmbeddingsConfig]
# Registry of known embedding models downloadable from the Hugging Face Hub.
# Insertion order matters: the first entry is the default picked by
# get_known_config() when no name is given.
known_hf_models: Dict[str, HfModelInfo] = {
    'nomic-embed-text-v1.5': HfModelInfo(
        repo_id='nomic-ai/nomic-embed-text-v1.5-GGUF',
        file_name='nomic-embed-text-v1.5.Q8_0.gguf',
        # Construct ContextOptions explicitly: the original passed a plain
        # dict, which pydantic coerces at runtime but type checkers reject
        # against the declared Optional[ContextOptions] annotation.
        context_options=ContextOptions(
            # nomic-embed-text v1.5 uses an extended 8192-token context with
            # YaRN RoPE scaling, per the model card.
            n_ctx=8192,
            rope_scaling_type='yarn',
            rope_freq_scale=.75,
        ),
    ),
}
def get_embedding_length(gguf_model_file: str) -> int:
    '''
    Read the embedding vector dimension from a GGUF model file's metadata.

    Raises:
        ValueError: if no embedding-length field can be found in the file.
    '''
    import gguf
    reader = gguf.GGUFReader(gguf_model_file)
    # gguf.Keys.LLM.EMBEDDING_LENGTH is the *template* '{arch}.embedding_length';
    # the original passed it unformatted, which never matches any field, so
    # everything fell through to the nomic-bert fallback. Resolve the
    # architecture from general.architecture first and format the key.
    field = None
    arch_field = reader.get_field(gguf.Keys.General.ARCHITECTURE)
    if arch_field is not None:
        arch = str(bytes(arch_field.parts[arch_field.data[0]]), encoding='utf-8')
        field = reader.get_field(gguf.Keys.LLM.EMBEDDING_LENGTH.format(arch=arch))
    # Fallback kept from the original, for nomic-bert GGUFs.
    field = field or reader.get_field('nomic-bert.embedding_length')
    if field is None:
        raise ValueError(f'Could not find embedding length in {gguf_model_file}')
    # parts[data[0]] is the field's value array; take its first scalar.
    return int(field.parts[field.data[0]][0])
def get_hf_config(model: HfModelInfo) -> LocalSQLiteEmbeddingsConfig:
    '''
    Download the model's GGUF file from the Hugging Face Hub (cached by
    huggingface_hub) and build a local embeddings config from it.
    '''
    from huggingface_hub import hf_hub_download
    local_path = hf_hub_download(repo_id=model.repo_id, filename=model.file_name)
    return LocalSQLiteEmbeddingsConfig(
        # The embedding dimension is read straight out of the GGUF metadata.
        embedding_length=get_embedding_length(local_path),
        gguf_model_file=local_path,
        context_options=model.context_options,
    )
def get_known_config(name: Optional[str] = None) -> SQLiteEmbeddingsConfig:
    '''
    Resolve a known model name into a (downloaded) local embeddings config.

    Args:
        name: key into `known_hf_models`; falsy values select the first
            registered model.

    Raises:
        ValueError: if `name` is not a known model.
    '''
    if not name:
        # next(iter(...)) picks the first key directly, without materializing
        # the whole key list as the original did.
        name = next(iter(known_hf_models))
    if not (model := known_hf_models.get(name)):
        raise ValueError(f'Unknown model name: {name}')
    return get_hf_config(model)
@asynccontextmanager
async def extension_loading_enabled(db: aiosqlite.Connection):
    '''
    Async context manager that temporarily enables SQLite extension loading
    on `db`, guaranteeing it is disabled again on exit — even if loading an
    extension inside the block raises.
    '''
    await db.enable_load_extension(True)
    try:
        yield
    finally:
        await db.enable_load_extension(False)
class SQLiteVecHelper:
    '''
    Wires a source table's TEXT column to a sqlite-vec (vec0) vector index
    whose embeddings are computed inside SQLite by sqlite-lembed (local GGUF
    model) or sqlite-rembed (remote server), chosen by the config's type.
    '''
    def __init__(self, config: SQLiteEmbeddingsConfig):
        '''
        Create a helper for working with sqlite-vec and sqlite-lembed/sqlite-rembed.
        '''
        self.config = config
        # embed_fn maps an SQL expression yielding text to an SQL expression
        # yielding its embedding, referencing the model/client registered
        # under config_name by setup_options.
        # NOTE(review): config_name is spliced directly into SQL text —
        # assumed to be a trusted identifier with no embedded quotes.
        if isinstance(self.config, LocalSQLiteEmbeddingsConfig):
            self.embed_fn = lambda x: f'lembed("{config.config_name}", {x})'
        else:
            self.embed_fn = lambda x: f'rembed("{config.config_name}", {x})'
    async def setup_extensions(self, db: aiosqlite.Connection):
        # Load sqlite-vec plus whichever embedding extension the config
        # requires, re-disabling extension loading afterwards.
        async with extension_loading_enabled(db):
            await db.load_extension(sqlite_vec.loadable_path())
            if isinstance(self.config, LocalSQLiteEmbeddingsConfig):
                await db.load_extension(sqlite_lembed.loadable_path())
            else:
                await db.load_extension(sqlite_rembed.loadable_path())
    async def setup_options(self, db: aiosqlite.Connection):
        # Register the embedding model (lembed_models) or remote client
        # (rembed_clients) under config_name, so the SQL emitted by embed_fn
        # can resolve it by name on this connection.
        if isinstance(self.config, LocalSQLiteEmbeddingsConfig):
            await db.execute(f'''
                INSERT INTO lembed_models(name, model, model_options, context_options) VALUES (
                    ?,
                    lembed_model_from_file(?),
                    lembed_model_options(
                        'n_gpu_layers', ?
                    ),
                    lembed_context_options(
                        'n_ctx', ?,
                        'rope_scaling_type', ?,
                        'rope_freq_scale', ?
                    )
                );
            ''', (
                self.config.config_name,
                self.config.gguf_model_file,
                self.config.n_gpu_layers,
                self.config.context_options.n_ctx,
                self.config.context_options.rope_scaling_type,
                self.config.context_options.rope_freq_scale
            ))
        else:
            await db.execute(f'''
                INSERT INTO rembed_clients(name, options) VALUES (
                    ?,
                    rembed_client_options('format', 'llamafile', 'url', ?, 'key', ?)
                );
            ''', (self.config.config_name, self.config.endpoint, self.config.api_key))
    async def setup_connection(self, db: aiosqlite.Connection):
        # Convenience wrapper: load the extensions, then register the
        # model/client. Must be called once per fresh connection.
        await self.setup_extensions(db)
        await self.setup_options(db)
    async def setup_embeddings(self, db: aiosqlite.Connection, table: str, column: str):
        '''
        Create an sqlite-vec virtual table w/ an embedding column
        kept in sync with a source table's text column,
        and return a search function that can be used to query it.
        '''
        # NOTE(review): table/column names are interpolated into DDL below
        # (placeholders can't name tables) — assumed trusted identifiers.
        embeddings_table = f'{table}_{column}_embeddings'
        try:
            # Probe whether the embeddings table already exists, so the
            # initial backfill below only runs on first creation.
            await db.execute(f'SELECT 1 FROM {embeddings_table} LIMIT 1')
            embeddings_table_exists = True
        except aiosqlite.OperationalError:
            embeddings_table_exists = False
        await db.execute(f'''
            CREATE VIRTUAL TABLE IF NOT EXISTS {embeddings_table}
            USING vec0(embedding float[{self.config.embedding_length}])
        ''')
        # Triggers keep the vector index in lockstep with the source column:
        # new rows are embedded on INSERT...
        await db.execute(f'''
            CREATE TRIGGER IF NOT EXISTS {embeddings_table}_insert AFTER INSERT ON {table}
            BEGIN
                INSERT INTO {embeddings_table} (rowid, embedding)
                VALUES (NEW.rowid, {self.embed_fn('NEW.' + column)});
            END;
        ''')
        # ...re-embedded when the tracked column is UPDATEd...
        await db.execute(f'''
            CREATE TRIGGER IF NOT EXISTS {embeddings_table}_update AFTER UPDATE OF {column} ON {table}
            BEGIN
                UPDATE {embeddings_table}
                SET embedding = {self.embed_fn('NEW.' + column)}
                WHERE rowid = NEW.rowid;
            END;
        ''')
        # ...and removed on DELETE.
        await db.execute(f'''
            CREATE TRIGGER IF NOT EXISTS {embeddings_table}_delete AFTER DELETE ON {table}
            BEGIN
                DELETE FROM {embeddings_table}
                WHERE rowid = OLD.rowid;
            END;
        ''')
        if not embeddings_table_exists:
            # First-time setup: backfill embeddings for rows that predate
            # the triggers.
            await db.execute(f'''
                INSERT INTO {embeddings_table} (rowid, embedding)
                SELECT rowid, {self.embed_fn(column)}
                FROM {table} WHERE {column} IS NOT NULL
            ''')
        # NOTE(review): mutable default arguments below are never mutated,
        # so they are safe here, though None-sentinels would be more idiomatic.
        def text_search(text: str, top_n: int, columns: list[str] = ['rowid', column], joins: list[str] = []) -> aiosqlite.Cursor:
            '''
            Search the vector index for the embedding of the provided text and return
            the distance of the top_n nearest matches + their corresponding original table's columns.
            '''
            # Inner query does the KNN MATCH against vec0, outer query joins
            # back to the source table to surface the requested columns.
            col_seq = ', '.join(['distance', *columns])
            return db.execute(f'''
                SELECT {col_seq}
                FROM (
                    SELECT rowid, distance
                    FROM {embeddings_table}
                    WHERE {embeddings_table}.embedding MATCH {self.embed_fn('?')}
                    ORDER BY distance
                    LIMIT ?
                )
                JOIN {table} USING (rowid)
                {' '.join(joins)}
            ''', (text, top_n))
        return text_search
async def main():
    '''
    Standalone demo: index the text lines of the files given on the command
    line into an SQLite database, then run an interactive semantic search
    (one query per stdin line).

    Environment variables:
        ENDPOINT / API_KEY: use a remote embeddings server (sqlite-rembed).
        GGUF_MODEL_FILE (+ N_CTX, ROPE_SCALING_TYPE, ROPE_FREQ_SCALE):
            use a local GGUF model (sqlite-lembed).
        CONFIG: name of a model from `known_hf_models` (used as fallback).
        DB_FILE: database path (default: index.db).
        K: number of search results per query (default: 5).
    '''
    # Embeddings configuration:
    # Can either provide an embeddings model file (to be loaded locally by sqlite-lembed)
    # or an embeddings endpoint w/ optional api key (to be queried remotely by sqlite-rembed).
    if 'ENDPOINT' in os.environ:
        from openai import AsyncOpenAI
        endpoint = os.environ['ENDPOINT']
        api_key = os.environ.get('API_KEY')
        # Probe the endpoint once to discover the embedding dimension.
        # Fix: embeddings.create takes keyword-only arguments; the original
        # positional call `create('Test')` raised a TypeError. `model` is a
        # required parameter (llamafile-style servers typically ignore it).
        client = AsyncOpenAI(base_url=endpoint + '/v1', api_key=api_key)
        probe = await client.embeddings.create(model=os.environ.get('MODEL', ''), input='Test')
        config = RemoteSQLiteEmbeddingsConfig(
            embedding_length=len(probe.data[0].embedding),
            endpoint=endpoint,
            api_key=api_key
        )
    elif (gguf_model_file := os.environ.get('GGUF_MODEL_FILE')):
        config = LocalSQLiteEmbeddingsConfig(
            embedding_length=get_embedding_length(gguf_model_file),
            gguf_model_file=gguf_model_file,
            context_options=ContextOptions(
                n_ctx=int(os.environ.get('N_CTX', 2048)),
                rope_scaling_type=os.environ.get('ROPE_SCALING_TYPE', 'none'),
                rope_freq_scale=float(os.environ.get('ROPE_FREQ_SCALE', '0.0')),
            )
        )
    else:
        config = get_known_config(os.environ.get('CONFIG'))
    logging.info(f'Using embeddings config: {config}')
    helper = SQLiteVecHelper(config)
    db_file = os.environ.get('DB_FILE', 'index.db')
    input_files = sys.argv[1:]
    async with aiosqlite.connect(db_file) as db:
        # Fix: SQLite does not enforce foreign keys (and hence the
        # ON DELETE CASCADE on lines.fileid) unless enabled per connection.
        await db.execute('PRAGMA foreign_keys = ON')
        await helper.setup_connection(db)
        await db.execute('''
            CREATE TABLE IF NOT EXISTS files (
                rowid INTEGER PRIMARY KEY AUTOINCREMENT,
                path TEXT NOT NULL,
                lastmod INTEGER NOT NULL
            );
        ''')
        await db.execute('''
            CREATE TABLE IF NOT EXISTS lines (
                rowid INTEGER PRIMARY KEY AUTOINCREMENT,
                fileid INTEGER NOT NULL REFERENCES files(rowid) ON DELETE CASCADE,
                lineno INTEGER NOT NULL,
                content TEXT NOT NULL
            )
        ''')
        search_content = await helper.setup_embeddings(db, 'lines', 'content')
        await db.commit()
        for file in input_files:
            if not os.path.exists(file):
                print(f'File {file} does not exist, skipping.', file=sys.stderr)
                continue
            file = os.path.abspath(file)
            # NOTE(review): getmtime returns a float; SQLite's INTEGER
            # affinity stores lossy values as REAL, so the equality check
            # against the stored value still holds.
            lastmod = os.path.getmtime(file)
            lastmod_cursor = await db.execute('SELECT lastmod FROM files WHERE path = ?', (file,))
            lastmod_saved = await lastmod_cursor.fetchone()
            lastmod_saved = lastmod_saved[0] if lastmod_saved else None
            if lastmod_saved == lastmod:
                print(f'Index for {file} is still valid', file=sys.stderr)
                continue
            if lastmod_saved:
                print(f'Deleting stale index for {file}', file=sys.stderr)
                # Fix: the column is named `path`, not `file` — the original
                # SQL referenced a non-existent column and raised on any
                # re-index of a modified file.
                await db.execute('DELETE FROM files WHERE path = ?', (file,))
            with open(file) as f:
                print('Indexing', file, file=sys.stderr)
                await db.execute('INSERT INTO files (path, lastmod) VALUES (?, ?)', (file, lastmod))
                fileid = (await (await db.execute('SELECT last_insert_rowid()')).fetchone())[0]
                # Only index lines containing at least one letter (skips
                # blank and punctuation-only lines).
                inputs = [
                    (fileid, i + 1, line.rstrip())
                    for i, line in enumerate(f.readlines())
                    if any(c.isalpha() for c in line)
                ]
                await db.executemany(
                    'INSERT INTO lines (fileid, lineno, content) VALUES (?, ?, ?)',
                    inputs
                )
                await db.commit()
        # Fix: environment values are strings; coerce to int so the SQL
        # LIMIT placeholder receives an integer.
        k = int(os.environ.get('K', 5))
        print('Indexing complete. Type one search per line:', file=sys.stderr)
        for line in sys.stdin:
            line = line.strip()
            if not line:
                continue
            async with search_content(line, k, columns=['path', 'lineno', 'content'], joins=['INNER JOIN files ON fileid = files.rowid']) as cursor:
                results = await cursor.fetchall()
                cols = [c[0] for c in cursor.description]
                for res in [dict(zip(cols, row)) for row in results]:
                    print(f'{res["path"]}:{res["lineno"]} ({res["distance"]:.2f}): {res["content"]}')
if __name__ == '__main__':
    # Standalone entry point: run the indexer + interactive search demo.
    asyncio.run(main())
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment