MLX is lazy. No actual computation happens until you explicitly or implicitly evaluate the graph. Here are some ways that can happen:
- Explicit call to mx.eval
- Call a.item() on a scalar array
- Convert an array to NumPy, i.e. np.array(a)
- Print an array
There is no hard-and-fast rule for when to evaluate the graph. Evaluating the graph has overhead since it is a synchronizing event, so do enough work between evaluations to amortize that overhead. Numerical computations are usually iterative, and the end of each iteration is typically a good time to evaluate the graph: for example, after one optimization step of SGD or after generating one token with a language model.
Here is a common pitfall worth avoiding though:
import collections
import mlx.core as mx

counts = collections.Counter()
values = mx.random.randint(0, 100, shape=(100,))
for x in values:
    counts[x.item()] += 1
That is slow because each iteration slices values to get a scalar array and then x.item() triggers a graph evaluation, so the loop performs one evaluation per element. To see it more explicitly you can write the exact same computation this way:
counts = collections.Counter()
values = mx.random.randint(0, 100, shape=(100,))
for i in range(values.size):
    x = values[i]  # slice out one element (lazy)
    mx.eval(x)     # evaluate the graph for this single scalar
    counts[x.item()] += 1
This is much better, because tolist() evaluates values once and the loop then runs over a plain Python list:
counts = collections.Counter()
values = mx.random.randint(0, 100, shape=(100,))
for x in values.tolist():
    counts[x] += 1
This pattern is common (though values is usually the output of some other computation and you aren't typically counting its entries). Timing the two loops on an M1 Max gives 2.99 milliseconds for the slow version and 0.012 milliseconds for the fast one, almost 250 times faster.