Skip to content

Instantly share code, notes, and snippets.

@pashu123
Created September 18, 2024 09:50
Show Gist options
  • Save pashu123/0f41d93c12826be20756d40878d3b6ec to your computer and use it in GitHub Desktop.
Save pashu123/0f41d93c12826be20756d40878d3b6ec to your computer and use it in GitHub Desktop.
// Attribute aliases shared by the dispatch function below.
//
// #map: two results, (s0 - d0) and the constant 64. It is only used with
// affine.min below, which yields min(s0 - d0, 64), i.e. the extent left from
// position d0 up to bound s0, clamped to the step of 64 used by the scf.forall.
#map = affine_map<(d0)[s0] -> (-d0 + s0, 64)>
// #map1: multiplies the operand d0 by the symbol s0 (used with %39 as s0).
#map1 = affine_map<(d0)[s0] -> (d0 * s0)>
// #map2: 4-D identity map; used as both indexing maps of the copy-style
// linalg.generic ops.
#map2 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
// Pipeline layout: 10 constants and two storage-buffer bindings
// (binding 0 is ReadOnly, binding 1 is Indirect), with the Indirect flag set.
#pipeline_layout = #hal.pipeline.layout<constants = 10, bindings = [#hal.pipeline.binding<storage_buffer, ReadOnly>, #hal.pipeline.binding<storage_buffer, Indirect>], flags = Indirect>
// Translation info attached to the function: CPUDataTiling.
#translation = #iree_codegen.translation_info<CPUDataTiling>
// Dispatch function that packs a dynamically shaped 2-D i32 buffer (binding 0)
// into a dynamically shaped 4-D i32 buffer (binding 1) using
// iree_linalg_ext.pack with dynamic inner tile sizes.
//
// Outline:
//   1. Load ten i32 constants and combine them pairwise into five index values.
//   2. Bind the 2-D source and 4-D destination memrefs with those dynamic sizes.
//   3. In an scf.forall over the two outer destination dims (step 64), copy the
//      destination tile into a temporary buffer, pack the source into it one
//      outer (i, j) position at a time, and copy the buffer back.
module {
func.func @_fully_dynamic_pack_simple_dispatch_0_pack_i32() attributes {translation_info = #translation} {
%c1 = arith.constant 1 : index
%c0 = arith.constant 0 : index
%c3 = arith.constant 3 : index
%c2 = arith.constant 2 : index
%c64 = arith.constant 64 : index
%c32_i64 = arith.constant 32 : i64
// Ten 32-bit constants from the pipeline layout (ordinals 0..9).
%0 = hal.interface.constant.load layout(#pipeline_layout) ordinal(0) : i32
%1 = hal.interface.constant.load layout(#pipeline_layout) ordinal(1) : i32
%2 = hal.interface.constant.load layout(#pipeline_layout) ordinal(2) : i32
%3 = hal.interface.constant.load layout(#pipeline_layout) ordinal(3) : i32
%4 = hal.interface.constant.load layout(#pipeline_layout) ordinal(4) : i32
%5 = hal.interface.constant.load layout(#pipeline_layout) ordinal(5) : i32
%6 = hal.interface.constant.load layout(#pipeline_layout) ordinal(6) : i32
%7 = hal.interface.constant.load layout(#pipeline_layout) ordinal(7) : i32
%8 = hal.interface.constant.load layout(#pipeline_layout) ordinal(8) : i32
%9 = hal.interface.constant.load layout(#pipeline_layout) ordinal(9) : i32
// Each consecutive pair (lo, hi) is zero-extended to i64 and combined as
// lo | (hi << 32), then cast (unsigned) to index.
// %14 = constants 0/1.
%10 = arith.extui %0 : i32 to i64
%11 = arith.extui %1 : i32 to i64
%12 = arith.shli %11, %c32_i64 : i64
%13 = arith.ori %10, %12 : i64
%14 = arith.index_castui %13 : i64 to index
// %19 = constants 2/3.
%15 = arith.extui %2 : i32 to i64
%16 = arith.extui %3 : i32 to i64
%17 = arith.shli %16, %c32_i64 : i64
%18 = arith.ori %15, %17 : i64
%19 = arith.index_castui %18 : i64 to index
// %24 = constants 4/5.
%20 = arith.extui %4 : i32 to i64
%21 = arith.extui %5 : i32 to i64
%22 = arith.shli %21, %c32_i64 : i64
%23 = arith.ori %20, %22 : i64
%24 = arith.index_castui %23 : i64 to index
// %29 = constants 6/7.
%25 = arith.extui %6 : i32 to i64
%26 = arith.extui %7 : i32 to i64
%27 = arith.shli %26, %c32_i64 : i64
%28 = arith.ori %25, %27 : i64
%29 = arith.index_castui %28 : i64 to index
// %34 = constants 8/9.
%30 = arith.extui %8 : i32 to i64
%31 = arith.extui %9 : i32 to i64
%32 = arith.shli %31, %c32_i64 : i64
%33 = arith.ori %30, %32 : i64
%34 = arith.index_castui %33 : i64 to index
// Tag the five index values as workload ordinals 0..4.
// Usage below: %35, %36 are the two source dims; %37, %38 are the two outer
// destination dims; %39 is used for both inner destination dims and as the
// pack inner tile size.
%35 = flow.dispatch.workload.ordinal %14, 0 : index
%36 = flow.dispatch.workload.ordinal %19, 1 : index
%37 = flow.dispatch.workload.ordinal %24, 2 : index
%38 = flow.dispatch.workload.ordinal %29, 3 : index
%39 = flow.dispatch.workload.ordinal %34, 4 : index
// Source: binding 0, read-only, 2-D with dynamic sizes {%35, %36}.
// The subspan is taken at offset(%c64) and the memref type carries offset: 16
// (presumably 64 bytes expressed in i32 elements -- confirm).
%40 = hal.interface.binding.subspan layout(#pipeline_layout) binding(0) alignment(64) offset(%c64) flags(ReadOnly) : memref<?x?xi32, strided<[?, 1], offset: 16>, #hal.descriptor_type<storage_buffer>>{%35, %36}
memref.assume_alignment %40, 64 : memref<?x?xi32, strided<[?, 1], offset: 16>, #hal.descriptor_type<storage_buffer>>
// Destination: binding 1, 4-D with dynamic sizes {%37, %38, %39, %39}.
%41 = hal.interface.binding.subspan layout(#pipeline_layout) binding(1) alignment(64) offset(%c64) flags(Indirect) : memref<?x?x?x?xi32, strided<[?, ?, ?, 1], offset: 16>, #hal.descriptor_type<storage_buffer>>{%37, %38, %39, %39}
memref.assume_alignment %41, 64 : memref<?x?x?x?xi32, strided<[?, ?, ?, 1], offset: 16>, #hal.descriptor_type<storage_buffer>>
// Iterate over the two outer destination dims in steps of 64; the loop is
// mapped to workgroup y (%arg0) and workgroup x (%arg1), see the mapping
// attribute on the closing brace.
scf.forall (%arg0, %arg1) = (0, 0) to (%37, %38) step (64, 64) {
// Outer tile extents for this iteration: min(bound - iv, 64).
%42 = affine.min #map(%arg0)[%37]
%43 = affine.min #map(%arg1)[%38]
// Source offsets (%44, %46) and sizes (%45, %47): the outer destination
// offsets/extents multiplied by %39.
%44 = affine.apply #map1(%arg0)[%39]
%45 = affine.apply #map1(%42)[%39]
%46 = affine.apply #map1(%arg1)[%39]
%47 = affine.apply #map1(%43)[%39]
// Source region feeding this tile.
// NOTE(review): the sizes here are derived from destination dims only; this
// presumably relies on the source extents (%35, %36) covering
// outer_dim * %39 in each dimension -- confirm, since the pack below has no
// padding value.
%subview = memref.subview %40[%44, %46] [%45, %47] [1, 1] : memref<?x?xi32, strided<[?, 1], offset: 16>, #hal.descriptor_type<storage_buffer>> to memref<?x?xi32, strided<[?, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
// Sizes of destination dims 2 and 3 (both bound to %39 in the subspan above).
%dim = memref.dim %41, %c2 : memref<?x?x?x?xi32, strided<[?, ?, ?, 1], offset: 16>, #hal.descriptor_type<storage_buffer>>
%dim_0 = memref.dim %41, %c3 : memref<?x?x?x?xi32, strided<[?, ?, ?, 1], offset: 16>, #hal.descriptor_type<storage_buffer>>
// Destination tile for this iteration.
%subview_1 = memref.subview %41[%arg0, %arg1, 0, 0] [%42, %43, %dim, %dim_0] [1, 1, 1, 1] : memref<?x?x?x?xi32, strided<[?, ?, ?, 1], offset: 16>, #hal.descriptor_type<storage_buffer>> to memref<?x?x?x?xi32, strided<[?, ?, ?, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
// Temporary buffer with the same shape as the destination tile.
// NOTE(review): no matching memref.dealloc is visible in this function.
%alloc = memref.alloc(%42, %43, %dim, %dim_0) : memref<?x?x?x?xi32>
// Element-wise copy of the current destination tile into the temporary.
linalg.generic {indexing_maps = [#map2, #map2], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%subview_1 : memref<?x?x?x?xi32, strided<[?, ?, ?, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>) outs(%alloc : memref<?x?x?x?xi32>) {
^bb0(%in: i32, %out: i32):
linalg.yield %in : i32
}
// Walk every outer (i, j) position of the tile with unit step. The temporary
// buffer is threaded through as an iter_arg and yielded unchanged, so %48 is
// the same memref value as %alloc.
%48 = scf.for %arg2 = %c0 to %42 step %c1 iter_args(%arg3 = %alloc) -> (memref<?x?x?x?xi32>) {
// Row offset into the source region: i * %39.
%49 = affine.apply #map1(%arg2)[%39]
%50 = scf.for %arg4 = %c0 to %43 step %c1 iter_args(%arg5 = %arg3) -> (memref<?x?x?x?xi32>) {
// Column offset into the source region: j * %39.
%51 = affine.apply #map1(%arg4)[%39]
// %39 x %39 source block for outer position (i, j).
%subview_3 = memref.subview %subview[%49, %51] [%39, %39] [1, 1] : memref<?x?xi32, strided<[?, 1], offset: ?>, #hal.descriptor_type<storage_buffer>> to memref<?x?xi32, strided<[?, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
// Matching 1 x 1 x %dim x %dim_0 slot of the temporary buffer.
%subview_4 = memref.subview %arg5[%arg2, %arg4, 0, 0] [1, 1, %dim, %dim_0] [1, 1, 1, 1] : memref<?x?x?x?xi32> to memref<1x1x?x?xi32, strided<[?, ?, ?, 1], offset: ?>>
// Pack the source block into the slot; both source dims are tiled by %39.
iree_linalg_ext.pack %subview_3 inner_dims_pos = [0, 1] inner_tiles = [%39, %39] into %subview_4 : (memref<?x?xi32, strided<[?, 1], offset: ?>, #hal.descriptor_type<storage_buffer>> memref<1x1x?x?xi32, strided<[?, ?, ?, 1], offset: ?>>)
// Copy of %subview_4 onto itself (ins and outs are the same value and the
// body yields %in). NOTE(review): looks like a redundant copy left over from
// an earlier transformation -- confirm it is expected to fold away.
linalg.generic {indexing_maps = [#map2, #map2], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%subview_4 : memref<1x1x?x?xi32, strided<[?, ?, ?, 1], offset: ?>>) outs(%subview_4 : memref<1x1x?x?xi32, strided<[?, ?, ?, 1], offset: ?>>) {
^bb0(%in: i32, %out: i32):
linalg.yield %in : i32
}
scf.yield %arg5 : memref<?x?x?x?xi32>
}
scf.yield %50 : memref<?x?x?x?xi32>
}
// Destination tile again, this time with the inner sizes given as %39
// directly instead of %dim / %dim_0.
%subview_2 = memref.subview %41[%arg0, %arg1, 0, 0] [%42, %43, %39, %39] [1, 1, 1, 1] : memref<?x?x?x?xi32, strided<[?, ?, ?, 1], offset: 16>, #hal.descriptor_type<storage_buffer>> to memref<?x?x?x?xi32, strided<[?, ?, ?, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
// Write the packed temporary back into the destination tile.
linalg.generic {indexing_maps = [#map2, #map2], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%48 : memref<?x?x?x?xi32>) outs(%subview_2 : memref<?x?x?x?xi32, strided<[?, ?, ?, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>) {
^bb0(%in: i32, %out: i32):
linalg.yield %in : i32
}
} {mapping = [#iree_codegen.workgroup_mapping<y>, #iree_codegen.workgroup_mapping<x>]}
// Copy of the whole destination buffer %41 onto itself, placed after the
// forall. NOTE(review): ins and outs are the same value, so this appears to be
// a redundant whole-buffer copy -- confirm it is intended at this stage.
linalg.generic {indexing_maps = [#map2, #map2], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%41 : memref<?x?x?x?xi32, strided<[?, ?, ?, 1], offset: 16>, #hal.descriptor_type<storage_buffer>>) outs(%41 : memref<?x?x?x?xi32, strided<[?, ?, ?, 1], offset: 16>, #hal.descriptor_type<storage_buffer>>) {
^bb0(%in: i32, %out: i32):
linalg.yield %in : i32
}
return
}
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment