@CoffeeVampir3
Created May 26, 2024 04:12
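
// Streaming inference over a GGUF-quantized Llama 3 8B Instruct model with candle.
// The tokenizer is rebuilt directly from the GGUF metadata instead of being loaded
// from a separate tokenizer.json.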
use std::collections::HashMap;
use std::io::Write;

use candle_core::quantized::gguf_file;
use candle_core::{Device, Tensor};
use candle_examples::token_output_stream::TokenOutputStream;
use candle_transformers::generation::{LogitsProcessor, Sampling};
use candle_transformers::models::quantized_llama as model;
use model::ModelWeights;

use tokenizers::models::bpe::BpeBuilder;
use tokenizers::normalizer::SplitDelimiterBehavior;
use tokenizers::pre_tokenizers::split::{Split, SplitPattern};
use tokenizers::pre_tokenizers::{byte_level::ByteLevel, sequence::Sequence, PreTokenizerWrapper};
use tokenizers::processors::byte_level::ByteLevel as ByteLevelProcessor;
use tokenizers::Tokenizer;
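
/// Loads the quantized weights from a GGUF file and reconstructs a Hugging Face
/// `tokenizers` byte-level BPE tokenizer from the GGUF metadata
/// (`tokenizer.ggml.tokens` and `tokenizer.ggml.merges`), so no tokenizer.json is needed.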
fn load_model_and_tokenizer(
    model_path: &str,
    device: &Device,
) -> Result<(ModelWeights, Tokenizer), Box<dyn std::error::Error>> {
    let mut file = std::fs::File::open(model_path)?;
    let model_content = gguf_file::Content::read(&mut file)?;

    // Pull the vocabulary size and the BPE merge list out of the GGUF metadata.
    let vocab_size = model_content.metadata["llama.vocab_size"].to_u32().unwrap();
    let merges_values = model_content.metadata["tokenizer.ggml.merges"].to_vec()?;
    let merges: Vec<(String, String)> = merges_values
        .iter()
        .map(|v| {
            let merge_str = v.to_string().unwrap();
            let parts: Vec<&str> = merge_str.split(' ').collect();
            (parts[0].to_string(), parts[1].to_string())
        })
        .collect();
    let tokens = model_content.metadata["tokenizer.ggml.tokens"].to_vec()?;
    let tokens: Vec<String> = tokens.iter().map(|v| v.to_string().unwrap().to_owned()).collect();
    // Debug: print which tokenizer model the GGUF metadata declares.
    println!("{:?}", model_content.metadata["tokenizer.ggml.model"]);

    // Map each token string to its id; the index in the GGUF token list is the id.
    let mut token_to_id = HashMap::with_capacity(vocab_size as usize);
    for (id, tok) in tokens.into_iter().enumerate() {
        token_to_id.insert(tok, id as u32);
    }
    // Rebuild the BPE model from the GGUF vocab and merges; no continuing-subword
    // prefix or end-of-word suffix is used.
    let bpe_builder = BpeBuilder::new()
        .vocab_and_merges(token_to_id, merges)
        .continuing_subword_prefix("".to_owned())
        .end_of_word_suffix("".to_owned())
        .byte_fallback(false)
        .fuse_unk(false)
        .ignore_merges(true);
    let mut tokenizer = Tokenizer::new(bpe_builder.build().unwrap());
    // Pre-tokenizer: the Llama 3 split regex followed by byte-level pre-tokenization.
    let pre_tokenizer = Sequence::new(vec![
        PreTokenizerWrapper::Split(
            Split::new(
                SplitPattern::Regex(
                    r"(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+".into(),
                ),
                SplitDelimiterBehavior::Isolated,
                false,
            )
            .unwrap(),
        ),
        PreTokenizerWrapper::ByteLevel(ByteLevel::new(false, true, false)),
    ]);
    tokenizer.with_pre_tokenizer(pre_tokenizer);

    // ByteLevel doubles as the post-processor and the decoder so byte-level
    // pieces decode back into readable text.
    let post_processor = ByteLevelProcessor::new(true, false, true);
    tokenizer.with_post_processor(post_processor);
    let decoder = ByteLevelProcessor::new(true, true, true);
    tokenizer.with_decoder(decoder);
    let model = ModelWeights::from_gguf(model_content, &mut file, device)?;
    Ok((model, tokenizer))
}
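
/// Encodes the prompt and runs it through the model in a single forward pass at
/// position 0, returning the prompt length in tokens and the logits for the final
/// prompt position.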
fn ingest_prompt(
    model: &mut ModelWeights,
    token_stream: &TokenOutputStream,
    device: &Device,
    prompt: &str,
) -> Result<(usize, Tensor), Box<dyn std::error::Error>> {
    let tokens = token_stream
        .tokenizer()
        .encode(prompt, true)
        .unwrap();
    let prompt_tokens = tokens.get_ids();
    let input = Tensor::new(prompt_tokens, device)?.unsqueeze(0)?;
    let logits = model.forward(&input, 0)?;
    let logits = logits.squeeze(0)?;
    Ok((prompt_tokens.len(), logits))
}
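
/// Greedily samples (ArgMax) up to 1000 new tokens, feeding each one back into the
/// model and streaming the decoded text to stdout. Returns the number of tokens
/// generated.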
fn streamed_inference(
    model: &mut ModelWeights,
    token_stream: &mut TokenOutputStream,
    device: &Device,
    prompt_len: usize,
    initial_logits: &Tensor,
) -> Result<usize, Box<dyn std::error::Error>> {
    let mut tokens_processed: usize = 0;
    let mut logits_processor = LogitsProcessor::from_sampling(0, Sampling::ArgMax);

    // Resolve the Llama 3 stop tokens to their ids.
    let eos_token = "<|end_of_text|>";
    let eot_token = "<|eot_id|>";
    let eos_token = *token_stream.tokenizer().get_vocab(true).get(eos_token).unwrap();
    let eot_token = *token_stream.tokenizer().get_vocab(true).get(eot_token).unwrap();
    println!("{eos_token}");
    println!("{eot_token}");

    // Flush the first token, sampled from the prompt logits.
    let mut next_token = logits_processor.sample(initial_logits)?;
    if let Some(t) = token_stream.next_token(next_token)? {
        print!("{t}");
        std::io::stdout().flush()?;
    }

    // Decode token by token, feeding each sampled token back in at the next
    // position so the model's KV cache is reused.
    for index in 0..1000 {
        let input = Tensor::new(&[next_token], device)?.unsqueeze(0)?;
        let logits = model.forward(&input, prompt_len + index)?;
        let logits = logits.squeeze(0)?;
        next_token = logits_processor.sample(&logits)?;
        if let Some(t) = token_stream.next_token(next_token)? {
            print!("{t}");
            std::io::stdout().flush()?;
        }
        //print!(" {next_token} ");
        if next_token == eos_token || next_token == eot_token {
            break;
        };
        tokens_processed += 1;
    }

    // Flush anything still buffered in the token output stream.
    if let Some(rest) = token_stream.decode_rest()? {
        print!("{rest}");
        std::io::stdout().flush()?;
    }
    Ok(tokens_processed)
}
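
/// Loads the quantized model onto the first CUDA device, builds a Llama 3 Instruct
/// chat prompt, then reports prompt-ingestion and generation throughput.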
fn main() -> anyhow::Result<()> {
    let model_path = "./Meta-Llama-3-8B-Instruct-Q4_K_M.gguf";
    let device = Device::new_cuda(0)?;
    let (mut model, tokenizer) = match load_model_and_tokenizer(model_path, &device) {
        Ok((model, tokenizer)) => (model, tokenizer),
        Err(e) => {
            println!("Error: {}", e);
            return Ok(());
        }
    };

    // Llama 3 Instruct chat template: system prompt, user turn, then the assistant
    // header so generation begins with the assistant's reply.
    let prompt_str = "<|begin_of_text|><|start_header_id|>system<|end_header_id|>
You are a storywriter. Write whatever the user asks for.<|eot_id|><|start_header_id|>user<|end_header_id|>
Write a story about a cute girl who finds an enchanted meadow.<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n";

    let mut tos = TokenOutputStream::new(tokenizer);

    // Time prompt ingestion.
    let ingest_start = std::time::Instant::now();
    let (prompt_token_len, initial_logits) =
        ingest_prompt(&mut model, &tos, &device, prompt_str).unwrap();
    let ingest_end = std::time::Instant::now();
    let total_time = ingest_end - ingest_start;
    println!(
        "\n\n{:4} prompt tokens processed: {:.1} t/s {:.2} s total",
        prompt_token_len,
        (prompt_token_len as f64) / total_time.as_secs_f64(),
        total_time.as_secs_f64()
    );

    // Time streamed generation.
    let inference_start = std::time::Instant::now();
    let tokens_processed =
        streamed_inference(&mut model, &mut tos, &device, prompt_token_len, &initial_logits).unwrap();
    let inference_end = std::time::Instant::now();
    let total_time = inference_end - inference_start;
    println!(
        "\n\n{:4} tokens generated: {:.2} token/s",
        tokens_processed,
        tokens_processed as f64 / total_time.as_secs_f64(),
    );
    print!("\n\n");
    Ok(())
}