Skip to content

Instantly share code, notes, and snippets.

@paulmaunders
Created February 7, 2025 12:06
Show Gist options
  • Save paulmaunders/fe1846498af47dc026e16d347c6c3070 to your computer and use it in GitHub Desktop.
Rust sentence transformers demo
[package]
name = "sentence_transformers_demo"
version = "0.1.0"
# The 2018 edition does not put `TryInto` in the prelude, which is why the
# Rust source imports `std::convert::TryInto` explicitly.
edition = "2018"
[dependencies]
# Tensor library used directly for the similarity matmul.
# NOTE(review): rust-bert is built against a specific tch release; keep this
# version matched to the one rust-bert 0.23 depends on, otherwise the crate graph
# can end up with two incompatible `Tensor` types — confirm against rust-bert's
# Cargo.toml whenever either version is bumped.
tch = "0.17.0"
# Provides the sentence-embeddings pipeline (`SentenceEmbeddingsBuilder`).
rust-bert = "0.23.0"
# Error type for `main`'s `anyhow::Result` return value.
anyhow = "1.0"
use rust_bert::pipelines::sentence_embeddings::{SentenceEmbeddingsBuilder, SentenceEmbeddingsModelType};
use tch::{nn::VarStore, Device, Tensor};
use std::convert::TryInto; // Add this line
/// Encodes two example sentences with the `AllMiniLmL6V2` sentence-embedding model
/// and prints the cosine similarity between every ordered pair of them.
///
/// # Errors
///
/// Returns an error if the model cannot be fetched or loaded, if encoding fails,
/// if an embedding count/dimension does not fit in `i64`, or if the similarity
/// tensor cannot be converted back into nested `Vec`s.
fn main() -> anyhow::Result<()> {
    // Run on the GPU when CUDA is available, otherwise on the CPU. The device is
    // handed to the builder so the model actually uses it.
    let device = Device::cuda_if_available();
    let model = SentenceEmbeddingsBuilder::remote(SentenceEmbeddingsModelType::AllMiniLmL6V2)
        .with_device(device)
        .create_model()?;

    // Define sentences to encode
    let sentences = [
        "This is a sentence transformer example.",
        "Sentence transformers are useful for encoding sentences.",
    ];

    // Encode sentences: one embedding vector per input sentence.
    let embeddings: Vec<Vec<f32>> = model.encode(&sentences)?;

    // Unit-normalise every embedding so that the dot products computed below are
    // cosine similarities rather than length-dependent raw dot products.
    let normalized: Vec<f32> = embeddings
        .iter()
        .flat_map(|embedding| l2_normalize(embedding))
        .collect();

    // Shape the flat buffer as a (num_sentences, embedding_dim) matrix.
    // `first()` avoids an out-of-bounds panic if no embeddings came back.
    let rows: i64 = embeddings.len().try_into()?;
    let cols: i64 = embeddings.first().map_or(0, Vec::len).try_into()?;
    let embeddings_tensor = Tensor::from_slice(&normalized).view((rows, cols));

    // Cosine similarity of every pair of unit-length rows: E · Eᵀ.
    let similarity = embeddings_tensor.matmul(&embeddings_tensor.transpose(0, 1));

    // Convert similarity tensor to nested Vec for iteration
    let similarity_vec: Vec<Vec<f32>> = similarity.to_kind(tch::Kind::Float).try_into()?;

    // Print original sentences
    println!("Sentences:");
    for (i, sentence) in sentences.iter().enumerate() {
        println!("Sentence {}: {}", i + 1, sentence);
    }

    // Print similarity matrix with sentence labels; row and column order both
    // follow `sentences`, so zipping pairs each value with its labels.
    println!("\nSimilarity matrix:");
    for (left, row) in sentences.iter().zip(&similarity_vec) {
        for (right, value) in sentences.iter().zip(row) {
            println!("Similarity between '{}' and '{}': {:.4}", left, right, value);
        }
    }

    Ok(())
}

/// Returns `v` scaled to unit Euclidean (L2) length.
///
/// An all-zero vector has no direction, so it is returned unchanged instead of
/// being divided by zero (which would fill it with NaNs).
fn l2_normalize(v: &[f32]) -> Vec<f32> {
    let norm = v.iter().map(|x| x * x).sum::<f32>().sqrt();
    if norm > 0.0 {
        v.iter().map(|x| x / norm).collect()
    } else {
        v.to_vec()
    }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment