jax-sharded-transformer
Created March 13, 2021 10:52
import os

import haiku as hk
import jax
import jax.experimental.maps  # exposes the xmap and mesh APIs used below
import jax.numpy as jnp
import numpy as np

# Run with a bunch of emulated CPU devices so the sharded model can be
# exercised on a single host.
def setUpModule():
    global prev_xla_flags
    prev_xla_flags = os.getenv("XLA_FLAGS")
    flags_str = prev_xla_flags or ""
    if "xla_force_host_platform_device_count" not in flags_str:
        os.environ["XLA_FLAGS"] = (flags_str + " --xla_force_host_platform_device_count=16")

# XLA reads this flag when the backend is first initialized, so set it before
# any devices are used.
setUpModule()

from transformer_shard import CausalTransformerShard

def causal_mask(ctx):
    # Large negative values strictly above the diagonal, zeros on and below it:
    # added to the attention logits, this blocks attention to future positions.
    mask = jnp.zeros((ctx, ctx))
    mask -= 10e10
    mask = jnp.triu(mask, 1)
    return mask
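
# Illustrative check (not part of the original gist): only entries strictly
# above the diagonal are masked out.
_m = causal_mask(3)
assert _m.shape == (3, 3)
assert float(_m[0, 1]) < 0 and float(_m[1, 0]) == 0 and float(_m[2, 2]) == 0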

def loss_fn(x):
    # Despite the name, this returns the model's logits rather than a loss.
    model = CausalTransformerShard(128, 8, 4, 256)  # dim=128, heads=8, layers=4, vocab=256
    return model(x)

# Map the model over two named axes: parameters are split along "shard"
# (model parallelism) and inputs along "batch" (data parallelism).
init = jax.experimental.maps.xmap(fun=hk.transform(loss_fn).init,
                                  in_axes=(["shard", ...],
                                           ["batch", ...]),
                                  out_axes=["shard", ...],
                                  axis_resources={'shard': 'shard', 'batch': 'batch'})

forward = jax.experimental.maps.xmap(fun=hk.without_apply_rng(hk.transform(loss_fn)).apply,
                                     in_axes=(["shard", ...],
                                              ["batch", ...]),
                                     out_axes=["batch", ...],
                                     axis_resources={'shard': 'shard', 'batch': 'batch'})

key = hk.PRNGSequence(42)
x = jax.random.uniform(next(key), (8, 64), minval=0, maxval=255).astype(jnp.int32)  # (batch, seq_len) token ids

# The 16 emulated devices form a 2x8 mesh: 2-way data parallel, 8-way model parallel.
devices = np.array(jax.devices()).reshape((2, 8))

with jax.experimental.maps.mesh(devices, ('batch', 'shard')):
    state = init(jnp.array(key.take(8)), x)  # one PRNG key per shard
    o = forward(state, x)
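
    # Illustrative check (not part of the original gist): with dim=128,
    # 8 shards, vocab=256 and inputs of shape (8, 64), the forward pass should
    # return one row of vocab logits per token.
    assert o.shape == (8, 64, 256)  # (batch, seq_len, vocab)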

transformer_shard.py

import jax
import jax.numpy as jnp
import haiku as hk
import numpy as np

class EmbeddingShard(hk.Module):
    """Vocabulary embedding with the vocab dimension split across shards."""

    def __init__(self, in_dim, out_dim, shards, name=None):
        super().__init__(name=name)
        assert in_dim % shards == 0

        self.in_dim = in_dim
        self.out_dim = out_dim
        self.dim_per_shard = in_dim // shards
        self.shards = shards

        self.ln = hk.LayerNorm(-1, True, True)
        self.proj = hk.Linear(self.out_dim)

    def __call__(self, x):
        # Each shard owns a contiguous slice of the vocabulary: build a one-hot
        # encoding of the tokens against this shard's slice, project it, and
        # average the partial embeddings across shards.
        shard_start_index = jax.lax.axis_index('shard') * self.dim_per_shard
        shard_index = jnp.arange(0, self.dim_per_shard) + shard_start_index

        proj_out = self.proj((shard_index.reshape(1, -1) == x.reshape(-1, 1)).astype(jnp.float32))

        return jax.lax.pmean(proj_out, "shard")

class TransformerLayerShard(hk.Module):
    """Transformer block where each shard computes one attention head and one
    slice of the feed-forward layer; outputs are averaged across shards."""

    def __init__(self, dim, heads, mask=None, name=None):
        super().__init__(name=name)
        assert dim % heads == 0

        self.dim = dim
        self.dim_per_head = dim // heads
        self.heads = heads

        if mask is None:
            mask = jnp.zeros(())
        self.mask = mask

        self.ln = hk.LayerNorm(-1, True, True)

        self.q = hk.Linear(self.dim_per_head)
        self.v = hk.Linear(self.dim_per_head)
        self.k = hk.Linear(self.dim_per_head)

        self.o = hk.Linear(self.dim)

        self.dense_proj = hk.Linear(self.dim_per_head * 4)
        self.dense_proj_o = hk.Linear(self.dim)

    def __call__(self, x):
        x = self.ln(x)

        q = self.q(x)
        v = self.v(x)
        k = self.k(x)

        # Single-head scaled dot-product attention on this shard.
        attention_logits = jnp.einsum("td,Td->tT", q, k)

        sqrt_key_size = np.sqrt(self.dim_per_head).astype(k.dtype)
        attention_logits = attention_logits / sqrt_key_size
        attention_logits += self.mask

        attention_weights = jax.nn.softmax(attention_logits)
        attention_vec = jnp.einsum("tT,Td->td", attention_weights, v)

        attn_out = self.o(attention_vec)

        dense_proj = self.dense_proj(x)
        dense_proj = jax.nn.gelu(dense_proj)
        dense_out = self.dense_proj_o(dense_proj)

        # Sum the attention and feed-forward branches, then average over shards.
        return jax.lax.pmean(attn_out + dense_out, "shard")

class ProjectionShard(hk.Module):
    """Projection to the vocabulary, split across shards; per-shard logits are
    all-gathered and concatenated into the full vocab dimension."""

    def __init__(self, out_dim, shards, name=None):
        super().__init__(name=name)
        assert out_dim % shards == 0

        self.dim = out_dim
        self.dim_per_shard = out_dim // shards
        self.shards = shards

        self.ln = hk.LayerNorm(-1, True, True)
        self.proj = hk.Linear(self.dim_per_shard)

    def __call__(self, x):
        x = self.ln(x)
        proj = self.proj(x)

        # Gather every shard's (seq, dim_per_shard) slice and flatten back to
        # (seq, out_dim).
        all_proj = jax.lax.all_gather(proj, 'shard')
        return hk.Flatten()(jnp.transpose(all_proj, (1, 0, 2)))

class CausalTransformerShard(hk.Module):
    """Embedding, a residual stack of transformer layers, and the output
    projection, using one shard per attention head."""

    def __init__(self, dim, heads, layer_count, vocab):
        super().__init__()
        self.layers = []

        self.embed = EmbeddingShard(vocab, dim, heads)

        for i in range(layer_count):
            self.layers.append(TransformerLayerShard(dim, heads))

        self.proj = ProjectionShard(vocab, heads)

    def __call__(self, x):
        x = self.embed(x)

        for l in self.layers:
            x = x + l(x)

        return self.proj(x)
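
The driver script defines a causal_mask helper but never passes it into the model, so attention here is effectively unmasked; the mask argument of TransformerLayerShard is where it would plug in. A minimal sketch of that wiring (hypothetical, not part of the gist), written as if appended to transformer_shard.py:

class MaskedCausalTransformerShard(hk.Module):
    """Hypothetical variant: the same sharded stack, but with a causal mask
    handed to every TransformerLayerShard, which already adds self.mask to
    its attention logits."""

    def __init__(self, dim, heads, layer_count, vocab, seq_len):
        super().__init__()
        # Same construction as causal_mask in the driver script: large negative
        # values strictly above the diagonal block attention to future tokens.
        mask = jnp.triu(jnp.zeros((seq_len, seq_len)) - 10e10, 1)

        self.embed = EmbeddingShard(vocab, dim, heads)
        self.layers = [TransformerLayerShard(dim, heads, mask=mask)
                       for _ in range(layer_count)]
        self.proj = ProjectionShard(vocab, heads)

    def __call__(self, x):
        x = self.embed(x)
        for l in self.layers:
            x = x + l(x)
        return self.proj(x)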