Created June 18, 2022 16:42
-
-
Save Birch-san/cd6aad79e5a671584d24d008b8a55c16 to your computer and use it in GitHub Desktop.
compile scatter on iree/vulkan
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Minimal repro: compile a single mhlo.scatter op with IREE's Vulkan backend.
# See https://github.com/google/iree/issues/9361 and
# https://github.com/google/iree/pull/9378
# (the repository has since moved to https://github.com/iree-org/iree).
from iree import compiler, runtime as rt

# NOTE(review): this is a runtime (task executor) flag; presumably it has no
# effect on compiler.compile_str below. Kept as in the original repro.
rt.flags.parse_flags("--task_topology_group_count=8")

# MHLO module exported from JAX: scatter a tensor<1xi32> of updates into a
# tensor<1x1xi32> operand at the index given by %arg1. The region returns
# %arg4, i.e. each update overwrites the operand element.
# Whitespace is insignificant to the MLIR parser; every token below is kept
# from the original export.
CODE = """
#loc0 = loc(unknown)
module @jit_prim_fun.12 {
  func.func public @main(%arg0: tensor<1x1xi32> loc(unknown), %arg1: tensor<1xi32> loc(unknown), %arg2: tensor<1xi32> loc(unknown)) -> tensor<1x1xi32> {
    %0 = "mhlo.scatter"(%arg0, %arg1, %arg2) ({
    ^bb0(%arg3: tensor<i32> loc(unknown), %arg4: tensor<i32> loc(unknown)):
      "mhlo.return"(%arg4) : (tensor<i32>) -> () loc(#loc1)
    }) {indices_are_sorted = true, scatter_dimension_numbers = #mhlo.scatter<update_window_dims = [0], inserted_window_dims = [1], scatter_dims_to_operand_dims = [1]>, unique_indices = true} : (tensor<1x1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<1x1xi32> loc(#loc1)
    return %0 : tensor<1x1xi32> loc(#loc0)
  } loc(#loc0)
} loc(#loc0)
#loc1 = loc("jit(scatter)/jit(main)/scatter[update_consts=() dimension_numbers=ScatterDimensionNumbers(update_window_dims=(0,), inserted_window_dims=(1,), scatter_dims_to_operand_dims=(1,)) indices_are_sorted=True unique_indices=True mode=GatherScatterMode.FILL_OR_DROP]"("/Users/birch/anaconda3/envs/torch-nightly/lib/python3.9/site-packages/transformers/models/bart/modeling_flax_bart.py":926:1))
"""

# Dump the IR when a compiler pass fails, so the failing op is visible.
extra_args = ["--mlir-print-ir-after-failure"]

# Realistically I'd recommend the following arguments too:
# https://github.com/nod-ai/SHARK/blob/1186d7c58e6046aea6a6115c608dbd77728e7aca/shark/iree_utils.py#L93-L96
# The problem reproduces without them anyway; they are applied here
# unconditionally to mirror the SHARK configuration.
extra_args += [
    "--iree-llvm-target-triple=arm64-apple-darwin21.5.0",
    # NOTE(review): the two i64->i32 demotion flags below point in opposite
    # directions (one enabled, one explicitly disabled); kept verbatim from
    # the SHARK reference args.
    "--iree-flow-demote-i64-to-i32",
    "--iree-vulkan-target-triple=m1-moltenvk-macos",
    "--iree-llvm-target-cpu-features=host",
    "--iree-mhlo-demote-i64-to-i32=false",
]

# compile_str is expected to raise if compilation fails. Returning normally
# (binding the compiled module to iree_binary) is the "succeeds" outcome.
iree_binary = compiler.compile_str(
    CODE, target_backends=["vulkan"], input_type="mhlo", extra_args=extra_args
)
Sign up for free to join this conversation on GitHub.
Already have an account? Sign in to comment.
Compilation succeeds when built against iree-org/iree@7c42a98.
The output is: