Skip to content

Instantly share code, notes, and snippets.

@pashu123
Created September 1, 2022 17:32
Show Gist options
  • Save pashu123/af830b60d04674ee3a1f5beeef3d46d6 to your computer and use it in GitHub Desktop.
Save pashu123/af830b60d04674ee3a1f5beeef3d46d6 to your computer and use it in GitHub Desktop.
// Indexing-map aliases used by the linalg ops in @forward.
// Dimension identifiers inside affine_map are positional placeholders; the
// names chosen here (i, j, k) do not affect the maps themselves.
// 2-D identity.
#map0 = affine_map<(i, j) -> (i, j)>
// 2-D transpose: iteration point (i, j) addresses element (j, i).
#map1 = affine_map<(i, j) -> (j, i)>
// 2-D access with the leading index pinned to 0 (reads a size-1 leading dim).
#map2 = affine_map<(i, j) -> (0, j)>
// 2-D iteration reading a 1-D operand along the second dimension.
#map3 = affine_map<(i, j) -> (j)>
// 3-D iteration reading a 1-D operand along the second dimension.
#map4 = affine_map<(i, j, k) -> (j)>
// 3-D iteration reading a 1-D operand along the first dimension.
#map5 = affine_map<(i, j, k) -> (i)>
// 3-D iteration addressing a 2-D operand by (first, third) dimensions.
#map6 = affine_map<(i, j, k) -> (i, k)>
// 3-D identity.
#map7 = affine_map<(i, j, k) -> (i, j, k)>
// 3-D access that swaps the last two dimensions.
#map8 = affine_map<(i, j, k) -> (i, k, j)>
// 2-D iteration always reading element (0, 0).
#map9 = affine_map<(i, j) -> (0, 0)>
// 2-D iteration always reading element 0 of a 1-D operand.
#map10 = affine_map<(i, j) -> (0)>
// DLRM_Net forward pass lowered to linalg-on-tensors.
//
// Fixes applied to the three embedding-bag generics (%17, %21, %25):
//   * The "end of bag" lookup read offsets[bag + 1] unconditionally and only
//     afterwards chose between it and the index count with arith.select
//     (which does not short-circuit). For the last bag -- here the only bag --
//     that tensor.extract indexes past the end of the 1-element offsets
//     tensor, which is undefined behaviour. The index is now clamped into
//     bounds. The select still discards the clamped value for the last bag,
//     so results are unchanged.
//   * The index-position dimension d1 accumulates into the same output element
//     on every iteration (the output map (d0, d2) does not mention d1), so it
//     is a reduction. Marking it "parallel" allowed parallel tiling or
//     distribution to race on the accumulator.
module attributes {torch.debug_module_name = "DLRM_Net"} {
  // Operands, as inferred from their uses below (NOTE(review): confirm
  // against the source model):
  //   %arg0        - 1x4 input consumed by the first matmul.
  //   %arg1        - 3x1 offsets; row k feeds embedding bag k.
  //   %arg2..%arg4 - lookup indices for embedding tables %cst_7, %cst_6, %cst_5.
  func.func @forward(%arg0: tensor<1x4xf32>, %arg1: tensor<3x1xi64>, %arg2: tensor<3xi64>, %arg3: tensor<1xi64>, %arg4: tensor<1xi64>) -> tensor<1x1xf32> {
    // Layer weights/biases (names follow their uses below).
    %cst = arith.constant dense<[[0.349731505, 0.0543110855]]> : tensor<1x2xf32>
    %cst_0 = arith.constant dense<0.302471906> : tensor<1xf32>
    %cst_1 = arith.constant dense<[[0.537720621, -0.173217818, -0.135193899, -0.468927145, -0.63505131, -0.763067245, 0.872413277, -0.227923378], [-0.195912778, -0.560267091, 0.347704262, -0.721757054, -0.0951403453, -0.400464833, 0.173028052, -2.284390e-01]]> : tensor<2x8xf32>
    %cst_2 = arith.constant dense<[-8.348330e-01, -0.0199278444]> : tensor<2xf32>
    // Column / row indices of the strictly-lower-triangular interaction terms.
    %cst_3 = arith.constant dense<[0, 0, 1, 0, 1, 2]> : tensor<6xi64>
    %cst_4 = arith.constant dense<[1, 2, 2, 3, 3, 3]> : tensor<6xi64>
    // Embedding tables (2-wide rows).
    %cst_5 = arith.constant dense<[[-0.606646597, -0.583887339], [-0.678513646, 0.470395505]]> : tensor<2x2xf32>
    %cst_6 = arith.constant dense<[[0.535391629, -0.134590134], [0.336855054, 0.0333649777], [0.0785710886, 0.49143666]]> : tensor<3x2xf32>
    %cst_7 = arith.constant dense<[[0.0488135032, 0.215189368], [0.102763377, 0.0448831841], [-7.634520e-02, 0.14589411], [-0.0624127872, 0.391773015]]> : tensor<4x2xf32>
    %cst_8 = arith.constant dense<[[0.929304063, 0.0979973599, 0.239170983], [-5.614850e-01, -1.25276566, -0.220038965]]> : tensor<2x3xf32>
    %cst_9 = arith.constant dense<[0.110555418, 0.869946897]> : tensor<2xf32>
    %cst_10 = arith.constant dense<[[0.237254873, 0.178356424, 0.798618853, -0.109661706], [0.167341724, -0.456533372, -1.36463046, 0.349373847], [0.462060571, -0.396703899, 1.2132349, -0.777391135]]> : tensor<3x4xf32>
    %cst_11 = arith.constant dense<[0.0264186915, -0.108070649, 0.884950518]> : tensor<3xf32>
    %cst_12 = arith.constant 0.000000e+00 : f32
    // Last valid index into a 1-element offsets tensor (number of offsets - 1).
    %c0_i64 = arith.constant 0 : i64
    %c1_i64 = arith.constant 1 : i64
    %c3_i64 = arith.constant 3 : i64

    // ---- Dense layer 1: relu(%arg0 @ %cst_10^T + %cst_11) -> 1x3.
    %0 = linalg.init_tensor [4, 3] : tensor<4x3xf32>
    %1 = linalg.generic {indexing_maps = [#map0, #map1], iterator_types = ["parallel", "parallel"]} ins(%cst_10 : tensor<3x4xf32>) outs(%0 : tensor<4x3xf32>) {
    ^bb0(%arg5: f32, %arg6: f32):
      linalg.yield %arg5 : f32
    } -> tensor<4x3xf32>
    %2 = linalg.init_tensor [1, 3] : tensor<1x3xf32>
    %3 = linalg.fill ins(%cst_12 : f32) outs(%2 : tensor<1x3xf32>) -> tensor<1x3xf32>
    %4 = linalg.matmul ins(%arg0, %1 : tensor<1x4xf32>, tensor<4x3xf32>) outs(%3 : tensor<1x3xf32>) -> tensor<1x3xf32>
    %5 = linalg.generic {indexing_maps = [#map2, #map3, #map0], iterator_types = ["parallel", "parallel"]} ins(%4, %cst_11 : tensor<1x3xf32>, tensor<3xf32>) outs(%2 : tensor<1x3xf32>) {
    ^bb0(%arg5: f32, %arg6: f32, %arg7: f32):
      %53 = arith.addf %arg5, %arg6 : f32
      linalg.yield %53 : f32
    } -> tensor<1x3xf32>
    // ReLU: `ugt` is unordered-or-greater, so NaN inputs pass through as NaN.
    %6 = linalg.generic {indexing_maps = [#map2, #map0], iterator_types = ["parallel", "parallel"]} ins(%5 : tensor<1x3xf32>) outs(%2 : tensor<1x3xf32>) {
    ^bb0(%arg5: f32, %arg6: f32):
      %53 = arith.cmpf ugt, %arg5, %cst_12 : f32
      %54 = arith.select %53, %arg5, %cst_12 : f32
      linalg.yield %54 : f32
    } -> tensor<1x3xf32>

    // ---- Dense layer 2: relu(%6 @ %cst_8^T + %cst_9) -> 1x2.
    %7 = linalg.init_tensor [3, 2] : tensor<3x2xf32>
    %8 = linalg.generic {indexing_maps = [#map0, #map1], iterator_types = ["parallel", "parallel"]} ins(%cst_8 : tensor<2x3xf32>) outs(%7 : tensor<3x2xf32>) {
    ^bb0(%arg5: f32, %arg6: f32):
      linalg.yield %arg5 : f32
    } -> tensor<3x2xf32>
    %9 = linalg.init_tensor [1, 2] : tensor<1x2xf32>
    %10 = linalg.fill ins(%cst_12 : f32) outs(%9 : tensor<1x2xf32>) -> tensor<1x2xf32>
    %11 = linalg.matmul ins(%6, %8 : tensor<1x3xf32>, tensor<3x2xf32>) outs(%10 : tensor<1x2xf32>) -> tensor<1x2xf32>
    %12 = linalg.generic {indexing_maps = [#map2, #map3, #map0], iterator_types = ["parallel", "parallel"]} ins(%11, %cst_9 : tensor<1x2xf32>, tensor<2xf32>) outs(%9 : tensor<1x2xf32>) {
    ^bb0(%arg5: f32, %arg6: f32, %arg7: f32):
      %53 = arith.addf %arg5, %arg6 : f32
      linalg.yield %53 : f32
    } -> tensor<1x2xf32>
    %13 = linalg.generic {indexing_maps = [#map2, #map0], iterator_types = ["parallel", "parallel"]} ins(%12 : tensor<1x2xf32>) outs(%9 : tensor<1x2xf32>) {
    ^bb0(%arg5: f32, %arg6: f32):
      %53 = arith.cmpf ugt, %arg5, %cst_12 : f32
      %54 = arith.select %53, %arg5, %cst_12 : f32
      linalg.yield %54 : f32
    } -> tensor<1x2xf32>

    // ---- Embedding bags. Each generic sums table rows into a zero-filled
    // 1x2 output. Iteration space: d0 = bag, d1 = index position (reduction),
    // d2 = embedding column. The row table[indices[d1]] is accumulated into
    // out[d0, :] when offsets[d0] <= d1 < end. Here `end` is offsets[d0 + 1],
    // or the total index count for the last bag.

    // Bag 0: table %cst_7 (4x2), indices %arg2 (3 entries), offsets row 0.
    %14 = tensor.extract_slice %arg1[0, 0] [1, 1] [1, 1] : tensor<3x1xi64> to tensor<1x1xi64>
    %15 = tensor.collapse_shape %14 [[0, 1]] : tensor<1x1xi64> into tensor<1xi64>
    %16 = linalg.fill ins(%cst_12 : f32) outs(%9 : tensor<1x2xf32>) -> tensor<1x2xf32>
    %17 = linalg.generic {indexing_maps = [#map4, #map5, #map6], iterator_types = ["parallel", "reduction", "parallel"]} ins(%arg2, %15 : tensor<3xi64>, tensor<1xi64>) outs(%16 : tensor<1x2xf32>) {
    ^bb0(%arg5: i64, %arg6: i64, %arg7: f32):
      %53 = linalg.index 0 : index
      %54 = arith.index_cast %53 : index to i64
      %55 = arith.addi %54, %c1_i64 : i64
      // Clamp bag + 1 into [.., numOffsets - 1] so the extract is always in
      // bounds; for the last bag the value is discarded by the select below.
      %clamped = arith.minsi %55, %c0_i64 : i64
      %56 = arith.index_cast %clamped : i64 to index
      // bag + 1 == number of offsets (1) identifies the last bag.
      %57 = arith.cmpi eq, %55, %c1_i64 : i64
      %58 = tensor.extract %15[%56] : tensor<1xi64>
      // End position: index count (3) for the last bag, else offsets[bag + 1].
      %59 = arith.select %57, %c3_i64, %58 : i64
      %60 = linalg.index 1 : index
      %61 = arith.index_cast %60 : index to i64
      // offsets[bag] <= position
      %62 = arith.cmpi slt, %arg6, %61 : i64
      %63 = arith.cmpi eq, %arg6, %61 : i64
      %64 = arith.ori %62, %63 : i1
      // position < end
      %65 = arith.cmpi slt, %61, %59 : i64
      %66 = arith.andi %64, %65 : i1
      %67 = arith.index_cast %arg5 : i64 to index
      %68 = linalg.index 2 : index
      %69 = tensor.extract %cst_7[%67, %68] : tensor<4x2xf32>
      %70 = arith.addf %69, %arg7 : f32
      %71 = arith.select %66, %70, %arg7 : f32
      linalg.yield %71 : f32
    } -> tensor<1x2xf32>

    // Bag 1: table %cst_6 (3x2), indices %arg3 (1 entry), offsets row 1.
    %18 = tensor.extract_slice %arg1[1, 0] [1, 1] [1, 1] : tensor<3x1xi64> to tensor<1x1xi64>
    %19 = tensor.collapse_shape %18 [[0, 1]] : tensor<1x1xi64> into tensor<1xi64>
    %20 = linalg.fill ins(%cst_12 : f32) outs(%9 : tensor<1x2xf32>) -> tensor<1x2xf32>
    %21 = linalg.generic {indexing_maps = [#map4, #map5, #map6], iterator_types = ["parallel", "reduction", "parallel"]} ins(%arg3, %19 : tensor<1xi64>, tensor<1xi64>) outs(%20 : tensor<1x2xf32>) {
    ^bb0(%arg5: i64, %arg6: i64, %arg7: f32):
      %53 = linalg.index 0 : index
      %54 = arith.index_cast %53 : index to i64
      %55 = arith.addi %54, %c1_i64 : i64
      // Keep the offsets extract in bounds (see bag 0).
      %clamped = arith.minsi %55, %c0_i64 : i64
      %56 = arith.index_cast %clamped : i64 to index
      %57 = arith.cmpi eq, %55, %c1_i64 : i64
      %58 = tensor.extract %19[%56] : tensor<1xi64>
      // End position: index count (1) for the last bag, else offsets[bag + 1].
      %59 = arith.select %57, %c1_i64, %58 : i64
      %60 = linalg.index 1 : index
      %61 = arith.index_cast %60 : index to i64
      %62 = arith.cmpi slt, %arg6, %61 : i64
      %63 = arith.cmpi eq, %arg6, %61 : i64
      %64 = arith.ori %62, %63 : i1
      %65 = arith.cmpi slt, %61, %59 : i64
      %66 = arith.andi %64, %65 : i1
      %67 = arith.index_cast %arg5 : i64 to index
      %68 = linalg.index 2 : index
      %69 = tensor.extract %cst_6[%67, %68] : tensor<3x2xf32>
      %70 = arith.addf %69, %arg7 : f32
      %71 = arith.select %66, %70, %arg7 : f32
      linalg.yield %71 : f32
    } -> tensor<1x2xf32>

    // Bag 2: table %cst_5 (2x2), indices %arg4 (1 entry), offsets row 2.
    %22 = tensor.extract_slice %arg1[2, 0] [1, 1] [1, 1] : tensor<3x1xi64> to tensor<1x1xi64>
    %23 = tensor.collapse_shape %22 [[0, 1]] : tensor<1x1xi64> into tensor<1xi64>
    %24 = linalg.fill ins(%cst_12 : f32) outs(%9 : tensor<1x2xf32>) -> tensor<1x2xf32>
    %25 = linalg.generic {indexing_maps = [#map4, #map5, #map6], iterator_types = ["parallel", "reduction", "parallel"]} ins(%arg4, %23 : tensor<1xi64>, tensor<1xi64>) outs(%24 : tensor<1x2xf32>) {
    ^bb0(%arg5: i64, %arg6: i64, %arg7: f32):
      %53 = linalg.index 0 : index
      %54 = arith.index_cast %53 : index to i64
      %55 = arith.addi %54, %c1_i64 : i64
      // Keep the offsets extract in bounds (see bag 0).
      %clamped = arith.minsi %55, %c0_i64 : i64
      %56 = arith.index_cast %clamped : i64 to index
      %57 = arith.cmpi eq, %55, %c1_i64 : i64
      %58 = tensor.extract %23[%56] : tensor<1xi64>
      // End position: index count (1) for the last bag, else offsets[bag + 1].
      %59 = arith.select %57, %c1_i64, %58 : i64
      %60 = linalg.index 1 : index
      %61 = arith.index_cast %60 : index to i64
      %62 = arith.cmpi slt, %arg6, %61 : i64
      %63 = arith.cmpi eq, %arg6, %61 : i64
      %64 = arith.ori %62, %63 : i1
      %65 = arith.cmpi slt, %61, %59 : i64
      %66 = arith.andi %64, %65 : i1
      %67 = arith.index_cast %arg5 : i64 to index
      %68 = linalg.index 2 : index
      %69 = tensor.extract %cst_5[%67, %68] : tensor<2x2xf32>
      %70 = arith.addf %69, %arg7 : f32
      %71 = arith.select %66, %70, %arg7 : f32
      linalg.yield %71 : f32
    } -> tensor<1x2xf32>

    // ---- Feature interaction.
    // Concatenate [%13, %17, %21, %25] into 1x8, then view it as four
    // 2-wide vectors (1x4x2).
    %26 = linalg.init_tensor [1, 8] : tensor<1x8xf32>
    %27 = tensor.insert_slice %13 into %26[0, 0] [1, 2] [1, 1] : tensor<1x2xf32> into tensor<1x8xf32>
    %28 = tensor.insert_slice %17 into %27[0, 2] [1, 2] [1, 1] : tensor<1x2xf32> into tensor<1x8xf32>
    %29 = tensor.insert_slice %21 into %28[0, 4] [1, 2] [1, 1] : tensor<1x2xf32> into tensor<1x8xf32>
    %30 = tensor.insert_slice %25 into %29[0, 6] [1, 2] [1, 1] : tensor<1x2xf32> into tensor<1x8xf32>
    %31 = tensor.expand_shape %30 [[0], [1, 2]] : tensor<1x8xf32> into tensor<1x4x2xf32>
    // Transpose the last two dims (1x4x2 -> 1x2x4).
    %32 = linalg.init_tensor [1, 2, 4] : tensor<1x2x4xf32>
    %33 = linalg.generic {indexing_maps = [#map7, #map8], iterator_types = ["parallel", "parallel", "parallel"]} ins(%31 : tensor<1x4x2xf32>) outs(%32 : tensor<1x2x4xf32>) {
    ^bb0(%arg5: f32, %arg6: f32):
      linalg.yield %arg5 : f32
    } -> tensor<1x2x4xf32>
    // %36[0, i, j] = dot(vector i, vector j).
    %34 = linalg.init_tensor [1, 4, 4] : tensor<1x4x4xf32>
    %35 = linalg.fill ins(%cst_12 : f32) outs(%34 : tensor<1x4x4xf32>) -> tensor<1x4x4xf32>
    %36 = linalg.batch_matmul ins(%31, %33 : tensor<1x4x2xf32>, tensor<1x2x4xf32>) outs(%35 : tensor<1x4x4xf32>) -> tensor<1x4x4xf32>
    // Gather the strictly lower triangle of %36 in the order
    // (1,0), (2,0), (2,1), (3,0), (3,1), (3,2).
    %37 = linalg.init_tensor [1, 6] : tensor<1x6xf32>
    %38 = linalg.generic {indexing_maps = [#map3, #map3, #map0], iterator_types = ["parallel", "parallel"]} ins(%cst_4, %cst_3 : tensor<6xi64>, tensor<6xi64>) outs(%37 : tensor<1x6xf32>) {
    ^bb0(%arg5: i64, %arg6: i64, %arg7: f32):
      %53 = linalg.index 0 : index
      %54 = arith.index_cast %arg5 : i64 to index
      %55 = arith.index_cast %arg6 : i64 to index
      %56 = tensor.extract %36[%53, %54, %55] : tensor<1x4x4xf32>
      linalg.yield %56 : f32
    } -> tensor<1x6xf32>
    // %27 already holds %13 in columns 0-1, so writing the 6 interaction terms
    // into columns 2-7 yields [%13, interactions] with no uninitialised lanes.
    %39 = tensor.insert_slice %38 into %27[0, 2] [1, 6] [1, 1] : tensor<1x6xf32> into tensor<1x8xf32>

    // ---- Dense layer 3: relu(%39 @ %cst_1^T + %cst_2) -> 1x2.
    %40 = linalg.init_tensor [8, 2] : tensor<8x2xf32>
    %41 = linalg.generic {indexing_maps = [#map0, #map1], iterator_types = ["parallel", "parallel"]} ins(%cst_1 : tensor<2x8xf32>) outs(%40 : tensor<8x2xf32>) {
    ^bb0(%arg5: f32, %arg6: f32):
      linalg.yield %arg5 : f32
    } -> tensor<8x2xf32>
    %42 = linalg.fill ins(%cst_12 : f32) outs(%9 : tensor<1x2xf32>) -> tensor<1x2xf32>
    %43 = linalg.matmul ins(%39, %41 : tensor<1x8xf32>, tensor<8x2xf32>) outs(%42 : tensor<1x2xf32>) -> tensor<1x2xf32>
    %44 = linalg.generic {indexing_maps = [#map2, #map3, #map0], iterator_types = ["parallel", "parallel"]} ins(%43, %cst_2 : tensor<1x2xf32>, tensor<2xf32>) outs(%9 : tensor<1x2xf32>) {
    ^bb0(%arg5: f32, %arg6: f32, %arg7: f32):
      %53 = arith.addf %arg5, %arg6 : f32
      linalg.yield %53 : f32
    } -> tensor<1x2xf32>
    %45 = linalg.generic {indexing_maps = [#map2, #map0], iterator_types = ["parallel", "parallel"]} ins(%44 : tensor<1x2xf32>) outs(%9 : tensor<1x2xf32>) {
    ^bb0(%arg5: f32, %arg6: f32):
      %53 = arith.cmpf ugt, %arg5, %cst_12 : f32
      %54 = arith.select %53, %arg5, %cst_12 : f32
      linalg.yield %54 : f32
    } -> tensor<1x2xf32>

    // ---- Dense layer 4: relu(%45 @ %cst^T + %cst_0) -> 1x1 (function result).
    %46 = linalg.init_tensor [2, 1] : tensor<2x1xf32>
    %47 = linalg.generic {indexing_maps = [#map0, #map1], iterator_types = ["parallel", "parallel"]} ins(%cst : tensor<1x2xf32>) outs(%46 : tensor<2x1xf32>) {
    ^bb0(%arg5: f32, %arg6: f32):
      linalg.yield %arg5 : f32
    } -> tensor<2x1xf32>
    %48 = linalg.init_tensor [1, 1] : tensor<1x1xf32>
    %49 = linalg.fill ins(%cst_12 : f32) outs(%48 : tensor<1x1xf32>) -> tensor<1x1xf32>
    %50 = linalg.matmul ins(%45, %47 : tensor<1x2xf32>, tensor<2x1xf32>) outs(%49 : tensor<1x1xf32>) -> tensor<1x1xf32>
    %51 = linalg.generic {indexing_maps = [#map9, #map10, #map0], iterator_types = ["parallel", "parallel"]} ins(%50, %cst_0 : tensor<1x1xf32>, tensor<1xf32>) outs(%48 : tensor<1x1xf32>) {
    ^bb0(%arg5: f32, %arg6: f32, %arg7: f32):
      %53 = arith.addf %arg5, %arg6 : f32
      linalg.yield %53 : f32
    } -> tensor<1x1xf32>
    %52 = linalg.generic {indexing_maps = [#map9, #map0], iterator_types = ["parallel", "parallel"]} ins(%51 : tensor<1x1xf32>) outs(%48 : tensor<1x1xf32>) {
    ^bb0(%arg5: f32, %arg6: f32):
      %53 = arith.cmpf ugt, %arg5, %cst_12 : f32
      %54 = arith.select %53, %arg5, %cst_12 : f32
      linalg.yield %54 : f32
    } -> tensor<1x1xf32>
    return %52 : tensor<1x1xf32>
  }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment