Skip to content

Instantly share code, notes, and snippets.

@travisstaloch
Created January 19, 2025 07:14
Show Gist options
  • Save travisstaloch/60beb1973792bc76da9f9b385aa1b644 to your computer and use it in GitHub Desktop.
Save travisstaloch/60beb1973792bc76da9f9b385aa1b644 to your computer and use it in GitHub Desktop.
An incomplete generic matrix lib which imitates some aspects of numpy
//!
//! a work in progress
//!
//! this lib only does multiplication so far, but it can parse numpy output at
//! runtime and at comptime, and it has lots of helpers, including array(), which
//! allows multi-indexing, e.g. `mat.array()[i][j]`. there are working tests for
//! 1d, 2d, 3d, 4d and 5d matrices with different shapes.
//!
//! Resources:
//! https://ajcr.net/stride-guide-part-1/
const std = @import("std");
const assert = std.debug.assert;
const testing = std.testing;
const mem = std.mem;
const Allocator = mem.Allocator;
/// The error `parseBuf()` returns for malformed input at runtime. When parsing
/// happens at comptime, the same problems are reported as compile errors instead.
/// Number syntax errors from `std.fmt.parseInt` / `std.fmt.parseFloat` are
/// propagated as their own errors, not mapped to this set.
pub const ParseError = error{ParseFailure};
/// A fixed-shape, N-dimensional matrix view over a flat buffer of `T`.
///
/// `shape` must be a tuple, array, or slice coercible to []const usize.
/// `strides` must be a tuple, array, or slice coercible to []const isize with the
/// same length as `shape`. It may also be empty (e.g. `.{}`) to use the default
/// row-major strides. Strides are measured in items of `T`, not bytes.
///
/// `at()`, `atPtr()` and `iterator()` follow `strides`. `array()`, `format()`,
/// `mul()` and `fillArray()` treat `buf` as a contiguous row-major array and
/// ignore `strides`.
pub fn Matrix(comptime T: type, comptime shape: anytype, comptime strides: anytype) type {
    return struct {
        /// pointer to the flat storage: `len` items of `T`.
        buf: *Buf,
        /// the shape of `Buf`. the number of items in each dimension.
        comptime shape: []const usize = _shape,
        /// the strides of `Buf`, one per dimension, measured in items.
        comptime strides: []const isize = _strides,

        /// an N dimensional array of `T` shaped according to `shape`
        pub const Array = ArrayFromDim(0);
        /// a flattened, one dimensional array of `T`. has the same memory layout as Array.
        pub const Buf = [len]T;
        /// the total number of elements in the Buf: the product of all dimensions.
        pub const len = blk: {
            var n: usize = 1;
            for (_shape) |s| n *= s;
            break :blk n;
        };
        /// the element type.
        pub const child = T;

        const Self = @This();
        const _shape = toSlice(usize, shape);
        const _strides: []const isize = blk: {
            const s = if (strides.len == 0)
                newStrides(_shape)
            else
                toSlice(isize, strides);
            if (s.len != _shape.len)
                @compileError("Matrix: `strides` must have the same length as `shape`");
            break :blk s;
        };

        /// returns an Array type for the given dimension.
        /// for example:
        /// given shape = .{2,4}
        /// ArrayFromDim(0) = [2][4]T
        /// ArrayFromDim(1) = [4]T
        /// ArrayFromDim(2) = T
        /// any number of dimensions is supported.
        pub fn ArrayFromDim(comptime dim: usize) type {
            comptime {
                if (dim > _shape.len)
                    @compileError("ArrayFromDim: `dim` is larger than the number of dimensions");
                // wrap T from the innermost dimension outwards.
                var Result: type = T;
                var i: usize = _shape.len;
                while (i > dim) {
                    i -= 1;
                    Result = [_shape[i]]Result;
                }
                return Result;
            }
        }

        /// wrap `buf`; its contents are left as they are.
        pub fn init(buf: *Buf) Self {
            return .{ .buf = buf };
        }

        /// wrap `buf` and set every item to `value`.
        pub fn initFilled(buf: *Buf, value: T) Self {
            const result = init(buf);
            result.fill(value);
            return result;
        }

        /// wrap `buf` and copy the nested array `a` into it.
        pub fn initArray(buf: *Buf, a: *const Array) Self {
            const result = init(buf);
            result.fillArray(a);
            return result;
        }

        /// wrap `buf` and copy the flat `const_buf` into it.
        pub fn initBuf(buf: *Buf, const_buf: *const Buf) Self {
            @memcpy(buf, const_buf);
            return init(buf);
        }

        /// view constant nested data as a matrix. the constness of `arr` is cast
        /// away, so do not write through the result: `arr` may be read-only memory.
        pub fn initConst(arr: *const Array) Self {
            return .{ .buf = @ptrCast(@constCast(arr)) };
        }

        /// view a constant flat buffer as a matrix. do not write through the result.
        pub fn initConstBuf(buf: *const Buf) Self {
            return .{ .buf = @constCast(buf) };
        }

        /// allocate an uninitialized `Buf`. release it with `deinit()`.
        pub fn initAlloc(allocator: Allocator) !Self {
            return init(try allocator.create(Buf));
        }

        /// allocate a `Buf` filled with `value`. release it with `deinit()`.
        pub fn initFilledAlloc(allocator: Allocator, value: T) !Self {
            return initFilled(try allocator.create(Buf), value);
        }

        /// allocate a `Buf` holding a copy of `arr`. release it with `deinit()`.
        pub fn initArrayAlloc(allocator: Allocator, arr: *const Array) !Self {
            return initArray(try allocator.create(Buf), arr);
        }

        /// allocate a `Buf` holding a copy of `buf`. release it with `deinit()`.
        pub fn initBufAlloc(allocator: Allocator, buf: *const Buf) !Self {
            return initBuf(try allocator.create(Buf), buf);
        }

        /// free a buffer obtained from one of the `init*Alloc()` helpers.
        pub fn deinit(self: Self, allocator: Allocator) void {
            allocator.destroy(self.buf);
        }

        /// view the same buffer with `new_shape` and default row-major strides.
        /// `new_shape` must describe exactly `len` items (checked at comptime).
        /// the data is not moved, and the strides of this matrix are not consulted.
        pub fn reshape(
            self: Self,
            comptime new_shape: anytype,
        ) Matrix(T, new_shape, newStrides(toSlice(usize, new_shape))) {
            comptime {
                var new_len: usize = 1;
                for (toSlice(usize, new_shape)) |s| new_len *= s;
                if (new_len != len)
                    @compileError("reshape: the new shape must hold the same number of items");
            }
            return .{ .buf = self.buf };
        }

        /// view the same buffer with the same shape but `new_strides`. the stride
        /// count is checked when the new Matrix type is created. strides that
        /// reach outside `buf` are not rejected here. `at()` returns null for
        /// such offsets.
        pub fn withStrides(
            self: Self,
            comptime new_strides: anytype,
        ) Matrix(T, _shape, new_strides) {
            return .{ .buf = self.buf };
        }

        fn comptimeReverseSlice(comptime slice: anytype) @TypeOf(slice) {
            comptime {
                var buf = slice[0..slice.len].*;
                mem.reverse(@typeInfo(@TypeOf(buf)).array.child, &buf);
                const fbuf = buf;
                return &fbuf;
            }
        }

        /// numpy style transpose: a view with the axis order reversed. both the
        /// shape and the strides are reversed, so `at()` and `iterator()` see the
        /// transposed indices. the data in `buf` is not moved.
        pub fn transpose(self: Self) Matrix(
            T,
            comptimeReverseSlice(_shape),
            comptimeReverseSlice(_strides),
        ) {
            return .{ .buf = self.buf };
        }

        fn comptimeSwapAxes(
            comptime slice: anytype,
            comptime axis1: usize,
            comptime axis2: usize,
        ) @TypeOf(slice) {
            comptime {
                var buf = slice[0..slice.len].*;
                mem.swap(@typeInfo(@TypeOf(buf)).array.child, &buf[axis1], &buf[axis2]);
                const fbuf = buf;
                return &fbuf;
            }
        }

        /// numpy style swapaxes: a view with `axis1` and `axis2` exchanged in
        /// both the shape and the strides. the data in `buf` is not moved.
        pub fn swapAxes(
            self: Self,
            comptime axis1: usize,
            comptime axis2: usize,
        ) Matrix(
            T,
            comptimeSwapAxes(_shape, axis1, axis2),
            comptimeSwapAxes(_strides, axis1, axis2),
        ) {
            return .{ .buf = self.buf };
        }

        /// report a parse error: a compile error at comptime, otherwise log it
        /// (except in tests) and return error.ParseFailure.
        fn parseErr(comptime fmt: []const u8, args: anytype) ParseError {
            if (@inComptime())
                @compileError(std.fmt.comptimePrint(fmt, args))
            else {
                if (!@import("builtin").is_test) std.log.err(fmt, args);
                return error.ParseFailure;
            }
        }

        /// parse numpy like output at comptime
        pub inline fn parse(comptime input: []const u8) !Self {
            comptime {
                var buf: Buf = undefined;
                _ = try parseBuf(input, &buf);
                const fbuf = buf;
                return .{ .buf = @constCast(&fbuf) };
            }
        }

        /// parse numpy like output (`[[1 2] [3 4]]`) or zig literals
        /// (`.{.{1,2},.{3,4}}`) at runtime into `buf`.
        /// items in the innermost brackets may be separated by spaces, tabs,
        /// newlines or commas.
        pub fn parseBuf(input: []const u8, buf: *Buf) !Self {
            var i: usize = 0;
            var depth: usize = 0;
            var bufi: usize = 0;
            // kind of the most recently opened bracket: '[' or '{'
            var bracket: u8 = 0;
            while (i < input.len) : (i += 1) {
                switch (input[i]) {
                    ' ', '\n', '\t', '\r', ',' => {},
                    '.' => {
                        if (i + 1 == input.len)
                            return parseErr("unexpected eof position {}", .{i + 1});
                        if (input[i + 1] != '{') return parseErr(
                            "expecting '{{' but found '{c}' at position {}",
                            .{ input[i + 1], i + 1 },
                        );
                        i += 1;
                        depth += 1;
                        bracket = '{';
                    },
                    '[' => {
                        depth += 1;
                        bracket = '[';
                    },
                    ']', '}' => if (depth > 0) {
                        depth -= 1;
                    } else return parseErr("unbalanced bracket at position {}", .{i}),
                    '0'...'9', '-', '+' => {
                        if (depth < _shape.len)
                            return parseErr("missing bracket at position {}", .{i})
                        else if (depth > _shape.len)
                            return parseErr("extra bracket at position {}", .{i});
                        const end_bracket: u8 = switch (bracket) {
                            '[' => ']',
                            '{' => '}',
                            else => return parseErr("missing bracket at position {}", .{i}),
                        };
                        const rbi = mem.indexOfScalarPos(u8, input, i, end_bracket) orelse
                            return parseErr(
                                "missing closing bracket at position {}",
                                .{i},
                            );
                        // numpy wraps long rows onto several lines, so split on
                        // any whitespace as well as commas.
                        var it = mem.tokenizeAny(u8, input[i..rbi], " \t\r\n,");
                        var j: usize = 0;
                        while (it.next()) |nr| : (j += 1) {
                            const n = switch (@typeInfo(T)) {
                                .int, .comptime_int => try std.fmt.parseInt(T, nr, 10),
                                .float, .comptime_float => try std.fmt.parseFloat(T, nr),
                                else => @compileError("parse: unsupported element type " ++ @typeName(T)),
                            };
                            if (bufi >= len)
                                return parseErr(
                                    "expected {} items but found {} at position {}.",
                                    .{ len, bufi + 1, i },
                                );
                            buf[bufi] = n;
                            bufi += 1;
                        }
                        if (j != _shape[_shape.len - 1])
                            return parseErr(
                                "expected {} items but found {} at position {}.",
                                .{ _shape[_shape.len - 1], j, i },
                            );
                        i = rbi;
                        depth -= 1;
                    },
                    else => return parseErr(
                        "unexpected character '{c}' at position {}",
                        .{ input[i], i },
                    ),
                }
            }
            if (depth != 0) return parseErr("eof. missing closing bracket", .{});
            if (bufi != len) return parseErr(
                "eof. expected {} items but found {}.",
                .{ len, bufi },
            );
            return .{ .buf = buf };
        }

        /// map `indices` to an index into `buf` using `__strides`.
        /// when fewer indices than dimensions are given, they are matched to the
        /// trailing dimensions. returns null when the resulting offset is before
        /// the start or past the end of `buf`.
        fn strideOffset(indices: []const isize, __strides: []const isize) ?usize {
            assert(indices.len <= __strides.len);
            assert(indices.len <= _shape.len);
            var offset: isize = 0;
            var index_it = mem.reverseIterator(indices);
            var stride_it = mem.reverseIterator(__strides);
            var shape_it = mem.reverseIterator(_shape);
            while (index_it.next()) |index| {
                const stride = stride_it.next() orelse unreachable;
                const dim: isize = @intCast(shape_it.next() orelse unreachable);
                const this_offset = if (index < 0 and stride < 0)
                    index * stride - 1
                else if (index < 0)
                    @mod(index * stride, dim)
                else if (stride < 0)
                    index * stride + dim - 1
                else
                    index * stride;
                offset += this_offset;
                debug(
                    "shape={} index={} stride={} this_offset={} offset={}\n",
                    .{ dim, index, stride, this_offset, offset },
                );
            }
            // a negative total offset means the indices walked off the front of `buf`.
            if (offset < 0 or offset >= len) return null;
            return @as(usize, @intCast(offset));
        }

        /// the item at `indices` (a tuple, array or slice of isize), following
        /// `strides`. returns null when the offset falls outside `buf`.
        pub fn at(self: Self, indices: anytype) ?T {
            const i = strideOffset(toSlice(isize, indices), self.strides) orelse
                return null;
            return self.buf[i];
        }

        /// like `at()` but returns a pointer to the item.
        pub fn atPtr(self: Self, indices: anytype) ?*T {
            const i = strideOffset(toSlice(isize, indices), self.strides) orelse
                return null;
            return &self.buf[i];
        }

        /// set every item of `buf` to `value`.
        pub fn fill(self: Self, value: T) void {
            // zero-sized element types (e.g. u0) have no bytes to write.
            if (@sizeOf(T) > 0) @memset(self.buf, value);
        }

        /// copy `arr` into `buf` in storage order.
        pub fn fillArray(self: Self, arr: *const Array) void {
            // Array and Buf share a memory layout. view `arr` as flat so both
            // @memcpy operands have the same element type and length for any rank.
            @memcpy(self.buf, @as(*const Buf, @ptrCast(arr)));
        }

        /// view the items starting at `ptr` as a matrix of the trailing
        /// dimensions `shape[dim..]`.
        pub fn subMatrix(ptr: [*]T, comptime dim: usize) Matrix(T, _shape[dim..], _strides[dim..]) {
            return .{ .buf = @ptrCast(ptr) };
        }

        /// matrix product `a @ b` written into `dst`. the implementation is
        /// chosen by this matrix's rank: 1d x 2d, 2d x 2d, or batched over the
        /// leading dimension for rank >= 3. every operand is read in storage order.
        pub const mul = switch (_shape.len) {
            1 => mul1d,
            2 => mul2d,
            else => mulNd,
        };

        /// `buf` viewed as the nested `Array` type, e.g. `mat.array()[i][j]`.
        pub fn array(self: Self) *Array {
            return @ptrCast(self.buf);
        }

        fn mul1d(a: Self, b: anytype, dst: anytype) void {
            comptime assert(_shape.len == 1);
            if (b.shape.len != 2)
                @compileError("mul: a 1d matrix can only be multiplied by a 2d matrix");
            comptime assert(_shape[0] == b.shape[0]);
            comptime assert(dst.shape.len == 1 and dst.shape[0] == b.shape[1]);
            dst.fill(0);
            for (0.._shape[0]) |i| {
                for (0..b.shape[1]) |j| {
                    dst.array()[j] += a.array()[i] * b.array()[i][j];
                }
            }
        }

        fn mul2d(a: Self, b: anytype, dst: anytype) void {
            comptime assert(_shape.len == 2);
            if (b.shape.len != 2)
                @compileError("mul: a 2d matrix can only be multiplied by a 2d matrix");
            comptime assert(_shape[1] == b.shape[0]);
            comptime assert(dst.shape.len == 2 and
                dst.shape[0] == _shape[0] and
                dst.shape[1] == b.shape[1]);
            dst.fill(0);
            for (0.._shape[0]) |i| {
                for (0..b.shape[1]) |j| {
                    for (0.._shape[1]) |k| {
                        dst.array()[i][j] += a.array()[i][k] * b.array()[k][j];
                    }
                }
            }
        }

        fn mulNd(a: Self, b: anytype, dst: anytype) void {
            comptime assert(_shape.len >= 3);
            // the leading dimension is a batch dimension shared by all operands.
            comptime assert(b.shape.len == _shape.len and dst.shape.len == _shape.len);
            comptime assert(b.shape[0] == _shape[0] and dst.shape[0] == _shape[0]);
            for (0.._shape[0]) |i| {
                const asubm = subMatrix(@ptrCast(&a.array()[i]), 1);
                const bsubm = @TypeOf(b).subMatrix(@ptrCast(&b.array()[i]), 1);
                const dsubm = @TypeOf(dst).subMatrix(@ptrCast(&dst.array()[i]), 1);
                switch (_shape.len) {
                    3 => asubm.mul2d(bsubm, dsubm),
                    else => asubm.mulNd(bsubm, dsubm),
                }
            }
        }

        /// output either a numpy style or zig style text matrix representation.
        /// numpy style is the default and results in output like '[0 1]'.
        /// when the fmt specifier has a leading 'z', format() outputs zig style literals.
        /// for example the specifier '{zd:.1}' will result in output like '.{0.0,1.0}'
        /// `options` (width, precision, ...) apply to every item at every depth.
        pub fn format(
            self: Self,
            comptime fmt: []const u8,
            options: std.fmt.FormatOptions,
            writer: anytype,
        ) !void {
            const is_zig_fmt = comptime mem.startsWith(u8, fmt, "z");
            // the leading 'z' only selects the literal style; items use the rest.
            const item_fmt = if (is_zig_fmt) fmt[1..] else fmt;
            try writer.writeAll(if (is_zig_fmt) ".{" else "[");
            for (0.._shape[0]) |i| {
                if (_shape.len == 1) {
                    if (i != 0) try writer.writeByte(if (is_zig_fmt) ',' else ' ');
                    try std.fmt.formatType(self.buf[i], item_fmt, options, writer, 1);
                } else {
                    // recurse directly so `options` reach the innermost items.
                    const row = subMatrix(@ptrCast(&self.array()[i]), 1);
                    try row.format(fmt, options, writer);
                    if (is_zig_fmt and i + 1 != _shape[0]) try writer.writeByte(',');
                }
            }
            try writer.writeByte(if (is_zig_fmt) '}' else ']');
        }

        /// print `message`, the shape as `AxBx...`, then the matrix, with std.debug.print.
        pub fn dump(self: Self, message: []const u8) void {
            std.debug.print("--{s}:", .{message});
            for (_shape, 0..) |s, i| {
                std.debug.print("{s}{}", .{ if (i != 0) "x" else "", s });
            }
            std.debug.print("--\n", .{});
            std.debug.print("{}", .{self});
        }

        /// one index per dimension.
        pub const Indices = [_shape.len]isize;

        /// iterate over every item in row-major index order, following `strides`.
        pub fn iterator(self: Self) Iterator {
            return .{ .mat = self };
        }

        pub const Iterator = struct {
            mat: Self,
            indices: Indices = [1]isize{0} ** _shape.len,
            done: bool = false,

            /// set 'indices' to next index (the last index varies fastest).
            /// returns true when done indicating there are no more indices.
            fn nextIndex(indices: *Indices) bool {
                var i = _shape.len - 1;
                while (true) : (i -= 1) {
                    const ix = &indices[i];
                    ix.* += 1;
                    if (ix.* == _shape[i])
                        ix.* = 0
                    else
                        return false;
                    if (i == 0) break;
                }
                return true;
            }

            /// the next item, or null when done or when the next offset is
            /// outside `buf`.
            pub fn next(iter: *Iterator) ?T {
                if (iter.done) return null;
                const offset = strideOffset(&iter.indices, iter.mat.strides) orelse
                    return null;
                iter.done = nextIndex(&iter.indices);
                return iter.mat.buf[offset];
            }

            /// like `next()` but returns a pointer to the item.
            pub fn nextPtr(iter: *Iterator) ?*T {
                if (iter.done) return null;
                const offset = strideOffset(&iter.indices, iter.mat.strides) orelse
                    return null;
                iter.done = nextIndex(&iter.indices);
                return &iter.mat.buf[offset];
            }

            /// the item at `indices`, independent of the iteration position.
            pub fn at(iter: Iterator, indices: []const isize) ?T {
                const offset = strideOffset(indices, iter.mat.strides) orelse
                    return null;
                return iter.mat.buf[offset];
            }

            /// like `at()` but returns a pointer to the item.
            pub fn atPtr(iter: Iterator, indices: []const isize) ?*T {
                const offset = strideOffset(indices, iter.mat.strides) orelse
                    return null;
                return &iter.mat.buf[offset];
            }
        };
    };
}
/// Infer a matrix shape from the type of a (possibly nested) literal.
/// A top-level pointer is looked through once. Arrays, vectors and tuples each
/// contribute one dimension; for a tuple, its length is the dimension size and
/// its first field's type is the next level. Nesting stops at the first type
/// that is none of these (including a non-tuple struct), which is the element
/// type. An empty tuple contributes a dimension of 0 and ends the nesting.
fn InferShape(comptime T: type) []const usize {
    comptime {
        const Outer = switch (@typeInfo(T)) {
            .pointer => |p| p.child,
            .array, .vector => T,
            .@"struct" => |s| if (s.is_tuple)
                T
            else
                @compileError("unexpected type " ++ @typeName(T)),
            else => @compileError("unexpected type " ++ @tagName(@typeInfo(T))),
        };
        var shape: []const usize = &.{};
        var Current: type = Outer;
        while (true) {
            switch (@typeInfo(Current)) {
                .array => |a| {
                    shape = shape ++ [1]usize{a.len};
                    Current = a.child;
                },
                .vector => |v| {
                    shape = shape ++ [1]usize{v.len};
                    Current = v.child;
                },
                .@"struct" => |s| {
                    // only tuples nest; any other struct is an element type.
                    if (!s.is_tuple) break;
                    shape = shape ++ [1]usize{s.fields.len};
                    // an empty tuple has no element type to descend into.
                    if (s.fields.len == 0) break;
                    Current = s.fields[0].type;
                },
                else => break,
            }
        }
        return shape;
    }
}
/// Write the default row-major (C order) strides for `shape` into `strides`.
/// The last dimension gets stride 1. Each earlier dimension gets the product of
/// the sizes of all later dimensions. If `shape` is empty or every dimension is
/// 0, all strides are 0.
/// Asserts `shape.len == strides.len`.
fn setDefaultStrides(shape: []const usize, strides: []isize) void {
    assert(shape.len == strides.len);
    @memset(strides, 0);
    const any_nonzero = for (shape) |s| {
        if (s != 0) break true;
    } else false;
    if (!any_nonzero) return;
    var stride_it = mem.reverseIterator(strides);
    var shape_it = mem.reverseIterator(shape);
    if (stride_it.nextPtr()) |last| last.* = 1;
    var cum_prod: usize = 1;
    while (stride_it.nextPtr()) |stride| {
        cum_prod *= shape_it.next() orelse unreachable;
        // strides are signed (they may be negative elsewhere); the product is not.
        stride.* = @intCast(cum_prod);
    }
}
/// Compute the default row-major strides for `shape` at comptime, returned
/// as a slice of comptime memory holding one stride per dimension.
fn newStrides(comptime shape: []const usize) []const isize {
    comptime {
        var computed: [shape.len]isize = [_]isize{0} ** shape.len;
        setDefaultStrides(shape, &computed);
        // freeze into a const so the returned slice points at immutable data.
        const frozen = computed;
        return &frozen;
    }
}
/// View `x` (a tuple, an array, or a pointer/slice) as `[]const T`.
/// Tuples and arrays are referenced in place; pointers are coerced directly.
/// Any other type, including a non-tuple struct, is a compile error.
inline fn toSlice(comptime T: type, x: anytype) []const T {
    const X = @TypeOf(x);
    switch (@typeInfo(X)) {
        .pointer => return x,
        .array => return &x,
        .@"struct" => |info| {
            if (!info.is_tuple)
                @compileError("unexpected struct type. expected tuple. found " ++
                    @typeName(X));
            return &x;
        },
        else => @compileError("unexpected type. expected tuple, array or " ++
            " pointer. found " ++ @typeName(X)),
    }
}
/// Build a `Matrix` over constant data, inferring its shape from `items`.
/// `items` is either a pointer to the matrix's `Array` type, or an array, vector
/// or (nested) tuple literal that coerces to `Array`.
/// NOTE(review): for non-pointer `items` the data is copied into a local of this
/// inline call and a pointer to it is kept. This presumably relies on `items`
/// being comptime-known (as at every call site in this file); confirm before
/// passing runtime values.
pub inline fn constMatrix(comptime T: type, items: anytype) Matrix(T, InferShape(@TypeOf(items)), .{}) {
    const Items = @TypeOf(items);
    const Result = Matrix(T, InferShape(Items), .{});
    const array_ptr: *const Result.Array = switch (@typeInfo(Items)) {
        .pointer => items,
        .array, .@"struct", .vector => blk: {
            const copy: Result.Array = items;
            break :blk &copy;
        },
        else => @compileError("unexpected type " ++ @typeName(Items)),
    };
    return Result.initConst(array_ptr);
}
/// Debug tracing hook, currently switched off. Set `enabled` to true to print
/// `fmt`/`args` to stderr via std.debug.print.
fn debug(comptime fmt: []const u8, args: anytype) void {
    const enabled = false;
    if (!enabled) return;
    std.debug.print(fmt, args);
}
/// Exercise every `init*` / `init*Alloc` helper on a 2-item matrix of `T`.
fn testInitHelpers(comptime T: type) !void {
    const gpa = testing.allocator;
    const Vec2 = Matrix(T, .{2}, .{});

    const uninit = try Vec2.initAlloc(gpa);
    defer uninit.deinit(gpa);

    const ones = try Vec2.initFilledAlloc(gpa, 1);
    defer ones.deinit(gpa);
    try testing.expectEqualSlices(T, &.{ 1, 1 }, ones.buf);

    // refill the storage owned by `ones` in place
    const twos = Vec2.initFilled(ones.buf, 2);
    try testing.expectEqualSlices(T, &.{ 2, 2 }, twos.buf);

    const from_array = try Vec2.initArrayAlloc(gpa, &.{ 1, 2 });
    defer from_array.deinit(gpa);
    try testing.expectEqualSlices(T, &.{ 1, 2 }, from_array.buf);

    // copy into storage owned by `from_array`
    const copied_array = Vec2.initArray(from_array.buf, &.{ 3, 4 });
    try testing.expectEqualSlices(T, &.{ 3, 4 }, copied_array.buf);
    const copied_buf = Vec2.initBuf(from_array.buf, &.{ 3, 4 });
    try testing.expectEqualSlices(T, &.{ 3, 4 }, copied_buf.buf);

    const from_buf = try Vec2.initBufAlloc(gpa, &.{ 1, 2 });
    defer from_buf.deinit(gpa);
    try testing.expectEqualSlices(T, &.{ 1, 2 }, from_buf.buf);
}
test "init helpers" {
    try testInitHelpers(u8);

    // zero-sized element type: u0 has the single value 0 and occupies no memory
    const gpa = testing.allocator;
    const ZeroSized = Matrix(u0, .{1}, .{});
    const zeroed = try ZeroSized.initFilledAlloc(gpa, 0);
    defer zeroed.deinit(gpa);
    try testing.expectEqualSlices(u0, &.{0}, zeroed.buf);
}
test "reshape" {
    const flat = constMatrix(u8, std.simd.iota(u8, 12));
    try testing.expectEqualSlices(isize, &.{1}, flat.strides);

    {
        const grid = flat.reshape(.{ 3, 4 });
        try testing.expectEqualSlices(usize, &.{ 3, 4 }, grid.shape);
        try testing.expectEqualSlices(isize, &.{ 4, 1 }, grid.strides);
        try testing.expectEqual(5, grid.at(.{ 1, 1 }));
    }

    const Cube = Matrix(u8, .{ 3, 2, 2 }, .{});
    {
        // round trip: format at comptime, then parse the text back
        const source = comptime Cube.initConstBuf(&std.simd.iota(u8, 12));
        const reparsed = try Cube.parse(std.fmt.comptimePrint("{}", .{source}));
        try testing.expectEqualSlices(u8, source.buf, reparsed.buf);
        try testing.expectEqualSlices(usize, source.shape, reparsed.shape);
        try testing.expectEqualSlices(isize, source.strides, reparsed.strides);
    }
    {
        const cube = Cube.initConstBuf(&std.simd.iota(u8, 12));
        try testing.expectEqualSlices(usize, &.{ 2, 2, 3 }, cube.transpose().shape);
        try testing.expectEqualSlices(usize, &.{ 2, 3, 2 }, cube.swapAxes(0, 1).shape);
    }
}
test "format()" {
    const flat = constMatrix(u8, std.simd.iota(u8, 4));
    try testing.expectFmt("[0 1 2 3]", "{}", .{flat});
    // {z} specifier for zig style literals
    try testing.expectFmt(".{0,1,2,3}", "{z}", .{flat});

    const square = flat.reshape(.{ 2, 2 });
    try testing.expectFmt("[[0 1][2 3]]", "{}", .{square});
    try testing.expectFmt(".{.{0,1},.{2,3}}", "{z}", .{square});

    // both text styles parse back to the same data
    const U8x2x2 = Matrix(u8, .{ 2, 2 }, .{});
    const source = comptime U8x2x2.initConstBuf(&std.simd.iota(u8, 4));
    const from_numpy = try U8x2x2.parse(std.fmt.comptimePrint("{}", .{source}));
    try testing.expectEqualSlices(u8, source.buf, from_numpy.buf);
    const from_zig = try U8x2x2.parse(std.fmt.comptimePrint("{z}", .{source}));
    try testing.expectEqualSlices(u8, source.buf, from_zig.buf);

    // float items honor the item specifier and options
    const single = constMatrix(f32, std.simd.iota(f32, 1));
    try testing.expectFmt("[0.0]", "{d:.1}", .{single});
    try testing.expectFmt("[0e0]", "{e}", .{single});
    try testing.expectFmt(".{0.0}", "{zd:.1}", .{single});
    try testing.expectFmt(".{0e0}", "{ze}", .{single});
}
test "strides" {
    const flat = constMatrix(u8, std.simd.iota(u8, 12));

    // every other item
    const evens = flat.withStrides(.{2});
    try testing.expectEqualSlices(usize, &.{12}, evens.shape);
    try testing.expectEqualSlices(isize, &.{2}, evens.strides);
    try testing.expectEqual(10, evens.at(.{-1}));
    try testing.expectEqual(0, evens.at(.{0}));
    try testing.expectEqual(2, evens.at(.{1}));
    try testing.expectEqual(4, evens.at(.{2}));
    try testing.expectEqual(null, evens.at(.{6}));

    // negative strides
    {
        const reversed = flat.withStrides(.{-1});
        try testing.expectEqual(0, reversed.at(.{-1}));
        try testing.expectEqual(11, reversed.at(.{0}));
        try testing.expectEqual(10, reversed.at(.{1}));
        try testing.expectEqual(0, reversed.at(.{11}));
        // FIXME
        // try testing.expectEqual(null, reversed.at(.{12}));
    }
    {
        const every_third_back = flat.withStrides(.{-3});
        try testing.expectEqualSlices(isize, &.{-3}, every_third_back.strides);
        try testing.expectEqual(11, every_third_back.at(.{0}));
        try testing.expectEqual(8, every_third_back.at(.{1}));
        try testing.expectEqual(5, every_third_back.at(.{2}));
        try testing.expectEqual(2, every_third_back.at(.{3}));
        // FIXME
        // try testing.expectEqual(null, every_third_back.at(.{4}));
    }
    {
        const mixed = flat.reshape(.{ 2, 6 }).withStrides(.{ 6, -3 });
        try testing.expectEqual(5, mixed.at(.{ 0, 0 }));
        try testing.expectEqual(2, mixed.at(.{ 0, 1 }));
        try testing.expectEqual(11, mixed.at(.{ 1, 0 }));
        try testing.expectEqual(8, mixed.at(.{ 1, 1 }));
        try testing.expectEqual(null, mixed.at(.{ 2, 0 }));
    }
}
test "Iterator" {
    const flat = constMatrix(u8, std.simd.iota(u8, 12));
    const grid = flat.reshape(.{ 3, 4 });
    // both views use default strides, so iteration yields 0, 1, 2, ... in order
    inline for (.{ flat, grid }) |view| {
        var it = view.iterator();
        var expected: u8 = 0;
        while (it.next()) |item| : (expected += 1) {
            try testing.expectEqual(expected, item);
        }
    }
}
/// [1 2] @ [[3] [4]] == [1*3 + 2*4] == [11]
fn testMul1d(comptime T: type) !void {
    const lhs = constMatrix(T, .{ 1, 2 });
    const rhs = constMatrix(T, .{ .{3}, .{4} });
    const Out = Matrix(T, .{1}, .{});
    var out_buf: Out.Buf = undefined;
    const out = Out.init(&out_buf);
    lhs.mul(rhs, out);
    try testing.expectEqualSlices(T, constMatrix(T, .{11}).buf, out.buf);
}
test "1d mul" {
    inline for (.{ u8, i8, f32 }) |T| try testMul1d(T);
}
/// (3x2 of ones) @ (2x3 of twos) == 3x3 where every entry is 1*2 + 1*2 == 4
fn testMul2d(comptime T: type) !void {
    const ones = constMatrix(T, .{
        .{ 1, 1 },
        .{ 1, 1 },
        .{ 1, 1 },
    });
    const twos = constMatrix(T, .{
        .{ 2, 2, 2 },
        .{ 2, 2, 2 },
    });
    const Product = Matrix(T, .{ 3, 3 }, .{});
    var product_buf: Product.Buf = undefined;
    const product = Product.init(&product_buf);
    ones.mul(twos, product);
    const expected = Product.initConst(&.{
        .{ 4, 4, 4 },
        .{ 4, 4, 4 },
        .{ 4, 4, 4 },
    });
    try testing.expectEqualSlices(T, expected.buf, product.buf);
}
test "2d mul" {
    inline for (.{ u8, i8, f32 }) |T| try testMul2d(T);
}
/// batched product: each of the two (2x3) @ (3x2) pairs yields a 2x2 block.
fn testMul3d(comptime T: type) !void {
    const lhs = constMatrix(T, .{
        .{ .{ 1, 2, 3 }, .{ 4, 5, 6 } },
        .{ .{ 7, 8, 9 }, .{ 10, 11, 12 } },
    });
    const rhs = constMatrix(T, .{
        .{ .{ 1, 2 }, .{ 3, 4 }, .{ 5, 6 } },
        .{ .{ 7, 8 }, .{ 9, 10 }, .{ 11, 12 } },
    });
    const Product = Matrix(T, .{ 2, 2, 2 }, .{});
    var product_buf: Product.Buf = undefined;
    const product = Product.init(&product_buf);
    lhs.mul(rhs, product);
    const expected = Product.initConst(&.{
        .{ .{ 22, 28 }, .{ 49, 64 } },
        .{ .{ 220, 244 }, .{ 301, 334 } },
    });
    try testing.expectEqualSlices(T, expected.buf, product.buf);
}
test "3d mul" {
    inline for (.{ u16, i16, f32 }) |T| try testMul3d(T);
}
/// parse 1d, 2d and 3d matrices of `T` at comptime and at runtime, then check
/// that malformed input is rejected.
fn testParse(comptime T: type) !void {
    // 1d: numpy syntax, explicit sign, zig literal, runtime buffer
    const Vec = Matrix(T, .{2}, .{});
    try testing.expectEqualSlices(T, &.{ 1, 2 }, (try Vec.parse("[1 2]")).buf);
    try testing.expectEqualSlices(T, &.{ 1, 2 }, (try Vec.parse("[+1 2]")).buf);
    try testing.expectEqualSlices(T, &.{ 1, 2 }, (try Vec.parse(".{1,2}")).buf);
    var vec_buf: [2]T = undefined;
    const vec_rt = try Vec.parseBuf("[1 2]", &vec_buf);
    try testing.expectEqualSlices(T, &.{ 1, 2 }, vec_rt.buf);

    // 2d
    const Square = Matrix(T, .{ 2, 2 }, .{});
    const square = try Square.parse("[[1 2] [3 4]]");
    try testing.expectEqualSlices(T, &.{ 1, 2, 3, 4 }, square.buf);
    var square_buf: [4]T = undefined;
    const square_rt = try Square.parseBuf("[[1 2] [3 4]]", &square_buf);
    try testing.expectEqualSlices(T, &.{ 1, 2, 3, 4 }, square_rt.buf);
    const square_zig = try Square.parseBuf(".{.{1, 2}, .{3,4}}", &square_buf);
    try testing.expectEqualSlices(T, &.{ 1, 2, 3, 4 }, square_zig.buf);

    // 3d, including numpy's blank line between blocks
    const Cube = Matrix(T, .{ 2, 2, 2 }, .{});
    const cube_text =
        \\[[[1 2] [3 4]]
        \\
        \\ [[5 6] [ 7 8]]]
    ;
    {
        @setEvalBranchQuota(2000);
        const cube = try Cube.parse(cube_text);
        try testing.expectEqualSlices(T, &.{ 1, 2, 3, 4, 5, 6, 7, 8 }, cube.buf);
    }
    var cube_buf: [8]T = undefined;
    const cube_rt = try Cube.parseBuf(cube_text, &cube_buf);
    try testing.expectEqualSlices(T, &.{ 1, 2, 3, 4, 5, 6, 7, 8 }, cube_rt.buf);

    // malformed input
    const bad_cubes = [_][]const u8{ "a", "[", "]", "[0]", "[[0]]", "[[[0]]]", "[[[[0]]]]]" };
    for (bad_cubes) |text| {
        try testing.expectError(error.ParseFailure, Cube.parseBuf(text, &cube_buf));
    }
    try testing.expectError(error.ParseFailure, Square.parseBuf("[[1 2]]", &square_buf));
    try testing.expectError(error.ParseFailure, Square.parseBuf("[[1 2][1 2][1 2]]", &square_buf));
}
test "parse" {
    inline for (.{ u32, i32, f32 }) |T| {
        try testParse(T);
        // signed and float element types also accept negative numbers
        if (T != u32) {
            try testing.expectEqualSlices(
                T,
                &.{ -1, -2 },
                (try Matrix(T, .{2}, .{}).parse("[-1 -2]")).buf,
            );
        }
    }
}
/// parse `a_in` and `b_in` (numpy text) at comptime, multiply them, and compare
/// the product with the parsed `c_in`.
fn testMul(
    comptime T: type,
    comptime a_shape: anytype,
    comptime a_in: []const u8,
    comptime b_shape: anytype,
    comptime b_in: []const u8,
    comptime c_shape: anytype,
    comptime c_in: []const u8,
) !void {
    @setEvalBranchQuota(14_000);
    const A = Matrix(T, a_shape, .{});
    const B = Matrix(T, b_shape, .{});
    const C = Matrix(T, c_shape, .{});
    const lhs = try A.parse(a_in);
    const rhs = try B.parse(b_in);
    var product_buf: C.Buf = undefined;
    const product = C.init(&product_buf);
    lhs.mul(rhs, product);
    const expected = try C.parse(c_in);
    try testing.expectEqualSlices(T, expected.buf, product.buf);
}
/// table of numpy-generated matmul cases, 1d through 5d, run in order.
/// each case is: a shape, a text, b shape, b text, product shape, product text.
fn testMuls(comptime T: type) !void {
    const cases = .{
        // 1d
        .{ .{2},
            \\[6 3]
            \\
        , .{ 2, 1 },
            \\[[7]
            \\ [4]]
            \\
        , .{1},
            \\[54]
            \\
        },
        .{ .{2},
            \\[6 3]
            \\
        , .{ 2, 2 },
            \\[[7 4]
            \\ [6 9]]
            \\
        , .{2},
            \\[60 51]
            \\
        },
        // try testMul(T, .{ 2, 2, 2 },
        //     \\[[[7 4]
        //     \\  [6 9]]
        //     \\
        //     \\ [[2 6]
        //     \\  [7 4]]]
        //     \\
        // , .{2},
        //     \\[6 3]
        //     \\
        // , .{ 2, 2 },
        //     \\[[60 51]
        //     \\ [33 48]]
        //     \\
        // );
        // 2d
        .{ .{ 2, 5 },
            \\[[6 3 7 4 6]
            \\ [9 2 6 7 4]]
            \\
        , .{ 5, 1 },
            \\[[3]
            \\ [7]
            \\ [7]
            \\ [2]
            \\ [5]]
            \\
        , .{ 2, 1 },
            \\[[126]
            \\ [117]]
            \\
        },
        .{ .{ 5, 2 },
            \\[[6 3]
            \\ [7 4]
            \\ [6 9]
            \\ [2 6]
            \\ [7 4]]
            \\
        , .{ 2, 1 },
            \\[[3]
            \\ [7]]
            \\
        , .{ 5, 1 },
            \\[[39]
            \\ [49]
            \\ [81]
            \\ [48]
            \\ [49]]
            \\
        },
        // 3d
        .{ .{ 3, 3, 2 },
            \\[[[6 3]
            \\ [7 4]
            \\ [6 9]]
            \\
            \\ [[2 6]
            \\ [7 4]
            \\ [3 7]]
            \\
            \\ [[7 2]
            \\ [5 4]
            \\ [1 7]]]
        , .{ 3, 2, 4 },
            \\[[[5 1 4 0]
            \\ [9 5 8 0]]
            \\
            \\ [[9 2 6 3]
            \\ [8 2 4 2]]
            \\
            \\ [[6 4 8 6]
            \\ [1 3 8 1]]]
        , .{ 3, 3, 4 },
            \\[[[ 57 21 48 0]
            \\ [ 71 27 60 0]
            \\ [111 51 96 0]]
            \\
            \\ [[ 66 16 36 18]
            \\ [ 95 22 58 29]
            \\ [ 83 20 46 23]]
            \\
            \\ [[ 44 34 72 44]
            \\ [ 34 32 72 34]
            \\ [ 13 25 64 13]]]
        },
        .{ .{ 1, 3, 2 },
            \\[[[6 3]
            \\ [7 4]
            \\ [6 9]]]
            \\
        , .{ 1, 2, 4 },
            \\[[[2 6 7 4]
            \\ [3 7 7 2]]]
            \\
        , .{ 1, 3, 4 },
            \\[[[ 21 57 63 30]
            \\ [ 26 70 77 36]
            \\ [ 39 99 105 42]]]
            \\
        },
        .{ .{ 1, 3, 1 },
            \\[[[6]
            \\ [3]
            \\ [7]]]
            \\
        , .{ 1, 1, 3 },
            \\[[[4 6 9]]]
            \\
        , .{ 1, 3, 3 },
            \\[[[24 36 54]
            \\ [12 18 27]
            \\ [28 42 63]]]
            \\
        },
        // 4d
        .{ .{ 1, 1, 3, 1 },
            \\[[[[6]
            \\ [3]
            \\ [7]]]]
            \\
        , .{ 1, 1, 1, 3 },
            \\[[[[4 6 9]]]]
            \\
        , .{ 1, 1, 3, 3 },
            \\[[[[24 36 54]
            \\ [12 18 27]
            \\ [28 42 63]]]]
            \\
        },
        .{ .{ 2, 2, 3, 2 },
            \\[[[[6 3]
            \\ [7 4]
            \\ [6 9]]
            \\
            \\ [[2 6]
            \\ [7 4]
            \\ [3 7]]]
            \\
            \\
            \\ [[[7 2]
            \\ [5 4]
            \\ [1 7]]
            \\
            \\ [[5 1]
            \\ [4 0]
            \\ [9 5]]]]
            \\
        , .{ 2, 2, 2, 3 },
            \\[[[[8 0 9]
            \\ [2 6 3]]
            \\
            \\ [[8 2 4]
            \\ [2 6 4]]]
            \\
            \\
            \\ [[[8 6 1]
            \\ [3 8 1]]
            \\
            \\ [[9 8 9]
            \\ [4 1 3]]]]
            \\
        , .{ 2, 2, 3, 3 },
            \\[[[[ 54 18 63]
            \\ [ 64 24 75]
            \\ [ 66 54 81]]
            \\
            \\ [[ 28 40 32]
            \\ [ 64 38 44]
            \\ [ 38 48 40]]]
            \\
            \\
            \\ [[[ 62 58 9]
            \\ [ 52 62 9]
            \\ [ 29 62 8]]
            \\
            \\ [[ 49 41 48]
            \\ [ 36 32 36]
            \\ [101 77 96]]]]
            \\
        },
        // 5d
        .{ .{ 2, 2, 2, 3, 2 },
            \\[[[[[6 3]
            \\ [7 4]
            \\ [6 9]]
            \\
            \\ [[2 6]
            \\ [7 4]
            \\ [3 7]]]
            \\
            \\
            \\ [[[7 2]
            \\ [5 4]
            \\ [1 7]]
            \\
            \\ [[5 1]
            \\ [4 0]
            \\ [9 5]]]]
            \\
            \\
            \\
            \\ [[[[8 0]
            \\ [9 2]
            \\ [6 3]]
            \\
            \\ [[8 2]
            \\ [4 2]
            \\ [6 4]]]
            \\
            \\
            \\ [[[8 6]
            \\ [1 3]
            \\ [8 1]]
            \\
            \\ [[9 8]
            \\ [9 4]
            \\ [1 3]]]]]
            \\
        , .{ 2, 2, 2, 2, 3 },
            \\[[[[[6 7 2]
            \\ [0 3 1]]
            \\
            \\ [[7 3 1]
            \\ [5 5 9]]]
            \\
            \\
            \\ [[[3 5 1]
            \\ [9 1 9]]
            \\
            \\ [[3 7 6]
            \\ [8 7 4]]]]
            \\
            \\
            \\
            \\ [[[[1 4 7]
            \\ [9 8 8]]
            \\
            \\ [[0 8 6]
            \\ [8 7 0]]]
            \\
            \\
            \\ [[[7 7 2]
            \\ [0 7 2]]
            \\
            \\ [[2 0 4]
            \\ [9 6 9]]]]]
            \\
        , .{ 2, 2, 2, 3, 3 },
            \\[[[[[ 36 51 15]
            \\ [ 42 61 18]
            \\ [ 36 69 21]]
            \\
            \\ [[ 44 36 56]
            \\ [ 69 41 43]
            \\ [ 56 44 66]]]
            \\
            \\
            \\ [[[ 39 37 25]
            \\ [ 51 29 41]
            \\ [ 66 12 64]]
            \\
            \\ [[ 23 42 34]
            \\ [ 12 28 24]
            \\ [ 67 98 74]]]]
            \\
            \\
            \\
            \\ [[[[ 8 32 56]
            \\ [ 27 52 79]
            \\ [ 33 48 66]]
            \\
            \\ [[ 16 78 48]
            \\ [ 16 46 24]
            \\ [ 32 76 36]]]
            \\
            \\
            \\ [[[ 56 98 28]
            \\ [ 7 28 8]
            \\ [ 56 63 18]]
            \\
            \\ [[ 90 48 108]
            \\ [ 54 24 72]
            \\ [ 29 18 31]]]]]
            \\
        },
    };
    inline for (cases) |case| {
        try testMul(T, case[0], case[1], case[2], case[3], case[4], case[5]);
    }
}
test "muls" {
    inline for (.{ u32, i32, f32 }) |T| try testMuls(T);
}
test constMatrix {
    // 1d, shape inferred from a flat tuple
    {
        const vec = constMatrix(u8, .{ 1, 2 });
        try testing.expectEqualSlices(u8, &.{ 1, 2 }, vec.buf);
        try testing.expectEqualSlices(usize, &.{2}, vec.shape);
    }
    // 2d, shape inferred from a nested tuple
    {
        const square = constMatrix(u8, .{
            .{ 1, 2 },
            .{ 3, 4 },
        });
        try testing.expectEqualSlices(u8, &.{ 1, 2, 3, 4 }, square.buf);
        try testing.expectEqualSlices(usize, &.{ 2, 2 }, square.shape);
    }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment