candle model stream
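Two files from a small Rust crate (the imports in the second file name it burnt_wick): streamable_model.rs wraps candle's quantized_llama model, a tokenizer, and a sampler so that generation can be driven one token at a time; main.rs uses it to run an interactive Llama 3 chat loop in the terminal.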
streamable_model.rs:
use candle_transformers::models::quantized_llama as model;
use candle_transformers::generation::{LogitsProcessor, Sampling};
use candle_core::quantized::gguf_file;
use candle_core::Tensor;
pub use candle_core::Device;
pub use tokenizers::Tokenizer;

/// Wraps a quantized llama model, a tokenizer, and a sampler so that
/// generation can be driven one token at a time and decoded incrementally.
pub struct StreamableModel {
    model: model::ModelWeights,
    tokenizer: tokenizers::Tokenizer,
    sampler: LogitsProcessor,
    device: Device,
    /// Tokens generated so far (the prompt itself is not stored).
    tokens: Vec<u32>,
    /// Window into `tokens` that `next_token` has already decoded.
    prev_index: usize,
    current_index: usize,
    /// Position passed to the model's forward call (prompt length plus generated tokens).
    stream_position: usize,
    /// Token ids that terminate generation (e.g. <|eot_id|>).
    end_tokens: Vec<u32>,
}
impl StreamableModel {
    /// Loads a GGUF model and a huggingface tokenizer.json from disk.
    pub fn from_paths(
        model_path: &str,
        tokenizer_config_path: &str,
        sampler: LogitsProcessor,
        device: Device,
    ) -> anyhow::Result<Self> {
        let mut file = std::fs::File::open(model_path)?;
        let model_content = gguf_file::Content::read(&mut file)?;
        let tokenizer =
            tokenizers::Tokenizer::from_file(tokenizer_config_path).map_err(anyhow::Error::msg)?;
        let model = model::ModelWeights::from_gguf(model_content, &mut file, &device)?;
        Ok(StreamableModel::new(model, tokenizer, sampler, device))
    }

    /// Convenience constructor using deterministic greedy (argmax) sampling.
    pub fn from_paths_argmax_sampling(
        model_path: &str,
        tokenizer_config_path: &str,
        device: Device,
    ) -> anyhow::Result<Self> {
        let sampler = LogitsProcessor::from_sampling(0, Sampling::ArgMax);
        StreamableModel::from_paths(model_path, tokenizer_config_path, sampler, device)
    }
    /// Builds the wrapper. Panics if the tokenizer does not define the
    /// Llama 3 stop tokens <|end_of_text|> and <|eot_id|>.
    pub fn new(
        model: model::ModelWeights,
        tokenizer: tokenizers::Tokenizer,
        sampler: LogitsProcessor,
        device: Device,
    ) -> Self {
        let vocab = tokenizer.get_vocab(true);
        let eos_token = *vocab.get("<|end_of_text|>").unwrap();
        let eot_token = *vocab.get("<|eot_id|>").unwrap();
        Self {
            model,
            tokenizer,
            sampler,
            device,
            tokens: Vec::new(),
            prev_index: 0,
            current_index: 0,
            stream_position: 0,
            end_tokens: vec![eos_token, eot_token],
        }
    }
    /// Runs the whole prompt through the model in one forward pass and
    /// returns the prompt length plus the logits for the final position.
    fn ingest_prompt(&mut self, prompt: &str) -> anyhow::Result<(usize, Tensor)> {
        let tokens = self
            .tokenizer
            .encode(prompt, true)
            .map_err(anyhow::Error::msg)?;
        let prompt_tokens = tokens.get_ids();
        let input = Tensor::new(prompt_tokens, &self.device)?.unsqueeze(0)?;
        // Index 0 resets the model's KV cache before ingesting a new prompt.
        let logits = self.model.forward(&input, 0)?.squeeze(0)?;
        Ok((prompt_tokens.len(), logits))
    }

    /// Ingests the prompt, samples the first token, and resets the
    /// streaming state left over from any previous generation.
    pub fn begin_stream(&mut self, prompt: &str) -> anyhow::Result<()> {
        let (prompt_len, logits) = self.ingest_prompt(prompt)?;
        let next_token = self.sampler.sample(&logits)?;
        self.tokens.clear();
        self.prev_index = 0;
        self.current_index = 0;
        self.stream_position = prompt_len;
        self.tokens.push(next_token);
        Ok(())
    }
    /// Feeds the last sampled token back through the model and samples the
    /// next one. Returns Ok(false) once a stop token is produced.
    pub fn stream_step(&mut self) -> anyhow::Result<bool> {
        let last_token = *self.tokens.last().unwrap();
        let input = Tensor::new(&[[last_token]], &self.device)?;
        let logits = self.model.forward(&input, self.stream_position)?.squeeze(0)?;
        let next_token = self.sampler.sample(&logits)?;
        if self.end_tokens.contains(&next_token) {
            return Ok(false);
        }
        self.tokens.push(next_token);
        self.stream_position += 1;
        Ok(true)
    }
    /// Returns the text decoded since the previous call, or None once the
    /// stream is exhausted. Decoding runs one call behind sampling: each
    /// call decodes the tokens that were already present at the last call.
    pub fn next_token(&mut self) -> anyhow::Result<Option<String>> {
        if self.tokens.is_empty()
            || (self.current_index == self.tokens.len() && self.prev_index == self.current_index)
        {
            return Ok(None);
        }
        let tokens = &self.tokens[self.prev_index..self.current_index];
        let prev_text = self.tokenizer.decode(tokens, true).map_err(anyhow::Error::msg)?;
        self.prev_index = self.current_index;
        self.current_index = self.tokens.len();
        Ok(Some(prev_text))
    }
}
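A minimal sketch of driving the wrapper directly, assuming the file above lives at streamable_model.rs in a crate named burnt_wick (as the chat example below imports it). The function name run_once and both file paths are placeholders:

use burnt_wick::streamable_model::{Device, StreamableModel};

fn run_once() -> anyhow::Result<()> {
    let device = Device::new_cuda(0)?;
    let mut model = StreamableModel::from_paths_argmax_sampling(
        "./model.gguf",     // placeholder path
        "./tokenizer.json", // placeholder path
        device,
    )?;
    // The caller supplies any special tokens as part of the prompt text.
    model.begin_stream("<|begin_of_text|>Once upon a time")?;
    // stream_step samples tokens; next_token yields the decoded text one
    // call behind, so drain whatever is still buffered after the loop.
    while model.stream_step()? {
        if let Some(text) = model.next_token()? {
            print!("{text}");
        }
    }
    while let Some(text) = model.next_token()? {
        print!("{text}");
    }
    Ok(())
}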
main.rs:
use std::io::{self, Write};
use burnt_wick::streamable_model::{Device, StreamableModel, Tokenizer};
use std::collections::HashMap;
use anyhow::Result;
/// Encodes the system prompt as <|begin_of_text|> + system header + text + <|eot_id|>.
fn encode_system(tokenizer: &Tokenizer, system_prompt: &str) -> Vec<u32> {
    let bos_token = tokenizer.token_to_id("<|begin_of_text|>").unwrap();
    let eot_token = tokenizer.token_to_id("<|eot_id|>").unwrap();
    let mut tokens = vec![bos_token];
    tokens.extend(encode_header(tokenizer, "system"));
    let system_ids = tokenizer.encode(system_prompt, false).unwrap().get_ids().to_vec();
    tokens.extend(system_ids);
    tokens.push(eot_token);
    tokens
}
/// Encodes one chat message with its role header and trailing <|eot_id|>.
fn encode_message(tokenizer: &Tokenizer, username: &str, message: &str) -> Vec<u32> {
    let eot_token = tokenizer.token_to_id("<|eot_id|>").unwrap();
    let mut tokens = encode_header(tokenizer, username);
    let encoded = tokenizer.encode(message.trim(), false).unwrap();
    tokens.extend(encoded.get_ids());
    tokens.push(eot_token);
    tokens
}

/// Encodes message text + <|eot_id|> without a role header; used for the
/// assistant's replies, whose header is appended separately in main().
fn encode_message_no_user(tokenizer: &Tokenizer, message: &str) -> Vec<u32> {
    let mut tokens = Vec::new();
    let eot_token = tokenizer.token_to_id("<|eot_id|>").unwrap();
    let encoded = tokenizer.encode(message.trim(), false).unwrap();
    tokens.extend(encoded.get_ids());
    tokens.push(eot_token);
    tokens
}
/// Encodes <|start_header_id|>{role}<|end_header_id|> followed by "\n\n".
fn encode_header(tokenizer: &Tokenizer, username: &str) -> Vec<u32> {
    let start_header = tokenizer.token_to_id("<|start_header_id|>").unwrap();
    let end_header = tokenizer.token_to_id("<|end_header_id|>").unwrap();
    let mut tokens = vec![start_header];
    tokens.extend(tokenizer.encode(username, false).unwrap().get_ids().to_vec());
    tokens.push(end_header);
    tokens.extend(tokenizer.encode("\n\n", false).unwrap().get_ids().to_vec());
    tokens
}

/// Like encode_header, but also appends prefill text to steer the start of
/// the model's reply. Not called in the chat loop below.
#[allow(dead_code)]
fn encode_header_prefilled(tokenizer: &Tokenizer, username: &str, prefill: &str) -> Vec<u32> {
    let mut tokens = encode_header(tokenizer, username);
    tokens.extend(tokenizer.encode(prefill, false).unwrap().get_ids().to_vec());
    tokens
}
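// For reference, the ids assembled by main() below decode to the standard
// Llama 3 chat layout (each header is followed by the "\n\n" that
// encode_header appends):
//
//   <|begin_of_text|><|start_header_id|>system<|end_header_id|>
//
//   {system prompt}<|eot_id|><|start_header_id|>user<|end_header_id|>
//
//   {user message}<|eot_id|><|start_header_id|>assistant<|end_header_id|>
//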
fn main() -> Result<()> {
    let model_path = "./Meta-Llama-3-8B-Instruct-Q4_K_M.gguf";
    let tokenizer_path = "./tokenizer.json";
    let device = Device::new_cuda(0)?;
    let mut stream = StreamableModel::from_paths_argmax_sampling(model_path, tokenizer_path, device)?;
    let tokenizer = Tokenizer::from_file(tokenizer_path).map_err(anyhow::Error::msg)?;
    // Turn number -> encoded message; even turns are the user, odd the assistant.
    let mut current_conversation: HashMap<usize, Vec<u32>> = HashMap::new();
    let system_prompt = "You are a helpful AI assistant. Answer the user's questions to the best of your ability.";
    let system_tokens = encode_system(&tokenizer, system_prompt);
    println!("Welcome to the LLM chat! Type 'exit' to end the conversation.");
    let mut current_turn: usize = 0;
    loop {
        print!("User: ");
        io::stdout().flush()?;
        // Read the user's message.
        let mut user_input = String::new();
        io::stdin().read_line(&mut user_input)?;
        user_input = user_input.trim().to_string();
        if user_input == "exit" {
            return Ok(());
        }
        current_conversation.insert(current_turn, encode_message(&tokenizer, "user", &user_input));
        current_turn += 1;
        // Rebuild the full prompt: system prompt, then every turn in order.
        // After each user message (even index) an assistant header is
        // appended, so the model is always cued to reply.
        let mut keys: Vec<_> = current_conversation.keys().collect();
        keys.sort();
        let mut conv_buf = Vec::new();
        conv_buf.extend(system_tokens.clone());
        for (index, key) in keys.iter().enumerate() {
            if let Some(value) = current_conversation.get(key) {
                conv_buf.extend(value);
                if index % 2 == 0 {
                    conv_buf.extend(encode_header(&tokenizer, "assistant"));
                }
            }
        }
        // begin_stream re-tokenizes text, so round-trip the ids to a string
        // while keeping the special tokens in place.
        let prompt_str = tokenizer.decode(&conv_buf, false).map_err(anyhow::Error::msg)?;
        println!("{}", prompt_str); // Debug output: the full prompt sent to the model.
print!("Assistant: "); | |
io::stdout().flush()?; | |
let mut keep_going = true; | |
let mut stream_output = String::new(); | |
stream.begin_stream(&prompt_str); | |
while keep_going { | |
keep_going = stream.stream_step(); | |
if let Some(t) = stream.next_token()? { | |
stream_output.push_str(&t); | |
print!("{}", t); | |
io::stdout().flush()?; | |
keep_going = true; | |
} | |
} | |
let encoded_response = encode_message_no_user(&tokenizer, &stream_output); | |
current_conversation.insert(current_turn, encoded_response); | |
current_turn += 1; | |
println!("\n"); | |
} | |
} |
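As written, the example requires a CUDA-enabled build of candle (Device::new_cuda(0)). On a CPU-only machine, swapping in the CPU device should be the only change needed, with everything else as above:

    let device = Device::Cpu;
    let mut stream = StreamableModel::from_paths_argmax_sampling(model_path, tokenizer_path, device)?;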