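//! Standalone candle example: load a GGUF-quantized Llama 3 Instruct model,
//! rebuild its byte-level BPE tokenizer directly from the GGUF metadata, and
//! stream greedy (argmax) generation to stdout while reporting tokens/second.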
use std::collections::HashMap;
use std::io::Write;

use candle_core::quantized::gguf_file;
use candle_core::{Device, Tensor};
use candle_examples::token_output_stream::TokenOutputStream;
use candle_transformers::generation::{LogitsProcessor, Sampling};
use candle_transformers::models::quantized_llama as model;
use model::ModelWeights;
use tokenizers::models::bpe::BpeBuilder;
use tokenizers::normalizer::SplitDelimiterBehavior;
use tokenizers::pre_tokenizers::split::{Split, SplitPattern};
use tokenizers::pre_tokenizers::{byte_level::ByteLevel, sequence::Sequence, PreTokenizerWrapper};
use tokenizers::processors::byte_level::ByteLevel as ByteLevelProcessor;
use tokenizers::Tokenizer;
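
/// Reads a GGUF file once and builds both the quantized model weights and a
/// `tokenizers::Tokenizer` reconstructed from the file's own metadata
/// (`tokenizer.ggml.tokens` for the vocabulary, `tokenizer.ggml.merges` for
/// the BPE merge rules), so no separate tokenizer.json is needed.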
fn load_model_and_tokenizer(
    model_path: &str,
    device: &Device,
) -> Result<(ModelWeights, Tokenizer), Box<dyn std::error::Error>> {
    let mut file = std::fs::File::open(model_path)?;
    let model_content = gguf_file::Content::read(&mut file)?;
    let vocab_size = model_content.metadata["llama.vocab_size"].to_u32().unwrap();

    // Merge rules are stored as space-separated pairs of token strings.
    let merges_values = model_content.metadata["tokenizer.ggml.merges"].to_vec()?;
    let merges: Vec<(String, String)> = merges_values
        .iter()
        .map(|v| {
            let merge_str = v.to_string().unwrap();
            let parts: Vec<&str> = merge_str.split(' ').collect();
            (parts[0].to_string(), parts[1].to_string())
        })
        .collect();

    let tokens = model_content.metadata["tokenizer.ggml.tokens"].to_vec()?;
    let tokens: Vec<String> = tokens.iter().map(|v| v.to_string().unwrap().to_owned()).collect();
    println!("{:?}", model_content.metadata["tokenizer.ggml.model"]); // debug: tokenizer type declared by the GGUF

    // A token's id is its index in the GGUF token list.
    let mut token_to_id = HashMap::with_capacity(vocab_size as usize);
    for (id, tok) in tokens.into_iter().enumerate() {
        token_to_id.insert(tok, id as u32);
    }
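
    // Byte-level BPE without an unknown-token fallback; `ignore_merges(true)`
    // keeps tokens that already exist verbatim in the vocab intact instead of
    // re-merging them, matching Llama 3's tokenizer configuration.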
    let bpe_builder = BpeBuilder::new()
        .vocab_and_merges(token_to_id, merges)
        .continuing_subword_prefix("".to_owned())
        .end_of_word_suffix("".to_owned())
        .byte_fallback(false)
        .fuse_unk(false)
        .ignore_merges(true);
    let mut tokenizer = Tokenizer::new(bpe_builder.build().unwrap());
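
    // Pre-tokenization runs in two stages: split the input with the Llama 3
    // (GPT-4-style) regex, then apply the byte-level mapping to each piece.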
    let pre_tokenizer = Sequence::new(vec![
        PreTokenizerWrapper::Split(
            Split::new(
                SplitPattern::Regex(
                    r"(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+"
                        .into(),
                ),
                SplitDelimiterBehavior::Isolated,
                false,
            )
            .unwrap(),
        ),
        PreTokenizerWrapper::ByteLevel(ByteLevel::new(false, true, false)),
    ]);
    tokenizer.with_pre_tokenizer(pre_tokenizer);

    // Byte-level post-processor and decoder undo the byte-to-unicode mapping
    // when converting ids back into text.
    let post_processor = ByteLevelProcessor::new(true, false, true);
    tokenizer.with_post_processor(post_processor);
    let decoder = ByteLevelProcessor::new(true, true, true);
    tokenizer.with_decoder(decoder);

    let model = ModelWeights::from_gguf(model_content, &mut file, device)?;
    Ok((model, tokenizer))
}
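
/// Encodes the prompt and runs it through the model in a single forward pass
/// at position 0, returning the prompt length and the resulting logits, which
/// seed the first sampled token of the generation loop.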
fn ingest_prompt(
    model: &mut ModelWeights,
    token_stream: &TokenOutputStream,
    device: &Device,
    prompt: &str,
) -> Result<(usize, Tensor), Box<dyn std::error::Error>> {
    let tokens = token_stream.tokenizer().encode(prompt, true).unwrap();
    let prompt_tokens = tokens.get_ids();
    let input = Tensor::new(prompt_tokens, device)?.unsqueeze(0)?;
    let logits = model.forward(&input, 0)?;
    let logits = logits.squeeze(0)?;
    Ok((prompt_tokens.len(), logits))
}
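
/// Greedy (argmax) decoding loop: feeds one token at a time at position
/// `prompt_len + index`, streams decoded text to stdout as it arrives, and
/// stops at `<|end_of_text|>` / `<|eot_id|>` or after 1000 generated tokens.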
fn streamed_inference(
    model: &mut ModelWeights,
    token_stream: &mut TokenOutputStream,
    device: &Device,
    prompt_len: usize,
    initial_logits: &Tensor,
) -> Result<usize, Box<dyn std::error::Error>> {
    let mut tokens_processed: usize = 0;
    let mut logits_processor = LogitsProcessor::from_sampling(0, Sampling::ArgMax);

    // Resolve the two Llama 3 stop tokens to their ids.
    let eos_token = "<|end_of_text|>";
    let eot_token = "<|eot_id|>";
    let eos_token = *token_stream.tokenizer().get_vocab(true).get(eos_token).unwrap();
    let eot_token = *token_stream.tokenizer().get_vocab(true).get(eot_token).unwrap();
    println!("{eos_token}"); // debug: resolved stop-token ids
    println!("{eot_token}");

    // Sample and flush the first token from the prompt's logits.
    let mut next_token = logits_processor.sample(initial_logits)?;
    if let Some(t) = token_stream.next_token(next_token)? {
        print!("{t}");
        std::io::stdout().flush()?;
    }

    // Generation loop: feed the previous token back in at the next position.
    for index in 0..1000 {
        let input = Tensor::new(&[next_token], device)?.unsqueeze(0)?;
        let logits = model.forward(&input, prompt_len + index)?;
        let logits = logits.squeeze(0)?;
        next_token = logits_processor.sample(&logits)?;
        if let Some(t) = token_stream.next_token(next_token)? {
            print!("{t}");
            std::io::stdout().flush()?;
        }
        if next_token == eos_token || next_token == eot_token {
            break;
        }
        tokens_processed += 1;
    }

    // Flush whatever the streaming decoder is still holding back.
    if let Some(rest) = token_stream.decode_rest()? {
        print!("{rest}");
        std::io::stdout().flush()?;
    }
    Ok(tokens_processed)
}
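
/// Loads a local GGUF file (the path below is hardcoded), formats a single
/// request with the Llama 3 Instruct chat template, and reports prompt
/// ingestion and generation throughput.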
fn main() -> anyhow::Result<()> {
    let model_path = "./Meta-Llama-3-8B-Instruct-Q4_K_M.gguf";
    let device = Device::new_cuda(0)?; // requires a CUDA build of candle; use Device::Cpu to run on CPU
    let (mut model, tokenizer) = match load_model_and_tokenizer(model_path, &device) {
        Ok((model, tokenizer)) => (model, tokenizer),
        Err(e) => {
            println!("Error: {}", e);
            return Ok(());
        }
    };

    // Llama 3 Instruct chat template, written out by hand.
    let prompt_str = "<|begin_of_text|><|start_header_id|>system<|end_header_id|>
You are a storywriter. Write whatever the user asks for.<|eot_id|><|start_header_id|>user<|end_header_id|>
Write a story about a cute girl who finds an enchanted meadow.<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n";

    let mut tos = TokenOutputStream::new(tokenizer);

    let ingest_start = std::time::Instant::now();
    let (prompt_token_len, initial_logits) =
        ingest_prompt(&mut model, &tos, &device, prompt_str).unwrap();
    let total_time = ingest_start.elapsed();
    println!(
        "\n\n{:4} prompt tokens processed: {:.1} t/s, {:.2} s total",
        prompt_token_len,
        (prompt_token_len as f64) / total_time.as_secs_f64(),
        total_time.as_secs_f64()
    );

    let inference_start = std::time::Instant::now();
    let tokens_processed =
        streamed_inference(&mut model, &mut tos, &device, prompt_token_len, &initial_logits).unwrap();
    let total_time = inference_start.elapsed();
    println!(
        "\n\n{:4} tokens generated: {:.2} token/s",
        tokens_processed,
        tokens_processed as f64 / total_time.as_secs_f64(),
    );
    print!("\n\n");
    Ok(())
}