Skip to content

Instantly share code, notes, and snippets.

@stellaraccident
Last active December 7, 2021 03:56
Show Gist options
  • Save stellaraccident/41584b1c3941f6caa1ca6efc594379ba to your computer and use it in GitHub Desktop.
IREE Jax AQT Matmul Examples
// Two-layer dense model with *simulated* AQT int8 quantization: all of the
// quantization arithmetic stays in f32 (operands are scaled, rounded via
// +0.5-then-floor, and clipped to the int8 range [-127, 127], but never
// converted to an integer type), so the matmuls themselves run on f32.
module @aqt_dense {
// Layer-1 weights (6x3).
iree_input.global private @_params$0 = dense<[[0.000000e+00, 1.000000e-03, 2.000000e-03], [3.000000e-03, 4.000000e-03, 0.00500000035], [6.000000e-03, 7.000000e-03, 8.000000e-03], [0.00900000054, 0.0100000007, 0.0110000009], [1.200000e-02, 1.300000e-02, 1.400000e-02], [0.0150000006, 1.600000e-02, 1.700000e-02]]> : tensor<6x3xf32>
// Layer-1 bias (3).
iree_input.global private @_params$1 = dense<[0.000000e+00, 1.000000e+01, 2.000000e+01]> : tensor<3xf32>
// Layer-1 activation scale (scalar).
iree_input.global private @_params$2 = dense<5.000000e+00> : tensor<f32>
// Layer-2 weights (3x9).
iree_input.global private @_params$3 = dense<[[0.000000e+00, 0.00999999977, 2.000000e-02, 3.000000e-02, 4.000000e-02, 0.049999997, 6.000000e-02, 7.000000e-02, 8.000000e-02], [0.0899999961, 0.099999994, 1.100000e-01, 1.200000e-01, 1.300000e-01, 1.400000e-01, 0.149999991, 1.600000e-01, 1.700000e-01], [0.179999992, 1.900000e-01, 0.199999988, 2.100000e-01, 2.200000e-01, 0.229999989, 2.400000e-01, 2.500000e-01, 2.600000e-01]]> : tensor<3x9xf32>
// Layer-2 bias (9).
iree_input.global private @_params$4 = dense<[0.000000e+00, 3.000000e+00, 6.000000e+00, 9.000000e+00, 1.200000e+01, 1.500000e+01, 1.800000e+01, 2.100000e+01, 2.400000e+01]> : tensor<9xf32>
// Layer-2 activation scale (scalar).
iree_input.global private @_params$5 = dense<5.000000e+00> : tensor<f32>
// Public entry point: loads the six stored parameters and forwards them,
// together with the input batch, to @main.
func @compute_simulated(%arg0: tensor<5x6xf32>) -> tensor<5x9xf32> {
%0 = iree_input.global.load @_params$0 : tensor<6x3xf32>
%1 = iree_input.global.load @_params$1 : tensor<3xf32>
%2 = iree_input.global.load @_params$2 : tensor<f32>
%3 = iree_input.global.load @_params$3 : tensor<3x9xf32>
%4 = iree_input.global.load @_params$4 : tensor<9xf32>
%5 = iree_input.global.load @_params$5 : tensor<f32>
%6 = call @main(%0, %1, %2, %3, %4, %5, %arg0) : (tensor<6x3xf32>, tensor<3xf32>, tensor<f32>, tensor<3x9xf32>, tensor<9xf32>, tensor<f32>, tensor<5x6xf32>) -> tensor<5x9xf32>
return %6 : tensor<5x9xf32>
}
// Simulated-quantized forward pass over two dense layers.
//   %arg0 / %arg1 / %arg2: layer-1 weights (6x3), bias (3), activation scale
//   %arg3 / %arg4 / %arg5: layer-2 weights (3x9), bias (9), activation scale
//   %arg6:                 input batch (5x6)
func private @main(%arg0: tensor<6x3xf32>, %arg1: tensor<3xf32>, %arg2: tensor<f32>, %arg3: tensor<3x9xf32>, %arg4: tensor<9xf32>, %arg5: tensor<f32>, %arg6: tensor<5x6xf32>) -> tensor<5x9xf32> {
// 0.5 splats at each shape implement round-half-up: x + 0.5 then floor.
%0 = mhlo.constant dense<5.000000e-01> : tensor<3x9xf32>
// 127.0: numerator of the per-tensor weight scale, 127 / max(|W|).
%1 = mhlo.constant dense<1.270000e+02> : tensor<f32>
// -inf: init value for the max-reductions below.
%2 = mhlo.constant dense<0xFF800000> : tensor<f32>
// +/-127: int8 clipping bounds handed to the clip helpers.
%3 = mhlo.constant dense<127> : tensor<i32>
%4 = mhlo.constant dense<-127> : tensor<i32>
%5 = mhlo.constant dense<5.000000e-01> : tensor<5x3xf32>
%6 = mhlo.constant dense<5.000000e-01> : tensor<6x3xf32>
%7 = mhlo.constant dense<5.000000e-01> : tensor<5x6xf32>
// --- Layer 1 ---
// Quantize the input: x * activation_scale, round-half-up, clip to [-127, 127].
%8 = "mhlo.broadcast_in_dim"(%arg2) {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor<f32>) -> tensor<5x6xf32>
%9 = mhlo.multiply %arg6, %8 : tensor<5x6xf32>
%10 = mhlo.add %9, %7 : tensor<5x6xf32>
%11 = "mhlo.floor"(%10) : (tensor<5x6xf32>) -> tensor<5x6xf32>
%12 = call @jit_clip(%11, %4, %3) : (tensor<5x6xf32>, tensor<i32>, tensor<i32>) -> tensor<5x6xf32>
// Weight scale: 127 / max(|W1|) reduced over both dimensions.
%13 = "mhlo.abs"(%arg0) : (tensor<6x3xf32>) -> tensor<6x3xf32>
%14 = mhlo.reduce %13, %2 ( {
^bb0(%arg7: tensor<f32>, %arg8: tensor<f32>): // no predecessors
%46 = mhlo.maximum %arg7, %arg8 : tensor<f32>
"mhlo.return"(%46) : (tensor<f32>) -> ()
}) {dimensions = dense<[0, 1]> : tensor<2xi64>} : (tensor<6x3xf32>, tensor<f32>) -> tensor<f32>
%15 = mhlo.divide %1, %14 : tensor<f32>
// Quantize the weights (rounded; note no clip call here, unlike activations).
%16 = "mhlo.broadcast_in_dim"(%15) {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor<f32>) -> tensor<6x3xf32>
%17 = mhlo.multiply %arg0, %16 : tensor<6x3xf32>
%18 = mhlo.add %17, %6 : tensor<6x3xf32>
%19 = "mhlo.floor"(%18) : (tensor<6x3xf32>) -> tensor<6x3xf32>
// f32 matmul of the two quantized operands (the "simulated" part).
%20 = "mhlo.dot_general"(%12, %19) {dot_dimension_numbers = #mhlo.dot<lhs_contracting_dimensions = [1], rhs_contracting_dimensions = [0]>, precision_config = ["DEFAULT", "DEFAULT"]} : (tensor<5x6xf32>, tensor<6x3xf32>) -> tensor<5x3xf32>
// Dequantize: divide by the product of activation and weight scales.
%21 = mhlo.multiply %arg2, %15 : tensor<f32>
%22 = "mhlo.broadcast_in_dim"(%21) {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor<f32>) -> tensor<5x3xf32>
%23 = mhlo.divide %20, %22 : tensor<5x3xf32>
// Add the layer-1 bias, broadcast across the batch dimension.
%24 = "mhlo.broadcast_in_dim"(%arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<3xf32>) -> tensor<1x3xf32>
%25 = "mhlo.broadcast_in_dim"(%24) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>} : (tensor<1x3xf32>) -> tensor<5x3xf32>
%26 = mhlo.add %23, %25 : tensor<5x3xf32>
// --- Layer 2 --- (identical pattern using %arg3/%arg4/%arg5)
// Quantize the layer-1 output with activation scale %arg5.
%27 = "mhlo.broadcast_in_dim"(%arg5) {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor<f32>) -> tensor<5x3xf32>
%28 = mhlo.multiply %26, %27 : tensor<5x3xf32>
%29 = mhlo.add %28, %5 : tensor<5x3xf32>
%30 = "mhlo.floor"(%29) : (tensor<5x3xf32>) -> tensor<5x3xf32>
%31 = call @jit_clip_0(%30, %4, %3) : (tensor<5x3xf32>, tensor<i32>, tensor<i32>) -> tensor<5x3xf32>
// Weight scale: 127 / max(|W2|).
%32 = "mhlo.abs"(%arg3) : (tensor<3x9xf32>) -> tensor<3x9xf32>
%33 = mhlo.reduce %32, %2 ( {
^bb0(%arg7: tensor<f32>, %arg8: tensor<f32>): // no predecessors
%46 = mhlo.maximum %arg7, %arg8 : tensor<f32>
"mhlo.return"(%46) : (tensor<f32>) -> ()
}) {dimensions = dense<[0, 1]> : tensor<2xi64>} : (tensor<3x9xf32>, tensor<f32>) -> tensor<f32>
%34 = mhlo.divide %1, %33 : tensor<f32>
// Quantize the layer-2 weights.
%35 = "mhlo.broadcast_in_dim"(%34) {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor<f32>) -> tensor<3x9xf32>
%36 = mhlo.multiply %arg3, %35 : tensor<3x9xf32>
%37 = mhlo.add %36, %0 : tensor<3x9xf32>
%38 = "mhlo.floor"(%37) : (tensor<3x9xf32>) -> tensor<3x9xf32>
// f32 matmul, then dequantize and add the layer-2 bias.
%39 = "mhlo.dot_general"(%31, %38) {dot_dimension_numbers = #mhlo.dot<lhs_contracting_dimensions = [1], rhs_contracting_dimensions = [0]>, precision_config = ["DEFAULT", "DEFAULT"]} : (tensor<5x3xf32>, tensor<3x9xf32>) -> tensor<5x9xf32>
%40 = mhlo.multiply %arg5, %34 : tensor<f32>
%41 = "mhlo.broadcast_in_dim"(%40) {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor<f32>) -> tensor<5x9xf32>
%42 = mhlo.divide %39, %41 : tensor<5x9xf32>
%43 = "mhlo.broadcast_in_dim"(%arg4) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<9xf32>) -> tensor<1x9xf32>
%44 = "mhlo.broadcast_in_dim"(%43) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>} : (tensor<1x9xf32>) -> tensor<5x9xf32>
%45 = mhlo.add %42, %44 : tensor<5x9xf32>
return %45 : tensor<5x9xf32>
}
// Elementwise clamp of a 5x6 tensor to [%arg1, %arg2] (bounds given as i32
// scalars, converted to f32): max(lo, x) followed by min(hi, .).
func private @jit_clip(%arg0: tensor<5x6xf32>, %arg1: tensor<i32>, %arg2: tensor<i32>) -> tensor<5x6xf32> {
%0 = "mhlo.convert"(%arg1) : (tensor<i32>) -> tensor<f32>
%1 = "mhlo.broadcast_in_dim"(%0) {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor<f32>) -> tensor<5x6xf32>
%2 = mhlo.maximum %1, %arg0 : tensor<5x6xf32>
%3 = "mhlo.convert"(%arg2) : (tensor<i32>) -> tensor<f32>
%4 = "mhlo.broadcast_in_dim"(%3) {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor<f32>) -> tensor<5x6xf32>
%5 = mhlo.minimum %4, %2 : tensor<5x6xf32>
return %5 : tensor<5x6xf32>
}
// Shape-specialized copy of @jit_clip for the 5x3 layer-1 output.
func private @jit_clip_0(%arg0: tensor<5x3xf32>, %arg1: tensor<i32>, %arg2: tensor<i32>) -> tensor<5x3xf32> {
%0 = "mhlo.convert"(%arg1) : (tensor<i32>) -> tensor<f32>
%1 = "mhlo.broadcast_in_dim"(%0) {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor<f32>) -> tensor<5x3xf32>
%2 = mhlo.maximum %1, %arg0 : tensor<5x3xf32>
%3 = "mhlo.convert"(%arg2) : (tensor<i32>) -> tensor<f32>
%4 = "mhlo.broadcast_in_dim"(%3) {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor<f32>) -> tensor<5x3xf32>
%5 = mhlo.minimum %4, %2 : tensor<5x3xf32>
return %5 : tensor<5x3xf32>
}
}
// Single quantized matmul, *native* int8 path: the quantized operands are
// actually converted to i8 and the dot_general runs as i8 x i8 -> i32 before
// the accumulator is dequantized back to f32.
module @aqt_matmul {
// Weights (6x3).
iree_input.global private @_params$0 = dense<[[0.000000e+00, 5.003000e+02, 1.000600e+03], [1500.8999, 2.001200e+03, 2.501500e+03], [3001.7998, 3502.09985, 4.002400e+03], [4502.69971, 5.003000e+03, 5.503300e+03], [6003.59961, 6503.8999, 7004.1997], [7.504500e+03, 8004.7998, 8.505100e+03]]> : tensor<6x3xf32>
// Activation scale (scalar).
iree_input.global private @_params$1 = dense<5.000000e+00> : tensor<f32>
// Public entry point: loads weights + activation scale and forwards to @main.
func @compute_native(%arg0: tensor<5x6xf32>) -> tensor<5x3xf32> {
%0 = iree_input.global.load @_params$0 : tensor<6x3xf32>
%1 = iree_input.global.load @_params$1 : tensor<f32>
%2 = call @main(%0, %1, %arg0) : (tensor<6x3xf32>, tensor<f32>, tensor<5x6xf32>) -> tensor<5x3xf32>
return %2 : tensor<5x3xf32>
}
// Native int8 matmul.
//   %arg0: weights (6x3), %arg1: activation scale, %arg2: input batch (5x6)
func private @main(%arg0: tensor<6x3xf32>, %arg1: tensor<f32>, %arg2: tensor<5x6xf32>) -> tensor<5x3xf32> {
// 0.5 splats implement round-half-up (x + 0.5 then floor).
%0 = mhlo.constant dense<5.000000e-01> : tensor<6x3xf32>
// 127.0: numerator of the weight scale, 127 / max(|W|).
%1 = mhlo.constant dense<1.270000e+02> : tensor<f32>
// -inf: init value for the max-reduction.
%2 = mhlo.constant dense<0xFF800000> : tensor<f32>
// +/-127: int8 clipping bounds for the activations.
%3 = mhlo.constant dense<127> : tensor<i32>
%4 = mhlo.constant dense<-127> : tensor<i32>
%5 = mhlo.constant dense<5.000000e-01> : tensor<5x6xf32>
// Quantize the input in f32 (scale, round-half-up, clip), then convert to i8.
%6 = "mhlo.broadcast_in_dim"(%arg1) {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor<f32>) -> tensor<5x6xf32>
%7 = mhlo.multiply %arg2, %6 : tensor<5x6xf32>
%8 = mhlo.add %7, %5 : tensor<5x6xf32>
%9 = "mhlo.floor"(%8) : (tensor<5x6xf32>) -> tensor<5x6xf32>
%10 = call @jit_clip(%9, %4, %3) : (tensor<5x6xf32>, tensor<i32>, tensor<i32>) -> tensor<5x6xf32>
%11 = "mhlo.convert"(%10) : (tensor<5x6xf32>) -> tensor<5x6xi8>
// Weight scale: 127 / max(|W|) over both dimensions.
%12 = "mhlo.abs"(%arg0) : (tensor<6x3xf32>) -> tensor<6x3xf32>
%13 = mhlo.reduce %12, %2 ( {
^bb0(%arg3: tensor<f32>, %arg4: tensor<f32>): // no predecessors
%25 = mhlo.maximum %arg3, %arg4 : tensor<f32>
"mhlo.return"(%25) : (tensor<f32>) -> ()
}) {dimensions = dense<[0, 1]> : tensor<2xi64>} : (tensor<6x3xf32>, tensor<f32>) -> tensor<f32>
%14 = mhlo.divide %1, %13 : tensor<f32>
// Quantize the weights and convert to i8 (rounded; no clip call here).
%15 = "mhlo.broadcast_in_dim"(%14) {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor<f32>) -> tensor<6x3xf32>
%16 = mhlo.multiply %arg0, %15 : tensor<6x3xf32>
%17 = mhlo.add %16, %0 : tensor<6x3xf32>
%18 = "mhlo.floor"(%17) : (tensor<6x3xf32>) -> tensor<6x3xf32>
%19 = "mhlo.convert"(%18) : (tensor<6x3xf32>) -> tensor<6x3xi8>
// Integer matmul: i8 x i8 accumulating into i32.
%20 = "mhlo.dot_general"(%11, %19) {dot_dimension_numbers = #mhlo.dot<lhs_contracting_dimensions = [1], rhs_contracting_dimensions = [0]>, precision_config = ["DEFAULT", "DEFAULT"]} : (tensor<5x6xi8>, tensor<6x3xi8>) -> tensor<5x3xi32>
// Dequantize: convert the i32 accumulator to f32 and divide by the product
// of the activation and weight scales.
%21 = mhlo.multiply %arg1, %14 : tensor<f32>
%22 = "mhlo.convert"(%20) : (tensor<5x3xi32>) -> tensor<5x3xf32>
%23 = "mhlo.broadcast_in_dim"(%21) {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor<f32>) -> tensor<5x3xf32>
%24 = mhlo.divide %22, %23 : tensor<5x3xf32>
return %24 : tensor<5x3xf32>
}
// Elementwise clamp of a 5x6 tensor to [%arg1, %arg2] (bounds given as i32
// scalars, converted to f32): max(lo, x) followed by min(hi, .).
func private @jit_clip(%arg0: tensor<5x6xf32>, %arg1: tensor<i32>, %arg2: tensor<i32>) -> tensor<5x6xf32> {
%0 = "mhlo.convert"(%arg1) : (tensor<i32>) -> tensor<f32>
%1 = "mhlo.broadcast_in_dim"(%0) {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor<f32>) -> tensor<5x6xf32>
%2 = mhlo.maximum %1, %arg0 : tensor<5x6xf32>
%3 = "mhlo.convert"(%arg2) : (tensor<i32>) -> tensor<f32>
%4 = "mhlo.broadcast_in_dim"(%3) {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor<f32>) -> tensor<5x6xf32>
%5 = mhlo.minimum %4, %2 : tensor<5x6xf32>
return %5 : tensor<5x6xf32>
}
}
// NOTE(review): this module reuses the symbol name @aqt_matmul from the
// preceding module — presumably the two are alternative examples (native vs.
// simulated) not meant to be loaded into one context; confirm before merging.
// Same quantized matmul as above, but the *simulated* path: quantization
// stays entirely in f32 (no i8 conversion), so the dot_general runs on the
// rounded/clipped f32 values.
module @aqt_matmul {
// Weights (6x3) — identical values to the native module above.
iree_input.global private @_params$0 = dense<[[0.000000e+00, 5.003000e+02, 1.000600e+03], [1500.8999, 2.001200e+03, 2.501500e+03], [3001.7998, 3502.09985, 4.002400e+03], [4502.69971, 5.003000e+03, 5.503300e+03], [6003.59961, 6503.8999, 7004.1997], [7.504500e+03, 8004.7998, 8.505100e+03]]> : tensor<6x3xf32>
// Activation scale (scalar).
iree_input.global private @_params$1 = dense<5.000000e+00> : tensor<f32>
// Public entry point: loads weights + activation scale and forwards to @main.
func @compute_simulated(%arg0: tensor<5x6xf32>) -> tensor<5x3xf32> {
%0 = iree_input.global.load @_params$0 : tensor<6x3xf32>
%1 = iree_input.global.load @_params$1 : tensor<f32>
%2 = call @main(%0, %1, %arg0) : (tensor<6x3xf32>, tensor<f32>, tensor<5x6xf32>) -> tensor<5x3xf32>
return %2 : tensor<5x3xf32>
}
// Simulated-quantized matmul.
//   %arg0: weights (6x3), %arg1: activation scale, %arg2: input batch (5x6)
func private @main(%arg0: tensor<6x3xf32>, %arg1: tensor<f32>, %arg2: tensor<5x6xf32>) -> tensor<5x3xf32> {
// 0.5 splats implement round-half-up (x + 0.5 then floor).
%0 = mhlo.constant dense<5.000000e-01> : tensor<6x3xf32>
// 127.0: numerator of the weight scale, 127 / max(|W|).
%1 = mhlo.constant dense<1.270000e+02> : tensor<f32>
// -inf: init value for the max-reduction.
%2 = mhlo.constant dense<0xFF800000> : tensor<f32>
// +/-127: int8 clipping bounds for the activations.
%3 = mhlo.constant dense<127> : tensor<i32>
%4 = mhlo.constant dense<-127> : tensor<i32>
%5 = mhlo.constant dense<5.000000e-01> : tensor<5x6xf32>
// Quantize the input in f32: scale, round-half-up, clip to [-127, 127].
%6 = "mhlo.broadcast_in_dim"(%arg1) {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor<f32>) -> tensor<5x6xf32>
%7 = mhlo.multiply %arg2, %6 : tensor<5x6xf32>
%8 = mhlo.add %7, %5 : tensor<5x6xf32>
%9 = "mhlo.floor"(%8) : (tensor<5x6xf32>) -> tensor<5x6xf32>
%10 = call @jit_clip(%9, %4, %3) : (tensor<5x6xf32>, tensor<i32>, tensor<i32>) -> tensor<5x6xf32>
// Weight scale: 127 / max(|W|) over both dimensions.
%11 = "mhlo.abs"(%arg0) : (tensor<6x3xf32>) -> tensor<6x3xf32>
%12 = mhlo.reduce %11, %2 ( {
^bb0(%arg3: tensor<f32>, %arg4: tensor<f32>): // no predecessors
%22 = mhlo.maximum %arg3, %arg4 : tensor<f32>
"mhlo.return"(%22) : (tensor<f32>) -> ()
}) {dimensions = dense<[0, 1]> : tensor<2xi64>} : (tensor<6x3xf32>, tensor<f32>) -> tensor<f32>
%13 = mhlo.divide %1, %12 : tensor<f32>
// Quantize the weights in f32 (rounded; no clip call here).
%14 = "mhlo.broadcast_in_dim"(%13) {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor<f32>) -> tensor<6x3xf32>
%15 = mhlo.multiply %arg0, %14 : tensor<6x3xf32>
%16 = mhlo.add %15, %0 : tensor<6x3xf32>
%17 = "mhlo.floor"(%16) : (tensor<6x3xf32>) -> tensor<6x3xf32>
// f32 matmul of the quantized operands (contrast with the i8 dot in the
// native module above).
%18 = "mhlo.dot_general"(%10, %17) {dot_dimension_numbers = #mhlo.dot<lhs_contracting_dimensions = [1], rhs_contracting_dimensions = [0]>, precision_config = ["DEFAULT", "DEFAULT"]} : (tensor<5x6xf32>, tensor<6x3xf32>) -> tensor<5x3xf32>
// Dequantize: divide by the product of activation and weight scales.
%19 = mhlo.multiply %arg1, %13 : tensor<f32>
%20 = "mhlo.broadcast_in_dim"(%19) {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor<f32>) -> tensor<5x3xf32>
%21 = mhlo.divide %18, %20 : tensor<5x3xf32>
return %21 : tensor<5x3xf32>
}
// Elementwise clamp of a 5x6 tensor to [%arg1, %arg2] (bounds given as i32
// scalars, converted to f32): max(lo, x) followed by min(hi, .).
func private @jit_clip(%arg0: tensor<5x6xf32>, %arg1: tensor<i32>, %arg2: tensor<i32>) -> tensor<5x6xf32> {
%0 = "mhlo.convert"(%arg1) : (tensor<i32>) -> tensor<f32>
%1 = "mhlo.broadcast_in_dim"(%0) {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor<f32>) -> tensor<5x6xf32>
%2 = mhlo.maximum %1, %arg0 : tensor<5x6xf32>
%3 = "mhlo.convert"(%arg2) : (tensor<i32>) -> tensor<f32>
%4 = "mhlo.broadcast_in_dim"(%3) {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor<f32>) -> tensor<5x6xf32>
%5 = mhlo.minimum %4, %2 : tensor<5x6xf32>
return %5 : tensor<5x6xf32>
}
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment