Rust Candle Inference Examples
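Two files follow: the binary entry point, which loads a GGUF-quantized Llama 3 model and streams a completion to stdout, and the StreamableModel wrapper it imports through burnt_wick::streamable_model. The file names shown below are assumed from the standard crate layout, and the Cargo.toml sketch here is reconstructed from the imports; the version numbers are a guess, and the cuda features can be dropped for CPU-only builds:

[package]
name = "burnt_wick"
version = "0.1.0"
edition = "2021"

[dependencies]
anyhow = "1"
tokenizers = "0.19"
candle-core = { version = "0.6", features = ["cuda"] }
candle-transformers = { version = "0.6", features = ["cuda"] }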
main.rs
use std::io::Write;

use tokenizers::Tokenizer;

use candle_core::quantized::gguf_file;
use candle_core::Device;
use candle_transformers::generation::{LogitsProcessor, Sampling};
use candle_transformers::models::quantized_llama as model;
use model::ModelWeights;

// StreamableModel lives in this crate; the module is the second file below.
use burnt_wick::streamable_model::StreamableModel;
fn load_model_and_tokenizer(
    model_path: &str,
    tokenizer_path: &str,
    device: &Device,
) -> anyhow::Result<(ModelWeights, Tokenizer)> {
    let mut file = std::fs::File::open(model_path)?;
    let model_content = gguf_file::Content::read(&mut file)?;
    // Tokenizer::from_file returns a boxed error that `?` cannot convert
    // into anyhow::Error directly, so wrap it explicitly.
    let tokenizer = Tokenizer::from_file(tokenizer_path).map_err(anyhow::Error::msg)?;
    let model = ModelWeights::from_gguf(model_content, &mut file, device)?;
    Ok((model, tokenizer))
}
fn main() -> anyhow::Result<()> {
    let model_path = "./Meta-Llama-3-8B-Instruct-Q4_K_M.gguf";
    let tokenizer_path = "./tokenizer.json";
    // Greedy decoding; the seed is unused with ArgMax.
    let logits_processor = LogitsProcessor::from_sampling(0, Sampling::ArgMax);
    let device = Device::new_cuda(0)?;
    let (model, tokenizer) = load_model_and_tokenizer(model_path, tokenizer_path, &device)?;

    let mut model_generator = StreamableModel::new(model, tokenizer, logits_processor, device);
    let mut stream = model_generator.get_stream_handle();
    // Llama 3 instruct chat template, written out by hand. The template
    // expects a blank line ("\n\n") after each <|end_header_id|>.
    let prompt_str = "<|begin_of_text|><|start_header_id|>system<|end_header_id|>

You are a storywriter. Write whatever the user asks for.<|eot_id|><|start_header_id|>user<|end_header_id|>

Write a story about a cute girl who finds an enchanted meadow.<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n";
    stream.begin_stream(prompt_str)?;
    loop {
        let more = stream.stream_step()?;
        if let Some(t) = stream.next_token()? {
            print!("{t}");
            std::io::stdout().flush()?;
        }
        if !more {
            break;
        }
    }
    // next_token lags one step behind generation, so drain what is left.
    while let Some(t) = stream.next_token()? {
        print!("{t}");
        std::io::stdout().flush()?;
    }
    println!();
    Ok(())
}
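The example above decodes greedily with Sampling::ArgMax. candle's Sampling enum also has temperature-based variants; a minimal sketch of swapping in seeded nucleus sampling, assuming the TopP variant and its p/temperature fields in your candle version (the values are illustrative, not tuned):

use candle_transformers::generation::{LogitsProcessor, Sampling};

fn make_sampler() -> LogitsProcessor {
    // Nucleus sampling with a fixed seed instead of greedy ArgMax.
    LogitsProcessor::from_sampling(42, Sampling::TopP { p: 0.9, temperature: 0.7 })
}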
streamable_model.rs
use candle_core::{Device, Tensor};
use candle_transformers::generation::LogitsProcessor;
use candle_transformers::models::quantized_llama as model;
pub struct StreamableModel {
    model: model::ModelWeights,
    tokenizer: tokenizers::Tokenizer,
    sampler: LogitsProcessor,
    device: Device,
}

/// Mutably borrows the model's parts for the duration of one generation stream.
pub struct ModelStreamHandle<'a> {
    model: &'a mut model::ModelWeights,
    tokenizer: &'a tokenizers::Tokenizer,
    sampler: &'a mut LogitsProcessor,
    device: &'a Device,
    /// Tokens generated so far (the prompt itself is not stored).
    tokens: Vec<u32>,
    /// Start and end of the decode window handed out by `next_token`.
    prev_index: usize,
    current_index: usize,
    /// Number of prompt tokens, used as the KV-cache position offset.
    prompt_index: usize,
    /// Number of completed `stream_step` calls.
    iteration: usize,
    /// Token ids that terminate generation.
    end_tokens: Vec<u32>,
}
impl StreamableModel {
    pub fn new(
        model: model::ModelWeights,
        tokenizer: tokenizers::Tokenizer,
        sampler: LogitsProcessor,
        device: Device,
    ) -> Self {
        Self { model, tokenizer, sampler, device }
    }

    pub fn get_stream_handle(&mut self) -> ModelStreamHandle {
        // Llama 3 has two stop tokens; this panics if the tokenizer lacks them.
        let vocab = self.tokenizer.get_vocab(true);
        let eos_token = *vocab.get("<|end_of_text|>").unwrap();
        let eot_token = *vocab.get("<|eot_id|>").unwrap();
        ModelStreamHandle {
            model: &mut self.model,
            tokenizer: &self.tokenizer,
            sampler: &mut self.sampler,
            device: &self.device,
            tokens: Vec::new(),
            prev_index: 0,
            current_index: 0,
            prompt_index: 0,
            iteration: 0,
            end_tokens: vec![eos_token, eot_token],
        }
    }
}
impl ModelStreamHandle<'_> {
    /// Runs the whole prompt through the model in one forward pass and
    /// returns the prompt length plus the logits for the last position.
    fn ingest_prompt(&mut self, prompt: &str) -> anyhow::Result<(usize, Tensor)> {
        let tokens = self.tokenizer.encode(prompt, true).map_err(anyhow::Error::msg)?;
        let prompt_tokens = tokens.get_ids();
        let input = Tensor::new(prompt_tokens, self.device)?.unsqueeze(0)?;
        let logits = self.model.forward(&input, 0)?;
        let logits = logits.squeeze(0)?;
        Ok((prompt_tokens.len(), logits))
    }
    pub fn begin_stream(&mut self, prompt: &str) -> anyhow::Result<()> {
        let (prompt_len, logits) = self.ingest_prompt(prompt)?;
        let next_token = self.sampler.sample(&logits)?;
        self.prompt_index = prompt_len;
        self.tokens.push(next_token);
        Ok(())
    }
    /// Feeds the last sampled token back in and samples the next one.
    /// Returns Ok(false) once an end token is sampled. `begin_stream`
    /// must have been called first.
    pub fn stream_step(&mut self) -> anyhow::Result<bool> {
        let terminal_token = self.tokens.last().copied().unwrap();
        let input = Tensor::new(&[terminal_token], self.device)?.unsqueeze(0)?;
        // Offsetting by the number of tokens processed so far keeps the
        // KV-cache position consistent.
        let logits = self.model.forward(&input, self.prompt_index + self.iteration)?;
        let logits = logits.squeeze(0)?;
        let next_token = self.sampler.sample(&logits)?;
        if self.end_tokens.contains(&next_token) {
            return Ok(false);
        }
        self.tokens.push(next_token);
        self.iteration += 1;
        Ok(true)
    }
    /// Decodes the tokens accumulated between the two previous calls. The
    /// window deliberately lags one step behind `stream_step`, so keep
    /// calling this until it returns Ok(None) to flush everything.
    pub fn next_token(&mut self) -> anyhow::Result<Option<String>> {
        if self.current_index == self.tokens.len() && self.prev_index == self.current_index {
            return Ok(None);
        }
        let tokens = &self.tokens[self.prev_index..self.current_index];
        let text = self.tokenizer.decode(tokens, true).map_err(anyhow::Error::msg)?;
        self.prev_index = self.current_index;
        self.current_index = self.tokens.len();
        Ok(Some(text))
    }
}
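Since the handle only hands text out through next_token, collecting a whole completion is a small loop. A sketch of a convenience helper built on the handle API above, as if added to this module (the function name and the max_steps cap are made up for illustration):

/// Hypothetical wrapper: runs a prompt to completion, or until `max_steps`
/// tokens, and returns everything the model generated as one String.
pub fn generate_to_string(
    handle: &mut ModelStreamHandle,
    prompt: &str,
    max_steps: usize,
) -> anyhow::Result<String> {
    let mut out = String::new();
    handle.begin_stream(prompt)?;
    for _ in 0..max_steps {
        let more = handle.stream_step()?;
        if let Some(t) = handle.next_token()? {
            out.push_str(&t);
        }
        if !more {
            break;
        }
    }
    // Flush the one-step lag in next_token.
    while let Some(t) = handle.next_token()? {
        out.push_str(&t);
    }
    Ok(out)
}

A step cap like max_steps is worth having in any case: stream_step only stops on an end token, so a model that never emits one would otherwise loop forever.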