Skip to content

Instantly share code, notes, and snippets.

@a-agmon
Created December 21, 2023 09:28
Show Gist options
  • Save a-agmon/e75f4cf1a2108bd05f28379dd43fe460 to your computer and use it in GitHub Desktop.
Save a-agmon/e75f4cf1a2108bd05f28379dd43fe460 to your computer and use it in GitHub Desktop.
embed
// Encode: tokenize the whole batch of sentences in a single call. Any
// tokenizer failure is converted into an `anyhow::Error` and propagated.
// NOTE(review): the `true` argument is presumably `add_special_tokens` in
// the `tokenizers` API — confirm against the crate version in use.
let tokens = self
    .tokenizer
    .encode_batch(sentences, true)
    .map_err(anyhow::Error::msg)?;
// Collect: build one tensor of token ids per sentence on `self.device`.
// `get_ids()` already yields a slice, so it is handed to `Tensor::new`
// directly instead of being copied into a temporary `Vec` first.
let token_ids = tokens
    .iter()
    .map(|encoding| Ok(Tensor::new(encoding.get_ids(), &self.device)?))
    .collect::<anyhow::Result<Vec<_>>>()?;
// Embed: stack the per-sentence tensors along a new leading dimension.
// NOTE(review): stacking presumably requires every sentence to have the same
// number of token ids, i.e. the tokenizer must be configured with padding
// elsewhere — confirm, otherwise this line errors on ragged batches.
let token_ids = Tensor::stack(&token_ids, 0)?;
// Token type ids are all zeros, with the same shape as `token_ids`.
let token_type_ids = token_ids.zeros_like()?;
let embeddings = self.model.forward(&token_ids, &token_type_ids)?;
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment