Created
January 19, 2025 07:14
-
-
Save travisstaloch/60beb1973792bc76da9f9b385aa1b644 to your computer and use it in GitHub Desktop.
An incomplete generic matrix lib which imitates some aspects of numpy
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
//! | |
//! a work in progress | |
//! | |
//! this lib only does multiplication so far. but can parse numpy output at | |
//! runtime and comptime and has lots of helpers including array() which
//! allows multi indexing, i.e. `mat.array()[i][j]`. there are some working
//! tests for 1d, 2d, 3d and 4d matrices w/ different shapes. | |
//! | |
//! Resources: | |
//! https://ajcr.net/stride-guide-part-1/ | |
const std = @import("std"); | |
const assert = std.debug.assert; | |
const testing = std.testing; | |
const mem = std.mem; | |
const Allocator = mem.Allocator; | |
pub const ParseError = error{ParseFailure}; | |
/// `shape` must be a tuple, array, or slice coercible to []const usize
/// `strides` must be a tuple, array, or slice coercible to []const usize with the same length as `shape`
pub fn Matrix(comptime T: type, comptime shape: anytype, comptime strides: anytype) type { | |
return struct { | |
/// pointer to an N-dimensional array of `T` shaped according to `shape` | |
buf: *Buf, | |
/// the shape of `Buf`. the number of items in each dimension. | |
comptime shape: []const usize = _shape, | |
/// the strides of `Buf`. | |
comptime strides: []const isize = _strides, | |
/// an N dimensional array of `T` shaped according to `shape` | |
pub const Array = ArrayFromDim(0); | |
/// a flattened, one dimensional array of `T`. has the same memory layout as Array. | |
pub const Buf = [len]T; | |
/// the total number of elements in the Buf | |
pub const len = @reduce(.Mul, @as( | |
@Vector(_shape.len, usize), | |
_shape[0.._shape.len].*, | |
)); | |
pub const child = T; | |
/// Returns the nested array type covering the dimensions `dim..` of the shape.
/// for example:
/// given shape = .{2,4}
/// ArrayFromDim(0) = [2][4]T
/// ArrayFromDim(1) = [4]T
/// ArrayFromDim(2) = T
/// Any number of dimensions is supported. `dim` must not be greater than
/// the number of dimensions in the shape.
pub fn ArrayFromDim(comptime dim: usize) type {
    if (dim > _shape.len) @compileError(std.fmt.comptimePrint(
        "ArrayFromDim: dim {} is out of range for a shape with {} dimensions",
        .{ dim, _shape.len },
    ));
    // Wrap the element type from the innermost dimension outwards so that
    // dimension `dim` ends up as the outermost array.
    var Result: type = T;
    var i: usize = _shape.len;
    while (i > dim) {
        i -= 1;
        Result = [_shape[i]]Result;
    }
    return Result;
}
const Self = @This(); | |
const _shape = toSlice(usize, shape); | |
const _strides = if (strides.len == 0) | |
newStrides(_shape) | |
else | |
toSlice(isize, strides); | |
pub fn init(buf: *Buf) Self { | |
return .{ .buf = buf }; | |
} | |
/// Wraps `buf` in a matrix after setting every element of it to `value`.
pub fn initFilled(buf: *Buf, value: T) Self {
    const filled: Self = .{ .buf = buf };
    filled.fill(value);
    return filled;
}
pub fn initArray(buf: *Buf, a: *const Array) Self { | |
const result = init(buf); | |
result.fillArray(a); | |
return result; | |
} | |
pub fn initBuf(buf: *Buf, const_buf: *const Buf) Self { | |
const result = init(buf); | |
result.fillArray(@ptrCast(const_buf)); | |
return result; | |
} | |
pub fn initConst(arr: *const Array) Self { | |
return .{ .buf = @ptrCast(@constCast(arr)) }; | |
} | |
pub fn initConstBuf(buf: *const Buf) Self { | |
return .{ .buf = @constCast(buf) }; | |
} | |
/// Allocates an uninitialized backing buffer with `allocator` and wraps it.
/// The caller owns the buffer and releases it with `deinit`.
pub fn initAlloc(allocator: Allocator) !Self {
    // Allocate the flat `Buf`, not the nested `Array`: `init` takes a `*Buf`
    // and a `*Array` is only the same type for one dimensional shapes.
    return init(try allocator.create(Buf));
}
/// Allocates a backing buffer with `allocator` and sets every element to
/// `value`. The caller owns the buffer and releases it with `deinit`.
pub fn initFilledAlloc(allocator: Allocator, value: T) !Self {
    // Allocate the flat `Buf`, not the nested `Array`: `initFilled` takes a
    // `*Buf` and a `*Array` is only the same type for one dimensional shapes.
    return initFilled(try allocator.create(Buf), value);
}
/// Allocates a backing buffer with `allocator` and copies `arr` into it.
/// The caller owns the buffer and releases it with `deinit`.
pub fn initArrayAlloc(allocator: Allocator, arr: *const Array) !Self {
    // Allocate the flat `Buf`, not the nested `Array`: `initArray` takes a
    // `*Buf` and a `*Array` is only the same type for one dimensional shapes.
    return initArray(try allocator.create(Buf), arr);
}
pub fn initBufAlloc(allocator: Allocator, buf: *const Buf) !Self { | |
return initBuf(try allocator.create(Buf), buf); | |
} | |
/// Releases a buffer that was obtained from one of the `*Alloc` constructors.
/// `allocator` must be the allocator that was used to create the buffer.
pub fn deinit(self: Self, allocator: Allocator) void {
    // The buffer is a single item allocated with `create`, so it is released
    // with the matching `destroy` rather than the slice oriented `free`.
    allocator.destroy(self.buf);
}
pub fn reshape( | |
self: Self, | |
comptime new_shape: anytype, | |
) Matrix(T, new_shape, newStrides(toSlice(usize, new_shape))) { | |
// TODO check for impossible reshape | |
return .{ .buf = self.buf }; | |
} | |
pub fn withStrides( | |
self: Self, | |
comptime new_strides: anytype, | |
) Matrix(T, _shape, new_strides) { | |
// TODO check for impossible re-stride | |
return .{ .buf = self.buf }; | |
} | |
/// Returns a comptime copy of `slice` with its elements in reverse order.
fn comptimeReverseSlice(comptime slice: anytype) @TypeOf(slice) {
    comptime {
        var reversed = slice[0..slice.len].*;
        for (slice, 0..) |item, idx| {
            reversed[slice.len - 1 - idx] = item;
        }
        // copy into a const so that the returned pointer does not refer
        // to a comptime var
        const frozen = reversed;
        return &frozen;
    }
}
pub fn transpose(self: Self) Matrix(T, comptimeReverseSlice(self.shape), .{}) { | |
return .{ .buf = self.buf }; | |
} | |
/// Returns a comptime copy of `slice` in which the elements at positions
/// `axis1` and `axis2` have traded places.
fn comptimeSwapAxes(
    comptime slice: anytype,
    comptime axis1: usize,
    comptime axis2: usize,
) @TypeOf(slice) {
    comptime {
        var swapped = slice[0..slice.len].*;
        const first = swapped[axis1];
        swapped[axis1] = swapped[axis2];
        swapped[axis2] = first;
        // copy into a const so that the returned pointer does not refer
        // to a comptime var
        const frozen = swapped;
        return &frozen;
    }
}
pub fn swapAxes( | |
self: Self, | |
comptime axis1: usize, | |
comptime axis2: usize, | |
) Matrix(T, comptimeSwapAxes(self.shape, axis1, axis2), .{}) { | |
return .{ .buf = self.buf }; | |
} | |
/// Reports a parse failure described by `fmt` and `args`.
/// At comptime the message becomes a compile error. At runtime it is logged
/// with `std.log.err` (except in test builds) and `error.ParseFailure` is
/// returned for the caller to propagate.
fn parseErr(comptime fmt: []const u8, args: anytype) ParseError {
    if (@inComptime()) {
        @compileError(std.fmt.comptimePrint(fmt, args));
    }
    const is_test = @import("builtin").is_test;
    if (!is_test) std.log.err(fmt, args);
    return error.ParseFailure;
}
/// parse numpy like output at comptime | |
pub inline fn parse(comptime input: []const u8) !Self { | |
comptime { | |
var buf: Buf = undefined; | |
_ = try parseBuf(input, &buf); | |
const fbuf = buf; | |
return .{ .buf = @constCast(&fbuf) }; | |
} | |
} | |
/// Parses numpy style (`[[1 2] [3 4]]`) or zig style (`.{.{1,2},.{3,4}}`)
/// text from `input` into `buf` and returns a matrix that wraps `buf`.
/// Malformed input is reported through `parseErr`; errors from
/// `std.fmt.parseInt` / `std.fmt.parseFloat` are passed on unchanged.
pub fn parseBuf(input: []const u8, buf: *Buf) !Self {
    var pos: usize = 0;
    // number of brackets that are currently open
    var nesting: usize = 0;
    // number of elements written to `buf` so far
    var filled: usize = 0;
    // the most recently opened bracket kind: '[' or '{'
    var open_bracket: u8 = 0;
    while (pos < input.len) : (pos += 1) {
        const c = input[pos];
        switch (c) {
            ' ', '\n', '\t', '\r', ',' => {},
            '[' => {
                nesting += 1;
                open_bracket = '[';
            },
            '.' => {
                // a '.' is only valid as the start of a zig style ".{"
                const brace_pos = pos + 1;
                if (brace_pos == input.len)
                    return parseErr("unexpected eof position {}", .{brace_pos});
                if (input[brace_pos] != '{') return parseErr(
                    "expecting '{{' but found '{c}' at position {}",
                    .{ input[brace_pos], brace_pos },
                );
                pos = brace_pos;
                nesting += 1;
                open_bracket = '{';
            },
            ']', '}' => {
                if (nesting == 0)
                    return parseErr("unbalanced bracket at position {}", .{pos});
                nesting -= 1;
            },
            '0'...'9', '-', '+' => {
                // numbers may only appear inside an innermost bracket pair
                if (nesting < _shape.len)
                    return parseErr("missing bracket at position {}", .{pos});
                if (nesting > _shape.len)
                    return parseErr("extra bracket at position {}", .{pos});
                const numpy_style = switch (open_bracket) {
                    '[' => true,
                    '{' => false,
                    else => unreachable,
                };
                const closer: u8 = if (numpy_style) ']' else '}';
                const close_pos = mem.indexOfScalarPos(u8, input, pos, closer) orelse
                    return parseErr(
                        "missing closing bracket at position {}",
                        .{pos},
                    );
                // consume the whole innermost row in one go
                var tokens = mem.tokenizeAny(
                    u8,
                    input[pos..close_pos],
                    if (numpy_style) " " else ", ",
                );
                var row_items: usize = 0;
                while (tokens.next()) |token| {
                    const value = switch (@typeInfo(T)) {
                        .int, .comptime_int => try std.fmt.parseInt(T, token, 10),
                        .float, .comptime_float => try std.fmt.parseFloat(T, token),
                        else => unreachable,
                    };
                    if (filled >= len)
                        return parseErr(
                            "expected {} items but found {} at position {}.",
                            .{ len, filled + 1, pos },
                        );
                    buf[filled] = value;
                    filled += 1;
                    row_items += 1;
                }
                if (row_items != _shape[_shape.len - 1])
                    return parseErr(
                        "expected {} items but found {} at position {}.",
                        .{ _shape[_shape.len - 1], row_items, pos },
                    );
                // continue after the closing bracket, which closes this row
                pos = close_pos;
                nesting -= 1;
            },
            else => return parseErr(
                "unexpected character '{c}' at position {}",
                .{ c, pos },
            ),
        }
    }
    if (nesting != 0) return parseErr("eof. missing closing bracket", .{});
    if (filled != len) return parseErr(
        "eof. expected {} items but found {}.",
        .{ len, filled },
    );
    return .{ .buf = buf };
}
/// Return the flat buffer offset for the given indices, or null when the
/// resulting offset is at or past the end of the buffer.
///
/// `indices` may be shorter than the shape. They are walked from last to
/// first and paired with the trailing entries of `__strides` and of the
/// shape. A negative total offset wraps around modulo `len`.
/// NOTE: the commented-out FIXME expectations in the "strides" test show
/// that out of range indices combined with negative strides are not yet
/// answered with null.
fn strideOffset(indices: []const isize, __strides: []const isize) ?usize {
    assert(indices.len <= __strides.len);
    // check against the normalized `_shape` slice that the loop below
    // iterates, not the raw `shape` parameter of `Matrix`
    assert(indices.len <= _shape.len);
    var offset: isize = 0;
    var index_it = mem.reverseIterator(indices);
    var stride_it = mem.reverseIterator(__strides);
    var shape_it = mem.reverseIterator(_shape);
    while (index_it.next()) |index| {
        const stride = stride_it.next() orelse unreachable;
        const shape_signed = @as(isize, @bitCast(shape_it.next() orelse unreachable));
        const this_offset = if (index < 0 and stride < 0)
            index * stride - 1
        else if (index < 0)
            // negative index: wrap within this dimension
            @mod(index * stride, shape_signed)
        else if (stride < 0)
            // negative stride: start from the last item of this dimension
            index * stride + shape_signed - 1
        else
            index * stride;
        offset += this_offset;
        debug(
            "shape={} index={} stride={} this_offset={} offset={}\n",
            .{ shape_signed, index, stride, this_offset, offset },
        );
    }
    return if (offset < 0)
        @bitCast(@mod(offset, len))
    else if (offset >= len)
        null
    else
        @bitCast(offset);
}
pub fn at(self: Self, indices: anytype) ?T { | |
const i = strideOffset(toSlice(isize, indices), self.strides) orelse | |
return null; | |
debug("at() i={}\n", .{i}); | |
return self.buf[i]; | |
} | |
pub fn atPtr(self: Self, indices: anytype) ?*T { | |
const i = strideOffset(toSlice(isize, indices), self.strides) orelse | |
return null; | |
return &self.buf[i]; | |
} | |
/// Sets every element of the buffer to `value`.
/// Does nothing when `T` is a zero sized type.
pub fn fill(self: Self, value: T) void {
    if (@sizeOf(T) == 0) return;
    @memset(self.buf, value);
}
/// Copies every element of `arr` into the matrix buffer.
pub fn fillArray(self: Self, arr: *const Array) void {
    // `Array` is nested for shapes with more than one dimension while the
    // buffer is flat. View the source as a `Buf` so that both sides of the
    // copy have the element type `T` for any number of dimensions.
    const flat: *const Buf = @ptrCast(arr);
    @memcpy(self.buf, flat);
}
pub fn subMatrix(ptr: [*]T, comptime dim: usize) Matrix(T, _shape[dim..], _strides[dim..]) { | |
return .{ .buf = @ptrCast(ptr) }; | |
} | |
pub const mul = switch (_shape.len) { | |
1 => mul1d, | |
2 => mul2d, | |
else => mulNd, | |
}; | |
/// Returns the buffer viewed as the nested `Array` type so that elements
/// can be indexed per dimension, i.e. `mat.array()[i][j]`.
/// The view is a plain pointer cast of `buf`; `strides` are not consulted.
/// Public because the module documentation advertises it as a helper.
pub fn array(self: Self) *Array {
    return @ptrCast(self.buf);
}
/// Multiplies the 1d matrix `a` with the 2d matrix `b` and stores the result
/// in `dst`: dst[j] = sum over i of a[i] * b[i][j].
/// `b` must be a 2d matrix whose first dimension equals the length of `a`.
/// `dst` is zeroed first and is indexed with `0..b.shape[1]`.
fn mul1d(a: Self, b: anytype, dst: anytype) void {
    comptime assert(a.shape.len == 1);
    comptime assert(a.shape[0] == b.shape[0]);
    // reject unsupported operands while compiling instead of reaching
    // `unreachable` at runtime
    if (comptime b.shape.len != 2)
        @compileError("mul1d: only multiplication of a 1d matrix with a 2d matrix is implemented");
    dst.fill(0);
    for (0.._shape[0]) |i| {
        for (0..b.shape[1]) |j| {
            dst.array()[j] += a.array()[i] * b.array()[i][j];
        }
    }
}
/// Multiplies the 2d matrices `a` and `b` and stores the result in `dst`:
/// dst[i][j] = sum over k of a[i][k] * b[k][j].
/// The second dimension of `a` must equal the first dimension of `b`.
/// `dst` is zeroed first and is indexed with `[0..a.shape[0]][0..b.shape[1]]`.
fn mul2d(a: Self, b: anytype, dst: anytype) void {
    // same style of shape check as mul1d / mulNd
    comptime assert(a.shape.len == 2);
    comptime assert(a.shape[1] == b.shape[0]);
    dst.fill(0);
    for (0.._shape[0]) |i| {
        for (0..b.shape[1]) |j| {
            for (0.._shape[1]) |k| {
                dst.array()[i][j] += a.array()[i][k] * b.array()[k][j];
            }
        }
    }
}
fn mulNd(a: Self, b: anytype, dst: anytype) void { | |
comptime assert(a.shape.len >= 3); | |
for (0..dst.shape[0]) |i| { | |
const asubm = subMatrix(@ptrCast(&a.array()[i]), 1); | |
const bsubm = @TypeOf(b).subMatrix(@ptrCast(&b.array()[i]), 1); | |
const dsubm = @TypeOf(dst).subMatrix(@ptrCast(&dst.array()[i]), 1); | |
switch (_shape.len) { | |
3 => asubm.mul2d(bsubm, dsubm), | |
else => asubm.mulNd(bsubm, dsubm), | |
} | |
} | |
} | |
/// output either a numpy style or zig style text matrix representation.
/// numpy style is the default and results in output like '[0 1]'.
/// when the fmt specifier has a leading 'z', format() outputs zig style literals.
/// for example the specifier '{zd:.1}' will result in output like '.{0.0,1.0}'
pub fn format(
    self: Self,
    comptime fmt: []const u8,
    options: std.fmt.FormatOptions,
    writer: anytype,
) !void {
    const is_zig_fmt = comptime mem.startsWith(u8, fmt, "z");
    // writeAll instead of write: write() is allowed to accept only a part
    // of its input and the zig style opener is two bytes long
    try writer.writeAll(if (is_zig_fmt) ".{" else "[");
    // elements are formatted with `fmt` minus the leading 'z'
    const elem_fmt = if (is_zig_fmt) fmt[1..] else fmt;
    for (0.._shape[0]) |i| {
        if (_shape.len == 1) {
            if (i != 0) try writer.writeByte(if (is_zig_fmt) ',' else ' ');
            try std.fmt.formatType(self.buf[i], elem_fmt, options, writer, 1);
        } else {
            // sub matrices receive the complete `fmt`, including any 'z',
            // and format themselves through this same function
            try std.fmt.format(writer, "{" ++ fmt ++ "}", .{subMatrix(@ptrCast(&self.array()[i]), 1)});
            if (is_zig_fmt and i + 1 != _shape[0]) try writer.writeByte(',');
        }
    }
    try writer.writeByte(if (is_zig_fmt) '}' else ']');
}
pub fn dump(self: Self, message: []const u8) void { | |
std.debug.print("--{s}:", .{message}); | |
for (_shape, 0..) |s, i| { | |
std.debug.print("{s}{}", .{ if (i != 0) "x" else "", s }); | |
} | |
std.debug.print("--\n", .{}); | |
std.debug.print("{}", .{self}); | |
} | |
pub const Indices = [_shape.len]isize; | |
pub fn iterator(self: Self) Iterator { | |
return .{ .mat = self }; | |
} | |
/// Visits every index of the matrix, with the last dimension advancing
/// first, and yields the element found at the stride offset of each index.
pub const Iterator = struct {
    /// the matrix being iterated
    mat: Self,
    /// the index that the next call to next() / nextPtr() will yield
    indices: Indices = [1]isize{0} ** _shape.len,
    /// set once `indices` has wrapped around past the last index
    done: bool = false,

    /// set 'indices' to next index. returns true when done indicating
    /// there are no more indices.
    fn nextIndex(indices: *Indices) bool {
        var i = _shape.len - 1;
        while (true) : (i -= 1) {
            const ix = &indices[i];
            ix.* += 1;
            // carry into the next outer dimension when this one wraps
            if (ix.* == _shape[i])
                ix.* = 0
            else
                return false;
            if (i == 0) break;
        }
        return true;
    }

    /// Returns the next element or null when the iteration is finished
    /// or the current index has no valid offset.
    pub fn next(iter: *Iterator) ?T {
        const ptr = iter.nextPtr() orelse return null;
        return ptr.*;
    }

    /// Returns a pointer to the next element or null when the iteration
    /// is finished or the current index has no valid offset.
    pub fn nextPtr(iter: *Iterator) ?*T {
        if (iter.done) return null;
        // strideOffset takes a slice, so pass a pointer to the index array
        const offset = strideOffset(&iter.indices, iter.mat.strides) orelse
            return null;
        iter.done = nextIndex(&iter.indices);
        return &iter.mat.buf[offset];
    }

    /// Returns the element at `indices` or null when they have no valid
    /// offset. Does not change the iteration state.
    pub fn at(iter: Iterator, indices: []const isize) ?T {
        const offset = strideOffset(indices, iter.mat.strides) orelse
            return null;
        return iter.mat.buf[offset];
    }

    /// Returns a pointer to the element at `indices` or null when they
    /// have no valid offset. Does not change the iteration state.
    pub fn atPtr(iter: Iterator, indices: []const isize) ?*T {
        const offset = strideOffset(indices, iter.mat.strides) orelse
            return null;
        return &iter.mat.buf[offset];
    }
};
}; | |
} | |
/// Infers a matrix shape from the type `T` of a nested array, vector or
/// tuple literal (or a pointer to one) by collecting the length of each
/// nesting level, outermost first.
/// For tuples only the type of the first field is followed, so all fields
/// are expected to have the same structure.
fn InferShape(comptime T: type) []const usize {
    comptime {
        var shape: []const usize = &.{};
        // handle the outermost level; a pointer contributes no dimension
        var info = @typeInfo(switch (@typeInfo(T)) {
            .pointer => |p| p.child,
            inline .array, .vector => |a| blk: {
                shape = shape ++ [1]usize{a.len};
                break :blk a.child;
            },
            .@"struct" => |s| blk: {
                if (s.is_tuple) {
                    shape = shape ++ [1]usize{s.fields.len};
                    break :blk s.fields[0].type;
                } else @compileError("unexpected type " ++ @typeName(T));
            },
            else => |x| @compileError("unexpected type " ++ @tagName(x)),
        });
        // descend through the remaining levels until an element type is found
        while (true) {
            const len, const U = switch (info) {
                .array => |a| .{ a.len, a.child },
                // only non-empty tuples are dimensions; any other struct is
                // an element type
                .@"struct" => |s| if (s.is_tuple and s.fields.len > 0)
                    .{ s.fields.len, s.fields[0].type }
                else
                    break,
                else => break,
            };
            shape = shape ++ [1]usize{len};
            info = @typeInfo(U);
        }
        return shape;
    }
}
/// Writes the default strides for a buffer of the given `shape` into
/// `strides`: the last dimension gets stride 1 and every earlier dimension
/// gets the product of the sizes of all later dimensions.
/// `shape` and `strides` must have the same length. When the shape is empty
/// or all of its dimensions are zero, every stride is left at zero.
fn setDefaultStrides(shape: []const usize, strides: []isize) void {
    assert(shape.len == strides.len);
    @memset(strides, 0);
    const any_nonzero = for (shape) |s| {
        if (s != 0) break true;
    } else false;
    if (!any_nonzero) return;
    // walk both slices from the last dimension to the first
    var stride_iter = mem.reverseIterator(strides);
    var shape_iter = mem.reverseIterator(shape);
    if (stride_iter.nextPtr()) |ptr| ptr.* = 1;
    var cum_prod: usize = 1;
    while (stride_iter.nextPtr()) |rs| {
        cum_prod *= shape_iter.next() orelse unreachable;
        // explicit cast: an implicit usize to isize conversion is only
        // accepted for comptime known values
        rs.* = @intCast(cum_prod);
    }
}
/// Computes the default strides for `shape` at comptime, see
/// setDefaultStrides(), and returns them as a constant slice.
fn newStrides(comptime shape: []const usize) []const isize {
    comptime {
        var strides = mem.zeroes([shape.len]isize);
        setDefaultStrides(shape, &strides);
        // copy into a const so that the returned pointer does not refer
        // to a comptime var
        const frozen = strides;
        return &frozen;
    }
}
/// Coerces a tuple, an array or a pointer `x` to a `[]const T`.
/// Any other type, including a struct that is not a tuple, is a compile error.
inline fn toSlice(comptime T: type, x: anytype) []const T {
    const X = @TypeOf(x);
    return switch (@typeInfo(X)) {
        .pointer => x,
        .array => &x,
        .@"struct" => |info| if (!info.is_tuple)
            @compileError("unexpected struct type. expected tuple. found " ++
                @typeName(X))
        else
            &x,
        else => @compileError("unexpected type. expected tuple, array or " ++
            " pointer. found " ++ @typeName(X)),
    };
}
/// Creates a matrix over constant `items` with the shape inferred from the
/// type of `items`. `items` may be a nested tuple, array or vector, or a
/// pointer to a matching array.
pub inline fn constMatrix(comptime T: type, items: anytype) Matrix(T, InferShape(@TypeOf(items)), .{}) {
    const Items = @TypeOf(items);
    const Result = Matrix(T, InferShape(Items), .{});
    const array_ptr: *const Result.Array = switch (@typeInfo(Items)) {
        // already a pointer: use it as is
        .pointer => items,
        // a value: coerce it to the nested array type and point at the copy
        .array, .@"struct", .vector => blk: {
            const copy: Result.Array = items;
            break :blk &copy;
        },
        else => @compileError("unexpected type " ++ @typeName(Items)),
    };
    return Result.initConst(array_ptr);
}
/// Debug print helper which is currently switched off: nothing is printed
/// unless `enabled` below is changed to true.
fn debug(comptime fmt: []const u8, args: anytype) void {
    const enabled = false;
    if (enabled) std.debug.print(fmt, args);
}
fn testInitHelpers(comptime T: type) !void { | |
const talloc = testing.allocator; | |
const M = Matrix(T, .{2}, .{}); | |
const a = try M.initAlloc(talloc); | |
defer a.deinit(talloc); | |
const b = try M.initFilledAlloc(talloc, 1); | |
defer b.deinit(talloc); | |
try std.testing.expectEqualSlices(T, &.{ 1, 1 }, b.buf); | |
const c = M.initFilled(b.buf, 2); | |
try std.testing.expectEqualSlices(T, &.{ 2, 2 }, c.buf); | |
const d = try M.initArrayAlloc(talloc, &.{ 1, 2 }); | |
defer d.deinit(talloc); | |
try std.testing.expectEqualSlices(T, &.{ 1, 2 }, d.buf); | |
const e = M.initArray(d.buf, &.{ 3, 4 }); | |
try std.testing.expectEqualSlices(T, &.{ 3, 4 }, e.buf); | |
const f = M.initBuf(d.buf, &.{ 3, 4 }); | |
try std.testing.expectEqualSlices(T, &.{ 3, 4 }, f.buf); | |
const g = try M.initBufAlloc(talloc, &.{ 1, 2 }); | |
defer g.deinit(talloc); | |
try std.testing.expectEqualSlices(T, &.{ 1, 2 }, g.buf); | |
} | |
test "init helpers" { | |
try testInitHelpers(u8); | |
// u0, zero sized | |
const talloc = testing.allocator; | |
const M = Matrix(u0, .{1}, .{}); | |
const a = try M.initFilledAlloc(talloc, 0); | |
defer a.deinit(talloc); | |
try testing.expectEqualSlices(u0, &.{0}, a.buf); | |
} | |
test "reshape" { | |
const a = constMatrix(u8, std.simd.iota(u8, 12)); | |
try testing.expectEqualSlices(isize, &.{1}, a.strides); | |
{ | |
const b = a.reshape(.{ 3, 4 }); | |
try testing.expectEqualSlices(usize, &.{ 3, 4 }, b.shape); | |
try testing.expectEqualSlices(isize, &.{ 4, 1 }, b.strides); | |
try testing.expectEqual(5, b.at(.{ 1, 1 })); | |
} | |
const u8x3x2x2 = Matrix(u8, .{ 3, 2, 2 }, .{}); | |
{ | |
const b = comptime u8x3x2x2.initConstBuf(&std.simd.iota(u8, 12)); | |
const c = try u8x3x2x2.parse(std.fmt.comptimePrint("{}", .{b})); | |
try testing.expectEqualSlices(u8, b.buf, c.buf); | |
try testing.expectEqualSlices(usize, b.shape, c.shape); | |
try testing.expectEqualSlices(isize, b.strides, c.strides); | |
} | |
{ | |
const b = u8x3x2x2.initConstBuf(&std.simd.iota(u8, 12)); | |
try testing.expectEqualSlices(usize, &.{ 2, 2, 3 }, b.transpose().shape); | |
try testing.expectEqualSlices(usize, &.{ 2, 3, 2 }, b.swapAxes(0, 1).shape); | |
} | |
} | |
test "format()" { | |
const a = constMatrix(u8, std.simd.iota(u8, 4)); | |
try testing.expectFmt("[0 1 2 3]", "{}", .{a}); | |
// {z} specifier for zig style literals | |
try testing.expectFmt(".{0,1,2,3}", "{z}", .{a}); | |
try testing.expectFmt("[[0 1][2 3]]", "{}", .{a.reshape(.{ 2, 2 })}); | |
try testing.expectFmt(".{.{0,1},.{2,3}}", "{z}", .{a.reshape(.{ 2, 2 })}); | |
const u8x2x2 = Matrix(u8, .{ 2, 2 }, .{}); | |
const b = comptime u8x2x2.initConstBuf(&std.simd.iota(u8, 4)); | |
const c = try u8x2x2.parse(std.fmt.comptimePrint("{}", .{b})); | |
try testing.expectEqualSlices(u8, b.buf, c.buf); | |
const c2 = try u8x2x2.parse(std.fmt.comptimePrint("{z}", .{b})); | |
try testing.expectEqualSlices(u8, b.buf, c2.buf); | |
const d = constMatrix(f32, std.simd.iota(f32, 1)); | |
try testing.expectFmt("[0.0]", "{d:.1}", .{d}); | |
try testing.expectFmt("[0e0]", "{e}", .{d}); | |
try testing.expectFmt(".{0.0}", "{zd:.1}", .{d}); | |
try testing.expectFmt(".{0e0}", "{ze}", .{d}); | |
} | |
test "strides" { | |
const a = constMatrix(u8, std.simd.iota(u8, 12)); | |
const b = a.withStrides(.{2}); | |
try testing.expectEqualSlices(usize, &.{12}, b.shape); | |
try testing.expectEqualSlices(isize, &.{2}, b.strides); | |
try testing.expectEqual(10, b.at(.{-1})); | |
try testing.expectEqual(0, b.at(.{0})); | |
try testing.expectEqual(2, b.at(.{1})); | |
try testing.expectEqual(4, b.at(.{2})); | |
try testing.expectEqual(null, b.at(.{6})); | |
// negative strides | |
{ | |
const c = a.withStrides(.{-1}); | |
try testing.expectEqual(0, c.at(.{-1})); | |
try testing.expectEqual(11, c.at(.{0})); | |
try testing.expectEqual(10, c.at(.{1})); | |
try testing.expectEqual(0, c.at(.{11})); | |
// FIXME | |
// try testing.expectEqual(null, c.at(.{12})); | |
} | |
{ | |
const c = a.withStrides(.{-3}); | |
try testing.expectEqualSlices(isize, &.{-3}, c.strides); | |
try testing.expectEqual(11, c.at(.{0})); | |
try testing.expectEqual(8, c.at(.{1})); | |
try testing.expectEqual(5, c.at(.{2})); | |
try testing.expectEqual(2, c.at(.{3})); | |
// FIXME | |
// try testing.expectEqual(null, c.at(.{4})); | |
} | |
{ | |
const c = a.reshape(.{ 2, 6 }).withStrides(.{ 6, -3 }); | |
try testing.expectEqual(5, c.at(.{ 0, 0 })); | |
try testing.expectEqual(2, c.at(.{ 0, 1 })); | |
try testing.expectEqual(11, c.at(.{ 1, 0 })); | |
try testing.expectEqual(8, c.at(.{ 1, 1 })); | |
try testing.expectEqual(null, c.at(.{ 2, 0 })); | |
} | |
} | |
test "Iterator" { | |
const a = constMatrix(u8, std.simd.iota(u8, 12)); | |
{ | |
var iter = a.iterator(); | |
var i: u8 = 0; | |
while (iter.next()) |it| : (i += 1) { | |
try testing.expectEqual(i, it); | |
} | |
} | |
const b = a.reshape(.{ 3, 4 }); | |
{ | |
var iter = b.iterator(); | |
var i: u8 = 0; | |
while (iter.next()) |it| : (i += 1) { | |
try testing.expectEqual(i, it); | |
} | |
} | |
} | |
fn testMul1d(comptime T: type) !void { | |
const a = constMatrix(T, .{ 1, 2 }); | |
const b = constMatrix(T, .{ .{3}, .{4} }); | |
var cbuf: [1]T = undefined; | |
const c = Matrix(T, .{1}, .{}).init(&cbuf); | |
a.mul(b, c); | |
try testing.expectEqualSlices(T, constMatrix(T, .{11}).buf, c.buf); | |
} | |
test "1d mul" { | |
try testMul1d(u8); | |
try testMul1d(i8); | |
try testMul1d(f32); | |
} | |
fn testMul2d(comptime T: type) !void { | |
const a = constMatrix(T, .{ | |
.{ 1, 1 }, | |
.{ 1, 1 }, | |
.{ 1, 1 }, | |
}); | |
const b = constMatrix(T, .{ | |
.{ 2, 2, 2 }, | |
.{ 2, 2, 2 }, | |
}); | |
const C = Matrix(T, .{ 3, 3 }, .{}); | |
var cbuf: C.Buf = undefined; | |
const c = C.init(&cbuf); | |
a.mul(b, c); | |
try testing.expectEqualSlices(T, C.initConst(&.{ | |
.{ 4, 4, 4 }, | |
.{ 4, 4, 4 }, | |
.{ 4, 4, 4 }, | |
}).buf, c.buf); | |
} | |
test "2d mul" { | |
try testMul2d(u8); | |
try testMul2d(i8); | |
try testMul2d(f32); | |
} | |
fn testMul3d(comptime T: type) !void { | |
const a = constMatrix(T, .{ | |
.{ .{ 1, 2, 3 }, .{ 4, 5, 6 } }, | |
.{ .{ 7, 8, 9 }, .{ 10, 11, 12 } }, | |
}); | |
const b = constMatrix(T, .{ | |
.{ .{ 1, 2 }, .{ 3, 4 }, .{ 5, 6 } }, | |
.{ .{ 7, 8 }, .{ 9, 10 }, .{ 11, 12 } }, | |
}); | |
const C = Matrix(T, .{ 2, 2, 2 }, .{}); | |
var cbuf: C.Buf = undefined; | |
const c = C.init(&cbuf); | |
a.mul(b, c); | |
try testing.expectEqualSlices(T, C.initConst(&.{ | |
.{ .{ 22, 28 }, .{ 49, 64 } }, | |
.{ .{ 220, 244 }, .{ 301, 334 } }, | |
}).buf, c.buf); | |
} | |
test "3d mul" { | |
try testMul3d(u16); | |
try testMul3d(i16); | |
try testMul3d(f32); | |
} | |
fn testParse(comptime T: type) !void { | |
const A = Matrix(T, .{2}, .{}); | |
try testing.expectEqualSlices(T, &.{ 1, 2 }, (try A.parse("[1 2]")).buf); | |
try testing.expectEqualSlices(T, &.{ 1, 2 }, (try A.parse("[+1 2]")).buf); | |
try testing.expectEqualSlices(T, &.{ 1, 2 }, (try A.parse(".{1,2}")).buf); | |
var abuf: [2]T = undefined; | |
const a2 = try A.parseBuf("[1 2]", &abuf); | |
try testing.expectEqualSlices(T, &.{ 1, 2 }, a2.buf); | |
const B = Matrix(T, .{ 2, 2 }, .{}); | |
const b = try B.parse("[[1 2] [3 4]]"); | |
try testing.expectEqualSlices(T, &.{ 1, 2, 3, 4 }, b.buf); | |
var bbuf: [4]T = undefined; | |
const b2 = try B.parseBuf("[[1 2] [3 4]]", &bbuf); | |
try testing.expectEqualSlices(T, &.{ 1, 2, 3, 4 }, b2.buf); | |
const b3 = try B.parseBuf(".{.{1, 2}, .{3,4}}", &bbuf); | |
try testing.expectEqualSlices(T, &.{ 1, 2, 3, 4 }, b3.buf); | |
const C = Matrix(T, .{ 2, 2, 2 }, .{}); | |
const cin = | |
\\[[[1 2] [3 4]] | |
\\ | |
\\ [[5 6] [ 7 8]]] | |
; | |
{ | |
@setEvalBranchQuota(2000); | |
const c = try C.parse(cin); | |
try testing.expectEqualSlices(T, &.{ 1, 2, 3, 4, 5, 6, 7, 8 }, c.buf); | |
} | |
var cbuf: [8]T = undefined; | |
const c2 = try C.parseBuf(cin, &cbuf); | |
try testing.expectEqualSlices(T, &.{ 1, 2, 3, 4, 5, 6, 7, 8 }, c2.buf); | |
try testing.expectError(error.ParseFailure, C.parseBuf("a", &cbuf)); | |
try testing.expectError(error.ParseFailure, C.parseBuf("[", &cbuf)); | |
try testing.expectError(error.ParseFailure, C.parseBuf("]", &cbuf)); | |
try testing.expectError(error.ParseFailure, C.parseBuf("[0]", &cbuf)); | |
try testing.expectError(error.ParseFailure, C.parseBuf("[[0]]", &cbuf)); | |
try testing.expectError(error.ParseFailure, C.parseBuf("[[[0]]]", &cbuf)); | |
try testing.expectError(error.ParseFailure, C.parseBuf("[[[[0]]]]]", &cbuf)); | |
try testing.expectError(error.ParseFailure, B.parseBuf("[[1 2]]", &bbuf)); | |
try testing.expectError(error.ParseFailure, B.parseBuf("[[1 2][1 2][1 2]]", &bbuf)); | |
} | |
test "parse" { | |
try testParse(u32); | |
try testParse(i32); | |
try testing.expectEqualSlices( | |
i32, | |
&.{ -1, -2 }, | |
(try Matrix(i32, .{2}, .{}).parse("[-1 -2]")).buf, | |
); | |
try testParse(f32); | |
try testing.expectEqualSlices( | |
f32, | |
&.{ -1, -2 }, | |
(try Matrix(f32, .{2}, .{}).parse("[-1 -2]")).buf, | |
); | |
} | |
fn testMul( | |
comptime T: type, | |
comptime a_shape: anytype, | |
comptime a_in: []const u8, | |
comptime b_shape: anytype, | |
comptime b_in: []const u8, | |
comptime c_shape: anytype, | |
comptime c_in: []const u8, | |
) !void { | |
@setEvalBranchQuota(14_000); | |
const a = try Matrix(T, a_shape, .{}).parse(a_in); | |
const b = try Matrix(T, b_shape, .{}).parse(b_in); | |
const C = Matrix(T, c_shape, .{}); | |
var cbuf: C.Buf = undefined; | |
const c = C.init(&cbuf); | |
a.mul(b, c); | |
const expected = try C.parse(c_in); | |
try testing.expectEqualSlices(T, expected.buf, c.buf); | |
} | |
fn testMuls(comptime T: type) !void { | |
// 1d | |
try testMul(T, .{2}, | |
\\[6 3] | |
\\ | |
, .{ 2, 1 }, | |
\\[[7] | |
\\ [4]] | |
\\ | |
, .{1}, | |
\\[54] | |
\\ | |
); | |
try testMul(T, .{2}, | |
\\[6 3] | |
\\ | |
, .{ 2, 2 }, | |
\\[[7 4] | |
\\ [6 9]] | |
\\ | |
, .{2}, | |
\\[60 51] | |
\\ | |
); | |
// try testMul(T, .{ 2, 2, 2 }, | |
// \\[[[7 4] | |
// \\ [6 9]] | |
// \\ | |
// \\ [[2 6] | |
// \\ [7 4]]] | |
// \\ | |
// , .{2}, | |
// \\[6 3] | |
// \\ | |
// , .{ 2, 2 }, | |
// \\[[60 51] | |
// \\ [33 48]] | |
// \\ | |
// ); | |
// 2d | |
try testMul(T, .{ 2, 5 }, | |
\\[[6 3 7 4 6] | |
\\ [9 2 6 7 4]] | |
\\ | |
, .{ 5, 1 }, | |
\\[[3] | |
\\ [7] | |
\\ [7] | |
\\ [2] | |
\\ [5]] | |
\\ | |
, .{ 2, 1 }, | |
\\[[126] | |
\\ [117]] | |
\\ | |
); | |
try testMul(T, .{ 5, 2 }, | |
\\[[6 3] | |
\\ [7 4] | |
\\ [6 9] | |
\\ [2 6] | |
\\ [7 4]] | |
\\ | |
, .{ 2, 1 }, | |
\\[[3] | |
\\ [7]] | |
\\ | |
, .{ 5, 1 }, | |
\\[[39] | |
\\ [49] | |
\\ [81] | |
\\ [48] | |
\\ [49]] | |
\\ | |
); | |
// 3d | |
try testMul(T, .{ 3, 3, 2 }, | |
\\[[[6 3] | |
\\ [7 4] | |
\\ [6 9]] | |
\\ | |
\\ [[2 6] | |
\\ [7 4] | |
\\ [3 7]] | |
\\ | |
\\ [[7 2] | |
\\ [5 4] | |
\\ [1 7]]] | |
, .{ 3, 2, 4 }, | |
\\[[[5 1 4 0] | |
\\ [9 5 8 0]] | |
\\ | |
\\ [[9 2 6 3] | |
\\ [8 2 4 2]] | |
\\ | |
\\ [[6 4 8 6] | |
\\ [1 3 8 1]]] | |
, .{ 3, 3, 4 }, | |
\\[[[ 57 21 48 0] | |
\\ [ 71 27 60 0] | |
\\ [111 51 96 0]] | |
\\ | |
\\ [[ 66 16 36 18] | |
\\ [ 95 22 58 29] | |
\\ [ 83 20 46 23]] | |
\\ | |
\\ [[ 44 34 72 44] | |
\\ [ 34 32 72 34] | |
\\ [ 13 25 64 13]]] | |
); | |
try testMul(T, .{ 1, 3, 2 }, | |
\\[[[6 3] | |
\\ [7 4] | |
\\ [6 9]]] | |
\\ | |
, .{ 1, 2, 4 }, | |
\\[[[2 6 7 4] | |
\\ [3 7 7 2]]] | |
\\ | |
, .{ 1, 3, 4 }, | |
\\[[[ 21 57 63 30] | |
\\ [ 26 70 77 36] | |
\\ [ 39 99 105 42]]] | |
\\ | |
); | |
try testMul(T, .{ 1, 3, 1 }, | |
\\[[[6] | |
\\ [3] | |
\\ [7]]] | |
\\ | |
, .{ 1, 1, 3 }, | |
\\[[[4 6 9]]] | |
\\ | |
, .{ 1, 3, 3 }, | |
\\[[[24 36 54] | |
\\ [12 18 27] | |
\\ [28 42 63]]] | |
\\ | |
); | |
// 4d | |
try testMul(T, .{ 1, 1, 3, 1 }, | |
\\[[[[6] | |
\\ [3] | |
\\ [7]]]] | |
\\ | |
, .{ 1, 1, 1, 3 }, | |
\\[[[[4 6 9]]]] | |
\\ | |
, .{ 1, 1, 3, 3 }, | |
\\[[[[24 36 54] | |
\\ [12 18 27] | |
\\ [28 42 63]]]] | |
\\ | |
); | |
try testMul(T, .{ 2, 2, 3, 2 }, | |
\\[[[[6 3] | |
\\ [7 4] | |
\\ [6 9]] | |
\\ | |
\\ [[2 6] | |
\\ [7 4] | |
\\ [3 7]]] | |
\\ | |
\\ | |
\\ [[[7 2] | |
\\ [5 4] | |
\\ [1 7]] | |
\\ | |
\\ [[5 1] | |
\\ [4 0] | |
\\ [9 5]]]] | |
\\ | |
, .{ 2, 2, 2, 3 }, | |
\\[[[[8 0 9] | |
\\ [2 6 3]] | |
\\ | |
\\ [[8 2 4] | |
\\ [2 6 4]]] | |
\\ | |
\\ | |
\\ [[[8 6 1] | |
\\ [3 8 1]] | |
\\ | |
\\ [[9 8 9] | |
\\ [4 1 3]]]] | |
\\ | |
, .{ 2, 2, 3, 3 }, | |
\\[[[[ 54 18 63] | |
\\ [ 64 24 75] | |
\\ [ 66 54 81]] | |
\\ | |
\\ [[ 28 40 32] | |
\\ [ 64 38 44] | |
\\ [ 38 48 40]]] | |
\\ | |
\\ | |
\\ [[[ 62 58 9] | |
\\ [ 52 62 9] | |
\\ [ 29 62 8]] | |
\\ | |
\\ [[ 49 41 48] | |
\\ [ 36 32 36] | |
\\ [101 77 96]]]] | |
\\ | |
); | |
// 5d | |
try testMul(T, .{ 2, 2, 2, 3, 2 }, | |
\\[[[[[6 3] | |
\\ [7 4] | |
\\ [6 9]] | |
\\ | |
\\ [[2 6] | |
\\ [7 4] | |
\\ [3 7]]] | |
\\ | |
\\ | |
\\ [[[7 2] | |
\\ [5 4] | |
\\ [1 7]] | |
\\ | |
\\ [[5 1] | |
\\ [4 0] | |
\\ [9 5]]]] | |
\\ | |
\\ | |
\\ | |
\\ [[[[8 0] | |
\\ [9 2] | |
\\ [6 3]] | |
\\ | |
\\ [[8 2] | |
\\ [4 2] | |
\\ [6 4]]] | |
\\ | |
\\ | |
\\ [[[8 6] | |
\\ [1 3] | |
\\ [8 1]] | |
\\ | |
\\ [[9 8] | |
\\ [9 4] | |
\\ [1 3]]]]] | |
\\ | |
, .{ 2, 2, 2, 2, 3 }, | |
\\[[[[[6 7 2] | |
\\ [0 3 1]] | |
\\ | |
\\ [[7 3 1] | |
\\ [5 5 9]]] | |
\\ | |
\\ | |
\\ [[[3 5 1] | |
\\ [9 1 9]] | |
\\ | |
\\ [[3 7 6] | |
\\ [8 7 4]]]] | |
\\ | |
\\ | |
\\ | |
\\ [[[[1 4 7] | |
\\ [9 8 8]] | |
\\ | |
\\ [[0 8 6] | |
\\ [8 7 0]]] | |
\\ | |
\\ | |
\\ [[[7 7 2] | |
\\ [0 7 2]] | |
\\ | |
\\ [[2 0 4] | |
\\ [9 6 9]]]]] | |
\\ | |
, .{ 2, 2, 2, 3, 3 }, | |
\\[[[[[ 36 51 15] | |
\\ [ 42 61 18] | |
\\ [ 36 69 21]] | |
\\ | |
\\ [[ 44 36 56] | |
\\ [ 69 41 43] | |
\\ [ 56 44 66]]] | |
\\ | |
\\ | |
\\ [[[ 39 37 25] | |
\\ [ 51 29 41] | |
\\ [ 66 12 64]] | |
\\ | |
\\ [[ 23 42 34] | |
\\ [ 12 28 24] | |
\\ [ 67 98 74]]]] | |
\\ | |
\\ | |
\\ | |
\\ [[[[ 8 32 56] | |
\\ [ 27 52 79] | |
\\ [ 33 48 66]] | |
\\ | |
\\ [[ 16 78 48] | |
\\ [ 16 46 24] | |
\\ [ 32 76 36]]] | |
\\ | |
\\ | |
\\ [[[ 56 98 28] | |
\\ [ 7 28 8] | |
\\ [ 56 63 18]] | |
\\ | |
\\ [[ 90 48 108] | |
\\ [ 54 24 72] | |
\\ [ 29 18 31]]]]] | |
\\ | |
); | |
} | |
test "muls" { | |
try testMuls(u32); | |
try testMuls(i32); | |
try testMuls(f32); | |
} | |
test constMatrix { | |
{ | |
const a = constMatrix(u8, .{ 1, 2 }); | |
try testing.expectEqualSlices(u8, &.{ 1, 2 }, a.buf); | |
try testing.expectEqualSlices(usize, &.{2}, a.shape); | |
} | |
{ | |
const b = constMatrix(u8, .{ | |
.{ 1, 2 }, | |
.{ 3, 4 }, | |
}); | |
try testing.expectEqualSlices(u8, &.{ 1, 2, 3, 4 }, b.buf); | |
try testing.expectEqualSlices(usize, &.{ 2, 2 }, b.shape); | |
} | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment