Skip to content

Instantly share code, notes, and snippets.

@ita9naiwa
Created February 16, 2025 07:51
Show Gist options
  • Save ita9naiwa/0eec0bede4c46a90d1121e00348f39b1 to your computer and use it in GitHub Desktop.
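Folding `tensor.pad`: the same function is dumped after canonicalization and again after `iree-codegen-block-dynamic-dimensions`; in the second dump the pad is applied before `tensor.collapse_shape` instead of after it.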
// -----// IR Dump After Canonicalizer (canonicalize) //----- //
// @test_fusion as printed after canonicalization. Runs in three steps:
//   1. an elementwise linalg.generic over a 32x16x256x256 tensor;
//   2. tensor.collapse_shape merging dims 0 and 1 (32x16 -> 512);
//   3. tensor.pad adding one 0.0 element on each side of the last two dims
//      (256 -> 258).
// Arguments (as indexed by the generic's indexing_maps):
//   %arg0: 32x16x256x256, read at (d0, d1, d2, d3); also passed as the outs init.
//   %arg1: 32, read at (d0), i.e. broadcast over d1, d2, d3.
//   %arg2, %arg3: 32x16, read at (d0, d1), i.e. broadcast over d2, d3.
// Returns: tensor<512x258x258xf32>.
func.func @test_fusion(%arg0: tensor<32x16x256x256xf32>, %arg1: tensor<32xf32>, %arg2: tensor<32x16xf32>, %arg3: tensor<32x16xf32>) -> tensor<512x258x258xf32> {
// 1.0: added to %in_1 before rsqrt, and used to build 1 / (1 + exp(-x)).
%cst = arith.constant 1.000000e+00 : f32
// 0.0: the padding value yielded by tensor.pad below.
%cst_0 = arith.constant 0.000000e+00 : f32
// Per element, with x, a, b, c = %in, %in_1, %in_2, %in_3:
//   y = x * rsqrt(a + 1.0) * b + c
//   result = y * sigmoid(y)
// All four loops are "parallel".
%0 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0)>, affine_map<(d0, d1, d2, d3) -> (d0, d1)>, affine_map<(d0, d1, d2, d3) -> (d0, d1)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg0, %arg1, %arg2, %arg3 : tensor<32x16x256x256xf32>, tensor<32xf32>, tensor<32x16xf32>, tensor<32x16xf32>) outs(%arg0 : tensor<32x16x256x256xf32>) {
^bb0(%in: f32, %in_1: f32, %in_2: f32, %in_3: f32, %out: f32):
// %out is never read in this body, so the outs(%arg0) init only supplies the result type.
// %2 = rsqrt(%in_1 + 1.0)
// NOTE(review): if %in_1 is a variance and 1.0 is meant as an epsilon, confirm the value.
%1 = arith.addf %in_1, %cst : f32
%2 = math.rsqrt %1 : f32
// %5 = %in * %2 * %in_2 + %in_3
%3 = arith.mulf %in, %2 : f32
%4 = arith.mulf %3, %in_2 : f32
%5 = arith.addf %4, %in_3 : f32
// %9 = 1 / (1 + exp(-%5)) = sigmoid(%5); %10 = %5 * sigmoid(%5)
%6 = arith.negf %5 : f32
%7 = math.exp %6 : f32
%8 = arith.addf %7, %cst : f32
%9 = arith.divf %cst, %8 : f32
%10 = arith.mulf %9, %5 : f32
linalg.yield %10 : f32
} -> tensor<32x16x256x256xf32>
// Merge dims 0 and 1: 32x16x256x256 -> 512x256x256.
%collapsed = tensor.collapse_shape %0 [[0, 1], [2], [3]] : tensor<32x16x256x256xf32> into tensor<512x256x256xf32>
// Pad dims 1 and 2 by one element on each side with 0.0: 512x256x256 -> 512x258x258.
%padded = tensor.pad %collapsed low[0, 1, 1] high[0, 1, 1] {
^bb0(%arg4: index, %arg5: index, %arg6: index):
tensor.yield %cst_0 : f32
} : tensor<512x256x256xf32> to tensor<512x258x258xf32>
return %padded : tensor<512x258x258xf32>
}
// -----// IR Dump After BlockDynamicDimensionsPass (iree-codegen-block-dynamic-dimensions) //----- //
// @test_fusion as printed after iree-codegen-block-dynamic-dimensions.
// The linalg.generic is the same elementwise computation over a
// 32x16x256x256 tensor. What differs is the order of the last two ops:
//   - tensor.pad is applied to the 4-D generic result (32x16x258x258);
//   - tensor.collapse_shape then merges dims 0 and 1 (-> 512x258x258).
// This is equivalent to collapsing first and padding second. The padded dims
// (2 and 3) are disjoint from the collapsed dims (0 and 1), and dims 0 and 1
// receive zero padding.
// Returns: tensor<512x258x258xf32> (same type as before the pass).
func.func @test_fusion(%arg0: tensor<32x16x256x256xf32>, %arg1: tensor<32xf32>, %arg2: tensor<32x16xf32>, %arg3: tensor<32x16xf32>) -> tensor<512x258x258xf32> {
// 1.0: added to %in_1 before rsqrt, and used to build 1 / (1 + exp(-x)).
%cst = arith.constant 1.000000e+00 : f32
// 0.0: the padding value yielded by tensor.pad below.
%cst_0 = arith.constant 0.000000e+00 : f32
// Per element, with x, a, b, c = %in, %in_1, %in_2, %in_3:
//   y = x * rsqrt(a + 1.0) * b + c
//   result = y * sigmoid(y)
// %arg1 is read at (d0); %arg2 and %arg3 are read at (d0, d1).
%0 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0)>, affine_map<(d0, d1, d2, d3) -> (d0, d1)>, affine_map<(d0, d1, d2, d3) -> (d0, d1)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg0, %arg1, %arg2, %arg3 : tensor<32x16x256x256xf32>, tensor<32xf32>, tensor<32x16xf32>, tensor<32x16xf32>) outs(%arg0 : tensor<32x16x256x256xf32>) {
^bb0(%in: f32, %in_1: f32, %in_2: f32, %in_3: f32, %out: f32):
// %out is never read in this body, so the outs(%arg0) init only supplies the result type.
// %2 = rsqrt(%in_1 + 1.0)
%1 = arith.addf %in_1, %cst : f32
%2 = math.rsqrt %1 : f32
// %5 = %in * %2 * %in_2 + %in_3
%3 = arith.mulf %in, %2 : f32
%4 = arith.mulf %3, %in_2 : f32
%5 = arith.addf %4, %in_3 : f32
// %9 = 1 / (1 + exp(-%5)) = sigmoid(%5); %10 = %5 * sigmoid(%5)
%6 = arith.negf %5 : f32
%7 = math.exp %6 : f32
%8 = arith.addf %7, %cst : f32
%9 = arith.divf %cst, %8 : f32
%10 = arith.mulf %9, %5 : f32
linalg.yield %10 : f32
} -> tensor<32x16x256x256xf32>
// Pad only the last two dims, one 0.0 element on each side:
// 32x16x256x256 -> 32x16x258x258.
%padded = tensor.pad %0 low[0, 0, 1, 1] high[0, 0, 1, 1] {
^bb0(%arg4: index, %arg5: index, %arg6: index, %arg7: index):
tensor.yield %cst_0 : f32
} : tensor<32x16x256x256xf32> to tensor<32x16x258x258xf32>
// Merge dims 0 and 1 of the padded tensor: 32x16x258x258 -> 512x258x258.
%collapsed = tensor.collapse_shape %padded [[0, 1], [2], [3]] : tensor<32x16x258x258xf32> into tensor<512x258x258xf32>
return %collapsed : tensor<512x258x258xf32>
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment