@snowclipsed
Last active January 15, 2025 19:00
Fast FP32 matrix multiplication in Zig (matmul_FP32.zig).
const std = @import("std");
// Can also try: 8 x 64, 16 x 64.
// Top GFLOPs/s on an Intel® Core™ i7-13620H processor: 300.9 GFLOPs/s.
// Comments were added using Claude.
// To build and run: zig build-exe -O ReleaseFast matmul_FP32.zig, then run the binary: ./matmul_FP32
// To test: zig test -O ReleaseFast matmul_FP32.zig
// To profile a generated binary: sudo perf stat -e cache-misses,cache-references,instructions,cycles ./matmul_FP32
// Configuration
const T: usize = 64; // Tile size (adjust as necessary)
const V: usize = 32; // Vector size for SIMD operations (adjust as necessary)
pub fn tiledMatMul(allocator: std.mem.Allocator, A: []const f32, B: []const f32, C: []f32, M: usize, N: usize, K: usize) !void {
    // Determine the number of available CPU threads
    const num_threads = try std.Thread.getCpuCount();

    // Calculate the number of tiles in each dimension.
    // The division rounds up, so partial (edge) tiles are also covered.
    const tiles_M = (M + T - 1) / T;
    const tiles_N = (N + T - 1) / T;
    // const tiles_K = (K + T - 1) / T; // Not used in the current implementation

    // Create a queue of work items for the threads to process
    var work_queue = std.ArrayList(WorkItem).init(allocator);
    defer work_queue.deinit(); // Ensure that resources are cleaned up

    // Populate the work queue with the coordinates of each tile
    for (0..tiles_M) |i| {
        for (0..tiles_N) |j| {
            try work_queue.append(.{ .i = i, .j = j }); // Add a work item for each (i, j) tile position
        }
    }

    // Shuffle the work queue to distribute the workload evenly among threads
    var rng = std.rand.DefaultPrng.init(@intCast(std.time.milliTimestamp())); // Seed with the current time
    rng.random().shuffle(WorkItem, work_queue.items); // Shuffle to improve load balancing

    // Create a thread pool to manage the threads
    var thread_pool = try std.ArrayList(std.Thread).initCapacity(allocator, num_threads);
    defer thread_pool.deinit(); // Ensure that resources are cleaned up

    // Shared context that will be passed to each thread
    var context = ThreadContext{
        .A = A,
        .B = B,
        .C = C,
        .M = M,
        .N = N,
        .K = K,
        .work_queue = &work_queue,
        .mutex = std.Thread.Mutex{},
    };

    // Spawn threads to process the work queue
    for (0..num_threads) |_| {
        try thread_pool.append(try std.Thread.spawn(.{}, workerThread, .{&context}));
    }

    // Wait for all threads to finish their work
    for (thread_pool.items) |thread| {
        thread.join();
    }
}
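
// Minimal usage sketch: tiledMatMul treats A, B, and C as row-major f32 slices,
// so a small 2x3 * 3x2 product looks like the test below. The sizes and values
// here are illustrative only; any f32 data with matching shapes works the same way.
test "tiledMatMul_usage_sketch" {
    const allocator = std.testing.allocator;
    const M: usize = 2;
    const N: usize = 2;
    const K: usize = 3;

    // A is 2x3, B is 3x2, C is 2x2, all stored row-major.
    const A = [_]f32{ 1, 2, 3, 4, 5, 6 };
    const B = [_]f32{ 7, 8, 9, 10, 11, 12 };
    var C = [_]f32{0} ** (M * N);

    try tiledMatMul(allocator, &A, &B, &C, M, N, K);

    // Expected entries of the 2x3 * 3x2 product.
    try std.testing.expectEqual(@as(f32, 58), C[0]); // 1*7 + 2*9 + 3*11
    try std.testing.expectEqual(@as(f32, 64), C[1]); // 1*8 + 2*10 + 3*12
    try std.testing.expectEqual(@as(f32, 139), C[2]); // 4*7 + 5*9 + 6*11
    try std.testing.expectEqual(@as(f32, 154), C[3]); // 4*8 + 5*10 + 6*12
}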
const WorkItem = struct {
    i: usize, // Tile row index
    j: usize, // Tile column index
};

const ThreadContext = struct {
    A: []const f32, // Matrix A
    B: []const f32, // Matrix B
    C: []f32, // Result matrix C
    M: usize, // Rows in A and C
    N: usize, // Columns in B and C
    K: usize, // Columns in A, rows in B
    work_queue: *std.ArrayList(WorkItem), // Pointer to the shared work queue
    mutex: std.Thread.Mutex, // Mutex to synchronize access to the work queue
};
fn workerThread(context: *ThreadContext) void {
    while (true) {
        // Acquire the mutex to safely pop from the shared work queue
        context.mutex.lock();
        const work_item = if (context.work_queue.popOrNull()) |item| item else {
            context.mutex.unlock();
            break; // Exit the loop when there are no more work items
        };
        context.mutex.unlock();

        // Compute the start and end indices for this tile of matrix C
        const i_start = work_item.i * T;
        const j_start = work_item.j * T;
        const i_end = @min(i_start + T, context.M); // Clamp where the tile extends beyond the matrix bounds
        const j_end = @min(j_start + T, context.N);

        // Local storage for the result of the current tile, initialized to zeros
        var local_C: [T][T]f32 = [_][T]f32{[_]f32{0} ** T} ** T;

        var k: usize = 0;
        while (k < context.K) : (k += T) {
            const k_end = @min(k + T, context.K); // Clamp where the tile extends beyond the matrix bounds
            tiledMultiplyKernel(context.A, context.B, &local_C, context.N, context.K, i_start, j_start, k, i_end, j_end, k_end);
        }

        // Accumulate the results from the local tile into the global matrix C
        for (i_start..i_end) |i| {
            for (j_start..j_end) |j| {
                context.C[i * context.N + j] += local_C[i - i_start][j - j_start];
            }
        }
    }
}
fn tiledMultiplyKernel(A: []const f32, B: []const f32, local_C: *[T][T]f32, N: usize, K: usize, i_start: usize, j_start: usize, k_start: usize, i_end: usize, j_end: usize, k_end: usize) void {
    // Local buffers for the current tiles of A and B
    var A_local: [T][T]f32 = undefined;
    var B_local: [T][T]f32 = undefined;

    // Load a tile of A into A_local
    for (0..T) |i| {
        for (0..T) |k| {
            if (i_start + i < i_end and k_start + k < k_end) {
                A_local[i][k] = A[(i_start + i) * K + (k_start + k)];
            } else {
                A_local[i][k] = 0; // Zero-padding for elements outside the matrix bounds
            }
        }
    }

    // Load a tile of B into B_local
    for (0..T) |k| {
        for (0..T) |j| {
            if (k_start + k < k_end and j_start + j < j_end) {
                B_local[k][j] = B[(k_start + k) * N + (j_start + j)];
            } else {
                B_local[k][j] = 0; // Zero-padding for elements outside the matrix bounds
            }
        }
    }

    // Multiply the A and B tiles and accumulate the result into local_C
    var i: usize = 0;
    while (i < T) : (i += 1) {
        var j: usize = 0;
        while (j < T) : (j += V) {
            // SIMD vector for accumulating the sum of products
            var vec_sum: @Vector(V, f32) = @splat(0);
            var k: usize = 0;
            while (k < T) : (k += 1) {
                const a_val = A_local[i][k]; // Load one element of A
                const a_vec = @as(@Vector(V, f32), @splat(a_val)); // Broadcast it across a SIMD vector
                const b_vec = blk: {
                    var temp: @Vector(V, f32) = undefined;
                    for (0..V) |idx| {
                        temp[idx] = B_local[k][j + idx]; // Load V contiguous elements of B
                    }
                    break :blk temp;
                };
                vec_sum += a_vec * b_vec; // Multiply and accumulate in the SIMD vector
            }
            // Store the results from the SIMD vector back into the local_C tile
            for (0..V) |idx| {
                local_C[i][j + idx] += vec_sum[idx];
            }
        }
    }
}
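
// Sketch of the SIMD pattern used in the kernel above, in isolation: one scalar
// of A is broadcast with @splat and multiplied element-wise against V contiguous
// elements of a B row, accumulating V partial sums at once. The width of 4 and
// the literal values below are illustrative only (the kernel uses V = 32).
test "simd_broadcast_multiply_accumulate_sketch" {
    const W = 4; // illustrative vector width
    var acc: @Vector(W, f32) = @splat(0);

    const a_val: f32 = 2.0; // one element of the A tile
    const a_vec: @Vector(W, f32) = @splat(a_val); // broadcast across the vector
    const b_vec: @Vector(W, f32) = .{ 1.0, 2.0, 3.0, 4.0 }; // W elements of a B row

    acc += a_vec * b_vec; // element-wise multiply-accumulate

    try std.testing.expectEqual(@as(f32, 2.0), acc[0]); // 2 * 1
    try std.testing.expectEqual(@as(f32, 8.0), acc[3]); // 2 * 4
}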
// Correctness test
test "tiledMatMul_correctness" {
    const allocator = std.testing.allocator;
    const test_sizes = [_][3]usize{
        .{ 128, 128, 128 },
        .{ 100, 100, 100 },
        .{ 200, 150, 175 },
        .{ 32, 64, 48 },
        .{ 47, 34, 45 },
    };

    for (test_sizes) |size| {
        const M = size[0];
        const N = size[1];
        const K = size[2];

        const A = try allocator.alloc(f32, M * K);
        defer allocator.free(A);
        const B = try allocator.alloc(f32, K * N);
        defer allocator.free(B);
        const C = try allocator.alloc(f32, M * N);
        defer allocator.free(C);
        const C_ref = try allocator.alloc(f32, M * N);
        defer allocator.free(C_ref);

        // Initialize matrices
        for (A, 0..) |*val, i| {
            val.* = @floatFromInt(i % 10);
        }
        for (B, 0..) |*val, i| {
            val.* = @floatFromInt((i + 1) % 10);
        }
        @memset(C, 0);
        @memset(C_ref, 0);

        // Perform tiled matrix multiplication
        try tiledMatMul(allocator, A, B, C, M, N, K);

        // Perform reference (naive) matrix multiplication
        for (0..M) |i| {
            for (0..N) |j| {
                var sum: f32 = 0;
                for (0..K) |k| {
                    sum += A[i * K + k] * B[k * N + j];
                }
                C_ref[i * N + j] = sum;
            }
        }

        // Compare results
        for (C, C_ref) |c, c_ref| {
            try std.testing.expectApproxEqAbs(c, c_ref, 1e-10);
            try std.testing.expectEqual(c, c_ref);
        }
        std.debug.print("Test passed for size: M={}, N={}, K={}\n", .{ M, N, K });
    }
}
pub fn calculateGflops(allocator: std.mem.Allocator, M: usize, N: usize, K: usize, iterations: usize) !f64 {
    const A = try allocator.alloc(f32, M * K);
    defer allocator.free(A);
    const B = try allocator.alloc(f32, K * N);
    defer allocator.free(B);
    const C = try allocator.alloc(f32, M * N);
    defer allocator.free(C);

    // Initialize matrices
    for (A, 0..) |*val, i| {
        val.* = @floatFromInt(i % 10);
    }
    for (B, 0..) |*val, i| {
        val.* = @floatFromInt((i + 1) % 10);
    }

    // Warmup run
    try tiledMatMul(allocator, A, B, C, M, N, K);

    // Timed runs
    var timer = try std.time.Timer.start();
    for (0..iterations) |_| {
        try tiledMatMul(allocator, A, B, C, M, N, K);
    }
    const elapsed_ns = timer.read();

    // Calculate GFLOPS
    const ops = 2 * M * N * K * iterations; // each multiply-add counts as 2 operations
    const seconds = @as(f64, @floatFromInt(elapsed_ns)) / 1e9;
    const gflops = @as(f64, @floatFromInt(ops)) / seconds / 1e9;
    return gflops;
}
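
// Worked example of the formula above (numbers are illustrative): for
// M = N = K = 1024 and 10 iterations, ops = 2 * 1024^3 * 10 ≈ 2.15e10.
// If those runs take 0.1 s in total, the result is 2.15e10 / 0.1 / 1e9 ≈ 215 GFLOPS.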
test "GFLOPS_Benchmark" {
    const allocator = std.testing.allocator;
    const sizes = [_][3]usize{
        .{ 256, 256, 256 },
        .{ 512, 512, 512 },
        .{ 1024, 1024, 1024 },
        .{ 1024, 2048, 1024 },
        .{ 2048, 2048, 2048 },
        .{ 2048, 4096, 2048 },
        .{ 4096, 4096, 4096 },
        .{ 8192, 2048, 8192 },
        .{ 1152, 4304, 1152 },
        .{ 1, 2048, 51200 },
        // .{ 8192, 8192, 8192 },
        // .{ 8192, 16384, 8192 },
        // .{ 16384, 16384, 16384 },
        // .{ 16384, 32768, 16384 },
        // .{ 32768, 32768, 32768 },
        // .{ 32768, 65536, 32768 },
        // .{ 65536, 65536, 65536 },
    };
    const iterations = 5;

    for (sizes) |size| {
        const M = size[0];
        const N = size[1];
        const K = size[2];
        const gflops = try calculateGflops(allocator, M, N, K, iterations);
        std.debug.print("Matrix size: {}x{}x{}, GFLOPS: {d:.2}\n", .{ M, N, K, gflops });
    }
}
pub fn main() !void {
    var arena = std.heap.ArenaAllocator.init(std.heap.page_allocator);
    defer arena.deinit();
    const allocator = arena.allocator();

    const sizes = [_][3]usize{
        .{ 256, 256, 256 },
        .{ 512, 512, 512 },
        .{ 1024, 1024, 1024 },
        .{ 1024, 2048, 1024 },
        .{ 2048, 2048, 2048 },
    };
    const iterations = 10;

    std.debug.print("T = {} \n V = {} \n", .{ T, V });
    for (sizes) |size| {
        const M = size[0];
        const N = size[1];
        const K = size[2];
        const gflops = try calculateGflops(allocator, M, N, K, iterations);
        std.debug.print("Matrix size: {}x{}x{}, GFLOPS: {d:.2}\n", .{ M, N, K, gflops });
    }
}