Skip to content

Instantly share code, notes, and snippets.

@lovely-error
Created October 1, 2024 15:10
Show Gist options
  • Save lovely-error/47d804e56050e8388b7fc793220f6a24 to your computer and use it in GitHub Desktop.
Save lovely-error/47d804e56050e8388b7fc793220f6a24 to your computer and use it in GitHub Desktop.
naive matrix multiplication
#![feature(debug_closure_helpers)]
use core::{alloc::Layout, mem::{align_of, size_of}};
/// Shape of a `Matrix`: its number of rows and columns.
///
/// Both fields are `u8`, so a matrix is at most 255×255.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
struct MatrixDims {
    /// Number of rows.
    rows: u8,
    /// Number of columns (the length of each row in the row-major storage).
    cols: u8,
}
/// Heap-allocated, row-major matrix of `i32` cells.
///
/// `storage` points at `dims.rows * dims.cols` contiguous cells; the cell at row `r`,
/// column `c` lives at offset `r * dims.cols + c`. The buffer is released in `Drop`.
struct Matrix {
    /// Shape of the matrix.
    dims: MatrixDims,
    /// Start of the owned cell buffer.
    storage: *mut i32,
}
impl Drop for Matrix {
    /// Returns the cell buffer to the global allocator.
    fn drop(&mut self) {
        let byte_count = self.storage_byte_count();
        // `std::alloc::alloc` may not be called with a zero-sized layout, so a zero-byte
        // buffer never came from a valid allocation and there is nothing to free.
        if byte_count == 0 {
            return;
        }
        // SAFETY: assumes `storage` was obtained from `std::alloc::alloc` with a layout of
        // `byte_count` bytes aligned for `i32`, as `new_matrix_filled` does, and that it has
        // not been freed before (each `Matrix` owns its buffer exclusively).
        unsafe {
            std::alloc::dealloc(
                self.storage.cast(),
                Layout::from_size_align_unchecked(byte_count, align_of::<i32>()),
            )
        };
    }
}
impl core::fmt::Debug for Matrix {
fn fmt(&self, fmt: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
let mut dbg = fmt.debug_list();
let mut ix = 0;
let col = self.dims.cols as usize;
let row = self.dims.rows as usize;
let mut ptr = self.storage;
loop {
if ix == row { break }
dbg.entry_with(|fmt| {
let mut dbg = fmt.debug_list();
let slc = unsafe { core::slice::from_raw_parts(ptr, col) };
dbg.entries(slc);
dbg.finish()
});
ptr = unsafe { ptr.add(col) };
ix += 1;
}
return dbg.finish();
}
}
impl Matrix {
    /// Total number of cells, `rows * cols`.
    fn item_count(&self) -> usize {
        usize::from(self.dims.rows) * usize::from(self.dims.cols)
    }

    /// Size of the cell buffer in bytes (`item_count` cells of `i32`).
    fn storage_byte_count(&self) -> usize {
        let cell_size = size_of::<i32>();
        cell_size * self.item_count()
    }

    /// Views every cell of the buffer as one flat, row-major slice.
    ///
    /// NOTE(review): this hands out `&mut [i32]` from `&self`, so holding two views at once
    /// (or a view alongside any other access) creates aliasing mutable references, which is
    /// undefined behavior. Callers in this file depend on the `&self` signature; a sound API
    /// would split this into `&self -> &[i32]` and `&mut self -> &mut [i32]`.
    fn item_view(&self) -> &mut [i32] {
        let len = self.item_count();
        // SAFETY: relies on `storage` pointing at `len` initialized i32s and on the caller
        // not keeping another view alive (see NOTE above).
        unsafe { core::slice::from_raw_parts_mut(self.storage, len) }
    }
}
/// Allocates a `rows × columns` matrix with every cell set to `value`.
///
/// # Panics
/// Panics if `rows` or `columns` is zero. `std::alloc::alloc` must not be called with a
/// zero-sized layout, and `Drop` frees the buffer through the global allocator, so empty
/// matrices are not representable. Aborts via `handle_alloc_error` if allocation fails.
fn new_matrix_filled(rows: u8, columns: u8, value: i32) -> Matrix {
    let item_count = usize::from(rows) * usize::from(columns);
    assert!(
        item_count > 0,
        "new_matrix_filled: matrix dimensions must be non-zero (got {}x{})",
        rows,
        columns
    );
    // Same size/alignment that `Drop` uses when freeing: item_count * 4 bytes, i32-aligned.
    let layout = Layout::array::<i32>(item_count).expect("matrix byte size overflows isize");
    // SAFETY: `layout` has a non-zero size (checked above).
    let memory = unsafe { std::alloc::alloc(layout) }.cast::<i32>();
    if memory.is_null() {
        std::alloc::handle_alloc_error(layout);
    }
    // Initialize through the raw pointer: forming a `&mut [i32]` over still-uninitialized
    // memory would violate `slice::from_raw_parts_mut`'s requirements.
    for offset in 0..item_count {
        // SAFETY: `offset < item_count`, so the write stays inside the fresh allocation.
        unsafe { memory.add(offset).write(value) };
    }
    Matrix { dims: MatrixDims { rows, cols: columns }, storage: memory }
}
/// Returns a new matrix with the same shape and contents as `target`, backed by its own
/// allocation.
fn new_matrix_copied(target: &Matrix) -> Matrix {
    // `new_matrix_filled` takes (rows, columns) in that order; keeping them in order
    // preserves the shape for non-square matrices.
    let copy = new_matrix_filled(target.dims.rows, target.dims.cols, 0);
    // SAFETY: both buffers hold exactly `target.item_count()` i32s (identical dims), and
    // `copy` is a fresh allocation, so source and destination cannot overlap.
    unsafe { core::ptr::copy_nonoverlapping(target.storage, copy.storage, target.item_count()) };
    copy
}
/// Builds a `rows × columns` matrix from `items` given in row-major order.
///
/// If `items` is shorter than `rows * columns`, the remaining cells stay zero.
///
/// # Panics
/// Panics if `items` holds more values than the matrix has cells.
fn new_matrix_from_linear_array(rows: u8, columns: u8, items: &[i32]) -> Matrix {
    let m = new_matrix_filled(rows, columns, 0);
    let capacity = m.item_count();
    assert!(
        items.len() <= capacity,
        "new_matrix_from_linear_array: {} items do not fit in a {}x{} matrix",
        items.len(),
        rows,
        columns
    );
    // SAFETY: `m.storage` has `capacity >= items.len()` writable cells, and the freshly
    // allocated buffer cannot overlap the caller's `items` slice.
    unsafe { core::ptr::copy_nonoverlapping(items.as_ptr(), m.storage, items.len()) };
    m
}
/// Appends a plain-text rendering of `matrix` to `buffer`: one line per row, each row
/// wrapped in brackets with every value followed by a comma, e.g. `[1,2,3,]`.
fn write_print_repr(matrix: &Matrix, buffer: &mut String) {
    use core::fmt::Write;
    let width = matrix.dims.cols as usize;
    let height = matrix.dims.rows as usize;
    // SAFETY: `storage` holds `rows * cols` initialized i32s laid out row by row, and is
    // only read here.
    let cells = unsafe { core::slice::from_raw_parts(matrix.storage, width * height) };
    for r in 0..height {
        let line = &cells[r * width..(r + 1) * width];
        buffer.push('[');
        line.iter().for_each(|value| write!(buffer, "{},", value).unwrap());
        buffer.push_str("]\n");
    }
}
/// Scales every cell of `operand1` by `operand2`, in place.
///
/// NOTE(review): mutation goes through `item_view`, which yields `&mut [i32]` from a shared
/// `&Matrix`; callers must ensure no other view of the same matrix is alive.
fn multiply_by_constant_inplace(
    operand1: &Matrix,
    operand2: i32
) {
    operand1.item_view().iter_mut().for_each(|cell| *cell *= operand2);
}
/// Computes the matrix product `operand1 × operand2`.
///
/// For an `m×n` left operand and an `n×p` right operand, returns a freshly allocated `m×p`
/// matrix. Returns `None` when the inner dimensions disagree, i.e. when the number of
/// columns of `operand1` differs from the number of rows of `operand2`.
///
/// Uses plain `i32` arithmetic, so under default settings overflow panics in debug builds
/// and wraps in release builds.
fn multiply_matricies(operand1: &Matrix, operand2: &Matrix) -> Option<Matrix> {
    // (m×n)·(n×p): the left operand's column count must equal the right operand's row count.
    if operand1.dims.cols != operand2.dims.rows {
        return None;
    }
    let out_rows = operand1.dims.rows as usize;
    let inner = operand1.dims.cols as usize;
    let out_cols = operand2.dims.cols as usize;

    let result = new_matrix_filled(operand1.dims.rows, operand2.dims.cols, 0);
    // SAFETY: each matrix's storage holds `rows * cols` initialized i32s in row-major order.
    // `result` is a fresh allocation, so `out` aliases neither input. `lhs` and `rhs` are
    // read-only and may point at the same matrix (e.g. when squaring).
    let lhs = unsafe { core::slice::from_raw_parts(operand1.storage, out_rows * inner) };
    let rhs = unsafe { core::slice::from_raw_parts(operand2.storage, inner * out_cols) };
    let out = unsafe { core::slice::from_raw_parts_mut(result.storage, out_rows * out_cols) };

    for i in 0..out_rows {
        let lhs_row = &lhs[i * inner..(i + 1) * inner];
        for j in 0..out_cols {
            // Dot product of row `i` of the left operand with column `j` of the right one;
            // column `j` is read with a stride of `out_cols` through the row-major buffer.
            let mut acc = 0;
            for (k, &a) in lhs_row.iter().enumerate() {
                acc += a * rhs[k * out_cols + j];
            }
            out[i * out_cols + j] = acc;
        }
    }
    Some(result)
}
fn main() {
let m1 = new_matrix_from_linear_array(3, 3, &[1,2,3, 4,5,6, 7,8,9]);
// multiply_by_constant_inplace(&m1, 2);
let m2 = multiply_matricies(&m1, &m1).unwrap();
let mut str = String::new();
write_print_repr(&m2, &mut str);
println!("{}", str);
let m1 = new_matrix_from_linear_array(2, 4, &[1,2,3,4, 5,6,7,8]);
let m2 = new_matrix_from_linear_array(4, 2, &[1,2, 3,4, 5,6, 7,8]);
let m3 = multiply_matricies(&m1, &m2).unwrap_or_else(||panic!("cant multiply cus dims diagree"));
let mut str = String::new();
write_print_repr(&m3, &mut str);
println!("{}", str);
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment