Last active
January 15, 2025 19:00
-
-
Save snowclipsed/339800915f13e1f95e2ceab374633b39 to your computer and use it in GitHub Desktop.
Fast Matrix Multiplication in ZIG in FP32.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
const std = @import("std");
// Other configurations to try:
// 8 x 64
// 16 x 64
// Peak throughput measured on an Intel® Core™ i7-13620H processor: 300.9 GFLOP/s.
// Comments were added using Claude.
// To run: build with `zig build-exe -O ReleaseFast matmul_FP32.zig`, then execute the binary `./matmul_FP32`.
// To test: `zig test -O ReleaseFast matmul_FP32.zig`
// To profile a built binary: `sudo perf stat -e cache-misses,cache-references,instructions,cycles ./matmul_FP32`
// Configuration
// T is the edge length of the square T x T tiles used for the output matrix and
// for the local copies of A and B. V is the lane count of the `@Vector(V, f32)`
// used by the kernel. The kernel walks each tile row in steps of V and touches
// V consecutive elements per step, so T must be a multiple of V.
const T: usize = 64; // Tile size (adjust as necessary)
const V: usize = 32; // Vector size for SIMD operations (adjust as necessary)
/// Computes `C += A * B` for row-major `f32` matrices using a pool of worker threads.
///
/// `A` is `M x K`, `B` is `K x N` and `C` is `M x N`. The output is split into
/// `T x T` tiles; every tile is an independent work item that `workerThread`
/// pulls from a shared, mutex-protected queue. Products are *added* to the
/// existing contents of `C`, so callers that want a plain product must zero
/// `C` beforehand.
///
/// `allocator` is used only for the temporary work queue and thread list; both
/// are released before this function returns.
///
/// Errors: fails if the CPU count cannot be queried, if an allocation fails, or
/// if a thread cannot be spawned. When spawning fails, the workers that did
/// start are stopped and joined before the error is returned; `C` may then be
/// only partially updated.
pub fn tiledMatMul(allocator: std.mem.Allocator, A: []const f32, B: []const f32, C: []f32, M: usize, N: usize, K: usize) !void {
    // Catch undersized buffers here (in safe builds) rather than deep inside a worker thread.
    std.debug.assert(A.len >= M * K);
    std.debug.assert(B.len >= K * N);
    std.debug.assert(C.len >= M * N);

    const cpu_count = try std.Thread.getCpuCount();

    // Ceiling division so that partial tiles on the bottom/right edges are included.
    const tiles_M = (M + T - 1) / T;
    const tiles_N = (N + T - 1) / T;
    const tile_count = tiles_M * tiles_N;
    if (tile_count == 0) return; // Empty output matrix: nothing to compute.

    // One work item per output tile; the final size is known, so allocate once.
    var work_queue = try std.ArrayList(WorkItem).initCapacity(allocator, tile_count);
    defer work_queue.deinit();
    for (0..tiles_M) |i| {
        for (0..tiles_N) |j| {
            work_queue.appendAssumeCapacity(.{ .i = i, .j = j });
        }
    }

    // Shuffle the tiles so that neighbouring tiles are not handed out in order.
    // `@bitCast` (not `@intCast`) keeps the seed valid even for a negative timestamp.
    var rng = std.rand.DefaultPrng.init(@bitCast(std.time.milliTimestamp()));
    rng.random().shuffle(WorkItem, work_queue.items);

    // State shared by all workers. It lives on this stack frame, so every
    // worker must be joined before this function returns.
    var context = ThreadContext{
        .A = A,
        .B = B,
        .C = C,
        .M = M,
        .N = N,
        .K = K,
        .work_queue = &work_queue,
        .mutex = .{},
    };

    // Never start more threads than there are tiles to hand out.
    const worker_count = @min(cpu_count, tile_count);
    var thread_pool = try std.ArrayList(std.Thread).initCapacity(allocator, worker_count);
    defer thread_pool.deinit();
    // Runs on success *and* on error, and (being declared last) before the
    // queue and thread list are freed.
    defer {
        for (thread_pool.items) |thread| thread.join();
    }

    for (0..worker_count) |_| {
        const thread = std.Thread.spawn(.{}, workerThread, .{&context}) catch |err| {
            // Empty the queue so already-running workers finish their current
            // tile and exit instead of completing a computation we abandon.
            context.mutex.lock();
            work_queue.clearRetainingCapacity();
            context.mutex.unlock();
            return err;
        };
        thread_pool.appendAssumeCapacity(thread);
    }
}
/// Identifies one output tile by its position in the grid of `T x T` tiles.
const WorkItem = struct {
    /// Tile row index (the tile starts at matrix row `i * T`).
    i: usize,
    /// Tile column index (the tile starts at matrix column `j * T`).
    j: usize,
};
/// State shared by every worker thread for the duration of one multiplication.
const ThreadContext = struct {
    /// Left operand, `M x K`, row-major.
    A: []const f32,
    /// Right operand, `K x N`, row-major.
    B: []const f32,
    /// Output matrix, `M x N`, row-major; workers add their tile results into it.
    C: []f32,
    /// Number of rows of `A` and `C`.
    M: usize,
    /// Number of columns of `B` and `C`.
    N: usize,
    /// Number of columns of `A`, which is also the number of rows of `B`.
    K: usize,
    /// Queue of tiles still to be computed; access only while holding `mutex`.
    work_queue: *std.ArrayList(WorkItem),
    /// Serialises access to `work_queue`.
    mutex: std.Thread.Mutex,
};
/// Thread entry point: repeatedly claims a tile from the shared queue and adds
/// that tile's product into `context.C`, returning once the queue is empty.
fn workerThread(context: *ThreadContext) void {
    while (true) {
        // Hold the mutex only for the duration of the queue access.
        const claimed: ?WorkItem = claim: {
            context.mutex.lock();
            defer context.mutex.unlock();
            break :claim context.work_queue.popOrNull();
        };
        const tile = claimed orelse return; // Queue drained: this worker is done.

        // Extent of this tile in C, clipped to the matrix edges.
        const row_begin = tile.i * T;
        const col_begin = tile.j * T;
        const row_end = @min(row_begin + T, context.M);
        const col_end = @min(col_begin + T, context.N);

        // Thread-private accumulator for the tile, starting from zero.
        var tile_acc: [T][T]f32 = [_][T]f32{[_]f32{0} ** T} ** T;

        // Walk the shared dimension one T-wide slab at a time.
        var k_begin: usize = 0;
        while (k_begin < context.K) : (k_begin += T) {
            const k_end = @min(k_begin + T, context.K);
            tiledMultiplyKernel(context.A, context.B, &tile_acc, context.N, context.K, row_begin, col_begin, k_begin, row_end, col_end, k_end);
        }

        // Add the in-bounds part of the accumulator into C. No lock is taken
        // here: each queue entry names a different tile of C.
        var row = row_begin;
        while (row < row_end) : (row += 1) {
            const out_row = context.C[row * context.N ..];
            const acc_row = &tile_acc[row - row_begin];
            var col = col_begin;
            while (col < col_end) : (col += 1) {
                out_row[col] += acc_row[col - col_begin];
            }
        }
    }
}
/// Adds the product of one tile of `A` and one tile of `B` into `local_C`.
///
/// `A` and `B` are row-major with row strides `K` and `N` respectively. The
/// tile of `A` covers rows `i_start..i_end` and columns `k_start..k_end`; the
/// tile of `B` covers rows `k_start..k_end` and columns `j_start..j_end`. Each
/// range is expected to span at most `T` elements; anything outside the given
/// ranges is treated as zero, so the full `T x T` update of `local_C` is always
/// well defined and padded positions receive `+= 0`.
fn tiledMultiplyKernel(A: []const f32, B: []const f32, local_C: *[T][T]f32, N: usize, K: usize, i_start: usize, j_start: usize, k_start: usize, i_end: usize, j_end: usize, k_end: usize) void {
    // The column loop below advances in steps of V and touches V elements at a
    // time; any other relationship between T and V would run past a tile row.
    comptime {
        if (V == 0 or T % V != 0) @compileError("T must be a positive multiple of V");
    }
    const Vec = @Vector(V, f32);

    // Number of in-bounds elements along each axis, clamped to the tile size.
    const rows = @min(i_end -| i_start, T);
    const cols = @min(j_end -| j_start, T);
    const depth = @min(k_end -| k_start, T);

    // Local, zero-padded copies of the two input tiles.
    var A_local: [T][T]f32 = undefined;
    var B_local: [T][T]f32 = undefined;

    for (&A_local, 0..) |*row, i| {
        if (i < rows) {
            @memcpy(row[0..depth], A[(i_start + i) * K + k_start ..][0..depth]);
            @memset(row[depth..], 0);
        } else {
            @memset(row, 0);
        }
    }

    for (&B_local, 0..) |*row, k| {
        if (k < depth) {
            @memcpy(row[0..cols], B[(k_start + k) * N + j_start ..][0..cols]);
            @memset(row[cols..], 0);
        } else {
            @memset(row, 0);
        }
    }

    // For every row i and every V-wide column block j, accumulate
    // sum_k A_local[i][k] * B_local[k][j..j+V] in a vector, then add it to local_C.
    for (0..T) |i| {
        var j: usize = 0;
        while (j < T) : (j += V) {
            var vec_sum: Vec = @splat(0);
            for (0..T) |k| {
                const a_vec: Vec = @splat(A_local[i][k]); // broadcast one element of A
                const b_vec: Vec = B_local[k][j..][0..V].*; // V consecutive elements of B
                vec_sum += a_vec * b_vec;
            }
            const out: *[V]f32 = local_C[i][j..][0..V];
            out.* = @as(Vec, out.*) + vec_sum;
        }
    }
}
// Checks tiledMatMul against a naive triple-loop product for square,
// rectangular and non-tile-aligned shapes.
test "tiledMatMul_correctness" {
    const allocator = std.testing.allocator;
    const test_sizes = [_][3]usize{
        .{ 128, 128, 128 },
        .{ 100, 100, 100 },
        .{ 200, 150, 175 },
        .{ 32, 64, 48 },
        .{ 47, 34, 45 },
    };
    for (test_sizes) |size| {
        const M = size[0];
        const N = size[1];
        const K = size[2];

        const A = try allocator.alloc(f32, M * K);
        defer allocator.free(A);
        const B = try allocator.alloc(f32, K * N);
        defer allocator.free(B);
        const C = try allocator.alloc(f32, M * N);
        defer allocator.free(C);
        const C_ref = try allocator.alloc(f32, M * N);
        defer allocator.free(C_ref);

        // Small integer inputs (0..9): every partial sum stays far below 2^24,
        // so f32 arithmetic is exact and the result does not depend on the
        // order in which the products are summed.
        for (A, 0..) |*val, i| {
            val.* = @floatFromInt(i % 10);
        }
        for (B, 0..) |*val, i| {
            val.* = @floatFromInt((i + 1) % 10);
        }
        // tiledMatMul accumulates into C, so it must start from zero.
        @memset(C, 0);
        @memset(C_ref, 0);

        // Perform tiled matrix multiplication
        try tiledMatMul(allocator, A, B, C, M, N, K);

        // Perform reference matrix multiplication
        for (0..M) |i| {
            for (0..N) |j| {
                var sum: f32 = 0;
                for (0..K) |k| {
                    sum += A[i * K + k] * B[k * N + j];
                }
                C_ref[i * N + j] = sum;
            }
        }

        // Exact comparison (expected first); on failure this reports the first
        // differing index.
        try std.testing.expectEqualSlices(f32, C_ref, C);
        std.debug.print("Test passed for size: M={}, N={}, K={}\n", .{ M, N, K });
    }
}
/// Measures the throughput of `tiledMatMul` for an `M x K` by `K x N` product.
///
/// Allocates and fills the operands, performs one untimed warm-up call, then
/// times `iterations` back-to-back calls and returns the achieved rate in
/// GFLOP/s, counting each multiply-add as two operations. Returns 0 when no
/// time elapsed (for example when `iterations` is 0).
///
/// All buffers are allocated from `allocator` and freed before returning.
/// Errors: allocation failure, timer unavailability, or any error from `tiledMatMul`.
pub fn calculateGflops(allocator: std.mem.Allocator, M: usize, N: usize, K: usize, iterations: usize) !f64 {
    const A = try allocator.alloc(f32, M * K);
    defer allocator.free(A);
    const B = try allocator.alloc(f32, K * N);
    defer allocator.free(B);
    const C = try allocator.alloc(f32, M * N);
    defer allocator.free(C);

    // Initialize matrices
    for (A, 0..) |*val, i| {
        val.* = @floatFromInt(i % 10);
    }
    for (B, 0..) |*val, i| {
        val.* = @floatFromInt((i + 1) % 10);
    }
    // tiledMatMul adds into C, so C must hold defined values before the first
    // call. It is deliberately not re-zeroed between iterations to keep the
    // timed region free of extra work; the values in C are never inspected.
    @memset(C, 0);

    // Warmup run
    try tiledMatMul(allocator, A, B, C, M, N, K);

    // Timed runs
    var timer = try std.time.Timer.start();
    for (0..iterations) |_| {
        try tiledMatMul(allocator, A, B, C, M, N, K);
    }
    const elapsed_ns = timer.read();
    if (elapsed_ns == 0) return 0; // Avoid a NaN/inf result.

    // Count operations in floating point so very large problem sizes cannot
    // overflow `usize`.
    const ops = 2.0 *
        @as(f64, @floatFromInt(M)) *
        @as(f64, @floatFromInt(N)) *
        @as(f64, @floatFromInt(K)) *
        @as(f64, @floatFromInt(iterations));
    const seconds = @as(f64, @floatFromInt(elapsed_ns)) / 1e9;
    return ops / seconds / 1e9;
}
// Prints the measured GFLOPS of tiledMatMul for a range of problem sizes.
// Nothing is asserted; the test only fails if a run returns an error.
test "GFLOPS_Benchmark" {
    const iterations = 5;
    // Each entry is { M, N, K }.
    const sizes = [_][3]usize{
        .{ 256, 256, 256 },
        .{ 512, 512, 512 },
        .{ 1024, 1024, 1024 },
        .{ 1024, 2048, 1024 },
        .{ 2048, 2048, 2048 },
        .{ 2048, 4096, 2048 },
        .{ 4096, 4096, 4096 },
        .{ 8192, 2048, 8192 },
        .{ 1152, 4304, 1152 },
        .{ 1, 2048, 51200 },
        // Larger sizes, disabled:
        // .{ 8192, 8192, 8192 },
        // .{ 8192, 16384, 8192 },
        // .{ 16384, 16384, 16384 },
        // .{ 16384, 32768, 16384 },
        // .{ 32768, 32768, 32768 },
        // .{ 32768, 65536, 32768 },
        // .{ 65536, 65536, 65536 },
    };

    for (sizes) |dims| {
        const rate = try calculateGflops(std.testing.allocator, dims[0], dims[1], dims[2], iterations);
        std.debug.print("Matrix size: {}x{}x{}, GFLOPS: {d:.2}\n", .{ dims[0], dims[1], dims[2], rate });
    }
}
/// Benchmark driver: prints the tile/vector configuration, then the measured
/// GFLOPS of `tiledMatMul` for a fixed list of problem sizes.
pub fn main() !void {
    // Everything allocated during the run is released together on exit.
    var arena = std.heap.ArenaAllocator.init(std.heap.page_allocator);
    defer arena.deinit();

    const iterations = 10;
    // Each entry is { M, N, K }.
    const sizes = [_][3]usize{
        .{ 256, 256, 256 },
        .{ 512, 512, 512 },
        .{ 1024, 1024, 1024 },
        .{ 1024, 2048, 1024 },
        .{ 2048, 2048, 2048 },
    };

    std.debug.print("T = {} \n V = {} \n", .{ T, V });

    for (sizes) |dims| {
        const rate = try calculateGflops(arena.allocator(), dims[0], dims[1], dims[2], iterations);
        std.debug.print("Matrix size: {}x{}x{}, GFLOPS: {d:.2}\n", .{ dims[0], dims[1], dims[2], rate });
    }
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment