Created
June 6, 2022 22:02
-
-
Save Birch-san/4f35b7261ee2d8c102a4359b8fee1b46 to your computer and use it in GitHub Desktop.
Stack trace from attempting to run dalle-playground on an Apple M1 GPU, using jax/jaxlib built from commit 345cc19949273cc414d94e6f13d0620b780af465 and IREE release candidate-20220606.161.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Error invoking IREE compiler tool iree-compile | |
Diagnostics: | |
/Users/birch/anaconda3/envs/torch-nightly/lib/python3.9/site-packages/transformers/models/bart/modeling_flax_bart.py:926:1: error: failed to legalize operation 'mhlo.scatter' that was explicitly marked illegal | |
input_ids = input_ids.at[(..., -1)].set(self.config.eos_token_id) | |
^ | |
compilation failed | |
Invoked with: | |
iree-compile /Users/birch/anaconda3/envs/torch-nightly/lib/python3.9/site-packages/iree/compiler/tools/../_mlir_libs/iree-compile - --iree-input-type=mhlo --iree-vm-bytecode-module-output-format=flatbuffer-binary --iree-hal-target-backends=vulkan --iree-mlir-to-vm-bytecode-module --iree-llvm-embedded-linker-path=/Users/birch/anaconda3/envs/torch-nightly/lib/python3.9/site-packages/iree/compiler/tools/../_mlir_libs/iree-lld --mlir-print-debuginfo --mlir-print-op-on-diagnostic=false --iree-llvm-target-triple=arm64-apple-darwin21.5.0 --iree-flow-demote-i64-to-i32 --iree-vulkan-target-triple=m1-moltenvk-macos --iree-llvm-target-cpu-features=host --iree-mhlo-demote-i64-to-i32=false | |
Need more information? Set IREE_SAVE_TEMPS=/some/dir in your environment to save all artifacts and reproducers. | |
File "/Users/birch/git/jax/jax/_src/iree.py", line 165, in compile | |
iree_binary = iree.compiler.compile_str( | |
File "/Users/birch/git/jax/jax/_src/dispatch.py", line 807, in backend_compile | |
return backend.compile(built_c, compile_options=options) | |
File "/Users/birch/git/jax/jax/_src/profiler.py", line 297, in wrapper | |
return func(*args, **kwargs) | |
File "/Users/birch/git/jax/jax/_src/dispatch.py", line 861, in compile_or_get_cached | |
return backend_compile(backend, computation, compile_options) | |
File "/Users/birch/git/jax/jax/_src/dispatch.py", line 900, in from_xla_computation | |
compiled = compile_or_get_cached(backend, xla_computation, options) | |
File "/Users/birch/git/jax/jax/_src/dispatch.py", line 797, in compile | |
self._executable = XlaCompiledComputation.from_xla_computation( | |
File "/Users/birch/git/jax/jax/_src/dispatch.py", line 245, in _xla_callable_uncached | |
return lower_xla_callable(fun, device, backend, name, donated_invars, False, | |
File "/Users/birch/git/jax/jax/_src/dispatch.py", line 164, in xla_primitive_callable | |
compiled = _xla_callable_uncached(lu.wrap_init(prim_fun), device, None, | |
File "/Users/birch/git/jax/jax/_src/util.py", line 212, in cached | |
return f(*args, **kwargs) | |
File "/Users/birch/git/jax/jax/_src/util.py", line 219, in wrapper | |
return cached(config._trace_context(), *args, **kwargs) | |
File "/Users/birch/git/jax/jax/_src/dispatch.py", line 99, in apply_primitive | |
compiled_fun = xla_primitive_callable(prim, *unsafe_map(arg_spec, args), | |
File "/Users/birch/git/jax/jax/core.py", line 676, in process_primitive | |
return primitive.impl(*tracers, **params) | |
File "/Users/birch/git/jax/jax/core.py", line 326, in bind_with_trace | |
out = trace.process_primitive(self, map(trace.full_raise, args), params) | |
File "/Users/birch/git/jax/jax/core.py", line 323, in bind | |
return self.bind_with_trace(find_top_trace(args), args, params) | |
File "/Users/birch/git/jax/jax/_src/lax/slicing.py", line 604, in scatter | |
return scatter_p.bind( | |
File "/Users/birch/git/jax/jax/_src/ops/scatter.py", line 109, in _scatter_impl | |
out = scatter_op( | |
File "/Users/birch/git/jax/jax/_src/ops/scatter.py", line 70, in _scatter_update | |
return _scatter_impl(x, y, scatter_op, treedef, static_idx, dynamic_idx, | |
File "/Users/birch/git/jax/jax/_src/numpy/lax_numpy.py", line 4875, in set | |
return scatter._scatter_update(self.array, self.index, values, lax.scatter, | |
File "/Users/birch/git/dalle-playground/backend/app.py", line 56, in <module> | |
model = DalleBart.from_pretrained(DALLE_MODEL, revision=DALLE_COMMIT_ID) |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment