Unexpected memory usage (VRAM) is reported, by both torch and NVML, when using 8-bit AdamW, paged or unpaged.
Here we train Llama 2 on 4096-token sequences, using either `--optim adamw_8bit` or `--optim paged_adamw_8bit`.
We do a full finetune using `qlora.py --full-finetune`, with our qlora.py fork, stepwise branch, commit 9a1045d.
We print the memory usage in the HF transformers trainer's `on_step_end` callback, i.e. after `optimizer.step(); model.zero_grad()`.
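
The callback is roughly like the sketch below (the class name and print format here are illustrative; the numbers reported above come from our fork's own logging):

```python
import torch
import pynvml
from transformers import TrainerCallback

class MemoryLoggingCallback(TrainerCallback):
    """Logs torch and NVML memory usage at the end of every optimizer step."""

    def __init__(self):
        pynvml.nvmlInit()
        self.handle = pynvml.nvmlDeviceGetHandleByIndex(0)

    def on_step_end(self, args, state, control, **kwargs):
        # Runs after optimizer.step(); model.zero_grad() in the trainer loop.
        allocated = torch.cuda.memory_allocated() / 2**30
        reserved = torch.cuda.memory_reserved() / 2**30
        nvml_used = pynvml.nvmlDeviceGetMemoryInfo(self.handle).used / 2**30
        print(
            f"step {state.global_step}: "
            f"torch allocated={allocated:.1f}GiB, "
            f"torch reserved={reserved:.1f}GiB, "
            f"NVML used={nvml_used:.1f}GiB"
        )

# registered with: trainer.add_callback(MemoryLoggingCallback())
```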
One would expect the memory usage at the end of step 1 to be the same as at the end of step 2.
Yet for the unpaged optimizer, memory usage leaps by 11.2GiB: 70.4GiB at the end of step 1, 81.6GiB at the end of step 2.
This appears to be a leap in PyTorch reserved memory only (32.6GiB -> 43.9GiB).
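
For anyone reproducing this, the allocated-vs-reserved split can be read straight from torch's allocator stats; a minimal sketch:

```python
import torch

stats = torch.cuda.memory_stats()
allocated_gib = stats["allocated_bytes.all.current"] / 2**30
reserved_gib = stats["reserved_bytes.all.current"] / 2**30
print(
    f"allocated={allocated_gib:.1f}GiB, reserved={reserved_gib:.1f}GiB, "
    f"cached-but-unused={reserved_gib - allocated_gib:.1f}GiB"
)

# torch.cuda.memory_summary() prints a more detailed breakdown of the same stats
print(torch.cuda.memory_summary())
```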