Notes to better understand nondeterministic output caused by different reduce orders:

A simple concrete reduce example is summation: in floating-point arithmetic, summing the same numbers in a different order can produce different results, i.e. (n_1 + ((n_2 + n_3) + n_4)) != (((n_1 + n_2) + n_3) + n_4). The larger snippet below simulates matrix multiplication with a reduce-sum over shards: assume X is the input batch with all rows identical, W is the weight matrix, and Z = XW is the model output; X is partitioned by column, W is partitioned by row, both into four shards, and the per-shard partial products are summed in a random order.
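First, the non-associativity itself in isolation, as a minimal float32 sketch (the specific values here are chosen only for illustration, they are not from the experiment below):

import numpy as np

# Same four terms, two groupings, two different float32 results.
n = np.array([1e8, -1e8, 0.5, 0.5], dtype=np.float32)
left = n[0] + ((n[1] + n[2]) + n[3])    # n_1 + ((n_2 + n_3) + n_4)
right = ((n[0] + n[1]) + n[2]) + n[3]   # ((n_1 + n_2) + n_3) + n_4
print(left, right)  # 0.0 1.0 -- the two groupings disagree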

# This cell runs on CPU; it takes a few minutes to finish.
import numpy as np

# Input batch x: 16 identical rows of length 8192.
x = np.random.normal(size=[8192])
x = np.stack([x] * 16, axis=0)
assert x.shape == (16, 8192)
# Weight matrix w and output z = x @ w.
w = np.random.normal(size=[8192, 65536])
z = np.zeros((x.shape[0], w.shape[1]))

# Split the contraction dimension (8192) into four shards and sum the
# per-shard partial products in a random order for every output element.
num_shards = 4
shards = np.arange(num_shards)
shard_size = x.shape[1] // num_shards
for i in range(x.shape[0]):
  for j in range(w.shape[1]):
    np.random.shuffle(shards)
    for shard in shards:
      z[i, j] += np.dot(x[i, shard * shard_size : (shard + 1) * shard_size],
                        w[shard * shard_size : (shard + 1) * shard_size, j])
  if i > 0:
    # Fraction of output columns whose values differ across the identical rows seen so far.
    print(f"within first {i + 1} samples: "
          f"error={(z[:(i+1), :].min(axis=0) != z[:(i+1), :].max(axis=0)).mean()}")

It is expected to observe an error between 0.35 and 0.76, i.e. that fraction of output columns ends up with values that differ across the identical input rows. We can see a similar error in the TPU computation further below.
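To make the metric concrete, here is a toy example of the same check (the array values are made up purely for illustration):

import numpy as np

# Two "identical" input rows produced three output columns; column 1 disagrees.
z = np.array([[1.0, 2.0, 3.0],
              [1.0, 2.5, 3.0]])
error = (z.min(axis=0) != z.max(axis=0)).mean()
print(error)  # 0.333... -- one out of three columns differs between the rows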

# TF1-style TPU reproduction. The imports below are an assumption about the
# original script's setup; FLAGS such as --tpu, --disable_logging,
# --computation_shape and override_flags() are defined elsewhere in it.
import functools

import numpy as np
import tensorflow.compat.v1 as tf
from absl import flags
from tensorflow.compiler.xla.experimental.xla_sharding import xla_sharding
from tensorflow.python.tpu import device_assignment as tpu_device_assignment
from tensorflow.python.tpu import tpu

FLAGS = flags.FLAGS


def main(unused_argv):
  override_flags()
  if FLAGS.disable_logging:
    tf.get_logger().setLevel('CRITICAL')
  cluster_resolver = tf.distribute.cluster_resolver.TPUClusterResolver(
         FLAGS.tpu)
  _cluster_def = cluster_resolver.cluster_spec().as_cluster_def()
  _tpu = cluster_resolver.master()
  def _no_opt_sess_cfg():
    # Disable constant folding and other graph optimizations.
    return tf.ConfigProto(
        graph_options=tf.GraphOptions(
            optimizer_options=tf.OptimizerOptions(
                opt_level=tf.OptimizerOptions.L0,
                do_common_subexpression_elimination=False,
                do_function_inlining=False,
                do_constant_folding=False)),
        cluster_def=_cluster_def)
  sess = tf.Session(_tpu, config=_no_opt_sess_cfg())
  with sess.as_default(), sess.graph.as_default():
    tf.logging.info('Initializing TPU system')
    init_tpu = tf.tpu.initialize_system()
    topology_str = sess.run(init_tpu)
    topology = tf.tpu.experimental.Topology(topology_str)
    if FLAGS.computation_shape:
      computation_shape = list(map(int, FLAGS.computation_shape.split(',')))
    else:
      computation_shape = list(topology.mesh_shape)
    tf.logging.info('computation_shape: %r', computation_shape)
    tpu_cores = functools.reduce(lambda x, y: x * y, computation_shape)
    tf.logging.info('tpu_cores: %d', tpu_cores)
    device_assignment = tpu_device_assignment.device_assignment(
        topology_str, computation_shape=computation_shape, num_replicas=1)
    print('device assignment:', device_assignment)

    mesh_shape = [8, -1]
    device_mesh = np.arange(tpu_cores).reshape(mesh_shape)
    print('device_mesh:', device_mesh)
    print('mesh_shape:', device_mesh.shape)
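    # A reading of the setup (my annotation, not from the original gist): the
    # logical mesh is [8, tpu_cores // 8]. mesh_split(x, device_mesh, [0, 1])
    # below puts x's batch dim (1024) on mesh axis 0 and its contraction dim
    # (8192) on mesh axis 1, so each core holds only a slice of the reduction
    # dimension, while y stays replicated. The einsum is then computed as
    # per-core partial matmuls combined by a cross-core sum, and the order of
    # that sum is what can differ per output element, analogous to the
    # shuffled shards in the CPU simulation above.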

    def tpu_fn():
      # x is a matrix [1024,8192] where all 1024 rows are the same
      x = tf.random.normal([8192],
                           mean=0.0,
                           stddev=1.0,
                           dtype=tf.bfloat16,
                           seed=10)
      x = tf.stack([x] * 1024, axis=0)
      assert x.shape == (1024, 8192)

      # x is sharded over mesh
      x = xla_sharding.mesh_split(
          x, device_mesh, [0, 1], use_sharding_op=True)

      # y is a random matrix
      y = tf.random.normal([8192, 65536],
                           mean=0.0,
                           stddev=1.0,
                           dtype=tf.bfloat16,
                           seed=20)

      # Mathematically, all rows of the result should be the same; any
      # differences come from the order of the cross-shard reduction.
      return tf.einsum('ab,bc->ac', x, y)

    xla_op, run_op = tpu.split_compile_and_shard(
        computation=tpu_fn, num_shards=1, device_assignment=device_assignment)
    sess.run(xla_op)

    ret = sess.run(run_op)
    z = ret[0]
    tf.logging.info('row diff=%e', (z.min(axis=0) != z.max(axis=0)).mean())
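    # Hypothetical extra check (not in the original gist): count how many of
    # the 65536 output columns actually disagree across the identical rows.
    diff_cols = np.nonzero(z.min(axis=0) != z.max(axis=0))[0]
    tf.logging.info('differing columns: %d of %d', diff_cols.size, z.shape[1])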