@awni
Last active November 9, 2024 04:33

Use Lazy Loading to Reduce Peak Memory Use

Recall that MLX is lazy: no actual computation happens until you explicitly or implicitly evaluate the graph. Even loading arrays from a file is lazy:

weights = mx.load("model.safetensors")

The call above returns instantly, regardless of the file size. To actually load the weights into memory, call mx.eval(weights).
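
For example, the following sketch (assuming a local model.safetensors file) only reads the file when mx.eval is called:

import mlx.core as mx

# Returns immediately; nothing is read from disk yet.
weights = mx.load("model.safetensors")

# Materialize the arrays; this is when the data is actually loaded.
mx.eval(weights)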

Assume the weights are stored on disk in 32-bit precision (i.e. mx.float32), but your model only needs 16-bit precision. If you do this:

weights = mx.load("model.safetensors")
mx.eval(weights)
weights = {k: v.astype(mx.float16) for k, v in weights.items()}

the weights will be loaded into memory in full precision and then cast to 16-bit. This requires memory for all the weights in 32-bit plus memory for the weights in 16-bit.

This is much better:

weights = mx.load("model.safetensors")
weights = {k: v.astype(mx.float16) for k, v in weights.items()}
mx.eval(weights)

Evaluating after the cast to mx.float16 reduces peak memory by nearly a third. That's because the full set of weights is never materialized in 32-bit at once. Right after each weight is loaded in 32-bit precision it is cast to 16-bit, and the memory for the 32-bit copy can be reused when loading the next weight.
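
To check this on your own model, you can compare the two orderings with MLX's peak memory counters. A minimal sketch, assuming model.safetensors exists and that your MLX version provides mx.metal.get_peak_memory and mx.metal.reset_peak_memory:

import mlx.core as mx

def peak_bytes(cast_before_eval):
    mx.metal.reset_peak_memory()
    weights = mx.load("model.safetensors")
    if cast_before_eval:
        # Cast while still lazy, then evaluate once.
        weights = {k: v.astype(mx.float16) for k, v in weights.items()}
        mx.eval(weights)
    else:
        # Evaluate first (materializes float32), then cast.
        mx.eval(weights)
        weights = {k: v.astype(mx.float16) for k, v in weights.items()}
        mx.eval(weights)
    return mx.metal.get_peak_memory()

print("eval then cast:", peak_bytes(False))
print("cast then eval:", peak_bytes(True))

On a large model, the second number should come out roughly two-thirds of the first.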

Note that MLX can only load a file lazily when mx.load is given a string path. Due to lifetime management issues, lazy loading from file handles is not supported, so avoid this:

weights = mx.load(open("model.safetensors", "rb"))
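
Instead, pass the path as a string so the load stays lazy:

weights = mx.load("model.safetensors")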