Skip to content

Instantly share code, notes, and snippets.

@ealmloff
Last active March 20, 2025 01:00
Show Gist options
  • Save ealmloff/7b3b8bf195316482207639554315ddb4 to your computer and use it in GitHub Desktop.
Save ealmloff/7b3b8bf195316482207639554315ddb4 to your computer and use it in GitHub Desktop.
// candle-core = "0.8.4"
use candle_core::Module;
/// Sanity-checks candle's Q4K quantized matmul against a hand-computed 3x3 product.
///
/// Two 256x256 matrices are built that are zero everywhere except their top-left
/// 3x3 corner. One is quantized to Q4K and wrapped in a `QMatMul`, the other is fed
/// through it as the input. The top-left 3x3 corner of the result is then compared
/// with the exact product `INPUT * WEIGHTS` using a loose absolute tolerance.
///
/// # Panics
///
/// Panics if any candle call fails or if an output element deviates from the
/// reference value by `TOLERANCE` or more.
fn main() {
    // Side length of the square, zero-padded matrices.
    // NOTE(review): presumably 256 because k-quant formats such as Q4K pack values
    // in 256-element blocks — confirm against the candle/ggml documentation.
    const DIM: usize = 256;
    // Largest accepted absolute difference (exclusive) between candle's output and
    // the exact product. Deliberately loose; quantization makes the result inexact.
    const TOLERANCE: f32 = 10.;

    // Logical operands of the product under test: `INPUT * WEIGHTS`.
    const INPUT: [[f32; 3]; 3] = [[4., 5., 6.], [6., 5., 21.], [4., 6., 5.]];
    const WEIGHTS: [[f32; 3]; 3] = [[1., 2., 3.], [3., 2., 1.], [1., 5., 3.]];

    // `QMatMul::forward` multiplies the input by the *transpose* of the stored
    // weight (weights are laid out `[out, in]`, as for a linear layer). To obtain
    // `INPUT * WEIGHTS` the quantized tensor therefore has to hold `WEIGHTS`
    // transposed: element (r, c) of `WEIGHTS` is written to position (c, r).
    let mut q_data = [[0f32; DIM]; DIM];
    for (r, row) in WEIGHTS.iter().enumerate() {
        for (c, &value) in row.iter().enumerate() {
            q_data[c][r] = value;
        }
    }
    let tensor = candle_core::Tensor::new(&q_data, &candle_core::Device::Cpu).unwrap();
    let quantized =
        candle_core::quantized::QTensor::quantize(&tensor, candle_core::quantized::GgmlDType::Q4K)
            .unwrap();
    let candle_q_matrix = candle_core::quantized::QMatMul::from_qtensor(quantized).unwrap();

    // Zero-padded input matrix with `INPUT` in its top-left corner.
    let mut tensor_data = vec![vec![0f32; DIM]; DIM];
    for (padded_row, row) in tensor_data.iter_mut().zip(INPUT.iter()) {
        padded_row[..row.len()].copy_from_slice(row);
    }
    let candle_input = candle_core::Tensor::from_iter(
        tensor_data.iter().flat_map(|x| x.iter().copied()),
        &candle_core::Device::Cpu,
    )
    .unwrap()
    .reshape(&[DIM, DIM])
    .unwrap();

    // Extracts the top-left 3x3 corner of a 2-D tensor as nested vectors.
    let top_left_3x3 = |t: &candle_core::Tensor| -> Vec<Vec<f32>> {
        t.narrow(0, 0, 3)
            .unwrap()
            .narrow(1, 0, 3)
            .unwrap()
            .to_vec2::<f32>()
            .unwrap()
    };

    println!("candle_input: {:?}", top_left_3x3(&candle_input));

    let candle_output = candle_q_matrix.forward(&candle_input).unwrap();
    let candle_output = top_left_3x3(&candle_output);
    println!("candle_output: {:?}", candle_output);

    // Solution from https://www.wolframalpha.com/input?i=matrix+multiplication+calculator&assumption=%7B%22F%22%2C+%22MatricesOperations%22%2C+%22theMatrix2%22%7D+-%3E%22%7B%7B1%2C2%2C3%7D%2C%7B3%2C2%2C1%7D%2C%7B1%2C5%2C3%7D%7D%22&assumption=%7B%22F%22%2C+%22MatricesOperations%22%2C+%22theMatrix1%22%7D+-%3E%22%7B%7B4%2C5%2C6%7D%2C%7B6%2C5%2C21%7D%2C%7B4%2C6%2C5%7D%7D%22
    // {{25, 48, 35}, {42, 127, 86}, {27, 45, 33}}
    // i.e. the exact value of `INPUT * WEIGHTS`.
    let expected = [[25f32, 48., 35.], [42., 127., 86.], [27., 45., 33.]];

    // Compare element by element so that a failure reports where and by how much
    // the result is off, and so that a short output cannot pass unnoticed.
    assert_eq!(
        candle_output.len(),
        expected.len(),
        "unexpected number of output rows"
    );
    for (r, (got_row, want_row)) in candle_output.iter().zip(expected.iter()).enumerate() {
        assert_eq!(
            got_row.len(),
            want_row.len(),
            "unexpected number of columns in output row {r}"
        );
        for (c, (got, want)) in got_row.iter().zip(want_row.iter()).enumerate() {
            assert!(
                (got - want).abs() < TOLERANCE,
                "output[{r}][{c}] = {got}, expected {want} (tolerance {TOLERANCE})"
            );
        }
    }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment