@Aphoh
Created January 28, 2025 03:19
crash.log
(run_docker pid=4406, ip=10.164.0.48) Traceback (most recent call last):
(run_docker pid=4406, ip=10.164.0.48) File "/opt/levanter/.venv/lib/python3.10/site-packages/jax/_src/compiler.py", line 261, in backend_compile
(run_docker pid=4406, ip=10.164.0.48) return backend.compile(
(run_docker pid=4406, ip=10.164.0.48) jaxlib.xla_extension.XlaRuntimeError: INTERNAL: Mosaic failed to compile TPU kernel: Unsupported input data type in matrix multiplication in this target.
(run_docker pid=4406, ip=10.164.0.48)
(run_docker pid=4406, ip=10.164.0.48) at location: loc("/dot_general"(callsite("_splash_attention"("/opt/levanter/.venv/lib/python3.10/site-packages/jax/experimental/pallas/ops/tpu/splash_attention/splash_attention_kernel.py":2277:0) at callsite("__call__"("/opt/levanter/.venv/lib/python3.10/site-packages/jax/experimental/pallas/ops/tpu/splash_attention/splash_attention_kernel.py":2312:0) at callsite("<lambda>"("/opt/levanter/src/levanter/models/attention.py":953:0) at callsite("wrap_flash_attention"("/opt/levanter/src/levanter/models/attention.py":952:0) at callsite("_tpu_splash_attention"("/opt/levanter/src/levanter/models/attention.py":956:0) at callsite("_try_tpu_splash_attention"("/opt/levanter/src/levanter/models/attention.py":747:0) at callsite("dot_product_attention"("/opt/levanter/src/levanter/models/attention.py":138:0) at callsite("__call__"("/opt/levanter/src/levanter/models/routed_qwen_model.py":488:0) at callsite("__call__"("/opt/levanter/src/levanter/models/routed_qwen_model.py":668:0) at "_do_block"("/opt/levanter/.venv/lib/python3.10/site-packages/haliax/nn/scan.py":319:0))))))))))))
(run_docker pid=4406, ip=10.164.0.48)
(run_docker pid=4406, ip=10.164.0.48) The MLIR operation involved:
(run_docker pid=4406, ip=10.164.0.48) %4335 = "tpu.matmul"(%4330, %4332, %4334) <{transpose_lhs = false, transpose_rhs = true}> : (vector<512x128xbf16>, vector<128x128xbf16>, vector<512x128xf32>) -> vector<512x128xf32>
(run_docker pid=4406, ip=10.164.0.48) ... additional diagnostics were skipped.
(run_docker pid=4406, ip=10.164.0.48)
(run_docker pid=4406, ip=10.164.0.48) Please report a bug at: https://github.com/google/jax/issues/new?assignees=apaszke
(run_docker pid=4406, ip=10.164.0.48)
(run_docker pid=4406, ip=10.164.0.48)
(run_docker pid=4406, ip=10.164.0.48) The above exception was the direct cause of the following exception:
(run_docker pid=4406, ip=10.164.0.48)
(run_docker pid=4406, ip=10.164.0.48) jax.errors.SimplifiedTraceback: For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.
(run_docker pid=4406, ip=10.164.0.48)
(run_docker pid=4406, ip=10.164.0.48) The above exception was the direct cause of the following exception:
(run_docker pid=4406, ip=10.164.0.48)
(run_docker pid=4406, ip=10.164.0.48) Traceback (most recent call last):
(run_docker pid=4406, ip=10.164.0.48) File "/opt/levanter/src/levanter/main/routed_lm.py", line 292, in <module>
(run_docker pid=4406, ip=10.164.0.48) levanter.config.main(main)()
(run_docker pid=4406, ip=10.164.0.48) File "/opt/levanter/src/levanter/config.py", line 84, in wrapper_inner
(run_docker pid=4406, ip=10.164.0.48) response = fn(cfg, *args, **kwargs)
(run_docker pid=4406, ip=10.164.0.48) File "/opt/levanter/src/levanter/main/routed_lm.py", line 279, in main
(run_docker pid=4406, ip=10.164.0.48) last_info = trainer.train(state, train_loader)
(run_docker pid=4406, ip=10.164.0.48) File "/opt/levanter/src/levanter/trainer.py", line 432, in train
(run_docker pid=4406, ip=10.164.0.48) for info in self.training_steps(state, train_loader):
(run_docker pid=4406, ip=10.164.0.48) File "/opt/levanter/src/levanter/trainer.py", line 421, in training_steps
(run_docker pid=4406, ip=10.164.0.48) info = self.train_step(state, example)
(run_docker pid=4406, ip=10.164.0.48) File "/opt/levanter/src/levanter/trainer.py", line 396, in train_step
(run_docker pid=4406, ip=10.164.0.48) loss, new_state, extras = self._jit_train_step_fn_no_hook(state, batch, batch_kwargs)
(run_docker pid=4406, ip=10.164.0.48) File "/opt/levanter/.venv/lib/python3.10/site-packages/haliax/partitioning.py", line 261, in __call__
(run_docker pid=4406, ip=10.164.0.48) return self._call(False, *args, **kwargs)
(run_docker pid=4406, ip=10.164.0.48) File "/opt/levanter/.venv/lib/python3.10/site-packages/equinox/_module.py", line 1096, in __call__
(run_docker pid=4406, ip=10.164.0.48) return self.__func__(self.__self__, *args, **kwargs)
(run_docker pid=4406, ip=10.164.0.48) File "/opt/levanter/.venv/lib/python3.10/site-packages/haliax/partitioning.py", line 337, in _call
(run_docker pid=4406, ip=10.164.0.48) out, out_static = cached_pjitted_fun(dynamic_donated, dynamic_reserved, static)
(run_docker pid=4406, ip=10.164.0.48) File "/opt/levanter/.venv/lib/python3.10/site-packages/haliax/nn/scan.py", line 319, in _do_block
(run_docker pid=4406, ip=10.164.0.48) return block(carry, *extra_args, **extra_kwargs)
(run_docker pid=4406, ip=10.164.0.48) File "/opt/levanter/src/levanter/models/routed_qwen_model.py", line 668, in __call__
(run_docker pid=4406, ip=10.164.0.48) attn_output = self.self_attn(x=x, mask=mask, expert_mask=expert_mask, key=k_attn)
(run_docker pid=4406, ip=10.164.0.48) File "/opt/levanter/src/levanter/models/routed_qwen_model.py", line 488, in __call__
(run_docker pid=4406, ip=10.164.0.48) attn_output = dot_product_attention(
(run_docker pid=4406, ip=10.164.0.48) File "/opt/levanter/src/levanter/models/attention.py", line 138, in dot_product_attention
(run_docker pid=4406, ip=10.164.0.48) attention_out = _try_tpu_splash_attention(
(run_docker pid=4406, ip=10.164.0.48) File "/opt/levanter/src/levanter/models/attention.py", line 747, in _try_tpu_splash_attention
(run_docker pid=4406, ip=10.164.0.48) return _tpu_splash_attention(
(run_docker pid=4406, ip=10.164.0.48) File "/opt/levanter/src/levanter/models/attention.py", line 956, in _tpu_splash_attention
(run_docker pid=4406, ip=10.164.0.48) attn_output = wrap_flash_attention(q_, k_, v_, segment_ids)
(run_docker pid=4406, ip=10.164.0.48) File "/opt/levanter/src/levanter/models/attention.py", line 952, in wrap_flash_attention
(run_docker pid=4406, ip=10.164.0.48) return jax.vmap(
(run_docker pid=4406, ip=10.164.0.48) File "/opt/levanter/src/levanter/models/attention.py", line 953, in <lambda>
(run_docker pid=4406, ip=10.164.0.48) lambda q, k, v, si: splash_kernel(q, k, v, segment_ids=si), in_axes=(0, 0, 0, segment_batch_axis)
(run_docker pid=4406, ip=10.164.0.48) File "/opt/levanter/.venv/lib/python3.10/site-packages/jax/experimental/pallas/ops/tpu/splash_attention/splash_attention_kernel.py", line 2312, in __call__
(run_docker pid=4406, ip=10.164.0.48) return _splash_attention(
(run_docker pid=4406, ip=10.164.0.48) File "/opt/levanter/.venv/lib/python3.10/site-packages/jax/experimental/pallas/ops/tpu/splash_attention/splash_attention_kernel.py", line 2277, in _splash_attention
(run_docker pid=4406, ip=10.164.0.48) return _splash_attention_custom(
(run_docker pid=4406, ip=10.164.0.48) jax._src.pallas.mosaic.error_handling.MosaicError: INTERNAL: Mosaic failed to compile TPU kernel: Unsupported input data type in matrix multiplication in this target.
(run_docker pid=4406, ip=10.164.0.48)
(run_docker pid=4406, ip=10.164.0.48) The MLIR operation involved:
(run_docker pid=4406, ip=10.164.0.48) %4335 = "tpu.matmul"(%4330, %4332, %4334) <{transpose_lhs = false, transpose_rhs = true}> : (vector<512x128xbf16>, vector<128x128xbf16>, vector<512x128xf32>) -> vector<512x128xf32>
(run_docker pid=4406, ip=10.164.0.48) ... additional diagnostics were skipped.
(run_docker pid=4406, ip=10.164.0.48)
(run_docker pid=4406, ip=10.164.0.48) Please report a bug at: https://github.com/google/jax/issues/new?assignees=apaszke
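
Likely cause and a possible workaround (not part of the original log): the kernel that fails is the q @ k^T matmul inside splash attention, lowered to a tpu.matmul with bfloat16 inputs accumulating into float32. "Unsupported input data type in matrix multiplication in this target" generally means the TPU generation being targeted does not support that dtype combination on this Mosaic/Pallas path. The sketch below only reproduces the op's shapes and dtypes from the error and shows one hedged workaround, casting the attention inputs to float32 before the kernel; the guard is illustrative and is not Levanter's actual wrap_flash_attention code. Another option, if the config allows it, is to fall back to a non-splash attention backend.

import jax
import jax.numpy as jnp

def qk_logits(q: jax.Array, k: jax.Array) -> jax.Array:
    # q @ k^T with float32 accumulation -- the same contraction as the failing
    # tpu.matmul (transpose_rhs=true, bf16 x bf16 -> f32) in the error above.
    return jax.lax.dot_general(
        q, k,
        dimension_numbers=(((1,), (1,)), ((), ())),  # contract over head_dim
        preferred_element_type=jnp.float32,
    )

# Shapes/dtypes taken from the MLIR op: vector<512x128xbf16> x vector<128x128xbf16>.
q = jnp.zeros((512, 128), dtype=jnp.bfloat16)
k = jnp.zeros((128, 128), dtype=jnp.bfloat16)

# Hypothetical guard: if the target rejects bf16 matmuls, run the kernel in f32.
if q.dtype == jnp.bfloat16:
    logits = qk_logits(q.astype(jnp.float32), k.astype(jnp.float32))
else:
    logits = qk_logits(q, k)

print(logits.shape, logits.dtype)  # (512, 128) float32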