@Birch-san
Created June 6, 2022 22:02
Stack trace from attempting to run dalle-playground on an M1 GPU with jax/jaxlib at commit 345cc19949273cc414d94e6f13d0620b780af465 and IREE candidate-20220606.161:
Error invoking IREE compiler tool iree-compile
Diagnostics:
/Users/birch/anaconda3/envs/torch-nightly/lib/python3.9/site-packages/transformers/models/bart/modeling_flax_bart.py:926:1: error: failed to legalize operation 'mhlo.scatter' that was explicitly marked illegal
input_ids = input_ids.at[(..., -1)].set(self.config.eos_token_id)
^
compilation failed
Invoked with:
iree-compile /Users/birch/anaconda3/envs/torch-nightly/lib/python3.9/site-packages/iree/compiler/tools/../_mlir_libs/iree-compile - --iree-input-type=mhlo --iree-vm-bytecode-module-output-format=flatbuffer-binary --iree-hal-target-backends=vulkan --iree-mlir-to-vm-bytecode-module --iree-llvm-embedded-linker-path=/Users/birch/anaconda3/envs/torch-nightly/lib/python3.9/site-packages/iree/compiler/tools/../_mlir_libs/iree-lld --mlir-print-debuginfo --mlir-print-op-on-diagnostic=false --iree-llvm-target-triple=arm64-apple-darwin21.5.0 --iree-flow-demote-i64-to-i32 --iree-vulkan-target-triple=m1-moltenvk-macos --iree-llvm-target-cpu-features=host --iree-mhlo-demote-i64-to-i32=false
Need more information? Set IREE_SAVE_TEMPS=/some/dir in your environment to save all artifacts and reproducers.
  File "/Users/birch/git/jax/jax/_src/iree.py", line 165, in compile
    iree_binary = iree.compiler.compile_str(
  File "/Users/birch/git/jax/jax/_src/dispatch.py", line 807, in backend_compile
    return backend.compile(built_c, compile_options=options)
  File "/Users/birch/git/jax/jax/_src/profiler.py", line 297, in wrapper
    return func(*args, **kwargs)
  File "/Users/birch/git/jax/jax/_src/dispatch.py", line 861, in compile_or_get_cached
    return backend_compile(backend, computation, compile_options)
  File "/Users/birch/git/jax/jax/_src/dispatch.py", line 900, in from_xla_computation
    compiled = compile_or_get_cached(backend, xla_computation, options)
  File "/Users/birch/git/jax/jax/_src/dispatch.py", line 797, in compile
    self._executable = XlaCompiledComputation.from_xla_computation(
  File "/Users/birch/git/jax/jax/_src/dispatch.py", line 245, in _xla_callable_uncached
    return lower_xla_callable(fun, device, backend, name, donated_invars, False,
  File "/Users/birch/git/jax/jax/_src/dispatch.py", line 164, in xla_primitive_callable
    compiled = _xla_callable_uncached(lu.wrap_init(prim_fun), device, None,
  File "/Users/birch/git/jax/jax/_src/util.py", line 212, in cached
    return f(*args, **kwargs)
  File "/Users/birch/git/jax/jax/_src/util.py", line 219, in wrapper
    return cached(config._trace_context(), *args, **kwargs)
  File "/Users/birch/git/jax/jax/_src/dispatch.py", line 99, in apply_primitive
    compiled_fun = xla_primitive_callable(prim, *unsafe_map(arg_spec, args),
  File "/Users/birch/git/jax/jax/core.py", line 676, in process_primitive
    return primitive.impl(*tracers, **params)
  File "/Users/birch/git/jax/jax/core.py", line 326, in bind_with_trace
    out = trace.process_primitive(self, map(trace.full_raise, args), params)
  File "/Users/birch/git/jax/jax/core.py", line 323, in bind
    return self.bind_with_trace(find_top_trace(args), args, params)
  File "/Users/birch/git/jax/jax/_src/lax/slicing.py", line 604, in scatter
    return scatter_p.bind(
  File "/Users/birch/git/jax/jax/_src/ops/scatter.py", line 109, in _scatter_impl
    out = scatter_op(
  File "/Users/birch/git/jax/jax/_src/ops/scatter.py", line 70, in _scatter_update
    return _scatter_impl(x, y, scatter_op, treedef, static_idx, dynamic_idx,
  File "/Users/birch/git/jax/jax/_src/numpy/lax_numpy.py", line 4875, in set
    return scatter._scatter_update(self.array, self.index, values, lax.scatter,
  File "/Users/birch/git/dalle-playground/backend/app.py", line 56, in <module>
    model = DalleBart.from_pretrained(DALLE_MODEL, revision=DALLE_COMMIT_ID)
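The failing source line is `input_ids = input_ids.at[(..., -1)].set(self.config.eos_token_id)` in `modeling_flax_bart.py`, and the diagnostics report that the IREE Vulkan target could not legalize the `mhlo.scatter` it lowered to. One possible workaround, sketched here under the assumption that this scatter is what blocks compilation (not tested against IREE), is to express the same update as a `jnp.where` select over a mask that is true only at the last position, which lowers to iota/compare/select rather than scatter. The function names `set_last_via_at` / `set_last_via_where` and the EOS id `2` are hypothetical placeholders; the real id comes from `self.config.eos_token_id`.

```python
import jax
import jax.numpy as jnp

EOS_TOKEN_ID = 2  # hypothetical placeholder; the model uses self.config.eos_token_id


def set_last_via_at(input_ids, eos_token_id):
    # Same pattern as modeling_flax_bart.py line 926; in the jax build from
    # this trace it lowered to mhlo.scatter (newer JAX versions may differ).
    return input_ids.at[(..., -1)].set(eos_token_id)


def set_last_via_where(input_ids, eos_token_id):
    # Boolean mask over the last axis: True only at the final position.
    n = input_ids.shape[-1]
    last = jnp.arange(n) == n - 1
    # Broadcast the mask across the batch axes and select the EOS id there,
    # keeping the original token ids everywhere else.
    return jnp.where(last, jnp.asarray(eos_token_id, input_ids.dtype), input_ids)


input_ids = jnp.array([[5, 6, 7, 8], [9, 10, 11, 12]], dtype=jnp.int32)

a = set_last_via_at(input_ids, EOS_TOKEN_ID)
b = set_last_via_where(input_ids, EOS_TOKEN_ID)
print(b)
print("same result:", bool(jnp.array_equal(a, b)))

# Lower both versions and search the emitted IR text for a scatter op.
at_ir = jax.jit(set_last_via_at, static_argnums=1).lower(input_ids, EOS_TOKEN_ID).as_text()
where_ir = jax.jit(set_last_via_where, static_argnums=1).lower(input_ids, EOS_TOKEN_ID).as_text()
print("scatter in .at version:", "scatter" in at_ir)
print("scatter in where version:", "scatter" in where_ir)
```

Both versions compute the same array, so swapping the line should not change model outputs. Whether IREE then compiles the rest of `DalleBart.from_pretrained` is not established by this trace; other scatter ops may surface later.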