
@bjacob
Created April 25, 2025 16:03
Test setup for FP8 pingpong after Llama dispatch. There are three parts: calls.mlir (the test harness that generates random FP8 inputs and checks the result against a reference matmul), matmul.mlir (the standalone dispatch under test), and a shell script that compiles both and runs them with iree-e2e-matmul-test.

calls.mlir:
builtin.module @calls attributes {
} {
  func.func private @matmul_test.generate_random_matrix(%device: !hal.device, %dim0: i64, %dim1: i64, %element_type: i32, %seed: i32) -> !hal.buffer_view
  func.func private @matmul_test.check_matmul_results(%device: !hal.device, %m: i64, %k: i64, %n: i64, %transpose_rhs: i32, %lhs: !hal.buffer_view, %rhs: !hal.buffer_view, %acc: !hal.buffer_view, %actual_result: !hal.buffer_view)
  func.func private @module.matmul(%lhs: !hal.buffer_view, %rhs: !hal.buffer_view) -> !hal.buffer_view

  func.func @matmul() attributes {
    iree.reflection = {description = "Matmul shape (MxKxN): 1280x4096x4096"}
  } {
    %device_index = arith.constant 0 : index
    %device = hal.devices.get %device_index : !hal.device
    // LHS: random 1280x4096 f8E4M3FNUZ matrix.
    %lhs_dim0 = arith.constant 1280 : i64
    %lhs_dim1 = arith.constant 4096 : i64
    %lhs_element_type = hal.element_type<f8E4M3FNUZ> : i32
    %lhs_seed = arith.constant 5 : i32
    %lhs = call @matmul_test.generate_random_matrix(%device, %lhs_dim0, %lhs_dim1, %lhs_element_type, %lhs_seed) : (!hal.device, i64, i64, i32, i32) -> !hal.buffer_view
    // RHS: random 4096x4096 f8E4M3FNUZ matrix, stored NxK (transposed).
    %rhs_dim0 = arith.constant 4096 : i64
    %rhs_dim1 = arith.constant 4096 : i64
    %rhs_element_type = hal.element_type<f8E4M3FNUZ> : i32
    %rhs_seed = arith.constant 6 : i32
    %rhs = call @matmul_test.generate_random_matrix(%device, %rhs_dim0, %rhs_dim1, %rhs_element_type, %rhs_seed) : (!hal.device, i64, i64, i32, i32) -> !hal.buffer_view
    // No accumulator input; the dispatch starts from a zero-filled accumulator.
    %acc = util.null : !hal.buffer_view
    %result = call @module.matmul(%lhs, %rhs) : (!hal.buffer_view, !hal.buffer_view) -> !hal.buffer_view
    // Check against the reference matmul: MxKxN = 1280x4096x4096, transpose_rhs = 1.
    %m = arith.constant 1280 : i64
    %k = arith.constant 4096 : i64
    %n = arith.constant 4096 : i64
    %transpose_rhs = arith.constant 1 : i32
    call @matmul_test.check_matmul_results(%device, %m, %k, %n, %transpose_rhs, %lhs, %rhs, %acc, %result) : (!hal.device, i64, i64, i64, i32, !hal.buffer_view, !hal.buffer_view, !hal.buffer_view, !hal.buffer_view) -> ()
    return
  }
}
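
For orientation, here is a minimal NumPy sketch of the computation that @matmul_test.check_matmul_results validates: a 1280x4096x4096 matmul with the RHS stored NxK (transpose_rhs = 1) and f32 accumulation. This is not the harness code: float32 stands in for f8E4M3FNUZ (which NumPy does not have), the random generation does not reproduce the harness's seeded generator, the helper name is made up, and the comparison is approximated as an element-wise absolute tolerance in the spirit of --acceptable_fp_delta.

import numpy as np

# Shapes from calls.mlir: M x K x N = 1280 x 4096 x 4096.
M, K, N = 1280, 4096, 4096

# Float32 stand-ins for the FP8 inputs (NumPy has no f8E4M3FNUZ dtype, and
# this RNG does not match the harness's seeded generator).
rng = np.random.default_rng(0)
lhs = rng.standard_normal((M, K), dtype=np.float32)   # A, M x K
rhs = rng.standard_normal((N, K), dtype=np.float32)   # B, stored N x K (transposed)

# Reference result with f32 accumulation: C[m, n] = sum_k A[m, k] * B[n, k].
reference = lhs @ rhs.T

def acceptable(actual: np.ndarray, fp_delta: float = 1e-2) -> bool:
    # Rough analogue of --acceptable_fp_delta: element-wise absolute tolerance.
    return bool(np.all(np.abs(actual - reference) <= fp_delta))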
matmul.mlir:

func.func @matmul(%lhs: tensor<?x4096xf8E4M3FNUZ>, %rhs: tensor<4096x4096xf8E4M3FNUZ>) -> tensor<?x4096xf32> {
  %c0 = arith.constant 0 : index
  %c256 = arith.constant 256 : index
  // Split the dynamic M dimension into (M/256, 256) so the inner row tile is static.
  %m = tensor.dim %lhs, %c0 : tensor<?x4096xf8E4M3FNUZ>
  %m_outer = arith.divsi %m, %c256 : index
  %lhs_expanded = tensor.expand_shape %lhs [[0, 1], [2]] output_shape [%m_outer, 256, 4096] : tensor<?x4096xf8E4M3FNUZ> into tensor<?x256x4096xf8E4M3FNUZ>
  // Zero-filled f32 accumulator.
  %init_acc = tensor.empty(%m_outer) : tensor<?x256x4096xf32>
  %c0_acc_type = arith.constant 0.0 : f32
  %acc = linalg.fill ins(%c0_acc_type : f32) outs(%init_acc : tensor<?x256x4096xf32>) -> tensor<?x256x4096xf32>
  // acc[m0, m1, n] += extf(lhs[m0, m1, k]) * extf(rhs[n, k]); the RHS is indexed NxK, i.e. transposed.
  %result_expanded = linalg.generic {
    indexing_maps = [
      affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>,
      affine_map<(d0, d1, d2, d3) -> (d2, d3)>,
      affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
    ], iterator_types = [
      "parallel", "parallel", "parallel", "reduction"
    ]
  } ins(%lhs_expanded, %rhs : tensor<?x256x4096xf8E4M3FNUZ>, tensor<4096x4096xf8E4M3FNUZ>)
    outs(%acc : tensor<?x256x4096xf32>) {
  ^bb0(%lhs_val: f8E4M3FNUZ, %rhs_val: f8E4M3FNUZ, %out: f32):
    %56 = arith.extf %lhs_val : f8E4M3FNUZ to f32
    %57 = arith.extf %rhs_val : f8E4M3FNUZ to f32
    %58 = arith.mulf %56, %57 : f32
    %59 = arith.addf %out, %58 : f32
    linalg.yield %59 : f32
  } -> tensor<?x256x4096xf32>
  // Collapse (M/256, 256, N) back to (M, N).
  %result = tensor.collapse_shape %result_expanded [[0, 1], [2]] : tensor<?x256x4096xf32> into tensor<?x4096xf32>
  return %result : tensor<?x4096xf32>
}
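
The reshaping that @matmul does around the contraction can be sketched in NumPy as follows (again with float32 stand-ins for FP8; this is an illustration, not the dispatch): tensor.expand_shape splits the dynamic M dimension into (M/256, 256), the linalg.generic contracts over K against the NxK (transposed) RHS, and tensor.collapse_shape folds the result back to (M, 4096).

import numpy as np

M, K, N, TILE_M = 1280, 4096, 4096, 256

# Float32 stand-ins for the f8E4M3FNUZ operands.
rng = np.random.default_rng(0)
lhs = rng.standard_normal((M, K), dtype=np.float32)
rhs = rng.standard_normal((N, K), dtype=np.float32)

# tensor.expand_shape: (M, K) -> (M/256, 256, K).
lhs_expanded = lhs.reshape(M // TILE_M, TILE_M, K)

# linalg.generic: acc[m0, m1, n] = sum_k lhs[m0, m1, k] * rhs[n, k].
acc = np.einsum("abk,nk->abn", lhs_expanded, rhs, optimize=True)

# tensor.collapse_shape: (M/256, 256, N) -> (M, N).
result = acc.reshape(M, N)

# The reshapes are layout-only: the result matches a plain transposed-B matmul
# up to f32 accumulation-order differences.
assert np.allclose(result, lhs @ rhs.T, atol=1e-2)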
run script:

#!/bin/bash
set -eux

ninja -C ~/iree-build iree-compile iree-e2e-matmul-test

mkdir -p tmp  # ensure the output directory exists

# Compile the test harness.
~/iree-build/tools/iree-compile calls.mlir \
  --iree-hip-target=gfx942 \
  --iree-hal-target-backends=rocm \
  -o tmp/calls.vmfb

# Compile the dispatch under test.
# (Optionally add --mlir-print-ir-after-all to dump IR after every pass.)
~/iree-build/tools/iree-compile matmul.mlir \
  --iree-hip-target=gfx942 \
  --iree-hal-target-backends=rocm \
  --mlir-disable-threading \
  --iree-codegen-enable-default-tuning-specs=true \
  -o tmp/dispatch.vmfb

# Run the dispatch on the HIP device and check the results.
~/iree-build/tools/testing/e2e/iree-e2e-matmul-test --device=hip \
  --module=tmp/dispatch.vmfb \
  --module=tmp/calls.vmfb \
  --acceptable_fp_delta=1e-02