Notes on why a different reduce order produces non-deterministic output:
A simple, concrete reduce operation is summation. Floating-point addition is not associative, so summing the same numbers in a different order can produce different results, e.g. (n_1 + ((n_2 + n_3) + n_4)) != (((n_1 + n_2) + n_3) + n_4).
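As a minimal sketch of the point above, the snippet below regroups the same operands in two ways. The values are arbitrary illustrations chosen here (they are not from the experiment in these notes), and the code assumes standard IEEE 754 double precision, which is what Python floats use on common platforms.

```python
# Floating-point addition is not associative: regrouping the same
# operands can change the rounded result.
a, b, c = 0.1, 0.2, 0.3

grouped_left = (a + b) + c    # 0.6000000000000001
grouped_right = a + (b + c)   # 0.6
print(grouped_left == grouped_right)  # False

# The two four-term groupings from the text, with illustrative values
# whose magnitudes differ widely so that rounding is easy to trigger.
n_1, n_2, n_3, n_4 = 1e16, 1.0, -1e16, 1.0

nested = n_1 + ((n_2 + n_3) + n_4)
sequential = ((n_1 + n_2) + n_3) + n_4
print(nested, sequential, nested != sequential)
```

In a sharded reduce, each grouping corresponds to one possible order in which partial sums arrive, which is why the final value can vary from run to run.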
The following code snippet simulates matrix multiplication with a reduce-sum operation. Assume X is the input batch with all rows identical, W is the weight matrix, and Z = XW is the model output. X is partitioned by column and W is partitioned by row, each into four shards.
# This cell runs on CPU and takes a few minutes to finish.
import numpy as np
x = np.random.normal(size=[8192])