Skip to content

Instantly share code, notes, and snippets.

@a-agmon
a-agmon / cn9.rs
Created December 21, 2023 09:30
save embed
// Embed the sentences in parallel: rayon splits them into chunks of 350
// and each worker thread embeds one chunk, yielding a Result<Tensor, _>.
let results: Vec<Result<Tensor, _>> = sentences
.par_chunks(350)
.map(|chunk| bert_model.create_embeddings(chunk.to_vec()))
.collect();
// Concatenate the per-chunk embedding tensors into one tensor.
// NOTE(review): unwrap() panics if any chunk failed — consider collecting
// into Result<Vec<_>, _> first. (Snippet is truncated below this line.)
let embeddings = Tensor::cat(
&results
.iter()
.map(|r| r.as_ref().unwrap())
@a-agmon
a-agmon / cn8.rs
Created December 21, 2023 09:28
embed
// Tokenize the whole batch in one call; the tokenizers crate has its own
// error type, so map it into anyhow so `?` works here.
let tokens = self
.tokenizer
.encode_batch(sentences, true)
.map_err(anyhow::Error::msg)?;
// Collect the token ids of each encoded sentence.
// (Snippet truncated: the closure body is not visible here.)
let token_ids = tokens
.iter()
.map(|tokens| {
@a-agmon
a-agmon / cn7.rs
Created December 21, 2023 09:23
find similar function
/// Axum handler: embeds the query text from the request payload and scores
/// it against the preloaded embeddings. (Snippet truncated mid-call below.)
async fn find_similar(
State(model_ctx): State<Arc<(BertInferenceModel, Vec<String>)>>,
Json(payload): Json<ReqPayload>,
) -> Json<ResPayload> {
// Shared app state: the inference model plus the index -> text lookup.
let (model, text_map) = &*model_ctx;
// NOTE(review): expect() panics the handler task on inference failure —
// consider mapping the error to an HTTP error response instead.
let query_vector = model
.infer_sentence_embedding(&payload.text)
.expect("error infering sentence embedding");
let results: Vec<(usize, f32)> = model
.score_vector_similarity(
@a-agmon
a-agmon / cn6.rs
Created December 21, 2023 09:20
start service
// Artefacts produced ahead of time by the embedding job.
let filename = "embeddings.bin";
let embedding_key = "my_embedding";
// Load the MiniLM sentence-transformer together with the precomputed
// embeddings stored under `embedding_key` in `filename`.
let bert_model = BertInferenceModel::load(
    "sentence-transformers/all-MiniLM-L6-v2",
    "refs/pr/21",
    filename,
    embedding_key,
)?;
// Propagate the I/O error instead of panicking: the surrounding scope
// already returns a Result (see the `?` on load above), so a missing
// text_map.bin becomes a reportable error rather than a crash.
let mut text_map_file = File::open("text_map.bin")?;
@a-agmon
a-agmon / cn5.rs
Created December 21, 2023 09:18
score
/// Scores `vector` against every stored embedding row, returning the
/// indices and similarity scores of the best matches.
/// NOTE(review): `top_k` is unused in the visible lines — the selection
/// presumably happens in the truncated remainder; confirm.
pub fn score_vector_similarity(
&self,
vector: Tensor,
top_k: usize,
) -> anyhow::Result<Vec<(usize, f32)>> {
// One (index, score) slot per stored embedding row.
let vec_len = self.embeddings.dim(0)?;
let mut scores = vec![(0, 0.0); vec_len];
for (embedding_index, score_tuple) in scores.iter_mut().enumerate() {
// Row as a 1 x dim tensor so it can be combined with `vector`.
let cur_vec = self.embeddings.get(embedding_index)?.unsqueeze(0)?;
// Embeddings are L2-normalized, so the dot product equals cosine similarity.
@a-agmon
a-agmon / cn4.rs
Created December 21, 2023 09:17
utility functions
/// Max-pools over dimension 1: keeps the element-wise maximum across that
/// axis, collapsing token-level embeddings into a single vector.
pub fn apply_max_pooling(embeddings: &Tensor)
-> anyhow::Result<Tensor> {
    let pooled = embeddings.max(1)?;
    Ok(pooled)
}
/// L2-normalizes each row: every embedding is divided by the square root
/// of its summed squares (sum_keepdim keeps the axis so the division
/// broadcasts across the row).
pub fn l2_normalize(embeddings: &Tensor)
-> anyhow::Result<Tensor> {
    let row_norms = embeddings.sqr()?.sum_keepdim(1)?.sqrt()?;
    let normalized = embeddings.broadcast_div(&row_norms)?;
    Ok(normalized)
}
@a-agmon
a-agmon / cn3.rs
Created December 21, 2023 09:15
infer function
/// Embeds a single sentence: tokenize, then run the BERT forward pass.
/// (Snippet truncated: pooling/normalization of `embeddings` follows.)
pub fn infer_sentence_embedding(&self, sentence: &str)
-> anyhow::Result<Tensor> {
// Tokenize with special tokens added (the `true` argument).
let tokens = self
.tokenizer
.encode(sentence, true)
.map_err(anyhow::Error::msg)?;
// unsqueeze(0) adds the batch dimension: shape becomes (1, seq_len).
let token_ids = Tensor::new(tokens.get_ids(), &self.device)?.unsqueeze(0)?;
// Single segment input -> all token type ids are zero.
let token_type_ids = token_ids.zeros_like()?;
// `start` presumably feeds a timing log in the truncated remainder.
let start = std::time::Instant::now();
let embeddings = self.model.forward(&token_ids, &token_type_ids)?;
@a-agmon
a-agmon / cn2.rs
Created December 21, 2023 09:14
loading embeddings
// Load the precomputed embeddings if a file name was provided; otherwise
// fall back to a one-element placeholder tensor.
let embeddings = match embeddings_filename.is_empty() {
true => {
println!("no file name provided");
Tensor::new(&[0.0], &device)?
}
false => {
// The safetensors file maps tensor names to tensors.
let tensor_file = safetensors::load(embeddings_filename, &device)?;
// NOTE(review): expect() panics if the key is absent — consider
// returning an error instead. (Snippet truncated below this line.)
tensor_file
.get(embeddings_key)
.expect("error getting key:embedding")
@a-agmon
a-agmon / cn1.rs
Created December 21, 2023 09:11
loading a model
// Point the Hugging Face Hub client at the model repo at a fixed revision.
let repo = Repo::with_revision(model_name.parse()?, RepoType::Model, revision.parse()?);
let api = Api::new()?;
// Shadowing narrows `api` to a handle scoped to this repo.
let api = api.repo(repo);
// Fetch (or resolve from local cache) the three artefacts we need.
let config_filename = api.get("config.json")?;
let tokenizer_filename = api.get("tokenizer.json")?;
let weights_filename = api.get("model.safetensors")?;
// Parse the model config JSON into the typed Config struct.
let config = std::fs::read_to_string(config_filename)?;
let config: Config = serde_json::from_str(&config)?;
// Execute runs a statement that returns no rows (DDL/DML) against the
// underlying database handle and surfaces any driver error to the caller.
func (d *DuckDBDriver) Execute(statement string) error {
	// err is nil on success, so returning it directly replaces the
	// redundant `if err != nil { return err }; return nil` pattern.
	_, err := d.db.Exec(statement)
	return err
}
func (d *DuckDBDriver) Query(statement string) (*sql.Rows, error) {