Skip to content

Instantly share code, notes, and snippets.

@thewh1teagle
Last active January 19, 2024 20:18
Show Gist options
  • Select an option

  • Save thewh1teagle/d855ec93f516e1235981c73c47ce86fd to your computer and use it in GitHub Desktop.

Select an option

Save thewh1teagle/d855ec93f516e1235981c73c47ce86fd to your computer and use it in GitHub Desktop.
use anyhow::{Result, Ok};
use heed::{EnvOpenOptions, Env};
use arroy::{Database, Reader, Writer, ItemId};
use arroy::distances::DotProduct;
use rand::SeedableRng;
use rand::rngs::StdRng;
use std::borrow::BorrowMut;
use std::sync::Arc;
/// A small persistent vector store: an `arroy` nearest-neighbour index using dot-product
/// similarity, stored in a `heed` (LMDB) environment on disk.
pub struct Engine {
    /// Handle to the arroy database living inside `env`.
    db: Database<DotProduct>,
    /// The heed environment that owns `db`; every read/write transaction is opened on it.
    env: Arc<Env>,
    /// Vector dimension handed to every arroy writer created for `db`.
    dimension: usize,
}
impl Engine {
pub fn try_create(path: &str, dimension: usize) -> Result<Self> {
let env = Arc::new(EnvOpenOptions::new()
.map_size(1024 * 1024 * 1024 * 2) // 2GiB
.open(path)?);
// if the path not found, create, otherwise use it
let env_clone = env.clone();
let mut wtxn = env_clone.write_txn().unwrap();
let db: Database<DotProduct> = env.create_database(&mut wtxn.borrow_mut(), None)?;
Ok(Engine {
db,
env,
dimension
})
}
pub fn insert(&self, id: u32, vector: &[f32]) -> Result<()> {
let mut wtxn = self.env.write_txn().unwrap();
let writer = Writer::<DotProduct>::new(self.db, 0, self.dimension)?;
writer.append_item(&mut wtxn, ItemId::from(id), vector)?;
let mut rng = StdRng::seed_from_u64(0);
writer.build(&mut wtxn, &mut rng, None)?;
wtxn.commit().unwrap();
Ok(())
}
pub fn update(&self, id: u32, vector: &[f32]) -> Result<()> {
let mut wtxn = self.env.write_txn().unwrap();
let writer = Writer::<DotProduct>::new(self.db, 0, self.dimension)?;
self.remove(id).unwrap();
self.insert(id, vector).unwrap();
Ok(())
}
pub fn remove(&self, id: u32) -> Result<()> {
let mut wtxn = self.env.write_txn().unwrap();
let writer = Writer::<DotProduct>::new(self.db, 0, self.dimension)?;
writer.del_item(&mut wtxn, id).unwrap();
let mut rng = StdRng::seed_from_u64(0);
writer.build(&mut wtxn, &mut rng, None)?;
wtxn.commit().unwrap();
Ok(())
}
pub fn find(&self, vector: &[f32]) -> Result<Vec<(u32, f32)>> {
let wtxn = self.env.write_txn().unwrap();
let rtxn = self.env.read_txn().unwrap();
let reader = Reader::open(&wtxn, 0, self.db)?;
let vectors = reader.nns_by_vector(&rtxn, vector, 1, None, None)?;
Ok(vectors)
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Name of the scratch database directory, created under the system temp dir.
    const TEST_DB_PATH: &str = "arroy_engine_test_db";

    /// Returns the path of an empty, freshly created directory for one test's database.
    ///
    /// The store is persistent, so anything left by an earlier run is wiped first; otherwise
    /// old items would leak into the results (and re-inserting the same id into a populated
    /// store may fail).
    fn fresh_db_dir(name: &str) -> String {
        let dir = std::env::temp_dir().join(name);
        // Ignore the result: on a clean machine the directory simply does not exist yet.
        let _ = std::fs::remove_dir_all(&dir);
        std::fs::create_dir_all(&dir).expect("failed to create test database directory");
        dir.to_str()
            .expect("temp dir path is valid UTF-8")
            .to_owned()
    }

    #[test]
    fn test_insert_and_search() {
        let path = fresh_db_dir(TEST_DB_PATH);
        let engine = Engine::try_create(&path, 3).expect("failed to open engine");

        let id = 1;
        let vector = vec![1.0, 2.0, 3.0];
        engine.insert(id, &vector).expect("insert failed");

        // With a single stored item, it must be the nearest neighbour of its own vector.
        let search_result = engine.find(&vector).expect("search failed");
        let (found_id, _distance) = *search_result
            .first()
            .expect("search returned no results");
        assert_eq!(found_id, id);

        // Close the environment before deleting its files (required on some platforms).
        drop(engine);
        let _ = std::fs::remove_dir_all(&path);
    }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment