Recall that MLX is lazy. No actual computation happens until you explicitly or implicitly evaluate the graph. Even loading arrays from a file is lazy:
weights = mx.load("model.safetensors")
The call above returns almost instantly, regardless of the file size.
To actually load the weights into memory, use mx.eval(weights).
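To see the laziness in action, here is a quick sketch (assuming a model.safetensors file is present in the working directory). The returned arrays expose their shape and dtype without any evaluation, since that metadata comes from the safetensors header; the tensor data itself is only read from disk when the graph is evaluated:

import mlx.core as mx

# Assumes a model.safetensors file exists in the working directory.
weights = mx.load("model.safetensors")

# Shape and dtype are static metadata on the lazy arrays and are
# available without evaluating anything.
first = next(iter(weights.values()))
print(first.shape, first.dtype)

# Only now is the tensor data actually read from disk and materialized.
mx.eval(weights)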
Assume the weights are stored on disk in 32-bit precision (i.e. mx.float32), but your model only needs 16-bit precision. If you do this:
weights = mx.load("model.safetensors")
mx.eval(weights)  # materializes every weight in full 32-bit precision
weights = {k: v.astype(mx.float16) for k, v in weights.items()}
the weights will be loaded into memory in full precision and then cast to 16-bit. This requires memory for all the weights in 32-bit plus memory for the weights in 16-bit.
This is much better:
weights = mx.load("model.safetensors")
weights = {k: v.astype(mx.float16) for k, v in weights.items()}
mx.eval(weights)  # the cast happens as each weight is loaded
Evaluating after the cast to mx.float16 reduces peak memory by nearly a third.
That's because the weights are never all materialized in 32-bit at once. Right after each weight is loaded in 32-bit precision it is cast to 16-bit, so the memory for its 32-bit buffer can be reused when loading the next weight.
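You can check the difference yourself with a sketch like the one below. It writes a small fp32 checkpoint and measures peak memory for both orderings; the file name demo_fp32.safetensors and the tensor sizes are made up for illustration, and it assumes the mx.reset_peak_memory and mx.get_peak_memory helpers (in older MLX versions these live under mx.metal instead):

import mlx.core as mx

# Write a small fp32 checkpoint to experiment with (names and sizes
# here are arbitrary).
mx.save_safetensors(
    "demo_fp32.safetensors",
    {f"layer{i}": mx.zeros((1024, 1024), dtype=mx.float32) for i in range(8)},
)

def peak_bytes(cast_before_eval):
    # Assumes mx.reset_peak_memory / mx.get_peak_memory; older MLX
    # versions expose these under mx.metal instead.
    mx.reset_peak_memory()
    weights = mx.load("demo_fp32.safetensors")
    if cast_before_eval:
        # Cast while still lazy: fp32 buffers are reused weight by weight.
        weights = {k: v.astype(mx.float16) for k, v in weights.items()}
        mx.eval(weights)
    else:
        # Evaluate first: every weight is materialized in fp32, then cast.
        mx.eval(weights)
        weights = {k: v.astype(mx.float16) for k, v in weights.items()}
        mx.eval(weights)
    return mx.get_peak_memory()

print("eval then cast:", peak_bytes(False))
print("cast then eval:", peak_bytes(True))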
Note, MLX can only load lazily from a file when mx.load is given the path as a string. Due to lifetime management issues, lazy loading from file handles is not supported, so avoid this:
weights = mx.load(open("model.safetensors", "rb"))
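Instead, pass the path itself and let MLX manage the file, as in the earlier examples:

weights = mx.load("model.safetensors")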