Created January 20, 2024 13:30
PHP + Candle proof of concept
[package]
name = "hello_world"
version = "0.1.0"
edition = "2021"

[lib]
crate-type = ["cdylib"]

[dependencies]
anyhow = "*"
ext-php-rs = "*"
candle = "*"
candle-core = { version = "*", features = ["cuda"] }
candle-nn = { version = "*", features = ["cuda"] }
candle-transformers = { version = "*", features = ["cuda"] }
hf-hub = "*"
tokenizers = "*"

[profile.release]
strip = "debuginfo"
use std::thread;

use anyhow::{Error as E, Result};
use candle_core::{DType, Device, Tensor};
use candle_core::Result as CandleResult;
use candle_core::utils;
use candle_core::utils::{cuda_is_available, metal_is_available};
use candle_nn::VarBuilder;
use candle_transformers::generation::LogitsProcessor;
use candle_transformers::models::mistral::{Config, Model as Mistral};
use candle_transformers::models::quantized_mistral::Model as QMistral;
use ext_php_rs::prelude::*;
use hf_hub::{api::sync::Api, Repo, RepoType};
use tokenizers::Tokenizer;
/// Select a device: CPU when explicitly requested, otherwise CUDA or Metal when available.
pub fn foo_device(cpu: bool) -> CandleResult<Device> {
    if cpu {
        Ok(Device::Cpu)
    } else if cuda_is_available() {
        Ok(Device::new_cuda(0)?)
    } else if metal_is_available() {
        Ok(Device::new_metal(0)?)
    } else {
        #[cfg(all(target_os = "macos", target_arch = "aarch64"))]
        {
            println!(
                "Running on CPU, to run on GPU(metal), build this example with `--features metal`"
            );
        }
        #[cfg(not(all(target_os = "macos", target_arch = "aarch64")))]
        {
            println!("Running on CPU, to run on GPU, build this example with `--features cuda`");
        }
        candle_core::bail!("No device available")
    }
}
/// This is a wrapper around a tokenizer to ensure that tokens can be returned to the user in a
/// streaming way rather than having to wait for the full decoding.
pub struct TokenOutputStream {
    tokenizer: Tokenizer,
    tokens: Vec<u32>,
    prev_index: usize,
    current_index: usize,
}

impl TokenOutputStream {
    pub fn new(tokenizer: Tokenizer) -> Self {
        Self {
            tokenizer,
            tokens: Vec::new(),
            prev_index: 0,
            current_index: 0,
        }
    }

    pub fn into_inner(self) -> Tokenizer {
        self.tokenizer
    }

    fn decode(&self, tokens: &[u32]) -> CandleResult<String> {
        match self.tokenizer.decode(tokens, true) {
            Ok(str) => Ok(str),
            Err(err) => candle_core::bail!("cannot decode: {err}"),
        }
    }

    // https://github.com/huggingface/text-generation-inference/blob/5ba53d44a18983a4de32d122f4cb46f4a17d9ef6/server/text_generation_server/models/model.py#L68
    pub fn next_token(&mut self, token: u32) -> CandleResult<Option<String>> {
        let prev_text = if self.tokens.is_empty() {
            String::new()
        } else {
            let tokens = &self.tokens[self.prev_index..self.current_index];
            self.decode(tokens)?
        };
        self.tokens.push(token);
        let text = self.decode(&self.tokens[self.prev_index..])?;
        if text.len() > prev_text.len() && text.chars().last().unwrap().is_ascii() {
            let text = text.split_at(prev_text.len());
            self.prev_index = self.current_index;
            self.current_index = self.tokens.len();
            Ok(Some(text.1.to_string()))
        } else {
            Ok(None)
        }
    }

    pub fn decode_rest(&self) -> CandleResult<Option<String>> {
        let prev_text = if self.tokens.is_empty() {
            String::new()
        } else {
            let tokens = &self.tokens[self.prev_index..self.current_index];
            self.decode(tokens)?
        };
        let text = self.decode(&self.tokens[self.prev_index..])?;
        if text.len() > prev_text.len() {
            let text = text.split_at(prev_text.len());
            Ok(Some(text.1.to_string()))
        } else {
            Ok(None)
        }
    }

    pub fn decode_all(&self) -> CandleResult<String> {
        self.decode(&self.tokens)
    }

    pub fn get_token(&self, token_s: &str) -> Option<u32> {
        self.tokenizer.get_vocab(true).get(token_s).copied()
    }

    pub fn tokenizer(&self) -> &Tokenizer {
        &self.tokenizer
    }

    pub fn clear(&mut self) {
        self.tokens.clear();
        self.prev_index = 0;
        self.current_index = 0;
    }
}
struct Args {
    /// Run on CPU rather than on GPU.
    // #[arg(long)]
    cpu: bool,
    // #[arg(long)]
    use_flash_attn: bool,
    // #[arg(long)]
    prompt: String,
    /// The temperature used to generate samples.
    // #[arg(long)]
    temperature: Option<f64>,
    /// Nucleus sampling probability cutoff.
    // #[arg(long)]
    top_p: Option<f64>,
    /// The seed to use when generating random samples.
    // #[arg(long, default_value_t = 299792458)]
    seed: u64,
    /// The length of the sample to generate (in tokens).
    // #[arg(long, short = 'n', default_value_t = 100)]
    sample_len: usize,
    // #[arg(long, default_value = "lmz/candle-mistral")]
    model_id: String,
    // #[arg(long, default_value = "main")]
    revision: String,
    // #[arg(long)]
    tokenizer_file: Option<String>,
    // #[arg(long)]
    weight_files: Option<String>,
    // #[arg(long)]
    quantized: bool,
    /// Penalty to be applied for repeating tokens, 1. means no penalty.
    // #[arg(long, default_value_t = 1.1)]
    repeat_penalty: f32,
    /// The context size to consider for the repeat penalty.
    // #[arg(long, default_value_t = 64)]
    repeat_last_n: usize,
}
enum Model {
    Mistral(Mistral),
    Quantized(QMistral),
}

struct TextGeneration {
    model: Model,
    device: Device,
    tokenizer: TokenOutputStream,
    logits_processor: LogitsProcessor,
    repeat_penalty: f32,
    repeat_last_n: usize,
}

impl TextGeneration {
    #[allow(clippy::too_many_arguments)]
    fn new(
        model: Model,
        tokenizer: Tokenizer,
        seed: u64,
        temp: Option<f64>,
        top_p: Option<f64>,
        repeat_penalty: f32,
        repeat_last_n: usize,
        device: &Device,
    ) -> Self {
        let logits_processor = LogitsProcessor::new(seed, temp, top_p);
        Self {
            model,
            tokenizer: TokenOutputStream::new(tokenizer),
            logits_processor,
            repeat_penalty,
            repeat_last_n,
            device: device.clone(),
        }
    }

    fn run(&mut self, prompt: &str, sample_len: usize) -> Result<()> {
        use std::io::Write;
        self.tokenizer.clear();
        let mut tokens = self
            .tokenizer
            .tokenizer()
            .encode(prompt, true)
            .map_err(E::msg)?
            .get_ids()
            .to_vec();
        for &t in tokens.iter() {
            if let Some(t) = self.tokenizer.next_token(t)? {
                print!("{t}")
            }
        }
        std::io::stdout().flush()?;
        let mut generated_tokens = 0usize;
        let eos_token = match self.tokenizer.get_token("</s>") {
            Some(token) => token,
            None => anyhow::bail!("cannot find the </s> token"),
        };
        let start_gen = std::time::Instant::now();
        for index in 0..sample_len {
            let context_size = if index > 0 { 1 } else { tokens.len() };
            let start_pos = tokens.len().saturating_sub(context_size);
            let ctxt = &tokens[start_pos..];
            let input = Tensor::new(ctxt, &self.device)?.unsqueeze(0)?;
            let logits = match &mut self.model {
                Model::Mistral(m) => m.forward(&input, start_pos)?,
                Model::Quantized(m) => m.forward(&input, start_pos)?,
            };
            let logits = logits.squeeze(0)?.squeeze(0)?.to_dtype(DType::F32)?;
            let logits = if self.repeat_penalty == 1. {
                logits
            } else {
                let start_at = tokens.len().saturating_sub(self.repeat_last_n);
                candle_transformers::utils::apply_repeat_penalty(
                    &logits,
                    self.repeat_penalty,
                    &tokens[start_at..],
                )?
            };
            let next_token = self.logits_processor.sample(&logits)?;
            tokens.push(next_token);
            generated_tokens += 1;
            if next_token == eos_token {
                break;
            }
            if let Some(t) = self.tokenizer.next_token(next_token)? {
                print!("{t}");
                std::io::stdout().flush()?;
            }
        }
        let dt = start_gen.elapsed();
        if let Some(rest) = self.tokenizer.decode_rest().map_err(E::msg)? {
            print!("{rest}");
        }
        std::io::stdout().flush()?;
        println!(
            "\n{generated_tokens} tokens generated ({:.2} token/s)",
            generated_tokens as f64 / dt.as_secs_f64(),
        );
        Ok(())
    }
}
/// Download the Mistral tokenizer and weights from the Hugging Face Hub, then run a short
/// generation for the given prompt, printing tokens to stdout as they are produced.
fn mistral(prompt: String) -> Result<()> {
    let args = Args {
        cpu: false,
        use_flash_attn: false,
        prompt,
        temperature: None,
        top_p: None,
        seed: 299792459,
        sample_len: 40,
        model_id: "lmz/candle-mistral".to_string(),
        revision: "main".to_string(),
        tokenizer_file: None,
        weight_files: None,
        quantized: false,
        repeat_penalty: 1.1,
        repeat_last_n: 64,
    };
    println!(
        "avx: {}, neon: {}, simd128: {}, f16c: {}",
        utils::with_avx(),
        utils::with_neon(),
        utils::with_simd128(),
        utils::with_f16c()
    );
    println!(
        "temp: {:.2} repeat-penalty: {:.2} repeat-last-n: {}",
        args.temperature.unwrap_or(0.),
        args.repeat_penalty,
        args.repeat_last_n
    );
    let start = std::time::Instant::now();
    let api = Api::new()?;
    let repo = api.repo(Repo::with_revision(
        args.model_id,
        RepoType::Model,
        args.revision,
    ));
    // let tokenizer_filename = match args.tokenizer_file {
    //     Some(file) => std::path::PathBuf::from(file),
    //     None => repo.get("tokenizer.json")?,
    // };
    let tokenizer_filename = repo.get("tokenizer.json")?;
    let filenames = match args.weight_files {
        Some(files) => files
            .split(',')
            .map(std::path::PathBuf::from)
            .collect::<Vec<_>>(),
        None => {
            if args.quantized {
                vec![repo.get("model-q4k.gguf")?]
            } else {
                vec![
                    repo.get("pytorch_model-00001-of-00002.safetensors")?,
                    repo.get("pytorch_model-00002-of-00002.safetensors")?,
                ]
            }
        }
    };
    println!("retrieved the files in {:?}", start.elapsed());
    let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
    let start = std::time::Instant::now();
    let config = Config::config_7b_v0_1(args.use_flash_attn);
    let (model, device) = if args.quantized {
        let device = foo_device(args.cpu)?;
        let filename = &filenames[0];
        let vb = candle_transformers::quantized_var_builder::VarBuilder::from_gguf(filename)?;
        let model = QMistral::new(&config, vb)?;
        (Model::Quantized(model), device)
    } else {
        let device = foo_device(args.cpu)?;
        let dtype = if device.is_cuda() {
            DType::BF16
        } else {
            DType::F32
        };
        let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, dtype, &device)? };
        let model = Mistral::new(&config, vb)?;
        (Model::Mistral(model), device)
    };
    println!("loaded the model in {:?}", start.elapsed());
    let mut pipeline = TextGeneration::new(
        model,
        tokenizer,
        args.seed,
        args.temperature,
        args.top_p,
        args.repeat_penalty,
        args.repeat_last_n,
        &device,
    );
    pipeline.run(&args.prompt, args.sample_len)?;
    Ok(())
}
#[php_function]
pub fn hello_world() -> String {
    // Spawn the generation on a background thread so the PHP call returns right away;
    // generated tokens are printed to stdout by TextGeneration::run.
    thread::spawn(|| {
        match mistral("Who are you?".to_string()) {
            Ok(()) => println!("YES"),
            Err(e) => println!("Error: {}", e),
        }
    });

    "YES".to_string()
}

#[php_module]
pub fn get_module(module: ModuleBuilder) -> ModuleBuilder {
    module
}
This is almost copy-pasted Candle example code with a few tweaks to make it work under PHP, combined with ext-php-rs to package it as a PHP extension.
This is just a proof of concept. If you are looking to actually serve LLMs through PHP, have a look at this instead :D -> https://resonance.distantmagic.com/tutorials/how-to-create-llm-websocket-chat-with-llama-cpp/
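For completeness, here is a minimal sketch of how the extension could be called from PHP. Everything here is an assumption rather than part of the gist: it presumes the crate is built with `cargo build --release` and that the resulting shared library (e.g. libhello_world.so on Linux; the file name and path depend on your platform) is loaded via an `extension=` line in php.ini.

<?php

// Assumed php.ini entry (hypothetical path):
//   extension=/path/to/target/release/libhello_world.so

// hello_world() returns "YES" immediately; Mistral generation runs on a
// background thread and prints tokens to stdout (visible under php-cli).
var_dump(hello_world()); // string(3) "YES"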