Created January 28, 2025 03:19
crash.log
(run_docker pid=4406, ip=10.164.0.48) Traceback (most recent call last):
(run_docker pid=4406, ip=10.164.0.48)   File "/opt/levanter/.venv/lib/python3.10/site-packages/jax/_src/compiler.py", line 261, in backend_compile
(run_docker pid=4406, ip=10.164.0.48)     return backend.compile(
(run_docker pid=4406, ip=10.164.0.48) jaxlib.xla_extension.XlaRuntimeError: INTERNAL: Mosaic failed to compile TPU kernel: Unsupported input data type in matrix multiplication in this target.
(run_docker pid=4406, ip=10.164.0.48)
(run_docker pid=4406, ip=10.164.0.48) at location: loc("/dot_general"(callsite("_splash_attention"("/opt/levanter/.venv/lib/python3.10/site-packages/jax/experimental/pallas/ops/tpu/splash_attention/splash_attention_kernel.py":2277:0) at callsite("__call__"("/opt/levanter/.venv/lib/python3.10/site-packages/jax/experimental/pallas/ops/tpu/splash_attention/splash_attention_kernel.py":2312:0) at callsite("<lambda>"("/opt/levanter/src/levanter/models/attention.py":953:0) at callsite("wrap_flash_attention"("/opt/levanter/src/levanter/models/attention.py":952:0) at callsite("_tpu_splash_attention"("/opt/levanter/src/levanter/models/attention.py":956:0) at callsite("_try_tpu_splash_attention"("/opt/levanter/src/levanter/models/attention.py":747:0) at callsite("dot_product_attention"("/opt/levanter/src/levanter/models/attention.py":138:0) at callsite("__call__"("/opt/levanter/src/levanter/models/routed_qwen_model.py":488:0) at callsite("__call__"("/opt/levanter/src/levanter/models/routed_qwen_model.py":668:0) at "_do_block"("/opt/levanter/.venv/lib/python3.10/site-packages/haliax/nn/scan.py":319:0))))))))))))
(run_docker pid=4406, ip=10.164.0.48)
(run_docker pid=4406, ip=10.164.0.48) The MLIR operation involved:
(run_docker pid=4406, ip=10.164.0.48)   %4335 = "tpu.matmul"(%4330, %4332, %4334) <{transpose_lhs = false, transpose_rhs = true}> : (vector<512x128xbf16>, vector<128x128xbf16>, vector<512x128xf32>) -> vector<512x128xf32>
(run_docker pid=4406, ip=10.164.0.48) ... additional diagnostics were skipped.
(run_docker pid=4406, ip=10.164.0.48)
(run_docker pid=4406, ip=10.164.0.48) Please report a bug at: https://github.com/google/jax/issues/new?assignees=apaszke
(run_docker pid=4406, ip=10.164.0.48)
(run_docker pid=4406, ip=10.164.0.48)
(run_docker pid=4406, ip=10.164.0.48) The above exception was the direct cause of the following exception:
(run_docker pid=4406, ip=10.164.0.48)
(run_docker pid=4406, ip=10.164.0.48) jax.errors.SimplifiedTraceback: For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.
(run_docker pid=4406, ip=10.164.0.48)
(run_docker pid=4406, ip=10.164.0.48) The above exception was the direct cause of the following exception:
(run_docker pid=4406, ip=10.164.0.48)
(run_docker pid=4406, ip=10.164.0.48) Traceback (most recent call last):
(run_docker pid=4406, ip=10.164.0.48)   File "/opt/levanter/src/levanter/main/routed_lm.py", line 292, in <module>
(run_docker pid=4406, ip=10.164.0.48)     levanter.config.main(main)()
(run_docker pid=4406, ip=10.164.0.48)   File "/opt/levanter/src/levanter/config.py", line 84, in wrapper_inner
(run_docker pid=4406, ip=10.164.0.48)     response = fn(cfg, *args, **kwargs)
(run_docker pid=4406, ip=10.164.0.48)   File "/opt/levanter/src/levanter/main/routed_lm.py", line 279, in main
(run_docker pid=4406, ip=10.164.0.48)     last_info = trainer.train(state, train_loader)
(run_docker pid=4406, ip=10.164.0.48)   File "/opt/levanter/src/levanter/trainer.py", line 432, in train
(run_docker pid=4406, ip=10.164.0.48)     for info in self.training_steps(state, train_loader):
(run_docker pid=4406, ip=10.164.0.48)   File "/opt/levanter/src/levanter/trainer.py", line 421, in training_steps
(run_docker pid=4406, ip=10.164.0.48)     info = self.train_step(state, example)
(run_docker pid=4406, ip=10.164.0.48)   File "/opt/levanter/src/levanter/trainer.py", line 396, in train_step
(run_docker pid=4406, ip=10.164.0.48)     loss, new_state, extras = self._jit_train_step_fn_no_hook(state, batch, batch_kwargs)
(run_docker pid=4406, ip=10.164.0.48)   File "/opt/levanter/.venv/lib/python3.10/site-packages/haliax/partitioning.py", line 261, in __call__
(run_docker pid=4406, ip=10.164.0.48)     return self._call(False, *args, **kwargs)
(run_docker pid=4406, ip=10.164.0.48)   File "/opt/levanter/.venv/lib/python3.10/site-packages/equinox/_module.py", line 1096, in __call__
(run_docker pid=4406, ip=10.164.0.48)     return self.__func__(self.__self__, *args, **kwargs)
(run_docker pid=4406, ip=10.164.0.48)   File "/opt/levanter/.venv/lib/python3.10/site-packages/haliax/partitioning.py", line 337, in _call
(run_docker pid=4406, ip=10.164.0.48)     out, out_static = cached_pjitted_fun(dynamic_donated, dynamic_reserved, static)
(run_docker pid=4406, ip=10.164.0.48)   File "/opt/levanter/.venv/lib/python3.10/site-packages/haliax/nn/scan.py", line 319, in _do_block
(run_docker pid=4406, ip=10.164.0.48)     return block(carry, *extra_args, **extra_kwargs)
(run_docker pid=4406, ip=10.164.0.48)   File "/opt/levanter/src/levanter/models/routed_qwen_model.py", line 668, in __call__
(run_docker pid=4406, ip=10.164.0.48)     attn_output = self.self_attn(x=x, mask=mask, expert_mask=expert_mask, key=k_attn)
(run_docker pid=4406, ip=10.164.0.48)   File "/opt/levanter/src/levanter/models/routed_qwen_model.py", line 488, in __call__
(run_docker pid=4406, ip=10.164.0.48)     attn_output = dot_product_attention(
(run_docker pid=4406, ip=10.164.0.48)   File "/opt/levanter/src/levanter/models/attention.py", line 138, in dot_product_attention
(run_docker pid=4406, ip=10.164.0.48)     attention_out = _try_tpu_splash_attention(
(run_docker pid=4406, ip=10.164.0.48)   File "/opt/levanter/src/levanter/models/attention.py", line 747, in _try_tpu_splash_attention
(run_docker pid=4406, ip=10.164.0.48)     return _tpu_splash_attention(
(run_docker pid=4406, ip=10.164.0.48)   File "/opt/levanter/src/levanter/models/attention.py", line 956, in _tpu_splash_attention
(run_docker pid=4406, ip=10.164.0.48)     attn_output = wrap_flash_attention(q_, k_, v_, segment_ids)
(run_docker pid=4406, ip=10.164.0.48)   File "/opt/levanter/src/levanter/models/attention.py", line 952, in wrap_flash_attention
(run_docker pid=4406, ip=10.164.0.48)     return jax.vmap(
(run_docker pid=4406, ip=10.164.0.48)   File "/opt/levanter/src/levanter/models/attention.py", line 953, in <lambda>
(run_docker pid=4406, ip=10.164.0.48)     lambda q, k, v, si: splash_kernel(q, k, v, segment_ids=si), in_axes=(0, 0, 0, segment_batch_axis)
(run_docker pid=4406, ip=10.164.0.48)   File "/opt/levanter/.venv/lib/python3.10/site-packages/jax/experimental/pallas/ops/tpu/splash_attention/splash_attention_kernel.py", line 2312, in __call__
(run_docker pid=4406, ip=10.164.0.48)     return _splash_attention(
(run_docker pid=4406, ip=10.164.0.48)   File "/opt/levanter/.venv/lib/python3.10/site-packages/jax/experimental/pallas/ops/tpu/splash_attention/splash_attention_kernel.py", line 2277, in _splash_attention
(run_docker pid=4406, ip=10.164.0.48)     return _splash_attention_custom(
(run_docker pid=4406, ip=10.164.0.48) jax._src.pallas.mosaic.error_handling.MosaicError: INTERNAL: Mosaic failed to compile TPU kernel: Unsupported input data type in matrix multiplication in this target.
(run_docker pid=4406, ip=10.164.0.48)
(run_docker pid=4406, ip=10.164.0.48) The MLIR operation involved:
(run_docker pid=4406, ip=10.164.0.48)   %4335 = "tpu.matmul"(%4330, %4332, %4334) <{transpose_lhs = false, transpose_rhs = true}> : (vector<512x128xbf16>, vector<128x128xbf16>, vector<512x128xf32>) -> vector<512x128xf32>
(run_docker pid=4406, ip=10.164.0.48) ... additional diagnostics were skipped.
(run_docker pid=4406, ip=10.164.0.48)
(run_docker pid=4406, ip=10.164.0.48) Please report a bug at: https://github.com/google/jax/issues/new?assignees=apaszke
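
Note: the MLIR op in the dump is a bf16 x bf16 -> f32 block matmul inside the Pallas splash-attention kernel, so the query/key tiles reached the kernel as bfloat16 and Mosaic could not lower that dtype mix on this TPU target. The sketch below reproduces the same contraction in plain JAX using the shapes from the dump; the commented-out upcast at the end is a hypothetical workaround, an assumption rather than anything the log confirms.

    import jax
    import jax.numpy as jnp

    # Shapes and dtypes taken from the failing op in the log:
    # tpu.matmul : (vector<512x128xbf16>, vector<128x128xbf16>) -> vector<512x128xf32>
    q = jnp.zeros((512, 128), dtype=jnp.bfloat16)  # lhs tile
    k = jnp.zeros((128, 128), dtype=jnp.bfloat16)  # rhs tile (transpose_rhs = true)

    # The same contraction outside the Mosaic kernel: contract q's dim 1 with
    # k's dim 1 (the transposed rhs) and accumulate in float32.
    out = jax.lax.dot_general(
        q, k,
        dimension_numbers=(((1,), (1,)), ((), ())),
        preferred_element_type=jnp.float32,
    )
    assert out.shape == (512, 128) and out.dtype == jnp.float32

    # Hypothetical workaround (assumption): upcast q/k/v to float32 before the
    # splash kernel so Mosaic sees an f32 matmul instead of the unsupported mix.
    # q_, k_, v_ = (x.astype(jnp.float32) for x in (q_, k_, v_))

This contraction compiles fine as ordinary XLA, which suggests the failure is specific to the Mosaic lowering of splash attention on this TPU generation; upcasting the inputs (or selecting a non-splash attention path) trades memory and speed for compatibility.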