@kingoflolz
Created March 13, 2021 10:52
jax-sharded-transformer
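# Driver script: spins up 16 emulated host CPU devices, arranges them as a
# (batch=2, shard=8) mesh, and runs init/forward of the sharded transformer
# defined in transformer_shard.py via jax.experimental.maps.xmap.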
import os

import haiku as hk
import jax
import jax.numpy as jnp
import numpy as np


# Run with a bunch of CPU devices.
def setUpModule():
    global prev_xla_flags
    prev_xla_flags = os.getenv("XLA_FLAGS")
    flags_str = prev_xla_flags or ""
    if "xla_force_host_platform_device_count" not in flags_str:
        os.environ["XLA_FLAGS"] = (flags_str + " --xla_force_host_platform_device_count=16")


setUpModule()
from transformer_shard import CausalTransformerShard


def causal_mask(ctx):
    mask = jnp.zeros((ctx, ctx))
    mask -= 10e10
    mask = jnp.triu(mask, 1)  # keep large negative values strictly above the diagonal (mask out attention to future positions)
    return mask
def loss_fn(x):
    model = CausalTransformerShard(128, 8, 4, 256)
    return model(x)
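# xmap maps the transformed function over two named logical axes:
#   - 'shard': model-parallel axis; each parameter leaf gains a leading shard dimension,
#     and the collectives inside the model (pmean / all_gather over "shard") run across it.
#   - 'batch': data-parallel axis over examples.
# axis_resources pins both logical axes onto the mesh axes of the same name below.
# loss_fn builds CausalTransformerShard(dim=128, heads=8, layer_count=4, vocab=256)
# and returns its per-token logits.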
init = jax.experimental.maps.xmap(fun=hk.transform(loss_fn).init,
                                  in_axes=(["shard", ...],
                                           ["batch", ...]),
                                  out_axes=["shard", ...],
                                  axis_resources={'shard': 'shard', 'batch': 'batch'})

forward = jax.experimental.maps.xmap(fun=hk.without_apply_rng(hk.transform(loss_fn)).apply,
                                     in_axes=(["shard", ...],
                                              ["batch", ...]),
                                     out_axes=["batch", ...],
                                     axis_resources={'shard': 'shard', 'batch': 'batch'})
key = hk.PRNGSequence(42)
x = jax.random.uniform(next(key), (8, 64), minval=0, maxval=255).astype(jnp.int32) # batch, len
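# The 16 host CPU devices are arranged as a (2, 8) mesh: 2-way data parallelism along
# 'batch' and 8-way model parallelism along 'shard'. init takes one PRNG key per shard.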
devices = np.array(jax.devices()).reshape((2, 8))
with jax.experimental.maps.mesh(devices, ('batch', 'shard')):
    state = init(jnp.array(key.take(8)), x)
    o = forward(state, x)
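# Sanity check (not in the original gist; the shapes follow from the config above):
# every leaf of `state` carries a leading shard dimension of 8, and the forward pass
# returns per-token logits over the full vocab, i.e. (batch, len, vocab) = (8, 64, 256).
print(jax.tree_util.tree_map(lambda t: t.shape, state))
print(o.shape)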
# transformer_shard.py
import jax
import jax.numpy as jnp
import haiku as hk
import numpy as np
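# EmbeddingShard: the (vocab -> dim) embedding is split across the 'shard' axis.
# Each shard owns in_dim // shards vocabulary entries, builds a one-hot encoding
# restricted to its own slice of the vocabulary, projects it to the full model
# dimension, and a pmean over 'shard' combines (averages) the per-shard contributions.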
class EmbeddingShard(hk.Module):
    def __init__(self, in_dim, out_dim, shards, name=None):
        super().__init__(name=name)
        assert in_dim % shards == 0

        self.in_dim = in_dim
        self.out_dim = out_dim
        self.dim_per_shard = in_dim // shards
        self.shards = shards

        self.ln = hk.LayerNorm(-1, True, True)
        self.proj = hk.Linear(self.out_dim)

    def __call__(self, x):
        shard_start_index = jax.lax.axis_index('shard') * self.dim_per_shard
        shard_index = jnp.arange(0, self.dim_per_shard) + shard_start_index

        proj_out = self.proj((shard_index.reshape(1, -1) == x.reshape(-1, 1)).astype(jnp.float32))

        return jax.lax.pmean(proj_out, "shard")
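# TransformerLayerShard: one shard per attention head. Each shard computes a single
# head of attention (q/k/v projected to dim // heads) plus a 1/heads slice of the
# feed-forward block, both projected back to the full model dimension; attention and
# feed-forward read the same layer-normed input, and a pmean over 'shard' combines
# the per-head outputs. Note: with mask=None the mask defaults to zeros, so no causal
# masking is applied unless a mask is passed in.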
class TransformerLayerShard(hk.Module):
    def __init__(self, dim, heads, mask=None, name=None):
        super().__init__(name=name)
        assert dim % heads == 0

        self.dim = dim
        self.dim_per_head = dim // heads
        self.heads = heads

        if mask is None:
            mask = jnp.zeros(())
        self.mask = mask

        self.ln = hk.LayerNorm(-1, True, True)

        self.q = hk.Linear(self.dim_per_head)
        self.v = hk.Linear(self.dim_per_head)
        self.k = hk.Linear(self.dim_per_head)

        self.o = hk.Linear(self.dim)

        self.dense_proj = hk.Linear(self.dim_per_head * 4)
        self.dense_proj_o = hk.Linear(self.dim)

    def __call__(self, x):
        x = self.ln(x)

        q = self.q(x)
        v = self.v(x)
        k = self.k(x)

        attention_logits = jnp.einsum("td,Td->tT", q, k)

        sqrt_key_size = np.sqrt(self.dim_per_head).astype(k.dtype)
        attention_logits = attention_logits / sqrt_key_size
        attention_logits += self.mask

        attention_weights = jax.nn.softmax(attention_logits)
        attention_vec = jnp.einsum("tT,Td->td", attention_weights, v)

        attn_out = self.o(attention_vec)

        dense_proj = self.dense_proj(x)
        dense_proj = jax.nn.gelu(dense_proj)
        dense_out = self.dense_proj_o(dense_proj)

        return jax.lax.pmean(attn_out + dense_out, "shard")
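# ProjectionShard: the output projection to the vocabulary is split across shards.
# Each shard produces logits for out_dim // shards vocabulary entries; an all_gather
# over 'shard' followed by a transpose and flatten reassembles full-vocab logits
# (shard slices concatenated along the vocab dimension).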
class ProjectionShard(hk.Module):
    def __init__(self, out_dim, shards, name=None):
        super().__init__(name=name)
        assert out_dim % shards == 0

        self.dim = out_dim
        self.dim_per_shard = out_dim // shards
        self.shards = shards

        self.ln = hk.LayerNorm(-1, True, True)

        self.proj = hk.Linear(self.dim_per_shard)

    def __call__(self, x):
        x = self.ln(x)
        proj = self.proj(x)

        all_proj = jax.lax.all_gather(proj, 'shard')

        return hk.Flatten()(jnp.transpose(all_proj, (1, 0, 2)))
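# CausalTransformerShard: sharded embedding, a stack of residual transformer layers,
# and a sharded output projection. The number of shards equals the number of heads,
# so the model must run under an xmap/mesh whose 'shard' axis has size `heads`.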
class CausalTransformerShard(hk.Module):
    def __init__(self, dim, heads, layer_count, vocab):
        super().__init__()
        self.layers = []
        self.embed = EmbeddingShard(vocab, dim, heads)

        for i in range(layer_count):
            self.layers.append(TransformerLayerShard(dim, heads))

        self.proj = ProjectionShard(vocab, heads)

    def __call__(self, x):
        x = self.embed(x)

        for l in self.layers:
            x = x + l(x)

        return self.proj(x)
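# Not part of the original gist: a minimal sketch of turning the logits returned by
# CausalTransformerShard into a next-token cross-entropy loss. The function name and
# the shift-by-one target convention are assumptions.
def next_token_loss(logits, tokens):
    # logits: (len, vocab) for one sequence; position t predicts token t + 1.
    log_probs = jax.nn.log_softmax(logits[:-1])
    targets = tokens[1:]
    return -jnp.mean(jnp.take_along_axis(log_probs, targets[:, None], axis=-1))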