oom error
Traceback (most recent call last):
  File "/opt/levanter/src/levanter/main/sft.py", line 258, in <module>
    levanter.config.main(train)()
  File "/opt/levanter/src/levanter/config.py", line 84, in wrapper_inner
    response = fn(cfg, *args, **kwargs)
  File "/opt/levanter/src/levanter/main/sft.py", line 234, in train
    trainer.train(state, loader)
  File "/opt/levanter/src/levanter/trainer.py", line 431, in train
    for info in self.training_steps(state, train_loader):
  File "/opt/levanter/src/levanter/trainer.py", line 420, in training_steps
    info = self.train_step(state, example)
  File "/opt/levanter/src/levanter/trainer.py", line 397, in train_step
    loss, new_state = self._jit_train_step_fn_no_hook(state, batch, batch_kwargs)
  File "/opt/levanter/.venv/lib/python3.10/site-packages/haliax/partitioning.py", line 261, in __call__
    return self._call(False, *args, **kwargs)
  File "/opt/levanter/.venv/lib/python3.10/site-packages/equinox/_module.py", line 1096, in __call__
    return self.__func__(self.__self__, *args, **kwargs)
  File "/opt/levanter/.venv/lib/python3.10/site-packages/haliax/partitioning.py", line 337, in _call
    out, out_static = cached_pjitted_fun(dynamic_donated, dynamic_reserved, static)
jaxlib.xla_extension.XlaRuntimeError: RESOURCE_EXHAUSTED: XLA:TPU compile permanent error. Ran out of memory in memory space hbm. Used 33.59G of 30.75G hbm. Exceeded hbm capacity by 2.85G.
Total hbm usage >= 34.85G:
reserved 1.25G
program 11.16G
arguments 22.44G
Output size 22.44G; shares 22.44G with arguments.
Program hbm requirement 11.16G:
global 8.84M
scoped 5.00M
HLO temp 11.14G (100.0% utilization: Unpadded (8.21G) Padded (8.21G), 26.4% fragmentation (2.94G))
Largest program allocations in hbm:
1. Size: 1.96G
Operator: op_name="jit(_train_step)/jit(main)/jvp(contract embed -> batch, position, vocab)/0 1 2 , 3 2 -> 0 1 3/dot_general" source_file="/opt/levanter/.venv/lib/python3.10/site-packages/equinox/_jit.py" source_line=244
Shape: bf16[2,4096,128258]{1,2,0:T(8,128)(2,1)}
Unpadded size: 1.96G
Extra memory due to padding: 96.0K (1.0x expansion)
XLA label: fusion.193.remat8 = fusion(all-gather.35.remat2.1.remat7, copy-done.24, copy-done.68, copy-done.89), kind=kOutput, calls=fused_computation.230.clone.clone.clone.clone.clone.clone.clone.clone
Allocation type: HLO temp
==========================
2. Size: 1002.06M
Operator: op_name="jit(_train_step)/jit(main)/transpose(jvp(contract embed -> batch, position, vocab))/0 1 2 , 3 2 -> 0 1 3/dot_general" source_file="/opt/levanter/.venv/lib/python3.10/site-packages/equinox/_jit.py" source_line=244
Shape: bf16[128258,4096]{1,0:T(8,128)(2,1)}
Unpadded size: 1002.02M
Extra memory due to padding: 48.0K (1.0x expansion)
XLA label: all-reduce.37 = all-reduce(convolution_bitcast_fusion.2), channel_id=79, replica_groups={{0,1,2,3}}, use_global_device_ids=true, to_apply=add.15.clone
Allocation type: HLO temp
==========================
3. Size: 1002.06M
Operator: op_name="jit(_train_step)/jit(main)/jvp(contract embed -> batch, position, vocab)/0 1 2 , 3 2 -> 0 1 3/dot_general" source_file="/opt/levanter/.venv/lib/python3.10/site-packages/equinox/_jit.py" source_line=244
Shape: bf16[128258,4096]{1,0:T(8,128)(2,1)}
Unpadded size: 1002.02M
Extra memory due to padding: 48.0K (1.0x expansion)
XLA label: all-gather.35.remat2.1.remat7 = all-gather(convert.151.remat2.1.remat5.2.remat.1.remat3), channel_id=155, replica_groups=[1,4]<=[4], dimensions={1}, use_global_device_ids=true
Allocation type: HLO temp
==========================
4. Size: 1002.06M
Operator: op_name="jit(_train_step)/jit(main)/jvp(contract embed -> batch, position, vocab)/0 1 2 , 3 2 -> 0 1 3/dot_general" source_file="/opt/levanter/.venv/lib/python3.10/site-packages/equinox/_jit.py" source_line=244
Shape: bf16[128258,4096]{1,0:T(8,128)(2,1)}
Unpadded size: 1002.02M
Extra memory due to padding: 48.0K (1.0x expansion)
XLA label: all-gather.35.remat6 = all-gather(convert.151.remat2.1.remat5.2.remat3), channel_id=154, replica_groups=[1,4]<=[4], dimensions={1}, use_global_device_ids=true
Allocation type: HLO temp
==========================
5. Size: 896.00M
Operator: op_name="jit(_train_step)/jit(main)/convert_element_type" source_file="/opt/levanter/.venv/lib/python3.10/site-packages/jmp/_src/policy.py" source_line=31
Shape: bf16[32,14336,1024]{2,1,0:T(8,128)(2,1)}
Unpadded size: 896.00M
XLA label: convert.144 = convert(param.12)
Allocation type: HLO temp
==========================
6. Size: 896.00M
Operator: op_name="jit(_train_step)/jit(main)/convert_element_type" source_file="/opt/levanter/.venv/lib/python3.10/site-packages/jmp/_src/policy.py" source_line=31
Shape: bf16[32,14336,1024]{2,1,0:T(8,128)(2,1)}
Unpadded size: 896.00M
XLA label: convert.145 = convert(param.13)
Allocation type: HLO temp
==========================
7. Size: 896.00M
Operator: op_name="jit(_train_step)/jit(main)/convert_element_type" source_file="/opt/levanter/.venv/lib/python3.10/site-packages/jmp/_src/policy.py" source_line=31
Shape: bf16[32,1024,14336]{2,1,0:T(8,128)(2,1)}
Unpadded size: 896.00M
XLA label: convert.146 = convert(param.14)
Allocation type: HLO temp
==========================
8. Size: 256.00M
Operator: op_name="jit(_train_step)/jit(main)/convert_element_type" source_file="/opt/levanter/.venv/lib/python3.10/site-packages/jmp/_src/policy.py" source_line=31
Shape: bf16[32,8,4,128,1024]{4,3,2,1,0:T(8,128)(2,1)}
Unpadded size: 256.00M
XLA label: convert.140 = convert(param.8)
Allocation type: HLO temp
==========================
9. Size: 256.00M
Operator: op_name="dynamic_donated[1][0][0].model.transformer.layers.stacked.self_attn.o_proj.weight[<flat index 0>]"
Shape: bf16[32,1024,32,128]{3,1,2,0:T(8,128)(2,1)}
Unpadded size: 256.00M
XLA label: copy.432 = copy(param.11), sharding={devices=[1,4,1,1]<=[4]}
Allocation type: HLO temp
==========================
Ok, in my effort to spend a little more time teaching people how to fish, today we’re gonna look at an OOM trace on TPU. The context is that [redacted] was running into OOMs when finetuning a Llama 3 model on a v4-8. Llama 3 is about 8B parameters, and a v4-8 has 4 devices with 30.75GiB of HBM each. Seq len 4096, batch size 2.
The first bit is:
```
Total hbm usage >= 34.85G:
reserved 1.25G
program 11.16G
arguments 22.44G
Output size 22.44G; shares 22.44G with arguments.
```
Ok, so this is telling us that most memory is “arguments”, which means the inputs to the jit-compiled train_step: the model, the optimizer states (and technically the data, but it’s pretty tiny). 22.44G is about right: `4 bytes * 3 * 8e9` is 96GB (the `* 3` comes from model, opt state momentum, opt state second moment), but we can divide by 4 because of FSDP, giving 24GB. I don’t know why it’s only 22.44, but I’m not gonna complain.
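To make that arithmetic concrete, here’s a quick back-of-the-envelope sketch (the 8e9 parameter count and fp32 Adam state are rough assumptions, not exact Levanter numbers):

```python
# Rough per-device HBM for fp32 params + Adam state under FSDP (assumed numbers).
params = 8e9        # ~8B parameters
bytes_per = 4       # fp32 for params and both optimizer moments
copies = 3          # params + Adam first moment + Adam second moment
devices = 4         # a v4-8 has 4 chips

per_device_bytes = params * bytes_per * copies / devices   # 24e9 bytes
print(per_device_bytes / 2**30)  # ~22.35 GiB -- 24 decimal GB is ~22.4 GiB,
                                 # which plausibly explains the reported 22.44G
```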
So the argument (parameter + optimizer state) memory usage is as expected. The only things we can do about it are:
1) increase the number of TPUs (this is by far the easiest thing to do)
2) decrease the precision of the optimizer state or model params (this might be ok for finetuning; it’s known to lead to degradation during pre-training)
Ok, so what about “program” memory? This is temporary memory allocated by XLA. We have a bit less control here because XLA is doing so much magic for us, but let’s see what we have:
```
1. Size: 1.96G
Operator: op_name="jit(_train_step)/jit(main)/jvp(contract embed -> batch, position, vocab)/0 1 2 , 3 2 -> 0 1 3/dot_general" source_file="/opt/levanter/.venv/lib/python3.10/site-packages/equinox/_jit.py" source_line=244
Shape: bf16[2,4096,128258]{1,2,0:T(8,128)(2,1)}
Unpadded size: 1.96G
Extra memory due to padding: 96.0K (1.0x expansion)
XLA label: fusion.193.remat8 = fusion(all-gather.35.remat2.1.remat7, copy-done.24, copy-done.68, copy-done.89), kind=kOutput, calls=fused_computation.230.clone.clone.clone.clone.clone.clone.clone.clone
Allocation type: HLO temp
==========================
```
This is the backward pass (the jvp means we’re inside the gradient computation). I helpfully add the named axes to the op_name, so you can see that this operation produces the final logit matrix: [batch=2, position=4096, vocab=128258]. Llama 3 has a huge tokenizer, so this buffer, which is usually small, is pretty big. I have seen people say that XLA is supposed to fuse/partition this so it’s never materialized, but that isn’t happening here.
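As a quick sanity check on the sizes (batch, seq len, and vocab taken from the shapes in the trace, bf16 = 2 bytes per element):

```python
# Allocation 1: the full [batch, position, vocab] logit matrix in bf16.
logits_bytes = 2 * 4096 * 128258 * 2
# Allocations 2-4: a full bf16 copy of the [vocab, embed] lm_head matrix.
lm_head_bytes = 128258 * 4096 * 2

print(f"{logits_bytes / 2**30:.2f} GiB")   # ~1.96 GiB, matching allocation 1
print(f"{lm_head_bytes / 2**20:.2f} MiB")  # ~1002 MiB, matching allocations 2-4
```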
However, we can do a quick fix by decreasing the batch size. I also added a slow/not-very-good manual fusion/tiling of this operation with `--model.cross_entropy_block_size`. Setting that to, e.g., 16384 ought to work; see the sketch below for the idea.
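For intuition, here is a minimal numpy sketch of the blockwise idea: compute logits one vocab block at a time and keep a running log-sum-exp, so the full [batch, vocab] matrix is never materialized. This is only an illustration of the technique, not Levanter’s actual implementation, and the names are made up.

```python
import numpy as np

def blockwise_cross_entropy(h, lm_head, labels, block_size):
    """Mean cross-entropy without materializing the full [batch, vocab] logits.

    h:       [batch, embed] hidden states
    lm_head: [vocab, embed] output projection
    labels:  [batch] integer targets
    """
    batch = h.shape[0]
    vocab = lm_head.shape[0]
    running_max = np.full(batch, -np.inf)
    running_sum = np.zeros(batch)
    # logit of the correct token for each example
    correct_logit = np.einsum("be,be->b", h, lm_head[labels])
    for start in range(0, vocab, block_size):
        block_logits = h @ lm_head[start:start + block_size].T  # only [batch, block_size]
        block_max = block_logits.max(axis=1)
        new_max = np.maximum(running_max, block_max)
        # rescale the running sum to the new max before adding this block
        running_sum = running_sum * np.exp(running_max - new_max)
        running_sum += np.exp(block_logits - new_max[:, None]).sum(axis=1)
        running_max = new_max
    logsumexp = running_max + np.log(running_sum)
    return float((logsumexp - correct_logit).mean())

# tiny demo with made-up shapes
rng = np.random.default_rng(0)
h = rng.standard_normal((2, 64))
lm_head = rng.standard_normal((1000, 64))
labels = np.array([3, 7])
print(blockwise_cross_entropy(h, lm_head, labels, block_size=128))
```

The trade-off is a loop of [batch, block_size] chunks instead of one big [batch, position, vocab] buffer, which is why it’s slower but much lighter on HBM.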
If we keep going, we can see:
```
2. Size: 1002.06M
Operator: op_name="jit(_train_step)/jit(main)/transpose(jvp(contract embed -> batch, position, vocab))/0 1 2 , 3 2 -> 0 1 3/dot_general" source_file="/opt/levanter/.venv/lib/python3.10/site-packages/equinox/_jit.py" source_line=244
Shape: bf16[128258,4096]{1,0:T(8,128)(2,1)}
Unpadded size: 1002.02M
Extra memory due to padding: 48.0K (1.0x expansion)
XLA label: all-reduce.37 = all-reduce(convolution_bitcast_fusion.2), channel_id=79, replica_groups={{0,1,2,3}}, use_global_device_ids=true, to_apply=add.15.clone
Allocation type: HLO temp
==========================
```
The all-reduce tells us it’s summing the gradients across all devices, so every device ends up with a full unsharded copy. I thought these were supposed to stay sharded; maybe something to look into.
```
3. Size: 1002.06M
Operator: op_name="jit(_train_step)/jit(main)/jvp(contract embed -> batch, position, vocab)/0 1 2 , 3 2 -> 0 1 3/dot_general" source_file="/opt/levanter/.venv/lib/python3.10/site-packages/equinox/_jit.py" source_line=244
Shape: bf16[128258,4096]{1,0:T(8,128)(2,1)}
Unpadded size: 1002.02M
Extra memory due to padding: 48.0K (1.0x expansion)
XLA label: all-gather.35.remat2.1.remat7 = all-gather(convert.151.remat2.1.remat5.2.remat.1.remat3), channel_id=155, replica_groups=[1,4]<=[4], dimensions={1}, use_global_device_ids=true
Allocation type: HLO temp
==========================
```
This is the just-in-time gathering of the lm_head matrix. Not much you can do here (aside from tiling).
```
4. Size: 1002.06M
Operator: op_name="jit(_train_step)/jit(main)/jvp(contract embed -> batch, position, vocab)/0 1 2 , 3 2 -> 0 1 3/dot_general" source_file="/opt/levanter/.venv/lib/python3.10/site-packages/equinox/_jit.py" source_line=244
Shape: bf16[128258,4096]{1,0:T(8,128)(2,1)}
Unpadded size: 1002.02M
Extra memory due to padding: 48.0K (1.0x expansion)
XLA label: all-gather.35.remat6 = all-gather(convert.151.remat2.1.remat5.2.remat3), channel_id=154, replica_groups=[1,4]<=[4], dimensions={1}, use_global_device_ids=true
Allocation type: HLO temp
==========================
```
Ok, there’s another copy? This is confusing.
After that, everything is sub-GB. There is some overhead from XLA not just-in-time converting the parameters to bf16 (those are the `convert(param.N)` allocations above). I started work on that a long time ago, but I never finished it.
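For context on where those `convert(param.N)` temps come from, here is a minimal sketch of the jmp mixed-precision pattern the trace points at (policy.py line 31). The policy string and toy param shapes here are assumptions for illustration, not Levanter’s config:

```python
import jax.numpy as jnp
import jmp  # the mixed-precision library that shows up in the trace

# fp32 "master" params, bf16 compute: the cast materializes a bf16 copy of
# every parameter at the top of the step, which is where the convert(param.N)
# HLO temps come from.
policy = jmp.get_policy("params=float32,compute=bfloat16,output=float32")

params = {"w": jnp.zeros((1024, 1024), dtype=jnp.float32)}  # made-up toy params
compute_params = policy.cast_to_compute(params)             # bf16 copies in HBM
```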