Skip to content

Instantly share code, notes, and snippets.

@rodydavis
Last active August 2, 2024 13:03
Show Gist options
  • Save rodydavis/5fbbecd45956230673e7043bae80e495 to your computer and use it in GitHub Desktop.
Save rodydavis/5fbbecd45956230673e7043bae80e495 to your computer and use it in GitHub Desktop.
import 'dart:typed_data';
import 'package:drift/drift.dart';
import 'package:google_generative_ai/google_generative_ai.dart';
import 'connection/connection.dart' as impl;
part 'database.g.dart';
/// Drift database for the app; tables and named queries come from `sql.drift`.
@DriftDatabase(include: {'sql.drift'})
class Database extends _$Database {
  /// Opens the database using the connection returned by
  /// `impl.connect('app.v2')`.
  Database() : super(impl.connect('app.v2'));

  /// Creates a database on top of an existing [DatabaseConnection].
  Database.forTesting(DatabaseConnection super.connection);

  /// Shared, eagerly constructed instance.
  static Database instance = Database();

  /// Embedding model (`text-embedding-004`).
  ///
  /// The API key is read from the compile-time environment declaration
  /// `GOOGLE_AI_API_KEY` and is the empty string when it is not defined.
  final textEmbedder = GenerativeModel(
    model: 'text-embedding-004',
    apiKey: const String.fromEnvironment('GOOGLE_AI_API_KEY'),
  );

  @override
  int get schemaVersion => 1;

  /// Inserts [vector] into the `chunks` table and returns the id of the new
  /// row.
  ///
  /// The insert is issued as a raw statement and the id is read back with
  /// `last_insert_rowid()`. Both statements run inside one transaction so
  /// that a concurrent [addChunk] call cannot slip its own insert in between
  /// and cause the wrong id to be returned.
  Future<int> addChunk(List<double> vector) {
    return transaction(() async {
      await customStatement(
        'INSERT INTO chunks (embedding) VALUES (:embedding)',
        [serializeFloat32(vector)],
      );
      return getLastId().getSingle();
    });
  }

  /// Deletes the row of `chunks` whose id equals [id].
  Future<void> deleteChunk(int id) async {
    await customStatement(
      'DELETE FROM chunks WHERE id = :id',
      [id],
    );
  }
}
/// Packs [vector] into a byte buffer of consecutive little-endian float32
/// values, the BLOB layout passed to sqlite-vec.
///
/// Each element is narrowed from a Dart double to 32-bit precision.
Uint8List serializeFloat32(List<double> vector) {
  const bytesPerFloat = Float32List.bytesPerElement;
  final bytes = ByteData(vector.length * bytesPerFloat);
  var offset = 0;
  for (final value in vector) {
    bytes.setFloat32(offset, value, Endian.little);
    offset += bytesPerFloat;
  }
  return bytes.buffer.asUint8List();
}
/// Lazily splits [text] into chunks for embedding.
///
/// A chunk is a maximal run of non-newline characters in which each character
/// may be followed by a single newline, so an empty line ends the chunk.
/// Yields `(chunk, start, end)` where `start`/`end` are offsets into [text].
Iterable<(String, int, int)> chunkText(String text) sync* {
  final paragraph = RegExp(r'((?:[^\n][\n]?)+)');
  // TODO: very long paragraphs still need to be limited to 500 tokens.
  yield* paragraph
      .allMatches(text)
      .map((m) => (text.substring(m.start, m.end), m.start, m.end));
}
import 'dart:ffi';
import 'dart:io';
import 'dart:typed_data';
import 'package:google_generative_ai/google_generative_ai.dart';
import 'package:sqlite3/open.dart';
import 'package:sqlite3/sqlite3.dart';
/// Indexes every file under `./files` into an in-memory sqlite-vec database
/// and prints the matches for a sample query.
///
/// For each file, the text is split with [_chunkText], every chunk is
/// embedded with the `text-embedding-004` model, and the vector is stored in
/// the `chunks` virtual table together with a `file_embeddings` row linking it
/// back to the file and the chunk's offsets. Finally the word `extension` is
/// embedded as a query and the matching rows are printed.
Future<void> main(List<String> args) async {
  final inDir = Directory('./files');
  final db = _dbOpen();
  try {
    db.execute(
      'create table files ( '
      '  id integer primary key AUTOINCREMENT, '
      '  path TEXT NOT NULL, '
      '  content TEXT '
      ');',
    );
    // The id is an auto-assigned integer key. Inserts below never supply an
    // id, so a TEXT primary key would leave it NULL on every row.
    db.execute(
      'create table file_embeddings ( '
      '  id integer primary key AUTOINCREMENT, '
      '  file_id INTEGER NOT NULL, '
      '  chunk_id INTEGER NOT NULL, '
      '  start INTEGER, '
      '  end INTEGER '
      ');',
    );
    db.execute(
      'create virtual table chunks using vec0( '
      '  id integer primary key AUTOINCREMENT, '
      '  embedding float[768] '
      ');',
    );

    // Compile-time environment declaration; empty when not defined.
    const googleAIApiKey = String.fromEnvironment('GOOGLE_AI_API_KEY');
    final model = GenerativeModel(
      model: 'text-embedding-004',
      apiKey: googleAIApiKey,
    );

    for (final file in inDir.listSync(recursive: true).whereType<File>()) {
      final str = file.readAsStringSync();
      // An INSERT returns no rows, so it is run with execute, not select.
      db.execute(
        'insert into files(path, content) values (?, ?)',
        [file.path, str],
      );
      final fileId = db.lastInsertRowId;
      // File name without directory or extension; same for every chunk.
      final title = file.path.split('/').last.split('.').first;

      // Embeds one chunk and records it. The two inserts and the
      // lastInsertRowId read run without an await in between, so chunks
      // embedded concurrently cannot mix up their ids.
      Future<void> embedChunk((String, int, int) match) async {
        final (chunk, start, end) = match;
        final result = await model.embedContent(
          Content.text(chunk),
          title: title,
          taskType: TaskType.retrievalDocument,
        );
        db.execute(
          'insert into chunks(embedding) values (?)',
          [serializeFloat32(result.embedding.values)],
        );
        final chunkId = db.lastInsertRowId;
        db.execute(
          'insert into file_embeddings(file_id, chunk_id, start, end) values (?, ?, ?, ?)',
          [fileId, chunkId, start, end],
        );
      }

      // All chunks of one file are embedded concurrently.
      await Future.wait(_chunkText(str).map(embedChunk));
    }

    final queryResult = await model.embedContent(
      Content.text('extension'),
      // RETRIEVAL_QUERY Specifies the given text is a query in a search/retrieval setting.
      // RETRIEVAL_DOCUMENT Specifies the given text is a document in a search/retrieval setting.
      // SEMANTIC_SIMILARITY Specifies the given text will be used for Semantic Textual Similarity (STS).
      // CLASSIFICATION Specifies that the embeddings will be used for classification.
      // CLUSTERING Specifies that the embeddings will be used for clustering.
      taskType: TaskType.retrievalQuery,
    );

    final rows = db.select(
      'select files.path, file_embeddings.start, file_embeddings.end, chunks.distance, files.content from chunks '
      'left join file_embeddings on file_embeddings.chunk_id = chunks.id '
      'left join files on files.id = file_embeddings.file_id '
      "where embedding match ? and k = 20",
      [serializeFloat32(queryResult.embedding.values)],
    );
    for (final row in rows) {
      final filename = row['path'];
      final distance = row['distance'];
      final start = row['start'];
      final end = row['end'];
      print((filename, distance, start, end));
    }
  } finally {
    // Release the native database handle even when indexing fails.
    db.dispose();
  }
}
/// Opens an in-memory SQLite database with the sqlite-vec extension loaded.
///
/// On macOS the Homebrew SQLite build is used in place of the default
/// library, since it is a build that can load extensions. Paths are
/// hard-coded: the Homebrew SQLite location and `./vec0.dylib` relative to
/// the working directory.
Database _dbOpen() {
  // Swap in the extension-capable SQLite before anything touches `sqlite3`.
  final homebrewSqlite =
      DynamicLibrary.open('/opt/homebrew/opt/sqlite/bin/sqlite3');
  open.overrideFor(OperatingSystem.macOS, () => homebrewSqlite);

  // Register sqlite-vec so that it is loaded into databases opened below.
  final vecExtension = SqliteExtension.inLibrary(
    DynamicLibrary.open('./vec0.dylib'),
    'sqlite3_vec_init',
  );
  sqlite3.ensureExtensionLoaded(vecExtension);

  final database = sqlite3.openInMemory();
  // Report which SQLite build ended up being used.
  print('Using sqlite3 ${sqlite3.version}');
  return database;
}
/// Packs [vector] into a byte buffer of consecutive little-endian float32
/// values, the BLOB layout passed to sqlite-vec.
///
/// Each element is narrowed from a Dart double to 32-bit precision.
Uint8List serializeFloat32(List<double> vector) {
  const bytesPerFloat = Float32List.bytesPerElement;
  final bytes = ByteData(vector.length * bytesPerFloat);
  var offset = 0;
  for (final value in vector) {
    bytes.setFloat32(offset, value, Endian.little);
    offset += bytesPerFloat;
  }
  return bytes.buffer.asUint8List();
}
/// Lazily splits [text] into chunks for embedding.
///
/// A chunk is a maximal run of non-newline characters in which each character
/// may be followed by a single newline, so an empty line ends the chunk.
/// Yields `(chunk, start, end)` where `start`/`end` are offsets into [text].
Iterable<(String, int, int)> _chunkText(String text) sync* {
  final paragraph = RegExp(r'((?:[^\n][\n]?)+)');
  // TODO: very long paragraphs still need to be limited to 500 tokens.
  yield* paragraph
      .allMatches(text)
      .map((m) => (text.substring(m.start, m.end), m.start, m.end));
}
-- Files that have been stored for search.
--   id      auto-assigned integer key
--   path    required path of the file
--   content optional text of the file
CREATE TABLE files (
id INTEGER PRIMARY KEY AUTOINCREMENT,
path TEXT NOT NULL,
content TEXT
);
-- Links one row of the `chunks` vector table (chunk_id) to the file it was
-- taken from (file_id).
-- [start] and [end] are optional integers; the searchEmbeddings query returns
-- them next to the file path.
-- NOTE(review): presumably these are the chunk's character offsets within
-- files.content, and presumably bracket-quoted because END is an SQL keyword;
-- confirm.
CREATE TABLE file_embeddings (
id INTEGER PRIMARY KEY AUTOINCREMENT,
file_id INTEGER NOT NULL,
chunk_id INTEGER NOT NULL,
[start] INTEGER,
[end] INTEGER
);
-- CREATE VIRTUAL TABLE chunks using vec0(
-- id INTEGER PRIMARY KEY AUTOINCREMENT,
-- embedding float[768]
-- );
-- Returns every row of `files`.
getFiles:
SELECT * FROM files;
-- Returns the `files` rows whose id equals :id.
getFileById:
SELECT * FROM files
WHERE id = :id;
-- Returns the `files` rows whose path equals :path.
getFileByPath:
SELECT * FROM files
WHERE path = :path;
-- Inserts a file and returns the inserted row, including its generated id.
insertFile:
INSERT INTO files (path, content) VALUES (:path, :content)
RETURNING *;
-- Deletes the `files` rows whose id equals :id.
-- Rows in `file_embeddings` are not touched by this statement.
deleteFileById:
DELETE FROM files
WHERE id = :id;
-- Returns all `file_embeddings` rows that belong to the file :fileId.
getFileEmbeddingsByFileId:
SELECT * FROM file_embeddings
WHERE file_id = :fileId;
-- Deletes all `file_embeddings` rows of one file.
-- Note that :id is compared against file_id, not against the row's own id.
deleteFileEmbeddingByFileId:
DELETE FROM file_embeddings
WHERE file_id = :id;
-- insertChunk(:embedding AS BLOB):
-- INSERT INTO chunks (embedding) VALUES (:embedding);
-- Returns last_insert_rowid(), i.e. the rowid of the most recent successful
-- INSERT on this connection.
getLastId:
SELECT last_insert_rowid();
-- Records that chunk :chunkId belongs to file :fileId, with its :start and
-- :end values.
insertFileEmbedding:
INSERT INTO file_embeddings (file_id, chunk_id, [start], [end])
VALUES (:fileId, :chunkId, :start, :end);
-- Vector search over `chunks` for the query vector :embedding (a BLOB).
-- Each matching chunk is joined to its `file_embeddings` row and to the file
-- it belongs to; LEFT JOINs keep chunks that have no such rows, with NULLs in
-- the joined columns.
-- NOTE(review): `embedding MATCH :embedding AND k = 20` is presumably
-- sqlite-vec's KNN syntax for the 20 nearest chunks; confirm against the
-- sqlite-vec documentation.
-- NOTE(review): the CAST to REAL presumably gives `distance` a concrete
-- column type for code generation; confirm.
searchEmbeddings(:embedding AS BLOB):
SELECT
files.path,
file_embeddings.start,
file_embeddings.end,
CAST(chunks.distance AS REAL) as distance,
files.content
FROM chunks
LEFT JOIN file_embeddings ON file_embeddings.chunk_id = chunks.id
LEFT JOIN files ON files.id = file_embeddings.file_id
WHERE embedding MATCH :embedding AND k = 20;
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment