Last active: August 2, 2024 13:03
Save rodydavis/5fbbecd45956230673e7043bae80e495 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review it, open the file in an editor that reveals hidden Unicode characters. (Learn more about bidirectional Unicode characters.)
import 'dart:typed_data'; | |
import 'package:drift/drift.dart'; | |
import 'package:google_generative_ai/google_generative_ai.dart'; | |
import 'connection/connection.dart' as impl; | |
part 'database.g.dart'; | |
/// Drift database holding indexed files, their chunk mappings and the
/// sqlite-vec `chunks` table.
///
/// Tables and named queries come from `sql.drift`. The `chunks` vector table
/// is not declared there (its definition is commented out), so this class
/// reads and writes it with raw SQL.
@DriftDatabase(include: {'sql.drift'})
class Database extends _$Database {
  /// Opens the `app.v2` database through the platform-specific
  /// `connection/connection.dart` implementation.
  Database() : super(impl.connect('app.v2'));

  /// Wraps an existing [connection], for example an in-memory one in tests.
  Database.forTesting(DatabaseConnection super.connection);

  /// Application-wide instance, created on first access.
  static Database instance = Database();

  /// Embedding model for text chunks.
  ///
  /// The API key is a compile-time constant read from the
  /// `GOOGLE_AI_API_KEY` environment declaration. It is an empty string when
  /// the app is built without that declaration.
  final textEmbedder = GenerativeModel(
    model: 'text-embedding-004',
    apiKey: const String.fromEnvironment('GOOGLE_AI_API_KEY'),
  );

  @override
  int get schemaVersion => 1;

  /// Stores [vector] as a new row of the `chunks` table and returns its id.
  ///
  /// [customInsert] takes the id from the insert itself. Reading it with a
  /// separate `SELECT last_insert_rowid()` is racy. Another insert issued on
  /// the same connection between the two statements (for example a
  /// concurrent [addChunk] call) would make it return the wrong id.
  Future<int> addChunk(List<double> vector) {
    return customInsert(
      'INSERT INTO chunks (embedding) VALUES (?)',
      variables: [Variable.withBlob(serializeFloat32(vector))],
    );
  }

  /// Deletes the row with the given [id] from the `chunks` table.
  Future<void> deleteChunk(int id) async {
    await customStatement(
      'DELETE FROM chunks WHERE id = :id',
      [id],
    );
  }
}
/// Encodes [vector] as a BLOB of little-endian IEEE-754 float32 values,
/// which is the binary vector format sqlite-vec accepts.
///
/// The result is exactly `4 * vector.length` bytes long.
Uint8List serializeFloat32(List<double> vector) {
  const bytesPerFloat = 4;
  final bytes = Uint8List(vector.length * bytesPerFloat);
  final view = ByteData.sublistView(bytes);
  var offset = 0;
  for (var index = 0; index < vector.length; index++) {
    view.setFloat32(offset, vector[index], Endian.little);
    offset += bytesPerFloat;
  }
  return bytes;
}
/// Splits [text] into paragraph-like chunks for embedding.
///
/// A chunk starts at a non-newline character. Each character in it may be
/// followed by at most one `\n`, so two or more consecutive newlines
/// (a blank line) end the chunk. Each record is
/// `(chunkText, start, end)`, with `[start, end)` offsets into [text].
///
/// TODO: very long paragraphs are not split further. They may need to be
/// capped (around 500 tokens) before embedding.
/// NOTE(review): `\r` is not treated as a line break, so CRLF blank lines do
/// not separate chunks. Confirm whether Windows line endings matter here.
Iterable<(String, int, int)> chunkText(String text) {
  final paragraphRun = RegExp(r'((?:[^\n][\n]?)+)');
  return paragraphRun
      .allMatches(text)
      .map((m) => (text.substring(m.start, m.end), m.start, m.end));
}
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import 'dart:ffi'; | |
import 'dart:io'; | |
import 'dart:typed_data'; | |
import 'package:google_generative_ai/google_generative_ai.dart'; | |
import 'package:sqlite3/open.dart'; | |
import 'package:sqlite3/sqlite3.dart'; | |
/// Indexes every file under `./files` into an in-memory sqlite-vec database
/// and prints the chunks closest to a search query.
///
/// The query is [args] joined with spaces, or `'extension'` when no
/// arguments are given. The Google AI API key is read at compile time from
/// `--define=GOOGLE_AI_API_KEY=...`. Without it, the script prints an error
/// and exits with code 64 instead of failing on the first API request.
Future<void> main(List<String> args) async {
  const googleAIApiKey = String.fromEnvironment('GOOGLE_AI_API_KEY');
  if (googleAIApiKey.isEmpty) {
    stderr.writeln(
      'GOOGLE_AI_API_KEY is not set; run with '
      '--define=GOOGLE_AI_API_KEY=<your key>.',
    );
    exitCode = 64; // EX_USAGE
    return;
  }
  final query = args.isEmpty ? 'extension' : args.join(' ');

  final inDir = Directory('./files');
  final db = _dbOpen();
  try {
    db.execute(
      'create table files ( '
      ' id integer primary key AUTOINCREMENT, '
      ' path TEXT NOT NULL, '
      ' content TEXT '
      ');',
    );
    // `id` must be an INTEGER primary key (a rowid alias) so it is filled in
    // automatically. The insert below never supplies it, and a TEXT primary
    // key would simply store NULL. This matches `file_embeddings` in sql.drift.
    db.execute(
      'create table file_embeddings ( '
      ' id integer primary key AUTOINCREMENT, '
      ' file_id INTEGER NOT NULL, '
      ' chunk_id INTEGER NOT NULL, '
      ' start INTEGER, '
      ' end INTEGER '
      ');',
    );
    db.execute(
      'create virtual table chunks using vec0( '
      ' id integer primary key AUTOINCREMENT, '
      ' embedding float[768] '
      ');',
    );

    final model = GenerativeModel(
      model: 'text-embedding-004',
      apiKey: googleAIApiKey,
    );

    for (final file in inDir.listSync(recursive: true).whereType<File>()) {
      final String str;
      try {
        str = file.readAsStringSync();
      } on FileSystemException catch (e) {
        // Unreadable or non-UTF-8 files (e.g. macOS `.DS_Store`) are skipped
        // rather than aborting the whole indexing run.
        stderr.writeln('Skipping ${file.path}: ${e.message}');
        continue;
      }

      db.execute(
        'insert into files(path, content) values (?, ?)',
        [file.path, str],
      );
      final fileId = db.lastInsertRowId;
      // Document title shared by every chunk of this file.
      final title = file.path.split('/').last.split('.').first;

      Future<void> embedChunk((String, int, int) chunk) async {
        final (text, start, end) = chunk;
        final result = await model.embedContent(
          Content.text(text),
          title: title,
          taskType: TaskType.retrievalDocument,
        );
        // There is no `await` between this insert and reading
        // lastInsertRowId, so no other chunk's insert can interleave on
        // this connection.
        db.execute(
          'insert into chunks(embedding) values (?)',
          [serializeFloat32(result.embedding.values)],
        );
        final chunkId = db.lastInsertRowId;
        db.execute(
          'insert into file_embeddings(file_id, chunk_id, start, end) values (?, ?, ?, ?)',
          [fileId, chunkId, start, end],
        );
      }

      // Embed all chunks of this file concurrently. Future.wait (with the
      // default eagerError: false) waits for every chunk before reporting a
      // failure, so the `finally` below never disposes the database while
      // inserts are still pending.
      await Future.wait(_chunkText(str).map(embedChunk));
    }

    final queryResult = await model.embedContent(
      Content.text(query),
      // retrievalQuery pairs with the retrievalDocument task type used when
      // indexing the chunks above.
      taskType: TaskType.retrievalQuery,
    );
    final rows = db.select(
      'select files.path, file_embeddings.start, file_embeddings.end, chunks.distance, files.content from chunks '
      'left join file_embeddings on file_embeddings.chunk_id = chunks.id '
      'left join files on files.id = file_embeddings.file_id '
      "where embedding match ? and k = 20",
      [serializeFloat32(queryResult.embedding.values)],
    );
    for (final row in rows) {
      print((row['path'], row['distance'], row['start'], row['end']));
    }
  } finally {
    db.dispose();
  }
}
/// Opens an in-memory SQLite database with the sqlite-vec extension loaded.
///
/// Prints the SQLite version in use before returning the database.
Database _dbOpen() {
  // On macOS, point package:sqlite3 at Homebrew's SQLite, a build that can
  // load extensions. The library is opened eagerly, so this throws wherever
  // the path does not exist.
  // NOTE(review): this path is the Homebrew `sqlite3` executable, not
  // `lib/libsqlite3.dylib`. Confirm that dlopen accepts it on the target
  // macOS version.
  final homebrewSqlite =
      DynamicLibrary.open('/opt/homebrew/opt/sqlite/bin/sqlite3');
  open.overrideFor(OperatingSystem.macOS, () => homebrewSqlite);

  // Register sqlite-vec through its `sqlite3_vec_init` entry point.
  final vecLibrary = DynamicLibrary.open('./vec0.dylib');
  sqlite3.ensureExtensionLoaded(
    SqliteExtension.inLibrary(vecLibrary, 'sqlite3_vec_init'),
  );

  final memoryDb = sqlite3.openInMemory();
  print('Using sqlite3 ${sqlite3.version}');
  return memoryDb;
}
/// Packs [vector] into a BLOB of little-endian IEEE-754 float32 values, the
/// binary vector format sqlite-vec accepts.
Uint8List serializeFloat32(List<double> vector) {
  final out = ByteData(vector.length * Float32List.bytesPerElement);
  var offset = 0;
  for (final value in vector) {
    out.setFloat32(offset, value, Endian.little);
    offset += Float32List.bytesPerElement;
  }
  return out.buffer.asUint8List();
}
/// Splits [text] into blank-line-separated chunks and yields
/// `(chunkText, start, end)` for each, with `[start, end)` offsets into
/// [text].
///
/// TODO: long paragraphs are yielded whole. They may need to be capped
/// (around 500 tokens) before embedding.
Iterable<(String, int, int)> _chunkText(String text) sync* {
  final paragraphs = RegExp(r'((?:[^\n][\n]?)+)');
  for (final RegExpMatch(:start, :end) in paragraphs.allMatches(text)) {
    yield (text.substring(start, end), start, end);
  }
}
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
-- Source documents that have been indexed.
CREATE TABLE files (
  id INTEGER PRIMARY KEY AUTOINCREMENT,
  path TEXT NOT NULL,
  content TEXT
);

