@CoffeeVampir3
Last active July 22, 2024 20:53
candle model stream
use candle_transformers::models::quantized_llama as model;
use candle_transformers::generation::{LogitsProcessor, Sampling};
use candle_core::quantized::gguf_file;
use candle_core::Tensor;
pub use candle_core::Device;
pub use tokenizers::Tokenizer;
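/// Streaming wrapper around a GGUF-quantized Llama model loaded with candle:
/// owns the model weights, tokenizer, sampler, and the incremental state
/// needed to generate and decode text one step at a time.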
pub struct StreamableModel {
    model: model::ModelWeights,
    tokenizer: tokenizers::Tokenizer,
    sampler: LogitsProcessor,
    device: Device,
    /// Tokens sampled so far (prompt tokens are not stored here).
    tokens: Vec<u32>,
    /// Decode window over `tokens`: everything before `prev_index` has
    /// already been handed out by `next_token`.
    prev_index: usize,
    current_index: usize,
    /// Position offset passed to the model's forward pass (KV-cache index).
    stream_position: usize,
    /// Token ids that terminate generation (e.g. <|eot_id|>).
    end_tokens: Vec<u32>,
}
impl StreamableModel {
    /// Loads a GGUF model and a tokenizer file from disk and wraps them
    /// with the given sampler and device.
    pub fn from_paths(
        model_path: &str,
        tokenizer_config_path: &str,
        sampler: LogitsProcessor,
        device: Device,
    ) -> anyhow::Result<Self> {
        let mut file = std::fs::File::open(model_path)?;
        let model_content = gguf_file::Content::read(&mut file)?;
        let tokenizer = tokenizers::Tokenizer::from_file(tokenizer_config_path).unwrap();
        let model = model::ModelWeights::from_gguf(model_content, &mut file, &device)?;
        Ok(StreamableModel::new(model, tokenizer, sampler, device))
    }

    /// Convenience constructor that uses greedy (ArgMax) sampling.
    pub fn from_paths_argmax_sampling(
        model_path: &str,
        tokenizer_config_path: &str,
        device: Device,
    ) -> anyhow::Result<Self> {
        let sampler = LogitsProcessor::from_sampling(0, Sampling::ArgMax);
        StreamableModel::from_paths(model_path, tokenizer_config_path, sampler, device)
    }
    /// Builds the wrapper and looks up the Llama 3 end-of-sequence and
    /// end-of-turn token ids used to stop generation.
    pub fn new(
        model: model::ModelWeights,
        tokenizer: tokenizers::Tokenizer,
        sampler: LogitsProcessor,
        device: Device,
    ) -> Self {
        let eos_token = "<|end_of_text|>";
        let eot_token = "<|eot_id|>";
        let eos_token = *tokenizer.get_vocab(true).get(eos_token).unwrap();
        let eot_token = *tokenizer.get_vocab(true).get(eot_token).unwrap();
        let ending_tokens = vec![eos_token, eot_token];
        Self {
            model,
            tokenizer,
            sampler,
            device,
            tokens: Vec::new(),
            prev_index: 0,
            current_index: 0,
            stream_position: 0,
            end_tokens: ending_tokens,
        }
    }
    /// Runs the full prompt through the model in one forward pass and
    /// returns the prompt length plus the logits for the last position.
    fn ingest_prompt(&mut self, prompt: &str) -> (usize, Tensor) {
        let tokens = self.tokenizer.encode(prompt, true).unwrap();
        let prompt_tokens = tokens.get_ids();
        let input = Tensor::new(prompt_tokens, &self.device)
            .unwrap()
            .unsqueeze(0)
            .unwrap();
        let logits = self.model.forward(&input, 0).unwrap();
        let logits = logits.squeeze(0).unwrap();
        (prompt_tokens.len(), logits)
    }
    /// Feeds the prompt through the model and samples the first new token.
    pub fn begin_stream(&mut self, prompt: &str) {
        let (prompt_len, logits) = self.ingest_prompt(prompt);
        let next_token = self.sampler.sample(&logits).unwrap();
        self.stream_position = prompt_len;
        self.tokens.push(next_token);
    }

    /// Feeds the most recent token back through the model and samples the
    /// next one. Returns false once an end token is produced.
    pub fn stream_step(&mut self) -> bool {
        let terminal_token = *self.tokens.last().unwrap();
        let input = Tensor::new(&[[terminal_token]], &self.device).unwrap();
        let logits = self.model.forward(&input, self.stream_position).unwrap();
        let logits = logits.squeeze(0).unwrap();
        let next_token = self.sampler.sample(&logits).unwrap();
        if self.end_tokens.contains(&next_token) {
            return false;
        }
        self.tokens.push(next_token);
        self.stream_position += 1;
        true
    }
    /// Returns the text decoded since the previous call, or None when there
    /// is nothing new. Decoding lags one step behind sampling, so the final
    /// chunk is flushed on the call after generation stops.
    pub fn next_token(&mut self) -> anyhow::Result<Option<String>> {
        if self.tokens.is_empty()
            || (self.current_index == self.tokens.len() && self.prev_index == self.current_index)
        {
            return Ok(None);
        }
        let prev_text = {
            let tokens = &self.tokens[self.prev_index..self.current_index];
            self.tokenizer.decode(tokens, true).unwrap()
        };
        self.prev_index = self.current_index;
        self.current_index = self.tokens.len();
        Ok(Some(prev_text))
    }
}
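// A minimal sketch (not part of the original gist): the same wrapper built
// with temperature / top-p sampling instead of greedy ArgMax. The seed and
// sampling values here are illustrative assumptions.
pub fn build_top_p_model(
    model_path: &str,
    tokenizer_config_path: &str,
    device: Device,
) -> anyhow::Result<StreamableModel> {
    let sampler = LogitsProcessor::from_sampling(
        42, // arbitrary RNG seed
        Sampling::TopP { p: 0.9, temperature: 0.7 },
    );
    StreamableModel::from_paths(model_path, tokenizer_config_path, sampler, device)
}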
// Chat example; it uses the streaming wrapper above via the `burnt_wick`
// crate's `streamable_model` module.
use std::collections::HashMap;
use std::io::{self, Write};

use anyhow::Result;
use burnt_wick::streamable_model::{Device, StreamableModel, Tokenizer};
/// Encodes the system prompt in the Llama 3 instruct format:
/// <|begin_of_text|>, a "system" header, the prompt text, then <|eot_id|>.
fn encode_system(tokenizer: &Tokenizer, system_prompt: &str) -> Vec<u32> {
    let bos_token = tokenizer.token_to_id("<|begin_of_text|>").unwrap();
    let eot_token = tokenizer.token_to_id("<|eot_id|>").unwrap();
    let mut tokens = vec![bos_token];
    tokens.extend(encode_header(tokenizer, "system"));
    let system_ids = tokenizer.encode(system_prompt, false).unwrap().get_ids().to_vec();
    tokens.extend(system_ids);
    tokens.push(eot_token);
    tokens
}
/// Encodes one chat message with a role header and a trailing <|eot_id|>.
fn encode_message(tokenizer: &Tokenizer, username: &str, message: &str) -> Vec<u32> {
    let eot_token = tokenizer.token_to_id("<|eot_id|>").unwrap();
    let mut tokens = encode_header(tokenizer, username);
    let encoded = tokenizer.encode(message.trim(), false).unwrap();
    tokens.extend(encoded.get_ids());
    tokens.push(eot_token);
    tokens
}

/// Encodes a message without a role header; used for assistant replies,
/// whose header is already appended while assembling the prompt below.
fn encode_message_no_user(tokenizer: &Tokenizer, message: &str) -> Vec<u32> {
    let eot_token = tokenizer.token_to_id("<|eot_id|>").unwrap();
    let mut tokens = Vec::new();
    let encoded = tokenizer.encode(message.trim(), false).unwrap();
    tokens.extend(encoded.get_ids());
    tokens.push(eot_token);
    tokens
}
/// Encodes a role header: <|start_header_id|>{role}<|end_header_id|>\n\n
fn encode_header(tokenizer: &Tokenizer, username: &str) -> Vec<u32> {
    let start_header = tokenizer.token_to_id("<|start_header_id|>").unwrap();
    let end_header = tokenizer.token_to_id("<|end_header_id|>").unwrap();
    let mut tokens = vec![start_header];
    tokens.extend(tokenizer.encode(username, false).unwrap().get_ids().to_vec());
    tokens.push(end_header);
    tokens.extend(tokenizer.encode("\n\n", false).unwrap().get_ids().to_vec());
    tokens
}

/// Like `encode_header`, but also appends prefill text so generation can be
/// steered to continue from it. (Unused in the example below.)
fn encode_header_prefilled(tokenizer: &Tokenizer, username: &str, prefill: &str) -> Vec<u32> {
    let start_header = tokenizer.token_to_id("<|start_header_id|>").unwrap();
    let end_header = tokenizer.token_to_id("<|end_header_id|>").unwrap();
    let mut tokens = vec![start_header];
    tokens.extend(tokenizer.encode(username, false).unwrap().get_ids().to_vec());
    tokens.push(end_header);
    tokens.extend(tokenizer.encode("\n\n", false).unwrap().get_ids().to_vec());
    tokens.extend(tokenizer.encode(prefill, false).unwrap().get_ids().to_vec());
    tokens
}
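// For reference, the helpers above assemble the standard Llama 3 instruct
// layout (shown here as an assumption based on the special tokens used):
//
//   <|begin_of_text|><|start_header_id|>system<|end_header_id|>
//
//   {system prompt}<|eot_id|><|start_header_id|>user<|end_header_id|>
//
//   {user message}<|eot_id|><|start_header_id|>assistant<|end_header_id|>
//
//   {assistant reply}<|eot_id|>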
fn main() -> Result<()> {
    let model_path = "./Meta-Llama-3-8B-Instruct-Q4_K_M.gguf";
    let tokenizer_path = "./tokenizer.json";
    let device = Device::new_cuda(0)?;
    let mut stream = StreamableModel::from_paths_argmax_sampling(model_path, tokenizer_path, device)?;
    let tokenizer = Tokenizer::from_file(tokenizer_path).unwrap();

    // Conversation history keyed by turn number; even turns are user
    // messages, odd turns are assistant replies.
    let mut current_conversation: HashMap<usize, Vec<u32>> = HashMap::new();
    let system_prompt = "You are a helpful AI assistant. Answer the user's questions to the best of your ability.";
    let system_tokens = encode_system(&tokenizer, system_prompt);

    println!("Welcome to the LLM chat! Type 'exit' to end the conversation.");
    let mut current_turn: usize = 0;
    loop {
        print!("User: ");
        io::stdout().flush()?;

        // Read the user's input and honor the advertised 'exit' command.
        let mut user_input = String::new();
        io::stdin().read_line(&mut user_input)?;
        user_input = user_input.trim().to_string();
        if user_input.eq_ignore_ascii_case("exit") {
            break;
        }
        current_conversation.insert(current_turn, encode_message(&tokenizer, "user", &user_input));
        current_turn += 1;

        // Rebuild the full prompt: system prompt, then each turn in order.
        // An assistant header is appended after every user (even-indexed)
        // message; it either prefixes the stored assistant reply or, on the
        // latest turn, primes the model to generate one.
        let mut keys: Vec<_> = current_conversation.keys().collect();
        keys.sort();
        let mut conv_buf = Vec::new();
        conv_buf.extend(system_tokens.clone());
        for (index, key) in keys.iter().enumerate() {
            if let Some(value) = current_conversation.get(key) {
                conv_buf.extend(value);
                if index % 2 == 0 {
                    let assistant_header = encode_header(&tokenizer, "assistant");
                    conv_buf.extend(assistant_header);
                }
            }
        }

        // Debug output: show the full prompt that will be fed to the model.
        let prompt_str = tokenizer.decode(&conv_buf, false).unwrap();
        println!("{}", prompt_str);

        print!("Assistant: ");
        io::stdout().flush()?;

        // Stream tokens until an end token is hit, then keep calling
        // next_token until the lagging decode buffer is drained.
        let mut keep_going = true;
        let mut stream_output = String::new();
        stream.begin_stream(&prompt_str);
        while keep_going {
            keep_going = stream.stream_step();
            if let Some(t) = stream.next_token()? {
                stream_output.push_str(&t);
                print!("{}", t);
                io::stdout().flush()?;
                keep_going = true;
            }
        }

        // Store the assistant reply (no header; see encode_message_no_user).
        let encoded_response = encode_message_no_user(&tokenizer, &stream_output);
        current_conversation.insert(current_turn, encoded_response);
        current_turn += 1;
        println!("\n");
    }
    Ok(())
}
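// Optional helper (not in the original gist): prefer CUDA but fall back to
// the CPU device so the example can still run on machines without a GPU.
// main() could call this instead of `Device::new_cuda(0)?`.
#[allow(dead_code)]
fn pick_device() -> Device {
    Device::new_cuda(0).unwrap_or(Device::Cpu)
}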