# Created March 1, 2025 23:07
# Save makslevental/8cc7102abea48f94c32f58dd262eb138 to your computer and use it in GitHub Desktop.
# This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
# Learn more about bidirectional Unicode characters
import numpy as np

from triton_mlir.extras.context import RAIIMLIRContextModule
from triton_mlir.dialects import tt as ttpp, scf, llvm, _tt_ops_gen as tt
from triton_mlir.ir import Attribute, ArrayAttr, TypeAttr, Type
from triton_mlir.extras.dialects.ext import arith
# Module-level MLIR context shared by every builder function below.
# NOTE(review): presumably RAIIMLIRContextModule creates an MLIR context plus an
# empty module that the `.emit()` calls at the end of the file populate -- confirm
# against triton_mlir.extras.context. `ctx.module` is printed and verified at the
# end of the script.
ctx = RAIIMLIRContextModule()
@ttpp.jit(
    arg_attrs=ArrayAttr.parse(
        '[{tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}, '
        '{tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}, '
        '{tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}, '
        '{tt.divisibility = 16 : i32}, {tt.divisibility = 16 : i32}, '
        '{tt.divisibility = 16 : i32}, {tt.divisibility = 16 : i32}, '
        '{tt.divisibility = 16 : i32}, {tt.divisibility = 16 : i32}]'
    ),
    function_type=TypeAttr.parse(
        '(!tt.ptr<f16>, !tt.ptr<f16>, !tt.ptr<f16>, i32, i32, i32, i32, i32, i32) -> ()'
    ),
    noinline=False,
    sym_name='matmul_kernel',
    sym_visibility='public',
)
def matmul_kernel(arg0, arg1, arg2, arg3, arg4, arg5, arg6, arg7, arg8):
    """Emit the public ``matmul_kernel`` Triton function.

    The body replays a Triton IR dump op by op. The dump is a tiled f16 matrix
    multiply with a 128x256 f32 accumulator, reduced in steps of 16 and stored
    back as f16. Every op is kept, including constants and range checks whose
    results are never read, so the emitted IR matches the dump exactly.

    How the ops below use the arguments:

    * ``arg0`` / ``arg1`` -- ``!tt.ptr<f16>`` bases of the 128x16 and 16x256
      tiles loaded inside the loop.
    * ``arg2`` -- ``!tt.ptr<f16>`` base of the 128x256 tile stored at the end.
    * ``arg3`` / ``arg4`` -- row / column bounds. Tile offsets are reduced modulo
      them for the loads, and the final store is masked against them.
    * ``arg5`` -- reduction length. The loop runs ``cdiv(arg5, 16)`` times and
      both loads are masked against ``arg5 - iv * 16``.
    * ``arg6`` / ``arg7`` / ``arg8`` -- row multipliers for the first operand,
      the second operand and the output offsets. The matching column offsets are
      multiplied by a constant 1.

    NOTE(review): reading arg3..arg8 as M/N/K and strides (as in the Triton matmul
    tutorial) is inferred from this usage -- confirm against the source kernel.
    """
    # --- Assumptions on the scalar arguments ---------------------------------
    # predicate=4 is `sgt` in arith's CmpIPredicate, so each cmpi/intr_assume
    # pair records `argN > 0` as an optimizer assumption. The constant-`true`
    # assumes are presumably conditions that were folded at trace time.
    c0_i32 = arith.constant(0, Type.parse('i32'))
    v0 = arith.cmpi(predicate=4, lhs=arg3, rhs=c0_i32)
    llvm.intr_assume(cond=v0, op_bundle_operands=[], op_bundle_sizes=[])
    c0_i32_0 = arith.constant(0, Type.parse('i32'))
    v1 = arith.cmpi(predicate=4, lhs=arg4, rhs=c0_i32_0)
    llvm.intr_assume(cond=v1, op_bundle_operands=[], op_bundle_sizes=[])
    c0_i32_1 = arith.constant(0, Type.parse('i32'))
    v2 = arith.cmpi(predicate=4, lhs=arg5, rhs=c0_i32_1)
    llvm.intr_assume(cond=v2, op_bundle_operands=[], op_bundle_sizes=[])
    c0_i32_2 = arith.constant(0, Type.parse('i32'))
    v3 = arith.cmpi(predicate=4, lhs=arg6, rhs=c0_i32_2)
    llvm.intr_assume(cond=v3, op_bundle_operands=[], op_bundle_sizes=[])
    true = arith.constant(True, Type.parse('i1'))
    llvm.intr_assume(cond=true, op_bundle_operands=[], op_bundle_sizes=[])
    true_3 = arith.constant(True, Type.parse('i1'))
    llvm.intr_assume(cond=true_3, op_bundle_operands=[], op_bundle_sizes=[])
    c0_i32_4 = arith.constant(0, Type.parse('i32'))
    v4 = arith.cmpi(predicate=4, lhs=arg7, rhs=c0_i32_4)
    llvm.intr_assume(cond=v4, op_bundle_operands=[], op_bundle_sizes=[])
    c0_i32_5 = arith.constant(0, Type.parse('i32'))
    v5 = arith.cmpi(predicate=4, lhs=arg8, rhs=c0_i32_5)
    llvm.intr_assume(cond=v5, op_bundle_operands=[], op_bundle_sizes=[])
    true_6 = arith.constant(True, Type.parse('i1'))
    llvm.intr_assume(cond=true_6, op_bundle_operands=[], op_bundle_sizes=[])

    # --- Tile coordinates of this program ------------------------------------
    # v7 = cdiv(arg3, 128), v8 = cdiv(arg4, 256).
    # Each "extsi to i64 -> op -> sle INT32_MAX (pred 3) & sge INT32_MIN (pred 5)"
    # group (e.g. v9..v14) yields an i1 that nothing reads. They look like
    # leftover overflow checks and are kept only so the IR matches the dump.
    # The i32 op that follows each group (e.g. v15) is the value actually used.
    v6 = tt.get_program_id(axis=0)
    v7 = tt.call(result=[Type.parse('i32')], callee='cdiv__i32__(1,)cconstexpr_128_', operands_=[arg3])
    v8 = tt.call(result=[Type.parse('i32')], callee='cdiv__i32__(1,)cconstexpr_256_', operands_=[arg4])
    c1_i32 = arith.constant(1, Type.parse('i32'))
    c1_i32_7 = arith.constant(1, Type.parse('i32'))
    v9 = arith.extsi(out=Type.parse('i64'), in_=c1_i32_7)
    v10 = arith.extsi(out=Type.parse('i64'), in_=v8)
    v11 = arith.muli(lhs=v9, rhs=v10, overflow_flags=Attribute.parse('#arith.overflow<none>'))
    c2147483647_i64 = arith.constant(2147483647, Type.parse('i64'))
    c_2147483648_i64 = arith.constant(-2147483648, Type.parse('i64'))
    v12 = arith.cmpi(predicate=3, lhs=v11, rhs=c2147483647_i64)
    v13 = arith.cmpi(predicate=5, lhs=v11, rhs=c_2147483648_i64)
    v14 = arith.andi(lhs=v12, rhs=v13)
    # v15 = 1 * v8; v16 = v6 / v15 (signed).
    v15 = arith.muli(lhs=c1_i32_7, rhs=v8, overflow_flags=Attribute.parse('#arith.overflow<none>'))
    v16 = arith.divsi(lhs=v6, rhs=v15)
    c1_i32_8 = arith.constant(1, Type.parse('i32'))
    c1_i32_9 = arith.constant(1, Type.parse('i32'))
    v17 = arith.extsi(out=Type.parse('i64'), in_=v16)
    v18 = arith.extsi(out=Type.parse('i64'), in_=c1_i32_9)
    v19 = arith.muli(lhs=v17, rhs=v18, overflow_flags=Attribute.parse('#arith.overflow<none>'))
    c2147483647_i64_10 = arith.constant(2147483647, Type.parse('i64'))
    c_2147483648_i64_11 = arith.constant(-2147483648, Type.parse('i64'))
    v20 = arith.cmpi(predicate=3, lhs=v19, rhs=c2147483647_i64_10)
    v21 = arith.cmpi(predicate=5, lhs=v19, rhs=c_2147483648_i64_11)
    v22 = arith.andi(lhs=v20, rhs=v21)
    # v23 = v16 * 1.
    v23 = arith.muli(lhs=v16, rhs=c1_i32_9, overflow_flags=Attribute.parse('#arith.overflow<none>'))
    v24 = arith.extsi(out=Type.parse('i64'), in_=v7)
    v25 = arith.extsi(out=Type.parse('i64'), in_=v23)
    v26 = arith.subi(lhs=v24, rhs=v25, overflow_flags=Attribute.parse('#arith.overflow<none>'))
    c2147483647_i64_12 = arith.constant(2147483647, Type.parse('i64'))
    c_2147483648_i64_13 = arith.constant(-2147483648, Type.parse('i64'))
    v27 = arith.cmpi(predicate=3, lhs=v26, rhs=c2147483647_i64_12)
    v28 = arith.cmpi(predicate=5, lhs=v26, rhs=c_2147483648_i64_13)
    v29 = arith.andi(lhs=v27, rhs=v28)
    # v31 = min(v7 - v23, 1).
    v30 = arith.subi(lhs=v7, rhs=v23, overflow_flags=Attribute.parse('#arith.overflow<none>'))
    c1_i32_14 = arith.constant(1, Type.parse('i32'))
    v31 = arith.minsi(lhs=v30, rhs=c1_i32_14)
    # v33 = (v6 % v15) % v31.
    v32 = arith.remsi(lhs=v6, rhs=v15)
    v33 = arith.remsi(lhs=v32, rhs=v31)
    v34 = arith.extsi(out=Type.parse('i64'), in_=v23)
    v35 = arith.extsi(out=Type.parse('i64'), in_=v33)
    v36 = arith.addi(lhs=v34, rhs=v35, overflow_flags=Attribute.parse('#arith.overflow<none>'))
    c2147483647_i64_15 = arith.constant(2147483647, Type.parse('i64'))
    c_2147483648_i64_16 = arith.constant(-2147483648, Type.parse('i64'))
    v37 = arith.cmpi(predicate=3, lhs=v36, rhs=c2147483647_i64_15)
    v38 = arith.cmpi(predicate=5, lhs=v36, rhs=c_2147483648_i64_16)
    v39 = arith.andi(lhs=v37, rhs=v38)
    # v40 = v23 + v33 is the tile row index (scaled by 128 below).
    # v42 = (v6 % v15) / v31 is the tile column index (scaled by 256 below).
    v40 = arith.addi(lhs=v23, rhs=v33, overflow_flags=Attribute.parse('#arith.overflow<none>'))
    v41 = arith.remsi(lhs=v6, rhs=v15)
    v42 = arith.divsi(lhs=v41, rhs=v31)

    # --- Row offsets: v60 = (v40 * 128 + arange(0, 128)) % arg3 --------------
    c128_i32 = arith.constant(128, Type.parse('i32'))
    c128_i32_17 = arith.constant(128, Type.parse('i32'))
    v43 = arith.extsi(out=Type.parse('i64'), in_=v40)
    v44 = arith.extsi(out=Type.parse('i64'), in_=c128_i32_17)
    v45 = arith.muli(lhs=v43, rhs=v44, overflow_flags=Attribute.parse('#arith.overflow<none>'))
    c2147483647_i64_18 = arith.constant(2147483647, Type.parse('i64'))
    c_2147483648_i64_19 = arith.constant(-2147483648, Type.parse('i64'))
    v46 = arith.cmpi(predicate=3, lhs=v45, rhs=c2147483647_i64_18)
    v47 = arith.cmpi(predicate=5, lhs=v45, rhs=c_2147483648_i64_19)
    v48 = arith.andi(lhs=v46, rhs=v47)
    v49 = arith.muli(lhs=v40, rhs=c128_i32_17, overflow_flags=Attribute.parse('#arith.overflow<none>'))
    v50 = tt.make_range(result=Type.parse('tensor<128xi32>'), start=0, end=128)
    v51 = tt.splat(result=Type.parse('tensor<128xi32>'), src=v49)
    v52 = arith.extsi(out=Type.parse('tensor<128xi64>'), in_=v51)
    v53 = arith.extsi(out=Type.parse('tensor<128xi64>'), in_=v50)
    v54 = arith.addi(lhs=v52, rhs=v53, overflow_flags=Attribute.parse('#arith.overflow<none>'))
    c2147483647_i64_20 = arith.constant(2147483647, Type.parse('i64'))
    c_2147483648_i64_21 = arith.constant(-2147483648, Type.parse('i64'))
    cst = arith.constant(np.full([128], 2147483647, np.int64), Type.parse('tensor<128xi64>'))
    v55 = arith.cmpi(predicate=3, lhs=v54, rhs=cst)
    cst_22 = arith.constant(np.full([128], -2147483648, np.int64), Type.parse('tensor<128xi64>'))
    v56 = arith.cmpi(predicate=5, lhs=v54, rhs=cst_22)
    v57 = arith.andi(lhs=v55, rhs=v56)
    v58 = arith.addi(lhs=v51, rhs=v50, overflow_flags=Attribute.parse('#arith.overflow<none>'))
    v59 = tt.splat(result=Type.parse('tensor<128xi32>'), src=arg3)
    v60 = arith.remsi(lhs=v58, rhs=v59)

    # --- Column offsets: v78 = (v42 * 256 + arange(0, 256)) % arg4 -----------
    c256_i32 = arith.constant(256, Type.parse('i32'))
    c256_i32_23 = arith.constant(256, Type.parse('i32'))
    v61 = arith.extsi(out=Type.parse('i64'), in_=v42)
    v62 = arith.extsi(out=Type.parse('i64'), in_=c256_i32_23)
    v63 = arith.muli(lhs=v61, rhs=v62, overflow_flags=Attribute.parse('#arith.overflow<none>'))
    c2147483647_i64_24 = arith.constant(2147483647, Type.parse('i64'))
    c_2147483648_i64_25 = arith.constant(-2147483648, Type.parse('i64'))
    v64 = arith.cmpi(predicate=3, lhs=v63, rhs=c2147483647_i64_24)
    v65 = arith.cmpi(predicate=5, lhs=v63, rhs=c_2147483648_i64_25)
    v66 = arith.andi(lhs=v64, rhs=v65)
    v67 = arith.muli(lhs=v42, rhs=c256_i32_23, overflow_flags=Attribute.parse('#arith.overflow<none>'))
    v68 = tt.make_range(result=Type.parse('tensor<256xi32>'), start=0, end=256)
    v69 = tt.splat(result=Type.parse('tensor<256xi32>'), src=v67)
    v70 = arith.extsi(out=Type.parse('tensor<256xi64>'), in_=v69)
    v71 = arith.extsi(out=Type.parse('tensor<256xi64>'), in_=v68)
    v72 = arith.addi(lhs=v70, rhs=v71, overflow_flags=Attribute.parse('#arith.overflow<none>'))
    c2147483647_i64_26 = arith.constant(2147483647, Type.parse('i64'))
    c_2147483648_i64_27 = arith.constant(-2147483648, Type.parse('i64'))
    cst_28 = arith.constant(np.full([256], 2147483647, np.int64), Type.parse('tensor<256xi64>'))
    v73 = arith.cmpi(predicate=3, lhs=v72, rhs=cst_28)
    cst_29 = arith.constant(np.full([256], -2147483648, np.int64), Type.parse('tensor<256xi64>'))
    v74 = arith.cmpi(predicate=5, lhs=v72, rhs=cst_29)
    v75 = arith.andi(lhs=v73, rhs=v74)
    v76 = arith.addi(lhs=v69, rhs=v68, overflow_flags=Attribute.parse('#arith.overflow<none>'))
    v77 = tt.splat(result=Type.parse('tensor<256xi32>'), src=arg4)
    v78 = arith.remsi(lhs=v76, rhs=v77)

    # --- Reduction offsets: v79 = arange(0, 16) ------------------------------
    v79 = tt.make_range(result=Type.parse('tensor<16xi32>'), start=0, end=16)

    # --- First-operand pointers: v107 = arg0 + v60[:, None] * arg6 + v79[None, :] * 1
    v80 = tt.expand_dims(src=v60, axis=1)
    v81 = tt.splat(result=Type.parse('tensor<128x1xi32>'), src=arg6)
    v82 = arith.extsi(out=Type.parse('tensor<128x1xi64>'), in_=v80)
    v83 = arith.extsi(out=Type.parse('tensor<128x1xi64>'), in_=v81)
    v84 = arith.muli(lhs=v82, rhs=v83, overflow_flags=Attribute.parse('#arith.overflow<none>'))
    c2147483647_i64_30 = arith.constant(2147483647, Type.parse('i64'))
    c_2147483648_i64_31 = arith.constant(-2147483648, Type.parse('i64'))
    cst_32 = arith.constant(np.full([128, 1], 2147483647, np.int64), Type.parse('tensor<128x1xi64>'))
    v85 = arith.cmpi(predicate=3, lhs=v84, rhs=cst_32)
    cst_33 = arith.constant(np.full([128, 1], -2147483648, np.int64), Type.parse('tensor<128x1xi64>'))
    v86 = arith.cmpi(predicate=5, lhs=v84, rhs=cst_33)
    v87 = arith.andi(lhs=v85, rhs=v86)
    v88 = arith.muli(lhs=v80, rhs=v81, overflow_flags=Attribute.parse('#arith.overflow<none>'))
    v89 = tt.expand_dims(src=v79, axis=0)
    c1_i32_34 = arith.constant(1, Type.parse('i32'))
    c1_i32_35 = arith.constant(1, Type.parse('i32'))
    cst_36 = arith.constant(np.full([1, 16], 1, np.int32), Type.parse('tensor<1x16xi32>'))
    v90 = arith.extsi(out=Type.parse('tensor<1x16xi64>'), in_=v89)
    v91 = arith.extsi(out=Type.parse('tensor<1x16xi64>'), in_=cst_36)
    v92 = arith.muli(lhs=v90, rhs=v91, overflow_flags=Attribute.parse('#arith.overflow<none>'))
    c2147483647_i64_37 = arith.constant(2147483647, Type.parse('i64'))
    c_2147483648_i64_38 = arith.constant(-2147483648, Type.parse('i64'))
    cst_39 = arith.constant(np.full([1, 16], 2147483647, np.int64), Type.parse('tensor<1x16xi64>'))
    v93 = arith.cmpi(predicate=3, lhs=v92, rhs=cst_39)
    cst_40 = arith.constant(np.full([1, 16], -2147483648, np.int64), Type.parse('tensor<1x16xi64>'))
    v94 = arith.cmpi(predicate=5, lhs=v92, rhs=cst_40)
    v95 = arith.andi(lhs=v93, rhs=v94)
    v96 = arith.muli(lhs=v89, rhs=cst_36, overflow_flags=Attribute.parse('#arith.overflow<none>'))
    v97 = tt.broadcast(result=Type.parse('tensor<128x16xi32>'), src=v88)
    v98 = tt.broadcast(result=Type.parse('tensor<128x16xi32>'), src=v96)
    v99 = arith.extsi(out=Type.parse('tensor<128x16xi64>'), in_=v97)
    v100 = arith.extsi(out=Type.parse('tensor<128x16xi64>'), in_=v98)
    v101 = arith.addi(lhs=v99, rhs=v100, overflow_flags=Attribute.parse('#arith.overflow<none>'))
    c2147483647_i64_41 = arith.constant(2147483647, Type.parse('i64'))
    c_2147483648_i64_42 = arith.constant(-2147483648, Type.parse('i64'))
    cst_43 = arith.constant(np.full([128, 16], 2147483647, np.int64), Type.parse('tensor<128x16xi64>'))
    v102 = arith.cmpi(predicate=3, lhs=v101, rhs=cst_43)
    cst_44 = arith.constant(np.full([128, 16], -2147483648, np.int64), Type.parse('tensor<128x16xi64>'))
    v103 = arith.cmpi(predicate=5, lhs=v101, rhs=cst_44)
    v104 = arith.andi(lhs=v102, rhs=v103)
    v105 = arith.addi(lhs=v97, rhs=v98, overflow_flags=Attribute.parse('#arith.overflow<none>'))
    v106 = tt.splat(result=Type.parse('tensor<128x16x!tt.ptr<f16>>'), src=arg0)
    v107 = tt.addptr(result=Type.parse('tensor<128x16x!tt.ptr<f16>>'), ptr=v106, offset=v105)

    # --- Second-operand pointers: v135 = arg1 + v79[:, None] * arg7 + v78[None, :] * 1
    v108 = tt.expand_dims(src=v79, axis=1)
    v109 = tt.splat(result=Type.parse('tensor<16x1xi32>'), src=arg7)
    v110 = arith.extsi(out=Type.parse('tensor<16x1xi64>'), in_=v108)
    v111 = arith.extsi(out=Type.parse('tensor<16x1xi64>'), in_=v109)
    v112 = arith.muli(lhs=v110, rhs=v111, overflow_flags=Attribute.parse('#arith.overflow<none>'))
    c2147483647_i64_45 = arith.constant(2147483647, Type.parse('i64'))
    c_2147483648_i64_46 = arith.constant(-2147483648, Type.parse('i64'))
    cst_47 = arith.constant(np.full([16, 1], 2147483647, np.int64), Type.parse('tensor<16x1xi64>'))
    v113 = arith.cmpi(predicate=3, lhs=v112, rhs=cst_47)
    cst_48 = arith.constant(np.full([16, 1], -2147483648, np.int64), Type.parse('tensor<16x1xi64>'))
    v114 = arith.cmpi(predicate=5, lhs=v112, rhs=cst_48)
    v115 = arith.andi(lhs=v113, rhs=v114)
    v116 = arith.muli(lhs=v108, rhs=v109, overflow_flags=Attribute.parse('#arith.overflow<none>'))
    v117 = tt.expand_dims(src=v78, axis=0)
    c1_i32_49 = arith.constant(1, Type.parse('i32'))
    c1_i32_50 = arith.constant(1, Type.parse('i32'))
    cst_51 = arith.constant(np.full([1, 256], 1, np.int32), Type.parse('tensor<1x256xi32>'))
    v118 = arith.extsi(out=Type.parse('tensor<1x256xi64>'), in_=v117)
    v119 = arith.extsi(out=Type.parse('tensor<1x256xi64>'), in_=cst_51)
    v120 = arith.muli(lhs=v118, rhs=v119, overflow_flags=Attribute.parse('#arith.overflow<none>'))
    c2147483647_i64_52 = arith.constant(2147483647, Type.parse('i64'))
    c_2147483648_i64_53 = arith.constant(-2147483648, Type.parse('i64'))
    cst_54 = arith.constant(np.full([1, 256], 2147483647, np.int64), Type.parse('tensor<1x256xi64>'))
    v121 = arith.cmpi(predicate=3, lhs=v120, rhs=cst_54)
    cst_55 = arith.constant(np.full([1, 256], -2147483648, np.int64), Type.parse('tensor<1x256xi64>'))
    v122 = arith.cmpi(predicate=5, lhs=v120, rhs=cst_55)
    v123 = arith.andi(lhs=v121, rhs=v122)
    v124 = arith.muli(lhs=v117, rhs=cst_51, overflow_flags=Attribute.parse('#arith.overflow<none>'))
    v125 = tt.broadcast(result=Type.parse('tensor<16x256xi32>'), src=v116)
    v126 = tt.broadcast(result=Type.parse('tensor<16x256xi32>'), src=v124)
    v127 = arith.extsi(out=Type.parse('tensor<16x256xi64>'), in_=v125)
    v128 = arith.extsi(out=Type.parse('tensor<16x256xi64>'), in_=v126)
    v129 = arith.addi(lhs=v127, rhs=v128, overflow_flags=Attribute.parse('#arith.overflow<none>'))
    c2147483647_i64_56 = arith.constant(2147483647, Type.parse('i64'))
    c_2147483648_i64_57 = arith.constant(-2147483648, Type.parse('i64'))
    cst_58 = arith.constant(np.full([16, 256], 2147483647, np.int64), Type.parse('tensor<16x256xi64>'))
    v130 = arith.cmpi(predicate=3, lhs=v129, rhs=cst_58)
    cst_59 = arith.constant(np.full([16, 256], -2147483648, np.int64), Type.parse('tensor<16x256xi64>'))
    v131 = arith.cmpi(predicate=5, lhs=v129, rhs=cst_59)
    v132 = arith.andi(lhs=v130, rhs=v131)
    v133 = arith.addi(lhs=v125, rhs=v126, overflow_flags=Attribute.parse('#arith.overflow<none>'))
    v134 = tt.splat(result=Type.parse('tensor<16x256x!tt.ptr<f16>>'), src=arg1)
    v135 = tt.addptr(result=Type.parse('tensor<16x256x!tt.ptr<f16>>'), ptr=v134, offset=v133)

    # --- Reduction loop setup ------------------------------------------------
    # v136 is the 128x256 f32 tensor returned by the private `zeros...` helper.
    # The induction variable arg9 runs over [0, cdiv(arg5, 16)) with step 1.
    v136 = tt.call(result=[Type.parse('tensor<128x256xf32>')], callee='zeros____(0, 0)cconstexpr_128__(0, 1)cconstexpr_256__(1,)cconstexpr_fp32_', operands_=[])
    v137 = tt.call(result=[Type.parse('i32')], callee='cdiv__i32__(1,)cconstexpr_16_', operands_=[arg5])
    c0_i32_60 = arith.constant(0, Type.parse('i32'))
    c1_i32_61 = arith.constant(1, Type.parse('i32'))
    v138 = arith.bitcast(out=Type.parse('i32'), in_=c0_i32_60)
    v139 = arith.bitcast(out=Type.parse('i32'), in_=v137)
    v140 = arith.bitcast(out=Type.parse('i32'), in_=c1_i32_61)

    # Loop-carried values: arg10 = accumulator, arg11 = first-operand pointers,
    # arg12 = second-operand pointers. The scf.for_ helper binds the loop
    # results to v141_0..v141_2; v141_0 is read after the loop.
    for arg9, [arg10, arg11, arg12], [v141_0, v141_1, v141_2] in scf.for_(v138, v139, v140, iter_args=[v136, v107, v135]):
        # First-operand tile: load where v79[None, :] < arg5 - arg9 * 16 (pred 2 = slt),
        # otherwise 0.0 (an f32 zero tensor truncated to f16).
        v206 = tt.expand_dims(src=v79, axis=0)
        c16_i32 = arith.constant(16, Type.parse('i32'))
        c16_i32_89 = arith.constant(16, Type.parse('i32'))
        v207 = arith.extsi(out=Type.parse('i64'), in_=arg9)
        v208 = arith.extsi(out=Type.parse('i64'), in_=c16_i32_89)
        v209 = arith.muli(lhs=v207, rhs=v208, overflow_flags=Attribute.parse('#arith.overflow<none>'))
        c2147483647_i64_90 = arith.constant(2147483647, Type.parse('i64'))
        c_2147483648_i64_91 = arith.constant(-2147483648, Type.parse('i64'))
        v210 = arith.cmpi(predicate=3, lhs=v209, rhs=c2147483647_i64_90)
        v211 = arith.cmpi(predicate=5, lhs=v209, rhs=c_2147483648_i64_91)
        v212 = arith.andi(lhs=v210, rhs=v211)
        v213 = arith.muli(lhs=arg9, rhs=c16_i32_89, overflow_flags=Attribute.parse('#arith.overflow<none>'))
        v214 = arith.extsi(out=Type.parse('i64'), in_=arg5)
        v215 = arith.extsi(out=Type.parse('i64'), in_=v213)
        v216 = arith.subi(lhs=v214, rhs=v215, overflow_flags=Attribute.parse('#arith.overflow<none>'))
        c2147483647_i64_92 = arith.constant(2147483647, Type.parse('i64'))
        c_2147483648_i64_93 = arith.constant(-2147483648, Type.parse('i64'))
        v217 = arith.cmpi(predicate=3, lhs=v216, rhs=c2147483647_i64_92)
        v218 = arith.cmpi(predicate=5, lhs=v216, rhs=c_2147483648_i64_93)
        v219 = arith.andi(lhs=v217, rhs=v218)
        v220 = arith.subi(lhs=arg5, rhs=v213, overflow_flags=Attribute.parse('#arith.overflow<none>'))
        v221 = tt.splat(result=Type.parse('tensor<1x16xi32>'), src=v220)
        v222 = arith.cmpi(predicate=2, lhs=v206, rhs=v221)
        cst_94 = arith.constant(0.0, Type.parse('f32'))
        v223 = tt.broadcast(result=Type.parse('tensor<128x16xi1>'), src=v222)
        cst_95 = arith.constant(np.full([128, 16], 0.0, np.float32), Type.parse('tensor<128x16xf32>'))
        v224 = arith.truncf(out=Type.parse('tensor<128x16xf16>'), in_=cst_95)
        v225 = tt.load(ptr=arg11, mask=v223, other=v224, boundary_check=[], cache=1, evict=1, is_volatile=False)
        # Second-operand tile: load where v79[:, None] < arg5 - arg9 * 16, otherwise 0.0.
        v226 = tt.expand_dims(src=v79, axis=1)
        c16_i32_96 = arith.constant(16, Type.parse('i32'))
        c16_i32_97 = arith.constant(16, Type.parse('i32'))
        v227 = arith.extsi(out=Type.parse('i64'), in_=arg9)
        v228 = arith.extsi(out=Type.parse('i64'), in_=c16_i32_97)
        v229 = arith.muli(lhs=v227, rhs=v228, overflow_flags=Attribute.parse('#arith.overflow<none>'))
        c2147483647_i64_98 = arith.constant(2147483647, Type.parse('i64'))
        c_2147483648_i64_99 = arith.constant(-2147483648, Type.parse('i64'))
        v230 = arith.cmpi(predicate=3, lhs=v229, rhs=c2147483647_i64_98)
        v231 = arith.cmpi(predicate=5, lhs=v229, rhs=c_2147483648_i64_99)
        v232 = arith.andi(lhs=v230, rhs=v231)
        v233 = arith.muli(lhs=arg9, rhs=c16_i32_97, overflow_flags=Attribute.parse('#arith.overflow<none>'))
        v234 = arith.extsi(out=Type.parse('i64'), in_=arg5)
        v235 = arith.extsi(out=Type.parse('i64'), in_=v233)
        v236 = arith.subi(lhs=v234, rhs=v235, overflow_flags=Attribute.parse('#arith.overflow<none>'))
        c2147483647_i64_100 = arith.constant(2147483647, Type.parse('i64'))
        c_2147483648_i64_101 = arith.constant(-2147483648, Type.parse('i64'))
        v237 = arith.cmpi(predicate=3, lhs=v236, rhs=c2147483647_i64_100)
        v238 = arith.cmpi(predicate=5, lhs=v236, rhs=c_2147483648_i64_101)
        v239 = arith.andi(lhs=v237, rhs=v238)
        v240 = arith.subi(lhs=arg5, rhs=v233, overflow_flags=Attribute.parse('#arith.overflow<none>'))
        v241 = tt.splat(result=Type.parse('tensor<16x1xi32>'), src=v240)
        v242 = arith.cmpi(predicate=2, lhs=v226, rhs=v241)
        cst_102 = arith.constant(0.0, Type.parse('f32'))
        v243 = tt.broadcast(result=Type.parse('tensor<16x256xi1>'), src=v242)
        cst_103 = arith.constant(np.full([16, 256], 0.0, np.float32), Type.parse('tensor<16x256xf32>'))
        v244 = arith.truncf(out=Type.parse('tensor<16x256xf16>'), in_=cst_103)
        v245 = tt.load(ptr=arg12, mask=v243, other=v244, boundary_check=[], cache=1, evict=1, is_volatile=False)
        cst_104 = arith.constant(0.0, Type.parse('f32'))
        # Accumulate: v246 = v225 @ v245 + arg10.
        v246 = tt.dot(a=v225, b=v245, c=arg10, input_precision=2, max_num_imprecise_acc=0)
        # Advance pointers: first operand by 16 elements, second by 16 * arg7.
        c16_i32_105 = arith.constant(16, Type.parse('i32'))
        cst_106 = arith.constant(np.full([128, 16], 16, np.int32), Type.parse('tensor<128x16xi32>'))
        v247 = tt.addptr(result=Type.parse('tensor<128x16x!tt.ptr<f16>>'), ptr=arg11, offset=cst_106)
        c16_i32_107 = arith.constant(16, Type.parse('i32'))
        c16_i32_108 = arith.constant(16, Type.parse('i32'))
        v248 = arith.extsi(out=Type.parse('i64'), in_=c16_i32_108)
        v249 = arith.extsi(out=Type.parse('i64'), in_=arg7)
        v250 = arith.muli(lhs=v248, rhs=v249, overflow_flags=Attribute.parse('#arith.overflow<none>'))
        c2147483647_i64_109 = arith.constant(2147483647, Type.parse('i64'))
        c_2147483648_i64_110 = arith.constant(-2147483648, Type.parse('i64'))
        v251 = arith.cmpi(predicate=3, lhs=v250, rhs=c2147483647_i64_109)
        v252 = arith.cmpi(predicate=5, lhs=v250, rhs=c_2147483648_i64_110)
        v253 = arith.andi(lhs=v251, rhs=v252)
        v254 = arith.muli(lhs=c16_i32_108, rhs=arg7, overflow_flags=Attribute.parse('#arith.overflow<none>'))
        v255 = tt.splat(result=Type.parse('tensor<16x256xi32>'), src=v254)
        v256 = tt.addptr(result=Type.parse('tensor<16x256x!tt.ptr<f16>>'), ptr=arg12, offset=v255)
        scf.yield_(results_=[v246, v247, v256])

    # --- Epilogue: store the accumulator (truncated to f16) ------------------
    v142 = arith.truncf(out=Type.parse('tensor<128x256xf16>'), in_=v141_0)
    # v158 = v40 * 128 + arange(0, 128)  (no modulo here, unlike v60).
    c128_i32_62 = arith.constant(128, Type.parse('i32'))
    c128_i32_63 = arith.constant(128, Type.parse('i32'))
    v143 = arith.extsi(out=Type.parse('i64'), in_=v40)
    v144 = arith.extsi(out=Type.parse('i64'), in_=c128_i32_63)
    v145 = arith.muli(lhs=v143, rhs=v144, overflow_flags=Attribute.parse('#arith.overflow<none>'))
    c2147483647_i64_64 = arith.constant(2147483647, Type.parse('i64'))
    c_2147483648_i64_65 = arith.constant(-2147483648, Type.parse('i64'))
    v146 = arith.cmpi(predicate=3, lhs=v145, rhs=c2147483647_i64_64)
    v147 = arith.cmpi(predicate=5, lhs=v145, rhs=c_2147483648_i64_65)
    v148 = arith.andi(lhs=v146, rhs=v147)
    v149 = arith.muli(lhs=v40, rhs=c128_i32_63, overflow_flags=Attribute.parse('#arith.overflow<none>'))
    v150 = tt.make_range(result=Type.parse('tensor<128xi32>'), start=0, end=128)
    v151 = tt.splat(result=Type.parse('tensor<128xi32>'), src=v149)
    v152 = arith.extsi(out=Type.parse('tensor<128xi64>'), in_=v151)
    v153 = arith.extsi(out=Type.parse('tensor<128xi64>'), in_=v150)
    v154 = arith.addi(lhs=v152, rhs=v153, overflow_flags=Attribute.parse('#arith.overflow<none>'))
    c2147483647_i64_66 = arith.constant(2147483647, Type.parse('i64'))
    c_2147483648_i64_67 = arith.constant(-2147483648, Type.parse('i64'))
    cst_68 = arith.constant(np.full([128], 2147483647, np.int64), Type.parse('tensor<128xi64>'))
    v155 = arith.cmpi(predicate=3, lhs=v154, rhs=cst_68)
    cst_69 = arith.constant(np.full([128], -2147483648, np.int64), Type.parse('tensor<128xi64>'))
    v156 = arith.cmpi(predicate=5, lhs=v154, rhs=cst_69)
    v157 = arith.andi(lhs=v155, rhs=v156)
    v158 = arith.addi(lhs=v151, rhs=v150, overflow_flags=Attribute.parse('#arith.overflow<none>'))
    # v174 = v42 * 256 + arange(0, 256)  (no modulo here, unlike v78).
    c256_i32_70 = arith.constant(256, Type.parse('i32'))
    c256_i32_71 = arith.constant(256, Type.parse('i32'))
    v159 = arith.extsi(out=Type.parse('i64'), in_=v42)
    v160 = arith.extsi(out=Type.parse('i64'), in_=c256_i32_71)
    v161 = arith.muli(lhs=v159, rhs=v160, overflow_flags=Attribute.parse('#arith.overflow<none>'))
    c2147483647_i64_72 = arith.constant(2147483647, Type.parse('i64'))
    c_2147483648_i64_73 = arith.constant(-2147483648, Type.parse('i64'))
    v162 = arith.cmpi(predicate=3, lhs=v161, rhs=c2147483647_i64_72)
    v163 = arith.cmpi(predicate=5, lhs=v161, rhs=c_2147483648_i64_73)
    v164 = arith.andi(lhs=v162, rhs=v163)
    v165 = arith.muli(lhs=v42, rhs=c256_i32_71, overflow_flags=Attribute.parse('#arith.overflow<none>'))
    v166 = tt.make_range(result=Type.parse('tensor<256xi32>'), start=0, end=256)
    v167 = tt.splat(result=Type.parse('tensor<256xi32>'), src=v165)
    v168 = arith.extsi(out=Type.parse('tensor<256xi64>'), in_=v167)
    v169 = arith.extsi(out=Type.parse('tensor<256xi64>'), in_=v166)
    v170 = arith.addi(lhs=v168, rhs=v169, overflow_flags=Attribute.parse('#arith.overflow<none>'))
    c2147483647_i64_74 = arith.constant(2147483647, Type.parse('i64'))
    c_2147483648_i64_75 = arith.constant(-2147483648, Type.parse('i64'))
    cst_76 = arith.constant(np.full([256], 2147483647, np.int64), Type.parse('tensor<256xi64>'))
    v171 = arith.cmpi(predicate=3, lhs=v170, rhs=cst_76)
    cst_77 = arith.constant(np.full([256], -2147483648, np.int64), Type.parse('tensor<256xi64>'))
    v172 = arith.cmpi(predicate=5, lhs=v170, rhs=cst_77)
    v173 = arith.andi(lhs=v171, rhs=v172)
    v174 = arith.addi(lhs=v167, rhs=v166, overflow_flags=Attribute.parse('#arith.overflow<none>'))
    # Output pointers: v196 = arg2 + arg8 * v158[:, None] + 1 * v174[None, :].
    v175 = tt.expand_dims(src=v158, axis=1)
    v176 = tt.splat(result=Type.parse('tensor<128x1xi32>'), src=arg8)
    v177 = arith.extsi(out=Type.parse('tensor<128x1xi64>'), in_=v176)
    v178 = arith.extsi(out=Type.parse('tensor<128x1xi64>'), in_=v175)
    v179 = arith.muli(lhs=v177, rhs=v178, overflow_flags=Attribute.parse('#arith.overflow<none>'))
    c2147483647_i64_78 = arith.constant(2147483647, Type.parse('i64'))
    c_2147483648_i64_79 = arith.constant(-2147483648, Type.parse('i64'))
    cst_80 = arith.constant(np.full([128, 1], 2147483647, np.int64), Type.parse('tensor<128x1xi64>'))
    v180 = arith.cmpi(predicate=3, lhs=v179, rhs=cst_80)
    cst_81 = arith.constant(np.full([128, 1], -2147483648, np.int64), Type.parse('tensor<128x1xi64>'))
    v181 = arith.cmpi(predicate=5, lhs=v179, rhs=cst_81)
    v182 = arith.andi(lhs=v180, rhs=v181)
    v183 = arith.muli(lhs=v176, rhs=v175, overflow_flags=Attribute.parse('#arith.overflow<none>'))
    v184 = tt.splat(result=Type.parse('tensor<128x1x!tt.ptr<f16>>'), src=arg2)
    v185 = tt.addptr(result=Type.parse('tensor<128x1x!tt.ptr<f16>>'), ptr=v184, offset=v183)
    v186 = tt.expand_dims(src=v174, axis=0)
    c1_i32_82 = arith.constant(1, Type.parse('i32'))
    c1_i32_83 = arith.constant(1, Type.parse('i32'))
    cst_84 = arith.constant(np.full([1, 256], 1, np.int32), Type.parse('tensor<1x256xi32>'))
    v187 = arith.extsi(out=Type.parse('tensor<1x256xi64>'), in_=cst_84)
    v188 = arith.extsi(out=Type.parse('tensor<1x256xi64>'), in_=v186)
    v189 = arith.muli(lhs=v187, rhs=v188, overflow_flags=Attribute.parse('#arith.overflow<none>'))
    c2147483647_i64_85 = arith.constant(2147483647, Type.parse('i64'))
    c_2147483648_i64_86 = arith.constant(-2147483648, Type.parse('i64'))
    cst_87 = arith.constant(np.full([1, 256], 2147483647, np.int64), Type.parse('tensor<1x256xi64>'))
    v190 = arith.cmpi(predicate=3, lhs=v189, rhs=cst_87)
    cst_88 = arith.constant(np.full([1, 256], -2147483648, np.int64), Type.parse('tensor<1x256xi64>'))
    v191 = arith.cmpi(predicate=5, lhs=v189, rhs=cst_88)
    v192 = arith.andi(lhs=v190, rhs=v191)
    v193 = arith.muli(lhs=cst_84, rhs=v186, overflow_flags=Attribute.parse('#arith.overflow<none>'))
    v194 = tt.broadcast(result=Type.parse('tensor<128x256x!tt.ptr<f16>>'), src=v185)
    v195 = tt.broadcast(result=Type.parse('tensor<128x256xi32>'), src=v193)
    v196 = tt.addptr(result=Type.parse('tensor<128x256x!tt.ptr<f16>>'), ptr=v194, offset=v195)
    # Store mask: (v158[:, None] < arg3) & (v174[None, :] < arg4).
    v197 = tt.expand_dims(src=v158, axis=1)
    v198 = tt.splat(result=Type.parse('tensor<128x1xi32>'), src=arg3)
    v199 = arith.cmpi(predicate=2, lhs=v197, rhs=v198)
    v200 = tt.expand_dims(src=v174, axis=0)
    v201 = tt.splat(result=Type.parse('tensor<1x256xi32>'), src=arg4)
    v202 = arith.cmpi(predicate=2, lhs=v200, rhs=v201)
    v203 = tt.broadcast(result=Type.parse('tensor<128x256xi1>'), src=v199)
    v204 = tt.broadcast(result=Type.parse('tensor<128x256xi1>'), src=v202)
    v205 = arith.andi(lhs=v203, rhs=v204)
    tt.store(ptr=v196, value=v142, mask=v205, boundary_check=[], cache=1, evict=1)
    tt.return_(srcs=[])
@ttpp.jit(
    function_type=TypeAttr.parse('(i32) -> i32'),
    noinline=False,
    sym_name='cdiv__i32__(1,)cconstexpr_128_',
    sym_visibility='private',
)
def cdiv__i32___1__cconstexpr_128_(arg0):
    """Emit private ``cdiv__i32__(1,)cconstexpr_128_``: return ``(arg0 + 128 - 1) / 128``.

    The division is ``arith.divsi`` (signed, truncating toward zero), so the
    result equals ``ceil(arg0 / 128)`` for non-negative ``arg0``. The i32
    add/sub carry no overflow flags, so values of ``arg0`` near INT32_MAX wrap.
    The i64 range checks (v5, v12) and the extra constants are never read;
    they are kept so the emitted IR matches the source dump.
    """
    c128_i32 = arith.constant(128, Type.parse('i32'))
    c128_i32_0 = arith.constant(128, Type.parse('i32'))
    # Unused overflow check on arg0 + 128 (pred 3 = sle, pred 5 = sge).
    v0 = arith.extsi(out=Type.parse('i64'), in_=arg0)
    v1 = arith.extsi(out=Type.parse('i64'), in_=c128_i32_0)
    v2 = arith.addi(lhs=v0, rhs=v1, overflow_flags=Attribute.parse('#arith.overflow<none>'))
    c2147483647_i64 = arith.constant(2147483647, Type.parse('i64'))
    c_2147483648_i64 = arith.constant(-2147483648, Type.parse('i64'))
    v3 = arith.cmpi(predicate=3, lhs=v2, rhs=c2147483647_i64)
    v4 = arith.cmpi(predicate=5, lhs=v2, rhs=c_2147483648_i64)
    v5 = arith.andi(lhs=v3, rhs=v4)
    v6 = arith.addi(lhs=arg0, rhs=c128_i32_0, overflow_flags=Attribute.parse('#arith.overflow<none>'))
    c1_i32 = arith.constant(1, Type.parse('i32'))
    c1_i32_1 = arith.constant(1, Type.parse('i32'))
    # Unused overflow check on (arg0 + 128) - 1.
    v7 = arith.extsi(out=Type.parse('i64'), in_=v6)
    v8 = arith.extsi(out=Type.parse('i64'), in_=c1_i32_1)
    v9 = arith.subi(lhs=v7, rhs=v8, overflow_flags=Attribute.parse('#arith.overflow<none>'))
    c2147483647_i64_2 = arith.constant(2147483647, Type.parse('i64'))
    c_2147483648_i64_3 = arith.constant(-2147483648, Type.parse('i64'))
    v10 = arith.cmpi(predicate=3, lhs=v9, rhs=c2147483647_i64_2)
    v11 = arith.cmpi(predicate=5, lhs=v9, rhs=c_2147483648_i64_3)
    v12 = arith.andi(lhs=v10, rhs=v11)
    v13 = arith.subi(lhs=v6, rhs=c1_i32_1, overflow_flags=Attribute.parse('#arith.overflow<none>'))
    c128_i32_4 = arith.constant(128, Type.parse('i32'))
    c128_i32_5 = arith.constant(128, Type.parse('i32'))
    v14 = arith.divsi(lhs=v13, rhs=c128_i32_5)
    tt.return_(srcs=[v14])
@ttpp.jit(
    function_type=TypeAttr.parse('(i32) -> i32'),
    noinline=False,
    sym_name='cdiv__i32__(1,)cconstexpr_256_',
    sym_visibility='private',
)
def cdiv__i32___1__cconstexpr_256_(arg0):
    """Emit private ``cdiv__i32__(1,)cconstexpr_256_``: return ``(arg0 + 256 - 1) / 256``.

    The division is ``arith.divsi`` (signed, truncating toward zero), so the
    result equals ``ceil(arg0 / 256)`` for non-negative ``arg0``. The i32
    add/sub carry no overflow flags, so values of ``arg0`` near INT32_MAX wrap.
    The i64 range checks (v5, v12) and the extra constants are never read;
    they are kept so the emitted IR matches the source dump.
    """
    c256_i32 = arith.constant(256, Type.parse('i32'))
    c256_i32_0 = arith.constant(256, Type.parse('i32'))
    # Unused overflow check on arg0 + 256 (pred 3 = sle, pred 5 = sge).
    v0 = arith.extsi(out=Type.parse('i64'), in_=arg0)
    v1 = arith.extsi(out=Type.parse('i64'), in_=c256_i32_0)
    v2 = arith.addi(lhs=v0, rhs=v1, overflow_flags=Attribute.parse('#arith.overflow<none>'))
    c2147483647_i64 = arith.constant(2147483647, Type.parse('i64'))
    c_2147483648_i64 = arith.constant(-2147483648, Type.parse('i64'))
    v3 = arith.cmpi(predicate=3, lhs=v2, rhs=c2147483647_i64)
    v4 = arith.cmpi(predicate=5, lhs=v2, rhs=c_2147483648_i64)
    v5 = arith.andi(lhs=v3, rhs=v4)
    v6 = arith.addi(lhs=arg0, rhs=c256_i32_0, overflow_flags=Attribute.parse('#arith.overflow<none>'))
    c1_i32 = arith.constant(1, Type.parse('i32'))
    c1_i32_1 = arith.constant(1, Type.parse('i32'))
    # Unused overflow check on (arg0 + 256) - 1.
    v7 = arith.extsi(out=Type.parse('i64'), in_=v6)
    v8 = arith.extsi(out=Type.parse('i64'), in_=c1_i32_1)
    v9 = arith.subi(lhs=v7, rhs=v8, overflow_flags=Attribute.parse('#arith.overflow<none>'))
    c2147483647_i64_2 = arith.constant(2147483647, Type.parse('i64'))
    c_2147483648_i64_3 = arith.constant(-2147483648, Type.parse('i64'))
    v10 = arith.cmpi(predicate=3, lhs=v9, rhs=c2147483647_i64_2)
    v11 = arith.cmpi(predicate=5, lhs=v9, rhs=c_2147483648_i64_3)
    v12 = arith.andi(lhs=v10, rhs=v11)
    v13 = arith.subi(lhs=v6, rhs=c1_i32_1, overflow_flags=Attribute.parse('#arith.overflow<none>'))
    c256_i32_4 = arith.constant(256, Type.parse('i32'))
    c256_i32_5 = arith.constant(256, Type.parse('i32'))
    v14 = arith.divsi(lhs=v13, rhs=c256_i32_5)
    tt.return_(srcs=[v14])
@ttpp.jit(
    function_type=TypeAttr.parse('() -> tensor<128x256xf32>'),
    noinline=False,
    sym_name='zeros____(0, 0)cconstexpr_128__(0, 1)cconstexpr_256__(1,)cconstexpr_fp32_',
    sym_visibility='private',
)
def zeros_____0_0_cconstexpr_128___0_1_cconstexpr_256___1__cconstexpr_fp32_():
    """Emit a private function that returns a dense all-zero ``tensor<128x256xf32>`` constant.

    The scalar ``cst`` is never read; it is kept so the emitted IR matches the
    source dump.
    """
    cst = arith.constant(0.0, Type.parse('f32'))
    cst_0 = arith.constant(np.full([128, 256], 0.0, np.float32), Type.parse('tensor<128x256xf32>'))
    tt.return_(srcs=[cst_0])
@ttpp.jit(
    function_type=TypeAttr.parse('(i32) -> i32'),
    noinline=False,
    sym_name='cdiv__i32__(1,)cconstexpr_16_',
    sym_visibility='private',
)
def cdiv__i32___1__cconstexpr_16_(arg0):
    """Emit private ``cdiv__i32__(1,)cconstexpr_16_``: return ``(arg0 + 16 - 1) / 16``.

    The division is ``arith.divsi`` (signed, truncating toward zero), so the
    result equals ``ceil(arg0 / 16)`` for non-negative ``arg0``. The i32
    add/sub carry no overflow flags, so values of ``arg0`` near INT32_MAX wrap.
    The i64 range checks (v5, v12) and the extra constants are never read;
    they are kept so the emitted IR matches the source dump.
    """
    c16_i32 = arith.constant(16, Type.parse('i32'))
    c16_i32_0 = arith.constant(16, Type.parse('i32'))
    # Unused overflow check on arg0 + 16 (pred 3 = sle, pred 5 = sge).
    v0 = arith.extsi(out=Type.parse('i64'), in_=arg0)
    v1 = arith.extsi(out=Type.parse('i64'), in_=c16_i32_0)
    v2 = arith.addi(lhs=v0, rhs=v1, overflow_flags=Attribute.parse('#arith.overflow<none>'))
    c2147483647_i64 = arith.constant(2147483647, Type.parse('i64'))
    c_2147483648_i64 = arith.constant(-2147483648, Type.parse('i64'))
    v3 = arith.cmpi(predicate=3, lhs=v2, rhs=c2147483647_i64)
    v4 = arith.cmpi(predicate=5, lhs=v2, rhs=c_2147483648_i64)
    v5 = arith.andi(lhs=v3, rhs=v4)
    v6 = arith.addi(lhs=arg0, rhs=c16_i32_0, overflow_flags=Attribute.parse('#arith.overflow<none>'))
    c1_i32 = arith.constant(1, Type.parse('i32'))
    c1_i32_1 = arith.constant(1, Type.parse('i32'))
    # Unused overflow check on (arg0 + 16) - 1.
    v7 = arith.extsi(out=Type.parse('i64'), in_=v6)
    v8 = arith.extsi(out=Type.parse('i64'), in_=c1_i32_1)
    v9 = arith.subi(lhs=v7, rhs=v8, overflow_flags=Attribute.parse('#arith.overflow<none>'))
    c2147483647_i64_2 = arith.constant(2147483647, Type.parse('i64'))
    c_2147483648_i64_3 = arith.constant(-2147483648, Type.parse('i64'))
    v10 = arith.cmpi(predicate=3, lhs=v9, rhs=c2147483647_i64_2)
    v11 = arith.cmpi(predicate=5, lhs=v9, rhs=c_2147483648_i64_3)
    v12 = arith.andi(lhs=v10, rhs=v11)
    v13 = arith.subi(lhs=v6, rhs=c1_i32_1, overflow_flags=Attribute.parse('#arith.overflow<none>'))
    c16_i32_4 = arith.constant(16, Type.parse('i32'))
    c16_i32_5 = arith.constant(16, Type.parse('i32'))
    v14 = arith.divsi(lhs=v13, rhs=c16_i32_5)
    tt.return_(srcs=[v14])
# Emit the four private helpers that matmul_kernel calls through tt.call, then
# the kernel itself. NOTE(review): presumably each .emit() appends the function
# to ctx.module, so this order is the order of functions in the printed module.
cdiv__i32___1__cconstexpr_128_.emit()
cdiv__i32___1__cconstexpr_256_.emit()
zeros_____0_0_cconstexpr_128___0_1_cconstexpr_256___1__cconstexpr_fp32_.emit()
cdiv__i32___1__cconstexpr_16_.emit()
matmul_kernel.emit()
print(ctx.module)
# Current MLIR Python bindings raise MLIRError from verify() on failure and
# return True otherwise. Older bindings returned False instead of raising, so
# check the result rather than discarding it.
if not ctx.module.operation.verify():
    raise RuntimeError("MLIR verification failed for the emitted module")
# Sign up for free to join this conversation on GitHub.
# Already have an account? Sign in to comment.