Last active
December 7, 2021 03:56
-
-
Save stellaraccident/41584b1c3941f6caa1ca6efc594379ba to your computer and use it in GitHub Desktop.
IREE Jax AQT Matmul Examples
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| // Simulated-quantization two-layer dense network (presumably AQT-style int8
| // quantization emulated in f32 — TODO confirm against the generating Jax code).
| // Each layer: scale inputs/weights into the [-127, 127] range, round half-up
| // via add-0.5-then-floor, matmul in f32, then divide the scales back out.
| module @aqt_dense { | |
| // Trained parameters: layer-1 weight (6x3), bias (3), input scale (scalar);
| // layer-2 weight (3x9), bias (9), input scale (scalar).
| iree_input.global private @_params$0 = dense<[[0.000000e+00, 1.000000e-03, 2.000000e-03], [3.000000e-03, 4.000000e-03, 0.00500000035], [6.000000e-03, 7.000000e-03, 8.000000e-03], [0.00900000054, 0.0100000007, 0.0110000009], [1.200000e-02, 1.300000e-02, 1.400000e-02], [0.0150000006, 1.600000e-02, 1.700000e-02]]> : tensor<6x3xf32> | |
| iree_input.global private @_params$1 = dense<[0.000000e+00, 1.000000e+01, 2.000000e+01]> : tensor<3xf32> | |
| iree_input.global private @_params$2 = dense<5.000000e+00> : tensor<f32> | |
| iree_input.global private @_params$3 = dense<[[0.000000e+00, 0.00999999977, 2.000000e-02, 3.000000e-02, 4.000000e-02, 0.049999997, 6.000000e-02, 7.000000e-02, 8.000000e-02], [0.0899999961, 0.099999994, 1.100000e-01, 1.200000e-01, 1.300000e-01, 1.400000e-01, 0.149999991, 1.600000e-01, 1.700000e-01], [0.179999992, 1.900000e-01, 0.199999988, 2.100000e-01, 2.200000e-01, 0.229999989, 2.400000e-01, 2.500000e-01, 2.600000e-01]]> : tensor<3x9xf32> | |
| iree_input.global private @_params$4 = dense<[0.000000e+00, 3.000000e+00, 6.000000e+00, 9.000000e+00, 1.200000e+01, 1.500000e+01, 1.800000e+01, 2.100000e+01, 2.400000e+01]> : tensor<9xf32> | |
| iree_input.global private @_params$5 = dense<5.000000e+00> : tensor<f32> | |
| // Public entry point: loads the six parameter globals and forwards them,
| // together with the 5x6 input batch, to the private @main body.
| func @compute_simulated(%arg0: tensor<5x6xf32>) -> tensor<5x9xf32> { | |
| %0 = iree_input.global.load @_params$0 : tensor<6x3xf32> | |
| %1 = iree_input.global.load @_params$1 : tensor<3xf32> | |
| %2 = iree_input.global.load @_params$2 : tensor<f32> | |
| %3 = iree_input.global.load @_params$3 : tensor<3x9xf32> | |
| %4 = iree_input.global.load @_params$4 : tensor<9xf32> | |
| %5 = iree_input.global.load @_params$5 : tensor<f32> | |
| %6 = call @main(%0, %1, %2, %3, %4, %5, %arg0) : (tensor<6x3xf32>, tensor<3xf32>, tensor<f32>, tensor<3x9xf32>, tensor<9xf32>, tensor<f32>, tensor<5x6xf32>) -> tensor<5x9xf32> | |
| return %6 : tensor<5x9xf32> | |
| } | |
| // %arg0/%arg1/%arg2: layer-1 weight/bias/input-scale; %arg3/%arg4/%arg5:
| // layer-2 weight/bias/input-scale; %arg6: input batch. Returns layer-2 output.
| func private @main(%arg0: tensor<6x3xf32>, %arg1: tensor<3xf32>, %arg2: tensor<f32>, %arg3: tensor<3x9xf32>, %arg4: tensor<9xf32>, %arg5: tensor<f32>, %arg6: tensor<5x6xf32>) -> tensor<5x9xf32> { | |
| // Constants: 0.5 tensors implement round-half-up (x + 0.5 then floor);
| // 127 is the int8 symmetric range bound; 0xFF800000 is -inf (max-reduce init).
| %0 = mhlo.constant dense<5.000000e-01> : tensor<3x9xf32> | |
| %1 = mhlo.constant dense<1.270000e+02> : tensor<f32> | |
| %2 = mhlo.constant dense<0xFF800000> : tensor<f32> | |
| %3 = mhlo.constant dense<127> : tensor<i32> | |
| %4 = mhlo.constant dense<-127> : tensor<i32> | |
| %5 = mhlo.constant dense<5.000000e-01> : tensor<5x3xf32> | |
| %6 = mhlo.constant dense<5.000000e-01> : tensor<6x3xf32> | |
| %7 = mhlo.constant dense<5.000000e-01> : tensor<5x6xf32> | |
| // Quantize activations: x * input_scale, round half-up, clip to [-127, 127].
| %8 = "mhlo.broadcast_in_dim"(%arg2) {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor<f32>) -> tensor<5x6xf32> | |
| %9 = mhlo.multiply %arg6, %8 : tensor<5x6xf32> | |
| %10 = mhlo.add %9, %7 : tensor<5x6xf32> | |
| %11 = "mhlo.floor"(%10) : (tensor<5x6xf32>) -> tensor<5x6xf32> | |
| %12 = call @jit_clip(%11, %4, %3) : (tensor<5x6xf32>, tensor<i32>, tensor<i32>) -> tensor<5x6xf32> | |
| // Weight scale = 127 / max(|W|) over all elements (abs then max-reduce).
| %13 = "mhlo.abs"(%arg0) : (tensor<6x3xf32>) -> tensor<6x3xf32> | |
| %14 = mhlo.reduce %13, %2 ( { | |
| ^bb0(%arg7: tensor<f32>, %arg8: tensor<f32>): // no predecessors | |
| %46 = mhlo.maximum %arg7, %arg8 : tensor<f32> | |
| "mhlo.return"(%46) : (tensor<f32>) -> () | |
| }) {dimensions = dense<[0, 1]> : tensor<2xi64>} : (tensor<6x3xf32>, tensor<f32>) -> tensor<f32> | |
| %15 = mhlo.divide %1, %14 : tensor<f32> | |
| // Quantize weights with the same round-half-up scheme (no clip here).
| %16 = "mhlo.broadcast_in_dim"(%15) {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor<f32>) -> tensor<6x3xf32> | |
| %17 = mhlo.multiply %arg0, %16 : tensor<6x3xf32> | |
| %18 = mhlo.add %17, %6 : tensor<6x3xf32> | |
| %19 = "mhlo.floor"(%18) : (tensor<6x3xf32>) -> tensor<6x3xf32> | |
| // "Simulated": quantized values stay f32, so this matmul runs in float.
| %20 = "mhlo.dot_general"(%12, %19) {dot_dimension_numbers = #mhlo.dot<lhs_contracting_dimensions = [1], rhs_contracting_dimensions = [0]>, precision_config = ["DEFAULT", "DEFAULT"]} : (tensor<5x6xf32>, tensor<6x3xf32>) -> tensor<5x3xf32> | |
| // Dequantize: divide by (input_scale * weight_scale), then add the bias.
| %21 = mhlo.multiply %arg2, %15 : tensor<f32> | |
| %22 = "mhlo.broadcast_in_dim"(%21) {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor<f32>) -> tensor<5x3xf32> | |
| %23 = mhlo.divide %20, %22 : tensor<5x3xf32> | |
| %24 = "mhlo.broadcast_in_dim"(%arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<3xf32>) -> tensor<1x3xf32> | |
| %25 = "mhlo.broadcast_in_dim"(%24) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>} : (tensor<1x3xf32>) -> tensor<5x3xf32> | |
| %26 = mhlo.add %23, %25 : tensor<5x3xf32> | |
| // Layer 2 repeats the identical quantize -> matmul -> dequantize pattern
| // on the layer-1 output using %arg3 (weight), %arg4 (bias), %arg5 (scale).
| %27 = "mhlo.broadcast_in_dim"(%arg5) {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor<f32>) -> tensor<5x3xf32> | |
| %28 = mhlo.multiply %26, %27 : tensor<5x3xf32> | |
| %29 = mhlo.add %28, %5 : tensor<5x3xf32> | |
| %30 = "mhlo.floor"(%29) : (tensor<5x3xf32>) -> tensor<5x3xf32> | |
| %31 = call @jit_clip_0(%30, %4, %3) : (tensor<5x3xf32>, tensor<i32>, tensor<i32>) -> tensor<5x3xf32> | |
| %32 = "mhlo.abs"(%arg3) : (tensor<3x9xf32>) -> tensor<3x9xf32> | |
| %33 = mhlo.reduce %32, %2 ( { | |
| ^bb0(%arg7: tensor<f32>, %arg8: tensor<f32>): // no predecessors | |
| %46 = mhlo.maximum %arg7, %arg8 : tensor<f32> | |
| "mhlo.return"(%46) : (tensor<f32>) -> () | |
| }) {dimensions = dense<[0, 1]> : tensor<2xi64>} : (tensor<3x9xf32>, tensor<f32>) -> tensor<f32> | |
| %34 = mhlo.divide %1, %33 : tensor<f32> | |
| %35 = "mhlo.broadcast_in_dim"(%34) {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor<f32>) -> tensor<3x9xf32> | |
| %36 = mhlo.multiply %arg3, %35 : tensor<3x9xf32> | |
| %37 = mhlo.add %36, %0 : tensor<3x9xf32> | |
| %38 = "mhlo.floor"(%37) : (tensor<3x9xf32>) -> tensor<3x9xf32> | |
| %39 = "mhlo.dot_general"(%31, %38) {dot_dimension_numbers = #mhlo.dot<lhs_contracting_dimensions = [1], rhs_contracting_dimensions = [0]>, precision_config = ["DEFAULT", "DEFAULT"]} : (tensor<5x3xf32>, tensor<3x9xf32>) -> tensor<5x9xf32> | |
| %40 = mhlo.multiply %arg5, %34 : tensor<f32> | |
| %41 = "mhlo.broadcast_in_dim"(%40) {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor<f32>) -> tensor<5x9xf32> | |
| %42 = mhlo.divide %39, %41 : tensor<5x9xf32> | |
| %43 = "mhlo.broadcast_in_dim"(%arg4) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<9xf32>) -> tensor<1x9xf32> | |
| %44 = "mhlo.broadcast_in_dim"(%43) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>} : (tensor<1x9xf32>) -> tensor<5x9xf32> | |
| %45 = mhlo.add %42, %44 : tensor<5x9xf32> | |
| return %45 : tensor<5x9xf32> | |
| } | |
| // Elementwise clamp of a 5x6 f32 tensor to [%arg1, %arg2] (i32 bounds are
| // converted to f32, then max with the lower bound and min with the upper).
| func private @jit_clip(%arg0: tensor<5x6xf32>, %arg1: tensor<i32>, %arg2: tensor<i32>) -> tensor<5x6xf32> { | |
| %0 = "mhlo.convert"(%arg1) : (tensor<i32>) -> tensor<f32> | |
| %1 = "mhlo.broadcast_in_dim"(%0) {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor<f32>) -> tensor<5x6xf32> | |
| %2 = mhlo.maximum %1, %arg0 : tensor<5x6xf32> | |
| %3 = "mhlo.convert"(%arg2) : (tensor<i32>) -> tensor<f32> | |
| %4 = "mhlo.broadcast_in_dim"(%3) {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor<f32>) -> tensor<5x6xf32> | |
| %5 = mhlo.minimum %4, %2 : tensor<5x6xf32> | |
| return %5 : tensor<5x6xf32> | |
| } | |
| // Same clamp specialized for the 5x3 layer-2 input shape (shapes are static
| // in this IR, so each shape gets its own monomorphic clip function).
| func private @jit_clip_0(%arg0: tensor<5x3xf32>, %arg1: tensor<i32>, %arg2: tensor<i32>) -> tensor<5x3xf32> { | |
| %0 = "mhlo.convert"(%arg1) : (tensor<i32>) -> tensor<f32> | |
| %1 = "mhlo.broadcast_in_dim"(%0) {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor<f32>) -> tensor<5x3xf32> | |
| %2 = mhlo.maximum %1, %arg0 : tensor<5x3xf32> | |
| %3 = "mhlo.convert"(%arg2) : (tensor<i32>) -> tensor<f32> | |
| %4 = "mhlo.broadcast_in_dim"(%3) {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor<f32>) -> tensor<5x3xf32> | |
| %5 = mhlo.minimum %4, %2 : tensor<5x3xf32> | |
| return %5 : tensor<5x3xf32> | |
| } | |
| } | |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| // "Native" quantized matmul: unlike the simulated variants in this gist, the
| // quantized operands are actually converted to i8 and the dot_general runs as
| // an i8 x i8 -> i32 integer matmul before dequantizing back to f32.
| module @aqt_matmul { | |
| // Trained parameters: weight matrix (6x3) and the scalar input scale.
| iree_input.global private @_params$0 = dense<[[0.000000e+00, 5.003000e+02, 1.000600e+03], [1500.8999, 2.001200e+03, 2.501500e+03], [3001.7998, 3502.09985, 4.002400e+03], [4502.69971, 5.003000e+03, 5.503300e+03], [6003.59961, 6503.8999, 7004.1997], [7.504500e+03, 8004.7998, 8.505100e+03]]> : tensor<6x3xf32> | |
| iree_input.global private @_params$1 = dense<5.000000e+00> : tensor<f32> | |
| // Public entry point: loads the parameter globals and calls @main.
| func @compute_native(%arg0: tensor<5x6xf32>) -> tensor<5x3xf32> { | |
| %0 = iree_input.global.load @_params$0 : tensor<6x3xf32> | |
| %1 = iree_input.global.load @_params$1 : tensor<f32> | |
| %2 = call @main(%0, %1, %arg0) : (tensor<6x3xf32>, tensor<f32>, tensor<5x6xf32>) -> tensor<5x3xf32> | |
| return %2 : tensor<5x3xf32> | |
| } | |
| // %arg0: weight (6x3), %arg1: input scale, %arg2: input batch (5x6).
| func private @main(%arg0: tensor<6x3xf32>, %arg1: tensor<f32>, %arg2: tensor<5x6xf32>) -> tensor<5x3xf32> { | |
| // Constants: 0.5 for round-half-up (x + 0.5 then floor), 127 for the int8
| // symmetric range, 0xFF800000 is -inf (identity for the max reduction).
| %0 = mhlo.constant dense<5.000000e-01> : tensor<6x3xf32> | |
| %1 = mhlo.constant dense<1.270000e+02> : tensor<f32> | |
| %2 = mhlo.constant dense<0xFF800000> : tensor<f32> | |
| %3 = mhlo.constant dense<127> : tensor<i32> | |
| %4 = mhlo.constant dense<-127> : tensor<i32> | |
| %5 = mhlo.constant dense<5.000000e-01> : tensor<5x6xf32> | |
| // Quantize activations: x * input_scale, round half-up, clip to [-127, 127],
| // then narrow to real i8 storage.
| %6 = "mhlo.broadcast_in_dim"(%arg1) {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor<f32>) -> tensor<5x6xf32> | |
| %7 = mhlo.multiply %arg2, %6 : tensor<5x6xf32> | |
| %8 = mhlo.add %7, %5 : tensor<5x6xf32> | |
| %9 = "mhlo.floor"(%8) : (tensor<5x6xf32>) -> tensor<5x6xf32> | |
| %10 = call @jit_clip(%9, %4, %3) : (tensor<5x6xf32>, tensor<i32>, tensor<i32>) -> tensor<5x6xf32> | |
| %11 = "mhlo.convert"(%10) : (tensor<5x6xf32>) -> tensor<5x6xi8> | |
| // Weight scale = 127 / max(|W|); quantize weights and narrow to i8.
| // NOTE(review): the weight path has no clip before the i8 convert — it
| // presumably relies on the scale guaranteeing the [-127, 127] range.
| %12 = "mhlo.abs"(%arg0) : (tensor<6x3xf32>) -> tensor<6x3xf32> | |
| %13 = mhlo.reduce %12, %2 ( { | |
| ^bb0(%arg3: tensor<f32>, %arg4: tensor<f32>): // no predecessors | |
| %25 = mhlo.maximum %arg3, %arg4 : tensor<f32> | |
| "mhlo.return"(%25) : (tensor<f32>) -> () | |
| }) {dimensions = dense<[0, 1]> : tensor<2xi64>} : (tensor<6x3xf32>, tensor<f32>) -> tensor<f32> | |
| %14 = mhlo.divide %1, %13 : tensor<f32> | |
| %15 = "mhlo.broadcast_in_dim"(%14) {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor<f32>) -> tensor<6x3xf32> | |
| %16 = mhlo.multiply %arg0, %15 : tensor<6x3xf32> | |
| %17 = mhlo.add %16, %0 : tensor<6x3xf32> | |
| %18 = "mhlo.floor"(%17) : (tensor<6x3xf32>) -> tensor<6x3xf32> | |
| %19 = "mhlo.convert"(%18) : (tensor<6x3xf32>) -> tensor<6x3xi8> | |
| // Integer matmul with i32 accumulation, then dequantize: convert back to
| // f32 and divide by (input_scale * weight_scale). No bias in this model.
| %20 = "mhlo.dot_general"(%11, %19) {dot_dimension_numbers = #mhlo.dot<lhs_contracting_dimensions = [1], rhs_contracting_dimensions = [0]>, precision_config = ["DEFAULT", "DEFAULT"]} : (tensor<5x6xi8>, tensor<6x3xi8>) -> tensor<5x3xi32> | |
| %21 = mhlo.multiply %arg1, %14 : tensor<f32> | |
| %22 = "mhlo.convert"(%20) : (tensor<5x3xi32>) -> tensor<5x3xf32> | |
| %23 = "mhlo.broadcast_in_dim"(%21) {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor<f32>) -> tensor<5x3xf32> | |
| %24 = mhlo.divide %22, %23 : tensor<5x3xf32> | |
| return %24 : tensor<5x3xf32> | |
| } | |
| // Elementwise clamp of a 5x6 f32 tensor to [%arg1, %arg2]; the i32 bounds
| // are converted to f32 and applied via max (lower) then min (upper).
| func private @jit_clip(%arg0: tensor<5x6xf32>, %arg1: tensor<i32>, %arg2: tensor<i32>) -> tensor<5x6xf32> { | |
| %0 = "mhlo.convert"(%arg1) : (tensor<i32>) -> tensor<f32> | |
| %1 = "mhlo.broadcast_in_dim"(%0) {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor<f32>) -> tensor<5x6xf32> | |
| %2 = mhlo.maximum %1, %arg0 : tensor<5x6xf32> | |
| %3 = "mhlo.convert"(%arg2) : (tensor<i32>) -> tensor<f32> | |
| %4 = "mhlo.broadcast_in_dim"(%3) {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor<f32>) -> tensor<5x6xf32> | |
| %5 = mhlo.minimum %4, %2 : tensor<5x6xf32> | |
| return %5 : tensor<5x6xf32> | |
| } | |
| } |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| // "Simulated" counterpart of the native quantized matmul above: identical
| // parameters and quantization math, but the quantized values stay in f32 and
| // the dot_general runs as a float matmul (no i8/i32 conversions anywhere).
| module @aqt_matmul { | |
| // Trained parameters: weight matrix (6x3) and the scalar input scale.
| iree_input.global private @_params$0 = dense<[[0.000000e+00, 5.003000e+02, 1.000600e+03], [1500.8999, 2.001200e+03, 2.501500e+03], [3001.7998, 3502.09985, 4.002400e+03], [4502.69971, 5.003000e+03, 5.503300e+03], [6003.59961, 6503.8999, 7004.1997], [7.504500e+03, 8004.7998, 8.505100e+03]]> : tensor<6x3xf32> | |
| iree_input.global private @_params$1 = dense<5.000000e+00> : tensor<f32> | |
| // Public entry point: loads the parameter globals and calls @main.
| func @compute_simulated(%arg0: tensor<5x6xf32>) -> tensor<5x3xf32> { | |
| %0 = iree_input.global.load @_params$0 : tensor<6x3xf32> | |
| %1 = iree_input.global.load @_params$1 : tensor<f32> | |
| %2 = call @main(%0, %1, %arg0) : (tensor<6x3xf32>, tensor<f32>, tensor<5x6xf32>) -> tensor<5x3xf32> | |
| return %2 : tensor<5x3xf32> | |
| } | |
| // %arg0: weight (6x3), %arg1: input scale, %arg2: input batch (5x6).
| func private @main(%arg0: tensor<6x3xf32>, %arg1: tensor<f32>, %arg2: tensor<5x6xf32>) -> tensor<5x3xf32> { | |
| // Constants: 0.5 for round-half-up (x + 0.5 then floor), 127 for the int8
| // symmetric range, 0xFF800000 is -inf (identity for the max reduction).
| %0 = mhlo.constant dense<5.000000e-01> : tensor<6x3xf32> | |
| %1 = mhlo.constant dense<1.270000e+02> : tensor<f32> | |
| %2 = mhlo.constant dense<0xFF800000> : tensor<f32> | |
| %3 = mhlo.constant dense<127> : tensor<i32> | |
| %4 = mhlo.constant dense<-127> : tensor<i32> | |
| %5 = mhlo.constant dense<5.000000e-01> : tensor<5x6xf32> | |
| // Quantize activations: x * input_scale, round half-up, clip to [-127, 127].
| %6 = "mhlo.broadcast_in_dim"(%arg1) {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor<f32>) -> tensor<5x6xf32> | |
| %7 = mhlo.multiply %arg2, %6 : tensor<5x6xf32> | |
| %8 = mhlo.add %7, %5 : tensor<5x6xf32> | |
| %9 = "mhlo.floor"(%8) : (tensor<5x6xf32>) -> tensor<5x6xf32> | |
| %10 = call @jit_clip(%9, %4, %3) : (tensor<5x6xf32>, tensor<i32>, tensor<i32>) -> tensor<5x6xf32> | |
| // Weight scale = 127 / max(|W|) over all elements; quantize the weights.
| %11 = "mhlo.abs"(%arg0) : (tensor<6x3xf32>) -> tensor<6x3xf32> | |
| %12 = mhlo.reduce %11, %2 ( { | |
| ^bb0(%arg3: tensor<f32>, %arg4: tensor<f32>): // no predecessors | |
| %22 = mhlo.maximum %arg3, %arg4 : tensor<f32> | |
| "mhlo.return"(%22) : (tensor<f32>) -> () | |
| }) {dimensions = dense<[0, 1]> : tensor<2xi64>} : (tensor<6x3xf32>, tensor<f32>) -> tensor<f32> | |
| %13 = mhlo.divide %1, %12 : tensor<f32> | |
| %14 = "mhlo.broadcast_in_dim"(%13) {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor<f32>) -> tensor<6x3xf32> | |
| %15 = mhlo.multiply %arg0, %14 : tensor<6x3xf32> | |
| %16 = mhlo.add %15, %0 : tensor<6x3xf32> | |
| %17 = "mhlo.floor"(%16) : (tensor<6x3xf32>) -> tensor<6x3xf32> | |
| // Float matmul on the (still f32) quantized operands, then dequantize by
| // dividing by (input_scale * weight_scale). No bias in this model.
| %18 = "mhlo.dot_general"(%10, %17) {dot_dimension_numbers = #mhlo.dot<lhs_contracting_dimensions = [1], rhs_contracting_dimensions = [0]>, precision_config = ["DEFAULT", "DEFAULT"]} : (tensor<5x6xf32>, tensor<6x3xf32>) -> tensor<5x3xf32> | |
| %19 = mhlo.multiply %arg1, %13 : tensor<f32> | |
| %20 = "mhlo.broadcast_in_dim"(%19) {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor<f32>) -> tensor<5x3xf32> | |
| %21 = mhlo.divide %18, %20 : tensor<5x3xf32> | |
| return %21 : tensor<5x3xf32> | |
| } | |
| // Elementwise clamp of a 5x6 f32 tensor to [%arg1, %arg2]; the i32 bounds
| // are converted to f32 and applied via max (lower) then min (upper).
| func private @jit_clip(%arg0: tensor<5x6xf32>, %arg1: tensor<i32>, %arg2: tensor<i32>) -> tensor<5x6xf32> { | |
| %0 = "mhlo.convert"(%arg1) : (tensor<i32>) -> tensor<f32> | |
| %1 = "mhlo.broadcast_in_dim"(%0) {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor<f32>) -> tensor<5x6xf32> | |
| %2 = mhlo.maximum %1, %arg0 : tensor<5x6xf32> | |
| %3 = "mhlo.convert"(%arg2) : (tensor<i32>) -> tensor<f32> | |
| %4 = "mhlo.broadcast_in_dim"(%3) {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor<f32>) -> tensor<5x6xf32> | |
| %5 = mhlo.minimum %4, %2 : tensor<5x6xf32> | |
| return %5 : tensor<5x6xf32> | |
| } | |
| } |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment