Last active
January 19, 2024 20:18
-
-
Save thewh1teagle/d855ec93f516e1235981c73c47ce86fd to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| use anyhow::{Result, Ok}; | |
| use heed::{EnvOpenOptions, Env}; | |
| use arroy::{Database, Reader, Writer, ItemId}; | |
| use arroy::distances::DotProduct; | |
| use rand::SeedableRng; | |
| use rand::rngs::StdRng; | |
| use std::borrow::BorrowMut; | |
| use std::sync::Arc; | |
| pub struct Engine { | |
| db: Database<DotProduct>, | |
| // writer: Writer<DotProduct>, | |
| env: Arc<Env>, | |
| dimension: usize | |
| } | |
| impl Engine { | |
| pub fn try_create(path: &str, dimension: usize) -> Result<Self> { | |
| let env = Arc::new(EnvOpenOptions::new() | |
| .map_size(1024 * 1024 * 1024 * 2) // 2GiB | |
| .open(path)?); | |
| // if the path not found, create, otherwise use it | |
| let env_clone = env.clone(); | |
| let mut wtxn = env_clone.write_txn().unwrap(); | |
| let db: Database<DotProduct> = env.create_database(&mut wtxn.borrow_mut(), None)?; | |
| Ok(Engine { | |
| db, | |
| env, | |
| dimension | |
| }) | |
| } | |
| pub fn insert(&self, id: u32, vector: &[f32]) -> Result<()> { | |
| let mut wtxn = self.env.write_txn().unwrap(); | |
| let writer = Writer::<DotProduct>::new(self.db, 0, self.dimension)?; | |
| writer.append_item(&mut wtxn, ItemId::from(id), vector)?; | |
| let mut rng = StdRng::seed_from_u64(0); | |
| writer.build(&mut wtxn, &mut rng, None)?; | |
| wtxn.commit().unwrap(); | |
| Ok(()) | |
| } | |
| pub fn update(&self, id: u32, vector: &[f32]) -> Result<()> { | |
| let mut wtxn = self.env.write_txn().unwrap(); | |
| let writer = Writer::<DotProduct>::new(self.db, 0, self.dimension)?; | |
| self.remove(id).unwrap(); | |
| self.insert(id, vector).unwrap(); | |
| Ok(()) | |
| } | |
| pub fn remove(&self, id: u32) -> Result<()> { | |
| let mut wtxn = self.env.write_txn().unwrap(); | |
| let writer = Writer::<DotProduct>::new(self.db, 0, self.dimension)?; | |
| writer.del_item(&mut wtxn, id).unwrap(); | |
| let mut rng = StdRng::seed_from_u64(0); | |
| writer.build(&mut wtxn, &mut rng, None)?; | |
| wtxn.commit().unwrap(); | |
| Ok(()) | |
| } | |
| pub fn find(&self, vector: &[f32]) -> Result<Vec<(u32, f32)>> { | |
| let wtxn = self.env.write_txn().unwrap(); | |
| let rtxn = self.env.read_txn().unwrap(); | |
| let reader = Reader::open(&wtxn, 0, self.db)?; | |
| let vectors = reader.nns_by_vector(&rtxn, vector, 1, None, None)?; | |
| Ok(vectors) | |
| } | |
| } | |
#[cfg(test)]
mod tests {
    use super::*;

    /// Creates an empty, uniquely named directory for one test's database.
    ///
    /// The directory lives under the system temp dir and carries the process
    /// id, so the test neither depends on the current working directory nor
    /// on data left behind by an earlier run.
    fn fresh_db_dir(test_name: &str) -> std::path::PathBuf {
        let dir = std::env::temp_dir().join(format!(
            "engine_{}_{}",
            test_name,
            std::process::id()
        ));
        // Best effort: the directory normally does not exist yet.
        let _ = std::fs::remove_dir_all(&dir);
        std::fs::create_dir_all(&dir).expect("failed to create test database directory");
        dir
    }

    #[test]
    fn test_insert_and_search() {
        let dir = fresh_db_dir("insert_and_search");
        let path = dir.to_str().expect("temp dir path is not valid UTF-8");

        // Create a new Engine on the empty directory.
        let engine = Engine::try_create(path, 3).unwrap();

        // Insert a single vector.
        let id = 1;
        let vector = vec![1.0, 2.0, 3.0];
        engine.insert(id, &vector).unwrap();

        // Searching for the same vector must return the inserted id first.
        let search_result = engine.find(&vector).unwrap();
        assert!(!search_result.is_empty());
        let found_id = search_result[0].0;
        assert_eq!(found_id, id);

        // Close the environment before deleting its files.
        drop(engine);
        let _ = std::fs::remove_dir_all(&dir);
    }
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment