@CoffeeVampir3
Created May 30, 2024 01:41
Rust Candle Inference Examples
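Two pieces of Rust code that use candle to run a quantized Llama 3 GGUF model: a small binary that loads the model and tokenizer, builds a Llama 3 instruct prompt, and prints tokens as they are generated, followed by the streamable_model module (from the burnt_wick crate referenced in the imports) that implements the streaming itself.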
// Example binary: load a quantized Llama 3 GGUF model and stream a completion to stdout.
use std::io::Write;

use candle_core::quantized::gguf_file;
use candle_core::Device;
use candle_transformers::generation::{LogitsProcessor, Sampling};
use candle_transformers::models::quantized_llama as model;
use model::ModelWeights;
use tokenizers::Tokenizer;

use burnt_wick::streamable_model::StreamableModel;
/// Loads a GGUF-quantized model and its tokenizer from disk onto the given device.
fn load_model_and_tokenizer(
    model_path: &str,
    tokenizer_path: &str,
    device: &Device,
) -> Result<(ModelWeights, Tokenizer), Box<dyn std::error::Error>> {
    let mut file = std::fs::File::open(model_path)?;
    let model_content = gguf_file::Content::read(&mut file)?;
    // The tokenizers error type doesn't convert cleanly here, so stringify it.
    let tokenizer = Tokenizer::from_file(tokenizer_path).map_err(|e| e.to_string())?;
    let model = ModelWeights::from_gguf(model_content, &mut file, device)?;
    Ok((model, tokenizer))
}
fn main() -> anyhow::Result<()> {
    let model_path = "./Meta-Llama-3-8B-Instruct-Q4_K_M.gguf";
    let tokenizer_path = "./tokenizer.json";

    // Greedy decoding; the seed is unused by ArgMax but required by the constructor.
    let logits_processor = LogitsProcessor::from_sampling(0, Sampling::ArgMax);
    let device = Device::new_cuda(0)?;

    let (model, tokenizer) = match load_model_and_tokenizer(model_path, tokenizer_path, &device) {
        Ok((model, tokenizer)) => (model, tokenizer),
        Err(e) => {
            eprintln!("Error: {e}");
            return Ok(());
        }
    };

    let mut model_generator = StreamableModel::new(model, tokenizer, logits_processor, device);
    let mut stream = model_generator.get_stream_handle();

    // Llama 3 instruct chat template, written out by hand.
    let prompt_str = "<|begin_of_text|><|start_header_id|>system<|end_header_id|>
You are a storywriter. Write whatever the user asks for.<|eot_id|><|start_header_id|>user<|end_header_id|>
Write a story about a cute girl who finds an enchanted meadow.<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n";

    let mut keep_going = true;
    stream.begin_stream(prompt_str);
    while keep_going {
        // Advance generation by one token, then flush any newly decoded text.
        keep_going = stream.stream_step();
        if let Some(t) = stream.next_token()? {
            print!("{t}");
            std::io::stdout().flush()?;
            // There was still text to emit, so take at least one more pass.
            keep_going = true;
        }
    }
    println!();
    std::io::stdout().flush()?;
    Ok(())
}

// The burnt_wick::streamable_model module used by the example above.
use candle_core::{Device, Tensor};
use candle_transformers::generation::LogitsProcessor;
use candle_transformers::models::quantized_llama as model;
/// Owns everything needed for generation: the weights, tokenizer, sampler, and device.
pub struct StreamableModel {
    model: model::ModelWeights,
    tokenizer: tokenizers::Tokenizer,
    sampler: LogitsProcessor,
    device: Device,
}

/// A borrowed view of a `StreamableModel` that tracks one in-progress generation.
pub struct ModelStreamHandle<'a> {
    model: &'a mut model::ModelWeights,
    tokenizer: &'a tokenizers::Tokenizer,
    sampler: &'a mut LogitsProcessor,
    device: &'a Device,
    /// Tokens generated so far (the prompt is not stored here).
    tokens: Vec<u32>,
    /// `tokens[prev_index..current_index]` is the span last returned by `next_token`.
    prev_index: usize,
    current_index: usize,
    /// Length of the prompt in tokens, used as the position offset for later forwards.
    prompt_index: usize,
    /// Number of single-token steps taken after the prompt.
    iteration: usize,
    /// Token ids that terminate generation (e.g. `<|eot_id|>`).
    end_tokens: Vec<u32>,
}
impl StreamableModel {
    pub fn new(
        model: model::ModelWeights,
        tokenizer: tokenizers::Tokenizer,
        sampler: LogitsProcessor,
        device: Device,
    ) -> Self {
        Self {
            model,
            tokenizer,
            sampler,
            device,
        }
    }

    /// Borrows the model state and returns a fresh handle for one streaming generation.
    pub fn get_stream_handle(&mut self) -> ModelStreamHandle {
        // Llama 3 can stop on either end-of-text or end-of-turn.
        let eos_token = "<|end_of_text|>";
        let eot_token = "<|eot_id|>";
        let eos_token = *self.tokenizer.get_vocab(true).get(eos_token).unwrap();
        let eot_token = *self.tokenizer.get_vocab(true).get(eot_token).unwrap();
        let ending_tokens = vec![eos_token, eot_token];
        ModelStreamHandle {
            model: &mut self.model,
            tokenizer: &self.tokenizer,
            sampler: &mut self.sampler,
            device: &self.device,
            tokens: Vec::new(),
            prev_index: 0,
            current_index: 0,
            prompt_index: 0,
            iteration: 0,
            end_tokens: ending_tokens,
        }
    }
}
impl ModelStreamHandle<'_> {
    /// Runs the full prompt through the model in one forward pass and returns
    /// the prompt length together with the logits for its final position.
    fn ingest_prompt(&mut self, prompt: &str) -> (usize, Tensor) {
        let tokens = self.tokenizer.encode(prompt, true).unwrap();
        let prompt_tokens = tokens.get_ids();
        let input = Tensor::new(prompt_tokens, self.device)
            .unwrap()
            .unsqueeze(0)
            .unwrap();
        let logits = self.model.forward(&input, 0).unwrap();
        let logits = logits.squeeze(0).unwrap();
        (prompt_tokens.len(), logits)
    }

    /// Ingests the prompt and samples the first completion token.
    pub fn begin_stream(&mut self, prompt: &str) {
        let (prompt_len, logits) = self.ingest_prompt(prompt);
        let next_token = self.sampler.sample(&logits).unwrap();
        self.prompt_index = prompt_len;
        self.tokens.push(next_token);
    }

    /// Feeds the most recently generated token back through the model and samples
    /// the next one. Returns false once an end token is sampled.
    pub fn stream_step(&mut self) -> bool {
        let terminal_token = self.tokens.last().copied().unwrap();
        let input = Tensor::new(&[terminal_token], self.device)
            .unwrap()
            .unsqueeze(0)
            .unwrap();
        // The position offset is the prompt length plus the steps taken so far.
        let logits = self
            .model
            .forward(&input, self.prompt_index + self.iteration)
            .unwrap();
        let logits = logits.squeeze(0).unwrap();
        let next_token = self.sampler.sample(&logits).unwrap();
        if self.end_tokens.contains(&next_token) {
            return false;
        }
        self.tokens.push(next_token);
        self.iteration += 1;
        true
    }

    /// Decodes and returns the tokens produced since the previous call,
    /// or None when nothing new has been generated.
    pub fn next_token(&mut self) -> anyhow::Result<Option<String>> {
        if self.current_index == self.tokens.len() && self.prev_index == self.current_index {
            return Ok(None);
        }
        let prev_text = if self.tokens.is_empty() {
            String::new()
        } else {
            let tokens = &self.tokens[self.prev_index..self.current_index];
            self.tokenizer.decode(tokens, true).unwrap()
        };
        self.prev_index = self.current_index;
        self.current_index = self.tokens.len();
        Ok(Some(prev_text))
    }
}
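
The example above decodes greedily with Sampling::ArgMax. If you want temperature plus nucleus sampling instead, LogitsProcessor also has a new(seed, temperature, top_p) constructor; below is a minimal, untested sketch of swapping it in. The helper name sampled_generator and the seed/temperature/top-p values are placeholders, not part of the original gist.

use burnt_wick::streamable_model::StreamableModel;
use candle_core::Device;
use candle_transformers::generation::LogitsProcessor;
use candle_transformers::models::quantized_llama::ModelWeights;
use tokenizers::Tokenizer;

// Hypothetical helper: build a StreamableModel that samples with temperature
// and top-p instead of greedy ArgMax. All numeric values are placeholders.
fn sampled_generator(model: ModelWeights, tokenizer: Tokenizer, device: Device) -> StreamableModel {
    let sampler = LogitsProcessor::new(
        42,         // RNG seed
        Some(0.8),  // temperature
        Some(0.95), // nucleus (top-p) cutoff
    );
    StreamableModel::new(model, tokenizer, sampler, device)
}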