-- Maps each embedded chunk to its file and start/end offsets.
-- start/end are bracket-quoted because END is an SQL keyword.
CREATE TABLE file_embeddings (
  id INTEGER PRIMARY KEY AUTOINCREMENT,
  file_id INTEGER NOT NULL,
  chunk_id INTEGER NOT NULL,
  [start] INTEGER,
  [end] INTEGER
);

-- searchEmbeddings joins on chunk_id, and the *ByFileId queries filter on
-- file_id. These indexes let those lookups avoid a full table scan.
-- NOTE(review): schemaVersion is 1 and no migration is visible, so databases
-- created before these indexes existed will not get them automatically.
CREATE INDEX file_embeddings_chunk_id ON file_embeddings (chunk_id);
CREATE INDEX file_embeddings_file_id ON file_embeddings (file_id);

-- The sqlite-vec `chunks` table is not declared to drift. It is presumably
-- created outside this file (confirm), and its expected shape is:
-- CREATE VIRTUAL TABLE chunks using vec0(
--   id INTEGER PRIMARY KEY AUTOINCREMENT,
--   embedding float[768]
-- );
-- Lists every indexed file.
getFiles:
  SELECT * FROM files;

-- Looks up one file by its primary key.
getFileById:
  SELECT * FROM files WHERE id = :id;

-- Looks up files by stored path. path is not declared UNIQUE, so this may
-- return more than one row.
getFileByPath:
  SELECT * FROM files WHERE path = :path;

-- Inserts a file and returns the stored row, including its generated id.
insertFile:
  INSERT INTO files (path, content)
  VALUES (:path, :content)
  RETURNING *;

-- Deletes one file. Its file_embeddings rows are left in place; remove them
-- with deleteFileEmbeddingByFileId.
deleteFileById:
  DELETE FROM files WHERE id = :id;

-- Lists the chunk mappings recorded for one file.
getFileEmbeddingsByFileId:
  SELECT * FROM file_embeddings WHERE file_id = :fileId;

-- Deletes the chunk mappings of one file (:id is the file id).
deleteFileEmbeddingByFileId:
  DELETE FROM file_embeddings WHERE file_id = :id;

-- Inserts into the sqlite-vec `chunks` table are done with raw SQL instead
-- (e.g. Database.addChunk):
-- insertChunk(:embedding AS BLOB):
--   INSERT INTO chunks (embedding) VALUES (:embedding);

-- Rowid of the most recent successful INSERT on this connection.
getLastId:
  SELECT last_insert_rowid();

-- Records which file a chunk belongs to and its start/end offsets.
insertFileEmbedding:
  INSERT INTO file_embeddings (file_id, chunk_id, [start], [end])
  VALUES (:fileId, :chunkId, :start, :end);
-- k-nearest-neighbour search over the sqlite-vec `chunks` table. It returns
-- the 20 chunks closest to :embedding, with their source file and offsets.
-- The LEFT JOINs keep chunks that have no file_embeddings/files row; such
-- rows come back with NULL path/start/end/content.
-- NOTE(review): the CAST presumably gives drift a concrete type for
-- `distance`, since `chunks` is not declared in this file. Confirm.
searchEmbeddings(:embedding AS BLOB):
  SELECT
    files.path,
    file_embeddings.start,
    file_embeddings.end,
    CAST(chunks.distance AS REAL) AS distance,
    files.content
  FROM chunks
    LEFT JOIN file_embeddings ON file_embeddings.chunk_id = chunks.id
    LEFT JOIN files ON files.id = file_embeddings.file_id
  WHERE embedding MATCH :embedding
    AND k = 20;
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.