Skip to content

Instantly share code, notes, and snippets.

@ochafik
Created November 15, 2024 02:48
Show Gist options
  • Select an option

  • Save ochafik/df77f85cea9309d0686fb3a7a7f01bcf to your computer and use it in GitHub Desktop.

Select an option

Save ochafik/df77f85cea9309d0686fb3a7a7f01bcf to your computer and use it in GitHub Desktop.
sqlite-vec + sqlite-lembed + sqlite-rembed in one neat package (+ semantic text file indexing example)
# Copyright 2024 Google LLC.
# SPDX-License-Identifier: Apache-2.0
# /// script
# requires-python = ">=3.11"
# dependencies = [
# "aiosqlite",
# "gguf",
# "huggingface-hub",
# "more-itertools",
# "openai",
# "pydantic",
# "sentencepiece",
# "sqlite-lembed",
# "sqlite-rembed",
# "sqlite-vec",
# ]
# ///
'''
Library to help setup sqlite-vec + sqlite-lembed / sqlite-rembed
for indexing and searching of text columns.
Given an existing TEXT column, SQLiteVecHelper creates and populates
an embedding index (auto updated w/ triggers), and returns a search function
that hides much of the complexity.
This file can also be run standalone and acts as a mini lines indexer + interactive search.
'''
import aiosqlite
import asyncio
from contextlib import asynccontextmanager
import logging
from more_itertools import unzip
import os
from pydantic import BaseModel
import sqlite_lembed
import sqlite_rembed
import sqlite_vec
import sys
from typing import Dict, Optional, Union, Literal
class ContextOptions(BaseModel):
    '''llama.cpp context parameters, passed to lembed_context_options().'''
    # Context window size, in tokens.
    n_ctx: int = 2048
    # RoPE frequency-scaling mode (required: no default, callers must choose).
    rope_scaling_type: Literal['none', 'linear', 'yarn']
    # RoPE frequency scale factor; presumably 0.0 means "use the model's
    # default" — confirm against sqlite-lembed/llama.cpp semantics.
    rope_freq_scale: float = 0.0
class HfModelInfo(BaseModel):
    '''Location of a GGUF embedding model on the Hugging Face Hub.'''
    # Hub repository, e.g. 'nomic-ai/nomic-embed-text-v1.5-GGUF'.
    repo_id: str
    # File to download from that repository, e.g. 'nomic-embed-text-v1.5.Q8_0.gguf'.
    file_name: str
    # Optional context overrides applied when the model is loaded locally.
    context_options: Optional[ContextOptions] = None
class LocalSQLiteEmbeddingsConfig(BaseModel):
    '''Configuration for local, in-process embeddings via sqlite-lembed.'''
    # Name under which the model is registered in the lembed_models table.
    config_name: Optional[str] = 'default'
    # Dimension of the embedding vectors (sizes the vec0 index).
    embedding_length: int
    # Path to the GGUF model file loaded with lembed_model_from_file().
    gguf_model_file: str
    # Layers to offload to GPU; 99 presumably means "offload everything" —
    # confirm against llama.cpp's n_gpu_layers semantics.
    n_gpu_layers: int = 99
    # llama.cpp context parameters (n_ctx, RoPE scaling).
    context_options: ContextOptions
class RemoteSQLiteEmbeddingsConfig(BaseModel):
    '''Configuration for remote embeddings over HTTP via sqlite-rembed.'''
    # Name under which the client is registered in the rembed_clients table.
    config_name: Optional[str] = 'default'
    # Dimension of the embedding vectors (sizes the vec0 index).
    embedding_length: int
    # Base URL of the embeddings server (registered with format 'llamafile').
    endpoint: str
    # Optional API key forwarded to the endpoint.
    api_key: Optional[str] = None
# Either flavor of embeddings configuration: local (sqlite-lembed) or remote (sqlite-rembed).
SQLiteEmbeddingsConfig = Union[LocalSQLiteEmbeddingsConfig, RemoteSQLiteEmbeddingsConfig]

# Registry of known-good Hugging Face embedding models; the first entry is the
# default picked by get_known_config() when no name is given.
known_hf_models: Dict[str, HfModelInfo] = {
    'nomic-embed-text-v1.5': HfModelInfo(
        repo_id='nomic-ai/nomic-embed-text-v1.5-GGUF',
        file_name='nomic-embed-text-v1.5.Q8_0.gguf',
        # Plain dict: pydantic coerces it into a ContextOptions instance.
        context_options=dict(
            n_ctx=8192,
            rope_scaling_type='yarn',
            rope_freq_scale=.75,
        ),
    ),
}
def get_embedding_length(gguf_model_file: str) -> int:
    '''
    Read the embedding length (vector dimension) from a GGUF file's metadata.

    Args:
        gguf_model_file: Path to a GGUF model file.
    Returns:
        The embedding dimension as a plain Python int.
    Raises:
        ValueError: if the file carries no embedding-length metadata.
    '''
    import gguf
    reader = gguf.GGUFReader(gguf_model_file)
    field = reader.get_field(gguf.Keys.LLM.EMBEDDING_LENGTH)
    # Fallback for GGUFs that store the value under an arch-specific key.
    field = field or reader.get_field('nomic-bert.embedding_length')
    if field is None:
        # FIX: previously this fell through to an AttributeError on None.
        raise ValueError(f'No embedding length metadata found in {gguf_model_file}')
    # FIX: field.parts holds numpy arrays; cast so the declared -> int holds.
    return int(field.parts[field.data[0]][0])
def get_hf_config(model: HfModelInfo) -> LocalSQLiteEmbeddingsConfig:
    '''
    Download a known Hugging Face GGUF model (cached by huggingface_hub)
    and build a local sqlite-lembed configuration for it.
    '''
    from huggingface_hub import hf_hub_download
    local_path = hf_hub_download(repo_id=model.repo_id, filename=model.file_name)
    return LocalSQLiteEmbeddingsConfig(
        embedding_length=get_embedding_length(local_path),
        gguf_model_file=local_path,
        context_options=model.context_options,
    )
def get_known_config(name: Optional[str] = None) -> SQLiteEmbeddingsConfig:
    '''
    Resolve a named entry from known_hf_models into a usable config
    (downloading the model file if needed).

    Args:
        name: Key into known_hf_models; falsy/None selects the first entry.
    Raises:
        ValueError: if the name is not in known_hf_models.
    '''
    if not name:
        # FIX (idiom): first key without materializing the whole key list.
        name = next(iter(known_hf_models))
    if not (model := known_hf_models.get(name)):
        raise ValueError(f'Unknown model name: {name}')
    return get_hf_config(model)
@asynccontextmanager
async def extension_loading_enabled(db: aiosqlite.Connection):
    '''
    Temporarily allow SQLite extension loading on `db`.

    Extension loading is re-disabled on exit even if the body raises.
    '''
    await db.enable_load_extension(True)
    try:
        yield
    finally:
        await db.enable_load_extension(False)
class SQLiteVecHelper:
    '''
    Wires sqlite-vec together with sqlite-lembed (local GGUF model) or
    sqlite-rembed (remote endpoint) so a table's TEXT column can be indexed
    and searched by embedding similarity.
    '''

    def __init__(self, config: SQLiteEmbeddingsConfig):
        '''
        Create a helper for working with sqlite-vec and sqlite-lembed/sqlite-rembed.
        '''
        self.config = config
        # embed_fn(expr) yields the SQL expression that embeds `expr` with the
        # configured model/client. NOTE(review): config_name is interpolated
        # directly between double quotes; a name containing a quote would
        # break the generated SQL.
        if isinstance(self.config, LocalSQLiteEmbeddingsConfig):
            self.embed_fn = lambda x: f'lembed("{config.config_name}", {x})'
        else:
            self.embed_fn = lambda x: f'rembed("{config.config_name}", {x})'

    async def setup_extensions(self, db: aiosqlite.Connection):
        '''Load sqlite-vec plus whichever of sqlite-lembed / sqlite-rembed the config needs.'''
        async with extension_loading_enabled(db):
            await db.load_extension(sqlite_vec.loadable_path())
            if isinstance(self.config, LocalSQLiteEmbeddingsConfig):
                await db.load_extension(sqlite_lembed.loadable_path())
            else:
                await db.load_extension(sqlite_rembed.loadable_path())

    async def setup_options(self, db: aiosqlite.Connection):
        '''Register the embedding model (lembed_models) or remote client (rembed_clients).'''
        if isinstance(self.config, LocalSQLiteEmbeddingsConfig):
            # FIX: dropped useless f-string prefixes — these statements contain
            # no {placeholders}; all values are bound with ? parameters.
            await db.execute('''
                INSERT INTO lembed_models(name, model, model_options, context_options) VALUES (
                    ?,
                    lembed_model_from_file(?),
                    lembed_model_options(
                        'n_gpu_layers', ?
                    ),
                    lembed_context_options(
                        'n_ctx', ?,
                        'rope_scaling_type', ?,
                        'rope_freq_scale', ?
                    )
                );
            ''', (
                self.config.config_name,
                self.config.gguf_model_file,
                self.config.n_gpu_layers,
                self.config.context_options.n_ctx,
                self.config.context_options.rope_scaling_type,
                self.config.context_options.rope_freq_scale
            ))
        else:
            await db.execute('''
                INSERT INTO rembed_clients(name, options) VALUES (
                    ?,
                    rembed_client_options('format', 'llamafile', 'url', ?, 'key', ?)
                );
            ''', (self.config.config_name, self.config.endpoint, self.config.api_key))

    async def setup_connection(self, db: aiosqlite.Connection):
        '''Prepare a fresh connection: load extensions, then register the model/client.'''
        await self.setup_extensions(db)
        await self.setup_options(db)

    async def setup_embeddings(self, db: aiosqlite.Connection, table: str, column: str):
        '''
        Create an sqlite-vec virtual table w/ an embedding column
        kept in sync with a source table's text column,
        and return a search function that can be used to query it.

        NOTE(review): `table` and `column` are interpolated into DDL strings —
        they must be trusted identifiers, never user input.
        '''
        embeddings_table = f'{table}_{column}_embeddings'
        # Detect whether the index already exists so we bulk-populate only once.
        try:
            await db.execute(f'SELECT 1 FROM {embeddings_table} LIMIT 1')
            embeddings_table_exists = True
        except aiosqlite.OperationalError:
            embeddings_table_exists = False
        await db.execute(f'''
            CREATE VIRTUAL TABLE IF NOT EXISTS {embeddings_table}
            USING vec0(embedding float[{self.config.embedding_length}])
        ''')
        # Triggers keep the vector index in sync with INSERT / UPDATE / DELETE.
        await db.execute(f'''
            CREATE TRIGGER IF NOT EXISTS {embeddings_table}_insert AFTER INSERT ON {table}
            BEGIN
                INSERT INTO {embeddings_table} (rowid, embedding)
                VALUES (NEW.rowid, {self.embed_fn('NEW.' + column)});
            END;
        ''')
        await db.execute(f'''
            CREATE TRIGGER IF NOT EXISTS {embeddings_table}_update AFTER UPDATE OF {column} ON {table}
            BEGIN
                UPDATE {embeddings_table}
                SET embedding = {self.embed_fn('NEW.' + column)}
                WHERE rowid = NEW.rowid;
            END;
        ''')
        await db.execute(f'''
            CREATE TRIGGER IF NOT EXISTS {embeddings_table}_delete AFTER DELETE ON {table}
            BEGIN
                DELETE FROM {embeddings_table}
                WHERE rowid = OLD.rowid;
            END;
        ''')
        if not embeddings_table_exists:
            # First run: embed all pre-existing rows in bulk.
            await db.execute(f'''
                INSERT INTO {embeddings_table} (rowid, embedding)
                SELECT rowid, {self.embed_fn(column)}
                FROM {table} WHERE {column} IS NOT NULL
            ''')

        def text_search(text: str, top_n: int, columns: list[str] | None = None, joins: list[str] | None = None) -> aiosqlite.Cursor:
            '''
            Search the vector index for the embedding of the provided text and return
            the distance of the top_n nearest matches + their corresponding original table's columns.
            '''
            # FIX: mutable default arguments ([...] literals) replaced with
            # None sentinels; effective defaults are unchanged.
            if columns is None:
                columns = ['rowid', column]
            if joins is None:
                joins = []
            col_seq = ', '.join(['distance', *columns])
            return db.execute(f'''
                SELECT {col_seq}
                FROM (
                    SELECT rowid, distance
                    FROM {embeddings_table}
                    WHERE {embeddings_table}.embedding MATCH {self.embed_fn('?')}
                    ORDER BY distance
                    LIMIT ?
                )
                JOIN {table} USING (rowid)
                {' '.join(joins)}
            ''', (text, top_n))
        return text_search
async def main():
    '''
    Mini line-indexer + interactive semantic search.

    Indexes the lines of the files named on the command line into an SQLite
    database (DB_FILE, default 'index.db'), then reads one query per line
    from stdin and prints the K nearest indexed lines for each.

    Embeddings configuration, from the environment (first match wins):
    - ENDPOINT (+ optional API_KEY): remote server, via sqlite-rembed
    - GGUF_MODEL_FILE (+ N_CTX / ROPE_SCALING_TYPE / ROPE_FREQ_SCALE): local model, via sqlite-lembed
    - CONFIG: name of an entry in known_hf_models (default: first entry)
    '''
    if 'ENDPOINT' in os.environ:
        from openai import AsyncOpenAI
        endpoint = os.environ['ENDPOINT']
        api_key = os.environ.get('API_KEY')
        # Probe the endpoint once to discover the embedding dimension.
        # FIX: embeddings.create takes keyword-only input/model params in the
        # openai v1 SDK — the previous positional call raised a TypeError.
        # NOTE(review): llamafile-style servers presumably ignore the model
        # name; set EMBEDDINGS_MODEL if yours requires a real one.
        client = AsyncOpenAI(base_url=endpoint + '/v1', api_key=api_key)
        probe = await client.embeddings.create(
            input='Test', model=os.environ.get('EMBEDDINGS_MODEL', ''))
        embedding_length = len(probe.data[0].embedding)
        config = RemoteSQLiteEmbeddingsConfig(
            embedding_length=embedding_length,
            endpoint=endpoint,
            api_key=api_key
        )
    elif (gguf_model_file := os.environ.get('GGUF_MODEL_FILE')):
        config = LocalSQLiteEmbeddingsConfig(
            embedding_length=get_embedding_length(gguf_model_file),
            gguf_model_file=gguf_model_file,
            context_options=ContextOptions(
                n_ctx=int(os.environ.get('N_CTX', 2048)),
                rope_scaling_type=os.environ.get('ROPE_SCALING_TYPE', 'none'),
                rope_freq_scale=float(os.environ.get('ROPE_FREQ_SCALE', '0.0')),
            )
        )
    else:
        config = get_known_config(os.environ.get('CONFIG'))
    logging.info(f'Using embeddings config: {config}')

    helper = SQLiteVecHelper(config)
    db_file = os.environ.get('DB_FILE', 'index.db')
    input_files = sys.argv[1:]
    async with aiosqlite.connect(db_file) as db:
        # FIX: SQLite leaves foreign-key enforcement OFF by default; without
        # this pragma the lines table's ON DELETE CASCADE never fires and
        # stale lines (and their embeddings) would linger after re-indexing.
        await db.execute('PRAGMA foreign_keys = ON')
        await helper.setup_connection(db)
        await db.execute('''
            CREATE TABLE IF NOT EXISTS files (
                rowid INTEGER PRIMARY KEY AUTOINCREMENT,
                path TEXT NOT NULL,
                lastmod INTEGER NOT NULL
            );
        ''')
        await db.execute('''
            CREATE TABLE IF NOT EXISTS lines (
                rowid INTEGER PRIMARY KEY AUTOINCREMENT,
                fileid INTEGER NOT NULL REFERENCES files(rowid) ON DELETE CASCADE,
                lineno INTEGER NOT NULL,
                content TEXT NOT NULL
            )
        ''')
        search_content = await helper.setup_embeddings(db, 'lines', 'content')
        await db.commit()

        for file in input_files:
            if not os.path.exists(file):
                print(f'File {file} does not exist, skipping.', file=sys.stderr)
                continue
            file = os.path.abspath(file)
            # getmtime is a float; SQLite's INTEGER affinity stores it as REAL,
            # so the equality check below still round-trips exactly.
            lastmod = os.path.getmtime(file)
            lastmod_cursor = await db.execute('SELECT lastmod FROM files WHERE path = ?', (file,))
            lastmod_saved = await lastmod_cursor.fetchone()
            lastmod_saved = lastmod_saved[0] if lastmod_saved else None
            if lastmod_saved == lastmod:
                print(f'Index for {file} is still valid', file=sys.stderr)
                continue
            if lastmod_saved:
                print(f'Deleting stale index for {file}', file=sys.stderr)
                # FIX: the column is `path`, not `file` — the original
                # statement failed with "no such column: file".
                await db.execute('DELETE FROM files WHERE path = ?', (file,))
            with open(file) as f:
                print('Indexing', file, file=sys.stderr)
                await db.execute('INSERT INTO files (path, lastmod) VALUES (?, ?)', (file, lastmod))
                fileid = (await (await db.execute('SELECT last_insert_rowid()')).fetchone())[0]
                lines = f.readlines()
                # Skip lines with no alphabetic content (blank / punctuation-only).
                inputs = [
                    (fileid, i + 1, line.rstrip())
                    for i, line in enumerate(lines)
                    if any(c.isalpha() for c in line)
                ]
                await db.executemany(
                    'INSERT INTO lines (fileid, lineno, content) VALUES (?, ?, ?)',
                    inputs
                )
            await db.commit()

        # FIX: environment values are strings — bind an int for LIMIT ?.
        k = int(os.environ.get('K', 5))
        print('Indexing complete. Type one search per line:', file=sys.stderr)
        for line in sys.stdin:
            line = line.strip()
            if not line:
                continue
            async with search_content(line, k, columns=['path', 'lineno', 'content'], joins=['INNER JOIN files ON fileid = files.rowid']) as cursor:
                results = await cursor.fetchall()
                cols = [c[0] for c in cursor.description]
                for res in [dict(zip(cols, row)) for row in results]:
                    print(f'{res["path"]}:{res["lineno"]} ({res["distance"]:.2f}): {res["content"]}')
# Standalone entry point: run the indexer + interactive search loop.
if __name__ == '__main__':
    asyncio.run(main())
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment