@a-agmon
Created December 21, 2023 09:15
infer function
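// Embeds one sentence: tokenize, run the model, pool over tokens, L2-normalize.
pub fn infer_sentence_embedding(&self, sentence: &str) -> anyhow::Result<Tensor> {
    // Tokenize; the `true` flag asks the tokenizer to add special tokens.
    let tokens = self
        .tokenizer
        .encode(sentence, true)
        .map_err(anyhow::Error::msg)?;
    // Add a batch dimension of one: [seq_len] -> [1, seq_len].
    let token_ids = Tensor::new(tokens.get_ids(), &self.device)?.unsqueeze(0)?;
    // Single-segment input, so every token type id is zero.
    let token_type_ids = token_ids.zeros_like()?;
    let start = std::time::Instant::now();
    let embeddings = self.model.forward(&token_ids, &token_type_ids)?;
    println!("time taken for forward: {:?}", start.elapsed());
    println!("embeddings: {:?}", embeddings);
    // Collapse the per-token embeddings into a single sentence vector.
    let embeddings = Self::apply_max_pooling(&embeddings)?;
    println!("embeddings after pooling: {:?}", embeddings);
    // Scale the pooled vector to unit length.
    let embeddings = Self::l2_normalize(&embeddings)?;
    Ok(embeddings)
}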
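
The gist does not show `apply_max_pooling` or `l2_normalize`. As a rough illustration of what the two steps compute, here is a dependency-free sketch on plain `Vec<f32>` data rather than tensors. The function names `max_pool` and `l2_normalize` and the `[tokens][hidden]` layout are assumptions made for this example, not the author's implementation.

```rust
/// Max pooling over the token axis: for each hidden dimension, keep the
/// largest value found across all tokens.
/// Assumed layout (hypothetical): token_embeddings[token][hidden_dim].
fn max_pool(token_embeddings: &[Vec<f32>]) -> Vec<f32> {
    // Width of the hidden dimension; zero if there are no tokens.
    let hidden = token_embeddings.first().map_or(0, |row| row.len());
    let mut pooled = vec![f32::NEG_INFINITY; hidden];
    for row in token_embeddings {
        for (p, &v) in pooled.iter_mut().zip(row) {
            *p = (*p).max(v);
        }
    }
    pooled
}

/// L2 normalization: divide every component by the Euclidean norm so the
/// result has unit length. A zero vector is returned unchanged.
fn l2_normalize(v: &[f32]) -> Vec<f32> {
    let norm = v.iter().map(|x| x * x).sum::<f32>().sqrt();
    if norm == 0.0 {
        return v.to_vec();
    }
    v.iter().map(|x| x / norm).collect()
}
```

Because the final vectors have unit length, the dot product of two such embeddings equals their cosine similarity.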