End-to-end example of model parallelism in Haiku + JAX, verifying that the sharded gradients are allclose to those of the unsharded case.
import os

# Run with a bunch of CPU devices (XLA_FLAGS is set here, before jax is imported below).
def setUpModule():
    global prev_xla_flags
    prev_xla_flags = os.getenv("XLA_FLAGS")
    flags_str = prev_xla_flags or ""
    if "xla_force_host_platform_device_count" not in flags_str:
        os.environ["XLA_FLAGS"] = (flags_str + " --xla_force_host_platform_device_count=4")

setUpModule()
import haiku as hk
import jax
import jax.numpy as jnp
import numpy as np

from jax.experimental.maps import thread_resources
# identity in forward pass, psum in backward
@jax.custom_vjp
def f_psum(x):
    return x

def f_psum_fwd(x):
    return f_psum(x), None

def f_psum_bwd(_, g):
    return jax.lax.psum(g, "shard"),

f_psum.defvjp(f_psum_fwd, f_psum_bwd)
# psum in forward pass, identity in backward
@jax.custom_vjp
def g_psum(x):
    return jax.lax.psum(x, "shard")

def g_psum_fwd(x):
    return g_psum(x), None

def g_psum_bwd(_, g):
    return g,

g_psum.defvjp(g_psum_fwd, g_psum_bwd)
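# Optional sanity check (not part of the original gist; a minimal sketch assuming the 4
# host-platform devices forced above). It illustrates the intended semantics under pmap:
# f_psum is the identity in the forward pass but sums cotangents over 'shard' in the
# backward pass, while g_psum sums in the forward pass. Defined only, never called here.
def _psum_sanity_check():
    vals = jnp.arange(4.0)
    fwd = jax.pmap(f_psum, axis_name="shard")(vals)              # equals vals
    grads = jax.pmap(jax.grad(f_psum), axis_name="shard")(vals)  # every entry is 4.0
    summed = jax.pmap(g_psum, axis_name="shard")(vals)           # every entry is 6.0
    return fwd, grads, summed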
class EmbeddingShard(hk.Module):
    # Embedding table sharded along the vocabulary axis: each shard owns
    # in_dim_per_shard rows and contributes a partial embedding that is
    # summed across shards by g_psum.
    def __init__(self, vocab, d_model, shards):
        super().__init__()
        assert vocab % shards == 0

        self.in_dim = vocab
        self.out_dim = d_model
        self.in_dim_per_shard = vocab // shards

        self.proj = hk.Linear(self.out_dim, w_init=hk.initializers.TruncatedNormal(stddev=1 / np.sqrt(vocab)))

    def __call__(self, x):
        # Tokens outside this shard's vocabulary slice one-hot to all zeros,
        # so only the owning shard contributes to the sum.
        shard_start_index = jax.lax.axis_index('shard') * self.in_dim_per_shard
        shard_index = jnp.arange(0, self.in_dim_per_shard) + shard_start_index
        proj_out = self.proj((shard_index.reshape(1, -1) == x.reshape(-1, 1)).astype(jnp.float32))

        return g_psum(proj_out)
class FFLayerShard(hk.Module):
    # Megatron-style feed-forward block: the first linear is split along its
    # output features (4 * dim // shards per shard); the second produces a
    # partial output of width dim that is summed across shards.
    def __init__(self, dim, shards):
        super().__init__()
        self.dense_proj = hk.Linear(4 * dim // shards)
        self.dense_proj_o = hk.Linear(dim, w_init=hk.initializers.TruncatedNormal(stddev=1 / np.sqrt(dim)))

    def __call__(self, x):
        # x is replicated on every shard, so its gradient must be summed across
        # shards in the backward pass (f_psum is the identity in the forward pass).
        x = f_psum(x)
        dense_proj = self.dense_proj(x)
        dense_proj = jax.nn.gelu(dense_proj)
        dense_out = self.dense_proj_o(dense_proj)

        return g_psum(dense_out)
class ProjectionShard(hk.Module):
    # Output projection sharded along the vocabulary axis, with a cross-entropy
    # loss computed from per-shard partial logits (the z_loss flag is accepted
    # but unused here).
    def __init__(self, vocab, shards):
        super().__init__()
        assert vocab % shards == 0
        self.dim = vocab
        self.dim_per_shard = vocab // shards

        self.proj = hk.Linear(self.dim_per_shard)

    def __call__(self, x):
        proj = self.proj(x)

        all_proj = jax.lax.all_gather(proj, 'shard')

        return hk.Flatten()(jnp.transpose(all_proj, (1, 0, 2)))

    def loss(self, x, targets, z_loss=False):
        x = f_psum(x)
        logits = self.proj(x).astype(jnp.float32)

        # Subtract the global max for numerical stability before exponentiating.
        shard_start_index = jax.lax.axis_index('shard') * self.dim_per_shard
        global_max = jax.lax.pmax(jax.lax.stop_gradient(logits.max(-1, keepdims=True)), "shard")
        logits -= jax.lax.stop_gradient(global_max)

        # Targets outside this shard's slice one-hot to all zeros, so summing
        # across shards recovers the target logit and the full-vocabulary
        # sum of exponentials.
        gt_onehot = jax.nn.one_hot(targets - shard_start_index, self.dim_per_shard)
        predicted_logits = jnp.sum(jnp.multiply(gt_onehot, logits), axis=-1)
        predicted_logits = g_psum(predicted_logits)

        exp_logits = jnp.exp(logits)
        sum_exp_logits = exp_logits.sum(axis=-1)
        sum_exp_logits = g_psum(sum_exp_logits)

        return jnp.log(sum_exp_logits) - predicted_logits
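# Reference (not part of the original gist): a minimal unsharded cross entropy, written
# out only to make explicit what the sharded loss above computes, i.e.
# log(sum_j exp(logit_j)) - logit_target with a max subtraction for stability.
# The helper name is hypothetical and it is never called by this script.
def _unsharded_cross_entropy(logits, targets):
    # logits: [batch, vocab] floats; targets: [batch] integer class ids
    logits = logits - jax.lax.stop_gradient(logits.max(-1, keepdims=True))
    target_logits = jnp.take_along_axis(logits, targets[:, None], axis=-1)[:, 0]
    return jnp.log(jnp.exp(logits).sum(-1)) - target_logits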
class Model:
    def __init__(self, shards):
        def train_loss(x, y):
            x = EmbeddingShard(256, 128, shards)(x)

            for i in range(2):
                x += FFLayerShard(128, shards)(x)

            return ProjectionShard(256, shards).loss(x, y).mean()

        def grad(params, ctx, tgt):
            train_loss_fn = hk.without_apply_rng(hk.transform(train_loss)).apply

            value, grad = jax.value_and_grad(train_loss_fn)(params, ctx, tgt)
            return value, grad

        def init(key, x, y):
            param_init_fn = hk.transform(train_loss).init

            return param_init_fn(key, x, y)

        self.train_loss = train_loss
        self.grad = grad
        self.init = init

        # Parameters (and the per-shard RNG keys) are mapped over the named
        # 'shard' axis, the data over 'batch'; 'shard' runs on the 'mp' mesh
        # axis and 'batch' on 'dp'.
        self.init_xmap = jax.experimental.maps.xmap(fun=init,
                                                    in_axes=(["shard", ...],
                                                             ["batch", ...],
                                                             ["batch", ...]),
                                                    out_axes=["shard", ...],
                                                    axis_resources={'shard': 'mp', 'batch': 'dp'})

        self.grad_xmap = jax.experimental.maps.xmap(fun=grad,
                                                    in_axes=(["shard", ...],
                                                             ["batch", ...],
                                                             ["batch", ...]),
                                                    out_axes=(["batch", ...], ["shard", ...]),
                                                    axis_resources={'shard': 'mp', 'batch': 'dp'})
def unshard(sharded, unsharded):
    # Reassemble a parameter (or gradient) carrying a leading 'shard' axis into
    # the unsharded parameter's layout, depending on which dimension was split.
    s_shape = sharded.shape
    u_shape = unsharded.shape

    if len(s_shape) == 2:
        if np.prod(s_shape) == np.prod(u_shape):
            new_unsharded = sharded.reshape(u_shape)
        else:
            # Full-width parameter present on every shard (e.g. the output bias):
            # keep shard 0's copy.
            new_unsharded = sharded[:1]
    elif s_shape[0] * s_shape[2] == u_shape[2]:
        # Weight split along its output features: gather the shard blocks along the last axis.
        new_unsharded = jnp.transpose(sharded, (1, 0, 2)).reshape(u_shape)
    elif s_shape[0] * s_shape[1] == u_shape[1]:
        # Weight split along its input features: stack the shard blocks along that axis.
        new_unsharded = sharded.reshape(u_shape)
    else:
        raise Exception("unimplemented")

    assert unsharded.shape == new_unsharded.shape
    return new_unsharded
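# Example of the output-split case (illustration only, not executed): the FFLayerShard
# dense_proj kernel is stored as (shards=4, 128, 128); transposing to (128, 4, 128) and
# reshaping to (1, 128, 512) places shard s in columns [128 * s, 128 * (s + 1)) of the
# unsharded kernel.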
key = hk.PRNGSequence(42)

x = jax.random.uniform(next(key), (1, 64), maxval=256).astype(jnp.int32)
y = jax.random.uniform(next(key), (1, 64), maxval=256).astype(jnp.int32)

# 4-way model parallelism on a (dp=1, mp=4) mesh of CPU devices.
devices = np.array(jax.devices()).reshape((1, 4))

with jax.experimental.maps.mesh(devices, ('dp', 'mp')):
    sharded = Model(4)
    shard_state = sharded.init_xmap(jnp.array(key.take(4)), x, y)
    shard_out = sharded.grad_xmap(shard_state, x, y)

# The same model with a single shard on a (dp=1, mp=1) mesh; its freshly initialized
# parameters are replaced by unsharded copies of the sharded parameters so the two
# gradient computations can be compared directly.
device = np.array(jax.devices()[:1]).reshape((1, 1))

with jax.experimental.maps.mesh(device, ('dp', 'mp')):
    unsharded = Model(1)
    state = unsharded.init_xmap(jnp.array(key.take(1)), x, y)

    new_state = jax.tree_multimap(unshard, shard_state, state)
    out = unsharded.grad_xmap(new_state, x, y)
def check_close(x, y):
    # Gradients of sharded parameters are reassembled before comparison.
    if x.shape != y.shape:
        x = unshard(x, y)

    x = np.array(x)
    y = np.array(y)
    assert np.allclose(x, y)

jax.tree_multimap(check_close, shard_out, out)
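# Usage note (not part of the original gist): running this script as plain Python forces
# 4 host-platform (CPU) devices via XLA_FLAGS and then checks every loss value and gradient
# with np.allclose; if no assertion fires, the sharded and unsharded models agree.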