Skip to content

Instantly share code, notes, and snippets.

@pashu123
Created October 14, 2022 18:02
Show Gist options
  • Save pashu123/7cfa13c99621f539d608307f8e601526 to your computer and use it in GitHub Desktop.
Save pashu123/7cfa13c99621f539d608307f8e601526 to your computer and use it in GitHub Desktop.
// Identity 2-D indexing map: iteration point (d0, d1) accesses operand element (d0, d1).
#map0 = affine_map<(d0, d1) -> (d0, d1)>
// Swapped 2-D indexing map: iteration point (d0, d1) accesses operand element (d1, d0).
// Pairing it with #map0 in a copy-only linalg.generic produces a transpose.
#map1 = affine_map<(d0, d1) -> (d1, d0)>
// Broadcast map: a 1-D operand indexed only by the inner dimension d1, so the same
// value is reused for every d0 (every row).
#map2 = affine_map<(d0, d1) -> (d1)>
// Identity 3-D indexing map, used for purely element-wise ops on rank-3 tensors.
#map3 = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
// @forward: a 320x320 linear projection (transpose weight, matmul, bias add)
// over 2*4096 = 8192 rows of 320 elements, followed by an element-wise
// residual add on the 2x4096x320 result.
//
// NOTE(review): %arg0 and %arg1 are never read. Every data input below
// (matmul LHS, weight, bias, residual) is a `tensor.empty`, whose element
// values are unspecified, so the returned tensor's contents are undefined.
// This looks like a deliberately reduced compiler reproducer. If real
// results are wanted, presumably %arg0 should feed the matmul (collapsed to
// 8192x320) and %arg1 the residual add. TODO confirm with the author.
module {
  func.func @forward(%arg0: tensor<2x4096x320xf16>, %arg1: tensor<2x4096x320xf16>) -> tensor<2x4096x320xf16> {
    // Zero used to initialise the matmul accumulator.
    // (The original also built an unused zero-filled 2x4096x320 tensor and a
    // duplicate zero constant; both were dead and have been removed.)
    %zero = arith.constant 0.000000e+00 : f16

    // Weight transpose: out[d1, d0] = in[d0, d1] (ins uses #map0, outs uses #map1).
    %weight = tensor.empty() : tensor<320x320xf16>
    %weight_t_init = tensor.empty() : tensor<320x320xf16>
    %weight_t = linalg.generic {indexing_maps = [#map0, #map1], iterator_types = ["parallel", "parallel"]} ins(%weight : tensor<320x320xf16>) outs(%weight_t_init : tensor<320x320xf16>) {
    ^bb0(%in: f16, %out: f16):
      linalg.yield %in : f16
    } -> tensor<320x320xf16>

    // Matmul LHS: view the 2x4096x8x40 tensor as 8192 rows of 8*40 = 320 elements.
    %lhs_4d = tensor.empty() : tensor<2x4096x8x40xf16>
    %lhs = tensor.collapse_shape %lhs_4d [[0, 1], [2, 3]] : tensor<2x4096x8x40xf16> into tensor<8192x320xf16>

    // Projection: (8192x320) x (320x320) accumulated into a zero-filled 8192x320 tensor.
    %acc_init = tensor.empty() : tensor<8192x320xf16>
    %acc = linalg.fill ins(%zero : f16) outs(%acc_init : tensor<8192x320xf16>) -> tensor<8192x320xf16>
    %proj = linalg.matmul ins(%lhs, %weight_t : tensor<8192x320xf16>, tensor<320x320xf16>) outs(%acc : tensor<8192x320xf16>) -> tensor<8192x320xf16>

    // Bias add: the 320-element vector is broadcast across all 8192 rows via #map2.
    // The outs operand is only a shape/destination; its elements are never read.
    %bias = tensor.empty() : tensor<320xf16>
    %biased = linalg.generic {indexing_maps = [#map2, #map0, #map0], iterator_types = ["parallel", "parallel"]} ins(%bias, %proj : tensor<320xf16>, tensor<8192x320xf16>) outs(%acc_init : tensor<8192x320xf16>) {
    ^bb0(%b: f16, %x: f16, %out: f16):
      %sum = arith.addf %b, %x : f16
      linalg.yield %sum : f16
    } -> tensor<8192x320xf16>

    // Restore the 2x4096 leading dimensions.
    %biased_3d = tensor.expand_shape %biased [[0, 1], [2]] : tensor<8192x320xf16> into tensor<2x4096x320xf16>

    // Residual add: element-wise sum over the full 2x4096x320 tensor.
    %residual = tensor.empty() : tensor<2x4096x320xf16>
    %result_init = tensor.empty() : tensor<2x4096x320xf16>
    %result = linalg.generic {indexing_maps = [#map3, #map3, #map3], iterator_types = ["parallel", "parallel", "parallel"]} ins(%biased_3d, %residual : tensor<2x4096x320xf16>, tensor<2x4096x320xf16>) outs(%result_init : tensor<2x4096x320xf16>) {
    ^bb0(%lhs_elt: f16, %rhs_elt: f16, %out: f16):
      %sum = arith.addf %lhs_elt, %rhs_elt : f16
      linalg.yield %sum : f16
    } -> tensor<2x4096x320xf16>
    return %result : tensor<2x4096x320xf16>
  }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment