Created
February 7, 2025 12:06
-
-
Save paulmaunders/fe1846498af47dc026e16d347c6c3070 to your computer and use it in GitHub Desktop.
Rust sentence transformers demo
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
[package]
name = "sentence_transformers_demo"
version = "0.1.0"
edition = "2018"

[dependencies]
tch = "0.17.0"
rust-bert = "0.23.0"
anyhow = "1.0"
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
use rust_bert::pipelines::sentence_embeddings::{SentenceEmbeddingsBuilder, SentenceEmbeddingsModelType}; | |
use tch::{nn::VarStore, Device, Tensor}; | |
use std::convert::TryInto; // Add this line | |
fn main() -> anyhow::Result<()> { | |
// Load the pre-trained sentence transformer model | |
let device = Device::cuda_if_available(); | |
let _vs = VarStore::new(device); | |
let model = SentenceEmbeddingsBuilder::remote(SentenceEmbeddingsModelType::AllMiniLmL6V2) | |
.create_model() | |
.unwrap(); | |
// Define sentences to encode | |
let sentences = [ | |
"This is a sentence transformer example.", | |
"Sentence transformers are useful for encoding sentences.", | |
]; | |
// Encode sentences | |
let embeddings: Vec<Vec<f32>> = model.encode(&sentences)?; | |
// Convert embeddings to tch::Tensor | |
let embeddings_tensor = Tensor::from_slice(&embeddings.concat()) | |
.view((embeddings.len() as i64, embeddings[0].len() as i64)); | |
// Compute similarity (cosine similarity) | |
let similarity = embeddings_tensor.matmul(&embeddings_tensor.transpose(0, 1)); | |
// Convert similarity tensor to nested Vec for iteration | |
let similarity_vec: Vec<Vec<f32>> = similarity | |
.to_kind(tch::Kind::Float) | |
.try_into()?; | |
// Print original sentences | |
println!("Sentences:"); | |
for (i, sentence) in sentences.iter().enumerate() { | |
println!("Sentence {}: {}", i + 1, sentence); | |
} | |
// Print similarity matrix with sentence labels | |
println!("\nSimilarity matrix:"); | |
for (i, row) in similarity_vec.iter().enumerate() { | |
for (j, &value) in row.iter().enumerate() { | |
println!( | |
"Similarity between '{}' and '{}': {:.4}", | |
sentences[i], | |
sentences[j], | |
value | |
); | |
} | |
} | |
Ok(()) | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment