a-agmon / cn2.rs
Created December 21, 2023 09:14
loading embeddings
let embeddings = match embeddings_filename.is_empty() {
    true => {
        println!("no file name provided");
        Tensor::new(&[0.0], &device)?
    }
    false => {
        let tensor_file = safetensors::load(embeddings_filename, &device)?;
        tensor_file
            .get(embeddings_key)
            .expect("error getting key:embedding")
            .to_owned()
    }
};
a-agmon / cn3.rs
Created December 21, 2023 09:15
infer function
pub fn infer_sentence_embedding(&self, sentence: &str) -> anyhow::Result<Tensor> {
    let tokens = self
        .tokenizer
        .encode(sentence, true)
        .map_err(anyhow::Error::msg)?;
    let token_ids = Tensor::new(tokens.get_ids(), &self.device)?.unsqueeze(0)?;
    let token_type_ids = token_ids.zeros_like()?;
    let start = std::time::Instant::now();
    let embeddings = self.model.forward(&token_ids, &token_type_ids)?;
    println!("inference took {:?}", start.elapsed());
    // max-pool over the token dimension and L2-normalize (see the utility functions below)
    let pooled = embeddings.max(1)?;
    Ok(pooled.broadcast_div(&pooled.sqr()?.sum_keepdim(1)?.sqrt()?)?)
}
a-agmon / cn4.rs
Created December 21, 2023 09:17
utility functions
// max pooling: take the per-dimension maximum over the token axis (dim 1),
// turning a (batch, tokens, hidden) tensor into a (batch, hidden) sentence vector
pub fn apply_max_pooling(embeddings: &Tensor) -> anyhow::Result<Tensor> {
    Ok(embeddings.max(1)?)
}

// L2 normalization: divide each row by its Euclidean norm so that a plain
// dot product between two vectors equals their cosine similarity
pub fn l2_normalize(embeddings: &Tensor) -> anyhow::Result<Tensor> {
    Ok(embeddings.broadcast_div(&embeddings.sqr()?.sum_keepdim(1)?.sqrt()?)?)
}
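A quick standalone illustration (not part of the original gist) of what these two helpers compute; the tensor shape is an arbitrary stand-in for a real BERT output:

use candle_core::{Device, Tensor};

fn main() -> anyhow::Result<()> {
    let device = Device::Cpu;
    // stand-in for a model output: (batch = 1, tokens = 4, hidden = 3)
    let embeddings = Tensor::rand(0f32, 1f32, (1, 4, 3), &device)?;
    // max over the token dimension -> shape (1, 3)
    let pooled = embeddings.max(1)?;
    // divide each row by its L2 norm -> a unit-length sentence vector
    let normalized = pooled.broadcast_div(&pooled.sqr()?.sum_keepdim(1)?.sqrt()?)?;
    println!("{:?}", normalized.dims()); // [1, 3]
    Ok(())
}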
a-agmon / cn5.rs
Created December 21, 2023 09:18
score
pub fn score_vector_similarity(
    &self,
    vector: Tensor,
    top_k: usize,
) -> anyhow::Result<Vec<(usize, f32)>> {
    let vec_len = self.embeddings.dim(0)?;
    let mut scores = vec![(0, 0.0); vec_len];
    for (embedding_index, score_tuple) in scores.iter_mut().enumerate() {
        let cur_vec = self.embeddings.get(embedding_index)?.unsqueeze(0)?;
        // the embeddings are L2-normalized, so the dot product is the cosine similarity
        let cosine_similarity = (&cur_vec * &vector)?.sum_all()?.to_scalar::<f32>()?;
        *score_tuple = (embedding_index, cosine_similarity);
    }
    // sort descending by score and keep only the top_k best matches
    scores.sort_by(|a, b| b.1.total_cmp(&a.1));
    scores.truncate(top_k);
    Ok(scores)
}
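The comment in the loop leans on the fact that, for L2-normalized vectors, the denominator of the cosine formula is 1.0, so the dot product alone is the cosine similarity. A small standalone check (values are arbitrary, not from the gist):

use candle_core::{Device, Tensor};

fn main() -> anyhow::Result<()> {
    let device = Device::Cpu;
    let a = Tensor::new(&[1.0f32, 2.0, 3.0], &device)?;
    let b = Tensor::new(&[2.0f32, 0.5, 1.0], &device)?;
    // full cosine similarity: dot(a, b) / (||a|| * ||b||)
    let dot = (&a * &b)?.sum_all()?.to_scalar::<f32>()?;
    let norm_a = a.sqr()?.sum_all()?.sqrt()?.to_scalar::<f32>()?;
    let norm_b = b.sqr()?.sum_all()?.sqrt()?.to_scalar::<f32>()?;
    println!("cosine: {}", dot / (norm_a * norm_b));
    // if a and b are already unit length, norm_a * norm_b == 1.0,
    // so the dot product by itself is the cosine similarity
    Ok(())
}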
a-agmon / cn6.rs
Created December 21, 2023 09:20
start service
let filename = "embeddings.bin";
let embedding_key = "my_embedding";
let bert_model = BertInferenceModel::load(
    "sentence-transformers/all-MiniLM-L6-v2",
    "refs/pr/21",
    filename,
    embedding_key,
)?;
let mut text_map_file = File::open("text_map.bin").unwrap();
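// --- illustrative sketch of the remaining setup (the preview stops above) ---
// The handler in cn7.rs expects State<Arc<(BertInferenceModel, Vec<String>)>>, so the
// text map is deserialized and shared together with the model. The bincode format,
// route path, and port below are assumptions (axum 0.7-style serve), not from the gist.
let text_map: Vec<String> =
    bincode::deserialize_from(&mut text_map_file).expect("error deserializing text_map");
let shared_state = std::sync::Arc::new((bert_model, text_map));
let app = axum::Router::new()
    .route("/similar", axum::routing::post(find_similar))
    .with_state(shared_state);
let listener = tokio::net::TcpListener::bind("0.0.0.0:3000").await?;
axum::serve(listener, app).await?;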
a-agmon / cn7.rs
Created December 21, 2023 09:23
find similar function
async fn find_similar(
    State(model_ctx): State<Arc<(BertInferenceModel, Vec<String>)>>,
    Json(payload): Json<ReqPayload>,
) -> Json<ResPayload> {
    let (model, text_map) = &*model_ctx;
    let query_vector = model
        .infer_sentence_embedding(&payload.text)
        .expect("error inferring sentence embedding");
    let results: Vec<(usize, f32)> = model
        .score_vector_similarity(query_vector, 10) // top_k value is illustrative
        .expect("error scoring vector similarity");
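    // --- illustrative sketch of the likely remainder (the preview stops above) ---
    // map each matching index back to its original sentence via the shared text_map
    let texts: Vec<String> = results
        .iter()
        .map(|(index, _score)| text_map[*index].clone())
        .collect();
    Json(ResPayload { texts })
}

// The payload types are not shown in the preview. ReqPayload must expose a `text`
// field (used above); the ResPayload shape below is a hypothetical illustration.
#[derive(serde::Deserialize)]
struct ReqPayload {
    text: String,
}

#[derive(serde::Serialize)]
struct ResPayload {
    texts: Vec<String>,
}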
a-agmon / cn8.rs
Created December 21, 2023 09:28
embed
// encode the whole batch of sentences in one call
let tokens = self
    .tokenizer
    .encode_batch(sentences, true)
    .map_err(anyhow::Error::msg)?;
// collect each sentence's token ids into its own tensor
let token_ids = tokens
    .iter()
    .map(|tokens| {
        let tokens = tokens.get_ids().to_vec();
        Tensor::new(tokens.as_slice(), &self.device)
    })
    .collect::<Result<Vec<_>, _>>()?;
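// --- illustrative sketch of the likely continuation (the preview stops above) ---
// Stacking requires every tokenized sentence to have the same length, so this assumes
// the tokenizer was configured with padding (e.g. pad to the longest sentence in the
// batch) before encode_batch was called.
let token_ids = Tensor::stack(&token_ids, 0)?;
let token_type_ids = token_ids.zeros_like()?;
let embeddings = self.model.forward(&token_ids, &token_type_ids)?;
// same post-processing as the single-sentence path: max pooling + L2 normalization
let pooled = embeddings.max(1)?;
Ok(pooled.broadcast_div(&pooled.sqr()?.sum_keepdim(1)?.sqrt()?)?)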
a-agmon / cn9.rs
Created December 21, 2023 09:30
save embed
let results: Vec<Result<Tensor, _>> = sentences
    .par_chunks(350)
    .map(|chunk| bert_model.create_embeddings(chunk.to_vec()))
    .collect();
let embeddings = Tensor::cat(
    &results
        .iter()
        .map(|r| r.as_ref().unwrap())
        .collect::<Vec<&Tensor>>(),
    0,
)?;
// persist all the embeddings under the key the service loads at startup
embeddings.save_safetensors("my_embedding", "embeddings.bin")?;
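// The serving side (cn6.rs) also opens "text_map.bin" to map result indices back to
// the original sentences. How that file is written isn't shown in the preview; assuming
// bincode, the indexing side could persist it like this:
let file = std::fs::File::create("text_map.bin")?;
bincode::serialize_into(file, &sentences)?;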
a-agmon / icebergrusttbl.rs
Created March 15, 2024 04:02
iceberg rust static table v1
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
use arrow::array::{array, Scalar, StringArray};
use parquet::arrow::arrow_reader::{ArrowPredicateFn, ParquetRecordBatchReaderBuilder, RowFilter};
use parquet::arrow::async_reader::ParquetRecordBatchStreamBuilder;
use parquet::arrow::ProjectionMask;
use parquet::schema::types::SchemaDescriptor;
use tokio::fs::File;
use tokio_stream::StreamExt;
#[tokio::main]
async fn main() -> anyhow::Result<()> {
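    // --- illustrative sketch only: the gist preview stops at the opening brace ---
    // The file path, column indices, and filter value below are assumptions; the imports
    // above suggest reading one of the table's Parquet data files asynchronously with a
    // column projection and a pushed-down row filter.
    let file = File::open("data/00000-0-data.parquet").await?;
    let builder = ParquetRecordBatchStreamBuilder::new(file).await?;
    let schema: &SchemaDescriptor = builder.parquet_schema();

    // keep only rows whose filtered column equals "some-value"
    let predicate = ArrowPredicateFn::new(ProjectionMask::roots(schema, [1]), |batch| {
        arrow::compute::kernels::cmp::eq(
            batch.column(0),
            &Scalar::new(StringArray::from(vec!["some-value"])),
        )
    });
    let row_filter = RowFilter::new(vec![Box::new(predicate)]);

    // project the first two columns and stream the matching record batches
    let projection = ProjectionMask::roots(schema, [0, 1]);
    let mut stream = builder
        .with_projection(projection)
        .with_row_filter(row_filter)
        .build()?;
    while let Some(batch) = stream.next().await {
        println!("read {} rows", batch?.num_rows());
    }
    Ok(())
}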