We want to compile linalg.mmt4d
via "scalable" vectorisation:
%out = linalg.mmt4d ins(%lhs, %rhs: tensor<2x2x4x8xi8>, tensor<?x2x?x8xi8>)
outs(%acc: tensor<2x?x4x?xi32>) -> tensor<2x?x4x?xi32>
import torch | |
import torch.nn as nn | |
class GatherModel(nn.Module): | |
def __init__(self): | |
super(GatherModel, self).__init__() | |
def forward(self, x): | |
print("Input Tensor:") | |
print(x) |
// -----// IR Dump After CleanupBufferAllocView (iree-codegen-cleanup-buffer-alloc-view) //----- // | |
func.func @pipeline_dispatch_0_depthwise_conv_2d_nhwc_hwc_1x10x20x1x1x9_i32() { | |
%c0_i32 = arith.constant 0 : i32 | |
%c10 = arith.constant 10 : index | |
%c20 = arith.constant 20 : index | |
%c0 = arith.constant 0 : index | |
%c1 = arith.constant 1 : index | |
%c2 = arith.constant 2 : index | |
%c5 = arith.constant 5 : index | |
%c3 = arith.constant 3 : index |
// Accumulates A * B into %mat_C on fixed-shape f32 memrefs, where A is supplied
// already transposed (%mat_A_tr holds A^T).
//
// The following matmul maps very nicely onto SME with SVL=128 bits. For f32,
// there are four 4x4 tiles, which can be assembled into an 8x8 matrix.
//
//   %mat_A_tr - transpose of A, K x M = 6x8
//   %mat_B    - B,              K x N = 6x8
//   %mat_C    - accumulator,    M x N = 8x8
//
// Because A is stored transposed, the op has to be the transposed-LHS variant:
// plain `linalg.matmul` expects the LHS as M x K, which makes these operand
// shapes inconsistent (6x8 * 6x8 -> 8x8) and fails verification.
// NOTE(review): on MLIR versions where the `linalg.matmul_transpose_a` named op
// is not available, the same contraction is spelled as `linalg.matmul` with
// explicit `indexing_maps` - confirm against the toolchain in use.
func.func @matmul(%mat_A_tr: memref<6x8xf32>, %mat_B: memref<6x8xf32>, %mat_C: memref<8x8xf32>) {
  linalg.matmul_transpose_a ins(%mat_A_tr, %mat_B: memref<6x8xf32>, memref<6x8xf32>)
                            outs(%mat_C: memref<8x8xf32>)
  return
}
// -----// IR Dump After TileAndDistributeToWorkgroups (iree-codegen-tile-and-distribute-to-workgroups) //----- // | |
hal.executable.variant public @embedded_elf_arm_64, target = <"llvm-cpu", "embedded-elf-arm_64", {cpu = "generic", cpu_features = "+reserve-x18", data_layout = "e-m:e-i8:8:32-i16:16:32-i64:64-i128:128-n32:64-S128", native_vector_size = 16 : index, target_triple = "aarch64-unknown-unknown-eabi-elf"}> { | |
hal.executable.export public @pipeline_dispatch_0_depthwise_conv_2d_nhwc_hwc_1x1080x1920x1x1x43 ordinal(0) layout(#hal.pipeline.layout<push_constants = 0, sets = [<0, bindings = [<0, storage_buffer, ReadOnly>, <1, storage_buffer>]>]>) attributes {translation_info = #iree_codegen.translation_info<CPUConvTileAndDecomposeExpert>} { | |
^bb0(%arg0: !hal.device, %arg1: index, %arg2: index, %arg3: index, %arg4: index, %arg5: index, %arg6: index): | |
%c30 = arith.constant 30 : index | |
%c18 = arith.constant 18 : index | |
%c1 = arith.constant 1 : index | |
hal.return %c30, %c18, %c1 : index, index, index |