Skip to content

Instantly share code, notes, and snippets.

@a-agmon
Created December 21, 2023 09:18
Show Gist options
  • Save a-agmon/4c477480f05534b8c2f17006f47d019f to your computer and use it in GitHub Desktop.
Save a-agmon/4c477480f05534b8c2f17006f47d019f to your computer and use it in GitHub Desktop.
score
/// Scores `vector` against every stored embedding and returns the `top_k` best matches.
///
/// Each score is the sum of the element-wise product of one stored embedding row
/// (with a leading axis of size 1 added) and `vector`, i.e. their dot product.
/// NOTE(review): this equals cosine similarity only if both the stored embeddings
/// and `vector` are L2-normalised — confirm at the call sites.
/// NOTE(review): `vector` presumably has the same shape as one unsqueezed row
/// (`1 x dim`) — confirm.
///
/// # Returns
/// At most `top_k` pairs of `(embedding_index, score)`, ordered by descending score.
/// Entries with a NaN score are ordered after all non-NaN scores.
///
/// # Errors
/// Propagates any error returned by the tensor operations (dimension lookup, row
/// access, multiplication, reduction, or scalar extraction).
pub fn score_vector_similarity(
    &self,
    vector: Tensor,
    top_k: usize,
) -> anyhow::Result<Vec<(usize, f32)>> {
    let vec_len = self.embeddings.dim(0)?;

    // Build one (index, score) pair per stored embedding, stopping at the first error.
    let mut scores = (0..vec_len)
        .map(|embedding_index| -> anyhow::Result<(usize, f32)> {
            let cur_vec = self.embeddings.get(embedding_index)?.unsqueeze(0)?;
            let similarity = (&cur_vec * &vector)?.sum_all()?.to_scalar::<f32>()?;
            Ok((embedding_index, similarity))
        })
        .collect::<anyhow::Result<Vec<_>>>()?;

    // Sort by descending score. `partial_cmp(..).unwrap()` would panic as soon as a
    // single score is NaN, so use a total order that pushes NaN scores to the end.
    // The sort is stable, so equal scores keep their ascending index order.
    scores.sort_by(|a, b| match (a.1.is_nan(), b.1.is_nan()) {
        (false, false) => b.1.total_cmp(&a.1),
        (true, true) => std::cmp::Ordering::Equal,
        (true, false) => std::cmp::Ordering::Greater,
        (false, true) => std::cmp::Ordering::Less,
    });

    // Keep only the best `top_k`; a no-op when fewer scores exist.
    scores.truncate(top_k);
    Ok(scores)
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment