@kingoflolz
Last active July 20, 2022 08:19
End-to-end example of model parallelism in Haiku + JAX, verifying that the sharded model's gradients are allclose to the unsharded case
import os

# Force XLA to expose several CPU devices so the sharded model can run on a single host.
# This must happen before jax is imported.
def setUpModule():
    global prev_xla_flags
    prev_xla_flags = os.getenv("XLA_FLAGS")
    flags_str = prev_xla_flags or ""
    if "xla_force_host_platform_device_count" not in flags_str:
        os.environ["XLA_FLAGS"] = (flags_str + " --xla_force_host_platform_device_count=4")

setUpModule()
import haiku as hk
import jax
import jax.numpy as jnp
import numpy as np
from jax.experimental.maps import thread_resources  # importing this submodule also makes jax.experimental.maps.xmap/mesh accessible below
# identity in forward pass, psum in backward
@jax.custom_vjp
def f_psum(x):
    return x

def f_psum_fwd(x):
    return f_psum(x), None

def f_psum_bwd(_, g):
    return jax.lax.psum(g, "shard"),

f_psum.defvjp(f_psum_fwd, f_psum_bwd)
# psum in forward pass, identity in backward
@jax.custom_vjp
def g_psum(x):
    return jax.lax.psum(x, "shard")

def g_psum_fwd(x):
    return g_psum(x), None

def g_psum_bwd(_, g):
    return g,

g_psum.defvjp(g_psum_fwd, g_psum_bwd)
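
# Taken together, f_psum and g_psum play the role of the conjugate f/g operators used in
# Megatron-style tensor parallelism: f_psum marks where a replicated activation enters a
# sharded block (every shard consumes the same input, so the input gradients must be summed
# across shards on the way back), and g_psum marks where per-shard partial outputs are
# summed into the replicated result (so the backward pass needs no extra communication).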
class EmbeddingShard(hk.Module):
    def __init__(self, vocab, d_model, shards):
        super().__init__()
        assert vocab % shards == 0
        self.in_dim = vocab
        self.out_dim = d_model
        self.in_dim_per_shard = vocab // shards

        self.proj = hk.Linear(self.out_dim, w_init=hk.initializers.TruncatedNormal(stddev=1 / np.sqrt(vocab)))

    def __call__(self, x):
        # each shard owns a contiguous slice of the vocabulary
        shard_start_index = jax.lax.axis_index('shard') * self.in_dim_per_shard
        shard_index = jnp.arange(0, self.in_dim_per_shard) + shard_start_index

        # one-hot encode against the local vocab slice; tokens owned by other shards contribute zero
        proj_out = self.proj((shard_index.reshape(1, -1) == x.reshape(-1, 1)).astype(jnp.float32))

        return g_psum(proj_out)
class FFLayerShard(hk.Module):
    def __init__(self, dim, shards):
        super().__init__()
        # each shard holds 1/shards of the 4x-wide hidden layer
        self.dense_proj = hk.Linear(4 * dim // shards)
        self.dense_proj_o = hk.Linear(dim, w_init=hk.initializers.TruncatedNormal(stddev=1 / np.sqrt(dim)))

    def __call__(self, x):
        x = f_psum(x)
        dense_proj = self.dense_proj(x)
        dense_proj = jax.nn.gelu(dense_proj)
        dense_out = self.dense_proj_o(dense_proj)

        return g_psum(dense_out)
class ProjectionShard(hk.Module):
    def __init__(self, vocab, shards):
        super().__init__()
        assert vocab % shards == 0

        self.dim = vocab
        self.dim_per_shard = vocab // shards

        self.proj = hk.Linear(self.dim_per_shard)

    def __call__(self, x):
        proj = self.proj(x)
        all_proj = jax.lax.all_gather(proj, 'shard')

        # gather the per-shard logit slices and flatten back to the full vocab dimension
        return hk.Flatten()(jnp.transpose(all_proj, (1, 0, 2)))
    def loss(self, x, targets, z_loss=False):
        x = f_psum(x)
        logits = self.proj(x).astype(jnp.float32)

        shard_start_index = jax.lax.axis_index('shard') * self.dim_per_shard

        # subtract the global (cross-shard) max for numerical stability
        global_max = jax.lax.pmax(jax.lax.stop_gradient(logits.max(-1, keepdims=True)), "shard")
        logits -= jax.lax.stop_gradient(global_max)

        # logit of the target token: only the shard owning the target contributes a nonzero term
        gt_onehot = jax.nn.one_hot(targets - shard_start_index, self.dim_per_shard)
        predicted_logits = jnp.sum(jnp.multiply(gt_onehot, logits), axis=-1)
        predicted_logits = g_psum(predicted_logits)

        # softmax denominator, summed over the full vocabulary across shards
        exp_logits = jnp.exp(logits)
        sum_exp_logits = exp_logits.sum(axis=-1)
        sum_exp_logits = g_psum(sum_exp_logits)

        return jnp.log(sum_exp_logits) - predicted_logits
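
# The sharded loss above evaluates, per token,
#   log(sum_j exp(z_j - m)) - (z_t - m) = -log(softmax(z)_t),
# i.e. ordinary softmax cross-entropy; the two g_psum calls recover the target logit and the
# full-vocabulary denominator from the per-shard logit slices.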
class Model:
    def __init__(self, shards):
        def train_loss(x, y):
            x = EmbeddingShard(256, 128, shards)(x)
            for _ in range(2):
                x += FFLayerShard(128, shards)(x)
            return ProjectionShard(256, shards).loss(x, y).mean()

        def grad(params, ctx, tgt):
            train_loss_fn = hk.without_apply_rng(hk.transform(train_loss)).apply
            value, grad = jax.value_and_grad(train_loss_fn)(params, ctx, tgt)
            return value, grad

        def init(key, x, y):
            param_init_fn = hk.transform(train_loss).init
            return param_init_fn(key, x, y)

        self.train_loss = train_loss
        self.grad = grad
        self.init = init

        # parameters live on the 'shard' (model-parallel) axis, data on the 'batch' axis
        self.init_xmap = jax.experimental.maps.xmap(fun=init,
                                                    in_axes=(["shard", ...],
                                                             ["batch", ...],
                                                             ["batch", ...]),
                                                    out_axes=["shard", ...],
                                                    axis_resources={'shard': 'mp', 'batch': 'dp'})

        self.grad_xmap = jax.experimental.maps.xmap(fun=grad,
                                                    in_axes=(["shard", ...],
                                                             ["batch", ...],
                                                             ["batch", ...]),
                                                    out_axes=(["batch", ...], ["shard", ...]),
                                                    axis_resources={'shard': 'mp', 'batch': 'dp'})
def unshard(sharded, unsharded):
    s_shape = sharded.shape
    u_shape = unsharded.shape

    if len(s_shape) == 2:
        if np.prod(s_shape) == np.prod(u_shape):
            # bias split along the feature dimension: concatenate by reshape
            new_unsharded = sharded.reshape(u_shape)
        else:
            # every shard holds a full copy of this bias (its layer output is psum-reduced), so keep one copy
            new_unsharded = sharded[:1]
    elif s_shape[0] * s_shape[2] == u_shape[2]:
        # kernel split along its output dimension
        new_unsharded = jnp.transpose(sharded, (1, 0, 2)).reshape(u_shape)
    elif s_shape[0] * s_shape[1] == u_shape[1]:
        # kernel split along its input dimension
        new_unsharded = sharded.reshape(u_shape)
    else:
        raise Exception("unimplemented")

    assert unsharded.shape == new_unsharded.shape
    return new_unsharded
key = hk.PRNGSequence(42)

# random token ids in [0, 256)
x = jax.random.uniform(next(key), (1, 64), maxval=256).astype(jnp.int32)
y = jax.random.uniform(next(key), (1, 64), maxval=256).astype(jnp.int32)

# 4-way model parallelism ('mp'), no data parallelism ('dp' size 1)
devices = np.array(jax.devices()).reshape((1, 4))
with jax.experimental.maps.mesh(devices, ('dp', 'mp')):
    sharded = Model(4)
    shard_state = sharded.init_xmap(jnp.array(key.take(4)), x, y)
    shard_out = sharded.grad_xmap(shard_state, x, y)
# same model on a single device with a single shard
device = np.array(jax.devices()[:1]).reshape((1, 1))
with jax.experimental.maps.mesh(device, ('dp', 'mp')):
    unsharded = Model(1)
    state = unsharded.init_xmap(jnp.array(key.take(1)), x, y)

    # copy the sharded initialization into the unsharded parameter layout
    new_state = jax.tree_multimap(unshard, shard_state, state)
    out = unsharded.grad_xmap(new_state, x, y)
def check_close(x, y):
    if x.shape != y.shape:
        x = unshard(x, y)

    x = np.array(x)
    y = np.array(y)
    assert np.allclose(x, y)

jax.tree_multimap(check_close, shard_out, out)
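
# If this runs without an assertion error, the loss and every gradient leaf of the 4-way
# sharded model match the single-shard model within np.allclose's default tolerances.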