Notes to better understand non-deterministic output caused by different reduce orders:
A simple concrete example of a reduce is summation: summing floating-point numbers in a different order can produce different results, i.e. (n_1 + ((n_2 + n_3) + n_4)) != (((n_1 + n_2) + n_3) + n_4).
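As a quick standalone illustration, here is a minimal sketch; the float32 values are arbitrary, chosen so that rounding makes the two groupings disagree.
import numpy as np

a = np.float32(1e8)
b = np.float32(-1e8)
c = np.float32(1.0)
# (a + b) + c: the large values cancel first, so c survives.
print((a + b) + c)  # prints 1.0
# a + (b + c): c is absorbed into b by rounding, then the large values cancel.
print(a + (b + c))  # prints 0.0
The same effect, accumulated over many partial sums, is what the sharded reduce below exhibits.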
The following code snippet simulates matrix multiplication with a reduce-sum operation. Assume X is the input batch with all rows identical, W is the weight matrix, and Z = XW is the model output. The input batch X is partitioned by column and the weight matrix W is partitioned by row, each into four shards:
# This cell runs on CPU; it takes a few minutes to finish.
import numpy as np

x = np.random.normal(size=[8192])
x = np.stack([x] * 16, axis=0)  # all 16 rows of x are identical
assert x.shape == (16, 8192)
w = np.random.normal(size=[8192, 65536])
z = np.zeros((x.shape[0], w.shape[1]))
num_shards = 4
shards = np.arange(num_shards)
shard_size = x.shape[1] // num_shards
for i in range(x.shape[0]):
  for j in range(w.shape[1]):
    # Accumulate the four partial dot products in a random shard order.
    np.random.shuffle(shards)
    for shard in shards:
      z[i, j] += np.dot(x[i, shard * shard_size : (shard + 1) * shard_size],
                        w[shard * shard_size : (shard + 1) * shard_size, j])
  if i > 0:
    # Fraction of output columns whose values differ across the identical rows.
    print(f"within first {i + 1} samples: "
          f"error={(z[:(i + 1), :].min(axis=0) != z[:(i + 1), :].max(axis=0)).mean()}")
We expect the printed error to fall roughly between 0.35 and 0.76. A similar error shows up in the TPU computation below:
# Assumed imports; the flag definitions (tpu, computation_shape,
# disable_logging) and override_flags() are defined elsewhere.
import functools

from absl import flags
import numpy as np
import tensorflow.compat.v1 as tf

from tensorflow.compiler.xla.experimental.xla_sharding import xla_sharding
from tensorflow.python.tpu import device_assignment as tpu_device_assignment
from tensorflow.python.tpu import tpu

FLAGS = flags.FLAGS


def main(unused_argv):
  override_flags()
  if FLAGS.disable_logging:
    tf.get_logger().setLevel('CRITICAL')
  cluster_resolver = tf.distribute.cluster_resolver.TPUClusterResolver(
      FLAGS.tpu)
  _cluster_def = cluster_resolver.cluster_spec().as_cluster_def()
  _tpu = cluster_resolver.master()

  def _no_opt_sess_cfg():
    # Disable constant folding (and other graph optimizations) so the
    # computation actually runs on the TPU.
    return tf.ConfigProto(
        graph_options=tf.GraphOptions(
            optimizer_options=tf.OptimizerOptions(
                opt_level=tf.OptimizerOptions.L0,
                do_common_subexpression_elimination=False,
                do_function_inlining=False,
                do_constant_folding=False)),
        cluster_def=_cluster_def)

  sess = tf.Session(_tpu, config=_no_opt_sess_cfg())
  with sess.as_default(), sess.graph.as_default():
    tf.logging.info('Initializing TPU system')
    init_tpu = tf.tpu.initialize_system()
    topology_str = sess.run(init_tpu)
    topology = tf.tpu.experimental.Topology(topology_str)
    if FLAGS.computation_shape:
      computation_shape = list(map(int, FLAGS.computation_shape.split(',')))
    else:
      computation_shape = list(topology.mesh_shape)
    tf.logging.info('computation_shape: %r', computation_shape)
    tpu_cores = functools.reduce(lambda x, y: x * y, computation_shape)
    tf.logging.info('tpu_cores: %d', tpu_cores)
    device_assignment = tpu_device_assignment.device_assignment(
        topology_str, computation_shape=computation_shape, num_replicas=1)
    print('device assignment:', device_assignment)
    mesh_shape = [8, -1]
    device_mesh = np.arange(tpu_cores).reshape(mesh_shape)
    print('device_mesh:', device_mesh)
    print('mesh_shape:', device_mesh.shape)

    def tpu_fn():
      # x is a matrix [1024, 8192] where all 1024 rows are the same.
      x = tf.random.normal([8192],
                           mean=0.0,
                           stddev=1.0,
                           dtype=tf.bfloat16,
                           seed=10)
      x = tf.stack([x] * 1024, axis=0)
      assert x.shape == (1024, 8192)
      # x is sharded over the device mesh.
      x = xla_sharding.mesh_split(
          x, device_mesh, [0, 1], use_sharding_op=True)
      # y is a random matrix.
      y = tf.random.normal([8192, 65536],
                           mean=0.0,
                           stddev=1.0,
                           dtype=tf.bfloat16,
                           seed=20)
      # All rows in z should be the same.
      return tf.einsum('ab,bc->ac', x, y)

    xla_op, run_op = tpu.split_compile_and_shard(
        computation=tpu_fn, num_shards=1, device_assignment=device_assignment)
    sess.run(xla_op)
    ret = sess.run(run_op)
    z = ret[0]
    # Fraction of output columns whose values differ across the identical rows.
    tf.logging.info('row diff=%e', (z.min(axis=0) != z.max(axis=0)).mean())