oom error
```
Traceback (most recent call last):
File "/opt/levanter/src/levanter/main/sft.py", line 258, in <module>
levanter.config.main(train)()
File "/opt/levanter/src/levanter/config.py", line 84, in wrapper_inner
response = fn(cfg, *args, **kwargs)
File "/opt/levanter/src/levanter/main/sft.py", line 234, in train
trainer.train(state, loader)
File "/opt/levanter/src/levanter/trainer.py", line 431, in train
for info in self.training_steps(state, train_loader):
File "/opt/levanter/src/levanter/trainer.py", line 420, in training_steps
info = self.train_step(state, example)
File "/opt/levanter/src/levanter/trainer.py", line 397, in train_step
loss, new_state = self._jit_train_step_fn_no_hook(state, batch, batch_kwargs)
File "/opt/levanter/.venv/lib/python3.10/site-packages/haliax/partitioning.py", line 261, in __call__
return self._call(False, *args, **kwargs)
File "/opt/levanter/.venv/lib/python3.10/site-packages/equinox/_module.py", line 1096, in __call__
return self.__func__(self.__self__, *args, **kwargs)
File "/opt/levanter/.venv/lib/python3.10/site-packages/haliax/partitioning.py", line 337, in _call
out, out_static = cached_pjitted_fun(dynamic_donated, dynamic_reserved, static)
jaxlib.xla_extension.XlaRuntimeError: RESOURCE_EXHAUSTED: XLA:TPU compile permanent error. Ran out of memory in memory space hbm. Used 33.59G of 30.75G hbm. Exceeded hbm capacity by 2.85G.
Total hbm usage >= 34.85G:
reserved 1.25G
program 11.16G
arguments 22.44G
Output size 22.44G; shares 22.44G with arguments.
Program hbm requirement 11.16G:
global 8.84M
scoped 5.00M
HLO temp 11.14G (100.0% utilization: Unpadded (8.21G) Padded (8.21G), 26.4% fragmentation (2.94G))
Largest program allocations in hbm:
1. Size: 1.96G
Operator: op_name="jit(_train_step)/jit(main)/jvp(contract embed -> batch, position, vocab)/0 1 2 , 3 2 -> 0 1 3/dot_general" source_file="/opt/levanter/.venv/lib/python3.10/site-packages/equinox/_jit.py" source_line=244
Shape: bf16[2,4096,128258]{1,2,0:T(8,128)(2,1)}
Unpadded size: 1.96G
Extra memory due to padding: 96.0K (1.0x expansion)
XLA label: fusion.193.remat8 = fusion(all-gather.35.remat2.1.remat7, copy-done.24, copy-done.68, copy-done.89), kind=kOutput, calls=fused_computation.230.clone.clone.clone.clone.clone.clone.clone.clone
Allocation type: HLO temp
==========================
2. Size: 1002.06M
Operator: op_name="jit(_train_step)/jit(main)/transpose(jvp(contract embed -> batch, position, vocab))/0 1 2 , 3 2 -> 0 1 3/dot_general" source_file="/opt/levanter/.venv/lib/python3.10/site-packages/equinox/_jit.py" source_line=244
Shape: bf16[128258,4096]{1,0:T(8,128)(2,1)}
Unpadded size: 1002.02M
Extra memory due to padding: 48.0K (1.0x expansion)
XLA label: all-reduce.37 = all-reduce(convolution_bitcast_fusion.2), channel_id=79, replica_groups={{0,1,2,3}}, use_global_device_ids=true, to_apply=add.15.clone
Allocation type: HLO temp
==========================
3. Size: 1002.06M
Operator: op_name="jit(_train_step)/jit(main)/jvp(contract embed -> batch, position, vocab)/0 1 2 , 3 2 -> 0 1 3/dot_general" source_file="/opt/levanter/.venv/lib/python3.10/site-packages/equinox/_jit.py" source_line=244
Shape: bf16[128258,4096]{1,0:T(8,128)(2,1)}
Unpadded size: 1002.02M
Extra memory due to padding: 48.0K (1.0x expansion)
XLA label: all-gather.35.remat2.1.remat7 = all-gather(convert.151.remat2.1.remat5.2.remat.1.remat3), channel_id=155, replica_groups=[1,4]<=[4], dimensions={1}, use_global_device_ids=true
Allocation type: HLO temp
==========================
4. Size: 1002.06M
Operator: op_name="jit(_train_step)/jit(main)/jvp(contract embed -> batch, position, vocab)/0 1 2 , 3 2 -> 0 1 3/dot_general" source_file="/opt/levanter/.venv/lib/python3.10/site-packages/equinox/_jit.py" source_line=244
Shape: bf16[128258,4096]{1,0:T(8,128)(2,1)}
Unpadded size: 1002.02M
Extra memory due to padding: 48.0K (1.0x expansion)
XLA label: all-gather.35.remat6 = all-gather(convert.151.remat2.1.remat5.2.remat3), channel_id=154, replica_groups=[1,4]<=[4], dimensions={1}, use_global_device_ids=true
Allocation type: HLO temp
==========================
5. Size: 896.00M
Operator: op_name="jit(_train_step)/jit(main)/convert_element_type" source_file="/opt/levanter/.venv/lib/python3.10/site-packages/jmp/_src/policy.py" source_line=31
Shape: bf16[32,14336,1024]{2,1,0:T(8,128)(2,1)}
Unpadded size: 896.00M
XLA label: convert.144 = convert(param.12)
Allocation type: HLO temp
==========================
6. Size: 896.00M
Operator: op_name="jit(_train_step)/jit(main)/convert_element_type" source_file="/opt/levanter/.venv/lib/python3.10/site-packages/jmp/_src/policy.py" source_line=31
Shape: bf16[32,14336,1024]{2,1,0:T(8,128)(2,1)}
Unpadded size: 896.00M
XLA label: convert.145 = convert(param.13)
Allocation type: HLO temp
==========================
7. Size: 896.00M
Operator: op_name="jit(_train_step)/jit(main)/convert_element_type" source_file="/opt/levanter/.venv/lib/python3.10/site-packages/jmp/_src/policy.py" source_line=31
Shape: bf16[32,1024,14336]{2,1,0:T(8,128)(2,1)}
Unpadded size: 896.00M
XLA label: convert.146 = convert(param.14)
Allocation type: HLO temp
==========================
8. Size: 256.00M
Operator: op_name="jit(_train_step)/jit(main)/convert_element_type" source_file="/opt/levanter/.venv/lib/python3.10/site-packages/jmp/_src/policy.py" source_line=31
Shape: bf16[32,8,4,128,1024]{4,3,2,1,0:T(8,128)(2,1)}
Unpadded size: 256.00M
XLA label: convert.140 = convert(param.8)
Allocation type: HLO temp
==========================
9. Size: 256.00M
Operator: op_name="dynamic_donated[1][0][0].model.transformer.layers.stacked.self_attn.o_proj.weight[<flat index 0>]"
Shape: bf16[32,1024,32,128]{3,1,2,0:T(8,128)(2,1)}
Unpadded size: 256.00M
XLA label: copy.432 = copy(param.11), sharding={devices=[1,4,1,1]<=[4]}
Allocation type: HLO temp
==========================
```

Ok, in my effort to spend a little more time teaching people how to fish, today we’re gonna look at an OOM trace on TPU. The context is that [redacted] was running into OOMs when finetuning a Llama 3 model on a v4-8. Llama 3 is about 8B parameters, and a v4-8 has 4 devices with 30.75GiB of HBM each. Seq len is 4096, batch size is 2.
The first bit is:
```
Total hbm usage >= 34.85G:
reserved 1.25G
program 11.16G
arguments 22.44G
Output size 22.44G; shares 22.44G with arguments.
```
Ok, so this is telling us that most of the memory is “arguments”, which means the inputs to the JIT-compiled train_step: the model and the optimizer state (and technically the data, but that’s pretty tiny). 22.44G is about right: `4 bytes * 3 * 8e9` is 96GB (the `* 3` comes from model, opt state momentum, opt state second moment), but we can divide by 4 because of FSDP, giving 24GB. I don’t know why it’s only 22.44, but I’m not gonna complain.
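Here’s that back-of-the-envelope arithmetic as a few lines of Python (assuming a round 8e9 parameters; Llama 3 8B is actually slightly more):
```
# Rough check of the "arguments" number: fp32 params plus two fp32 Adam
# moments, sharded 4 ways by FSDP.
n_params = 8e9        # Llama 3 8B, roughly
bytes_per_value = 4   # fp32
copies = 3            # params + Adam momentum + Adam second moment
n_devices = 4         # v4-8
per_device_bytes = n_params * bytes_per_value * copies / n_devices
print(per_device_bytes / 1e9)    # 24.0 GB
print(per_device_bytes / 2**30)  # ~22.35 GiB; GB-vs-GiB probably explains most of the gap to 22.44G
```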
So argument memory usage is about right. The only things we can do about it are:
1) increase the number of TPUs (this is by far the easiest thing to do)
2) decrease the precision of the optimizer state or model params (this might be ok for finetuning; it’s known to lead to degradations during pre-training). See the optax sketch after this list.
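As one concrete version of option 2 (this is raw optax, not Levanter’s own config, which may expose this differently): optax’s `adamw` takes a `mu_dtype` argument, so you can keep the first moment in bf16 and halve that slice of the optimizer state.
```
import jax.numpy as jnp
import optax

# Store Adam's first moment (momentum) in bf16; the second moment still
# follows the parameter dtype.
optimizer = optax.adamw(learning_rate=1e-5, mu_dtype=jnp.bfloat16)
# opt_state = optimizer.init(params)  # mu leaves come out as bf16
```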
Ok, so what about “program” memory? This is temporary memory allocated by XLA. We have a bit less control here because XLA is doing so much magic for us, but let’s see what we have:
```
1. Size: 1.96G
Operator: op_name="jit(_train_step)/jit(main)/jvp(contract embed -> batch, position, vocab)/0 1 2 , 3 2 -> 0 1 3/dot_general" source_file="/opt/levanter/.venv/lib/python3.10/site-packages/equinox/_jit.py" source_line=244
Shape: bf16[2,4096,128258]{1,2,0:T(8,128)(2,1)}
Unpadded size: 1.96G
Extra memory due to padding: 96.0K (1.0x expansion)
XLA label: fusion.193.remat8 = fusion(all-gather.35.remat2.1.remat7, copy-done.24, copy-done.68, copy-done.89), kind=kOutput, calls=fused_computation.230.clone.clone.clone.clone.clone.clone.clone.clone
Allocation type: HLO temp
==========================
```
This is in the backward pass (the jvp in the op name means we’re inside the gradient computation). I helpfully add the named axes to the op names, so you can see that this operation produces the final logit matrix. Llama 3 has a huge tokenizer, so this typically-small buffer is pretty big. I have seen people say that XLA is supposed to fuse/partition this so it’s never materialized, but that isn’t happening here.
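The size checks out:
```
# Why allocation 1 is ~1.96G: the full bf16 logits for this microbatch.
batch, seq_len, vocab = 2, 4096, 128258
bytes_per_value = 2   # bf16
print(batch * seq_len * vocab * bytes_per_value / 2**30)  # ~1.96 GiB
```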
However, a quick fix is to decrease the batch size. I also added a slow/not-very-good manual fusion/tiling of this operation behind `--model.cross_entropy_block_size`. Setting that to, e.g., 16384 ought to work.
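To illustrate the idea behind that kind of tiling (this is just a sketch, not Levanter’s actual implementation): stream over vocab blocks and keep a running logsumexp, so only a `[batch, seq, block_size]` slice of the logits ever exists at once.
```
import jax
import jax.numpy as jnp

def blockwise_cross_entropy(hidden, lm_head, labels, block_size=16384):
    """hidden: [B, T, E], lm_head: [V, E], labels: [B, T] int.
    Never materializes the full [B, T, V] logits."""
    V = lm_head.shape[0]
    B, T, _ = hidden.shape
    running_max = jnp.full((B, T), -jnp.inf, dtype=jnp.float32)
    running_sum = jnp.zeros((B, T), dtype=jnp.float32)
    label_logit = jnp.zeros((B, T), dtype=jnp.float32)
    for start in range(0, V, block_size):
        block = lm_head[start:start + block_size]                 # [v_blk, E]
        logits = jnp.einsum("bte,ve->btv", hidden, block).astype(jnp.float32)
        # streaming logsumexp over the vocab dimension
        new_max = jnp.maximum(running_max, logits.max(axis=-1))
        running_sum = (running_sum * jnp.exp(running_max - new_max)
                       + jnp.exp(logits - new_max[..., None]).sum(axis=-1))
        running_max = new_max
        # grab the true-label logit if the label falls in this block
        in_block = (labels >= start) & (labels < start + block.shape[0])
        local = jnp.clip(labels - start, 0, block.shape[0] - 1)
        picked = jnp.take_along_axis(logits, local[..., None], axis=-1)[..., 0]
        label_logit = jnp.where(in_block, picked, label_logit)
    return jnp.mean(running_max + jnp.log(running_sum) - label_logit)
```
With batch 2, seq len 4096, and a block size of 16384, the per-block logits slice is about 256 MiB instead of ~2 GiB.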
If we keep going, we can see
```
2. Size: 1002.06M
Operator: op_name="jit(_train_step)/jit(main)/transpose(jvp(contract embed -> batch, position, vocab))/0 1 2 , 3 2 -> 0 1 3/dot_general" source_file="/opt/levanter/.venv/lib/python3.10/site-packages/equinox/_jit.py" source_line=244
Shape: bf16[128258,4096]{1,0:T(8,128)(2,1)}
Unpadded size: 1002.02M
Extra memory due to padding: 48.0K (1.0x expansion)
XLA label: all-reduce.37 = all-reduce(convolution_bitcast_fusion.2), channel_id=79, replica_groups={{0,1,2,3}}, use_global_device_ids=true, to_apply=add.15.clone
Allocation type: HLO temp
==========================
```
The all-reduce tells us this is summing the lm_head gradient across devices (and leaving every device with a full copy). I thought these were supposed to stay sharded, so that’s maybe something to look into.
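For intuition, here’s a toy data-parallel step (purely illustrative; the names and setup are made up, and it’s plain data parallelism rather than Levanter’s actual sharding) where that gradient all-reduce shows up explicitly:
```
from functools import partial
import jax
import jax.numpy as jnp

@partial(jax.pmap, axis_name="data")
def toy_train_step(lm_head, x, y):
    # lm_head: [vocab, embed] (replicated); x: [batch, embed]; y: [batch] int
    def loss_fn(w):
        logits = x @ w.T
        logp = jax.nn.log_softmax(logits)
        return -jnp.mean(logp[jnp.arange(x.shape[0]), y])
    grads = jax.grad(loss_fn)(lm_head)
    # This psum is the all-reduce in the trace: every device ends up holding
    # the full summed [vocab, embed] gradient, a ~1GB buffer at Llama 3 scale.
    grads = jax.lax.psum(grads, axis_name="data")
    return lm_head - 1e-3 * grads
```
Under FSDP/ZeRO you’d hope for a reduce-scatter instead, so each device only keeps its shard of the gradient.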
```
3. Size: 1002.06M
Operator: op_name="jit(_train_step)/jit(main)/jvp(contract embed -> batch, position, vocab)/0 1 2 , 3 2 -> 0 1 3/dot_general" source_file="/opt/levanter/.venv/lib/python3.10/site-packages/equinox/_jit.py" source_line=244
Shape: bf16[128258,4096]{1,0:T(8,128)(2,1)}
Unpadded size: 1002.02M
Extra memory due to padding: 48.0K (1.0x expansion)
XLA label: all-gather.35.remat2.1.remat7 = all-gather(convert.151.remat2.1.remat5.2.remat.1.remat3), channel_id=155, replica_groups=[1,4]<=[4], dimensions={1}, use_global_device_ids=true
Allocation type: HLO temp
==========================
```
This is FSDP’s just-in-time all-gather of the lm_head matrix. Not much you can do here (aside from tiling, as above).
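A toy version of the pattern (again just for illustration, not Levanter’s code): each device stores a shard of lm_head and gathers the full matrix right before the matmul, which is exactly the ~1GB temporary you see above.
```
from functools import partial
import jax
import jax.numpy as jnp

@partial(jax.pmap, axis_name="fsdp")
def toy_logits(lm_head_shard, hidden):
    # lm_head_shard: [vocab // n_devices, embed] per device; hidden: [batch, embed]
    full_lm_head = jax.lax.all_gather(lm_head_shard, "fsdp", tiled=True)  # [vocab, embed]
    # The gathered copy only lives as a temporary, but it's ~1GB at this vocab size.
    return hidden @ full_lm_head.T
```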
```
4. Size: 1002.06M
Operator: op_name="jit(_train_step)/jit(main)/jvp(contract embed -> batch, position, vocab)/0 1 2 , 3 2 -> 0 1 3/dot_general" source_file="/opt/levanter/.venv/lib/python3.10/site-packages/equinox/_jit.py" source_line=244
Shape: bf16[128258,4096]{1,0:T(8,128)(2,1)}
Unpadded size: 1002.02M
Extra memory due to padding: 48.0K (1.0x expansion)
XLA label: all-gather.35.remat6 = all-gather(convert.151.remat2.1.remat5.2.remat3), channel_id=154, replica_groups=[1,4]<=[4], dimensions={1}, use_global_device_ids=true
Allocation type: HLO temp
==========================
```
Ok, there’s another copy? This is confusing.
After that, everything is sub-GB. There is some overhead where XLA is not just-in-time converting the parameters to bf16 (allocations 5 through 8). I started work on that a long time ago, but I never finished it.
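Those bf16 copies come from the mixed-precision policy casting the fp32 master params to the compute dtype. A minimal jmp sketch of that cast (assuming the usual params=fp32 / compute=bf16 policy; the shape here is made up):
```
import jax.numpy as jnp
import jmp

policy = jmp.get_policy("params=float32,compute=bfloat16,output=float32")

params = {"w": jnp.zeros((1024, 1024), dtype=jnp.float32)}
# This cast is the convert_element_type in the trace; XLA keeps the bf16
# copies around as HLO temps (the 896M and 256M "convert" buffers above).
compute_params = policy.cast_to_compute(params)
```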