# spmd_pipeline_xmap.py
# GitHub gist zhangqiaorjc/1166435bd68fd14c8e338a2399afa64d
# by @zhangqiaorjc, created January 20, 2022.
import itertools as it
import jax
import jax.numpy as jnp
from jax.experimental import maps
# Global JAX configuration for this demo, applied (in this order) before any
# arrays below are created.
# NOTE(review): 'experimental_xmap_spmd_lowering' only exists in JAX versions
# that still ship xmap — confirm against the installed JAX.
for _option, _value in (
    ('jax_enable_x64', True),
    ('jax_platform_name', 'cpu'),
    ('experimental_xmap_spmd_lowering', True),
):
    jax.config.update(_option, _value)
del _option, _value
# Problem sizes; the single-letter aliases match the shape notation used below.
num_stages = 5          # L: pipeline stages
batch_size = 6          # N: total rows in the batch
num_microbatches = 2    # M: microbatches the batch is split into
microbatch_size = 3     # B: rows per microbatch
num_feat = 3            # F: features per row
L, N, M, B, F = num_stages, batch_size, num_microbatches, microbatch_size, num_feat
assert N == M * B  # the batch must split evenly into microbatches
# One (F, F) weight matrix per stage, and a deterministic 0..N*F-1 input batch.
params = jax.random.normal(jax.random.PRNGKey(0), (L, F, F))
inputs = jnp.arange(N * F, dtype=jnp.float32).reshape(N, F)
def fn(params, inputs):
    """Apply one dense ``tanh`` layer to a single example.

    Args:
      params: 2-D weight matrix for one stage.
      inputs: 1-D feature vector.

    Returns:
      ``tanh(dot(params, inputs))``.

    Raises:
      AssertionError: if the ranks are not 2 and 1. Under ``vmap``/``xmap``
        these are the per-example (unbatched) ranks.
    """
    assert params.ndim == 2
    assert inputs.ndim == 1
    pre_activation = jnp.dot(params, inputs)
    return jnp.tanh(pre_activation)
# Sequential reference: push the whole batch through the L stages one by one.
stage_apply = jax.vmap(fn, in_axes=(None, 0))  # same weights for every row
state = inputs
for stage in range(L):
    state = stage_apply(params[stage], state)
outputs = state
print(outputs)
def spmd_pipeline(fn, params, inputs):
    """Run ``fn`` as a pipeline of stages over a sequence of microbatches.

    Stage ``s`` applies ``fn(params[s], row)`` to every row of the microbatch
    it currently holds. At each of the ``M + L - 1`` steps:

    - the per-stage states shift down by one, so stage ``s`` takes stage
      ``s - 1``'s previous output;
    - the next microbatch (zeros once the inputs run out) enters stage 0;
    - all stages are evaluated together with ``xmap`` over the named
      ``'num_stages'`` axis.

    The stage count and microbatch layout are read from the arguments, not
    from module-level constants.

    Args:
      fn: per-row stage function called as ``fn(params[s], row)``. Its output
        must have the same shape as ``row``, because it becomes the next
        stage's input and is stored in an output buffer shaped like
        ``inputs``.
      params: per-stage parameters stacked along axis 0; ``params.shape[0]``
        is the number of stages L.
      inputs: microbatches stacked along axis 0, shape ``(M, B, ...)``.

    Returns:
      Array of shape ``inputs.shape``: the last stage's output for each
      microbatch, in input order.
    """
    num_stages = params.shape[0]
    num_microbatches = inputs.shape[0]
    num_steps = num_microbatches + num_stages - 1
    # Zero microbatches appended after the real ones flush the pipeline.
    pad_width = [(0, num_stages - 1)] + [(0, 0)] * (inputs.ndim - 1)
    padded_inputs = jnp.pad(inputs, pad_width)
    outputs = jnp.zeros((num_steps,) + inputs.shape[1:])
    state = jnp.zeros((num_stages,) + inputs.shape[1:])
    # Built once, outside the loop. The inner vmap maps one stage over its B
    # rows; the xmap runs all stages along the named 'num_stages' axis.
    # NOTE(review): jax.experimental.maps.xmap is deprecated and removed in
    # newer JAX releases; with axis_resources={} this is presumably equivalent
    # to jax.vmap over axis 0 — confirm before porting.
    run_all_stages = maps.xmap(
        jax.vmap(fn, (None, 0)),
        in_axes=['num_stages', ...],
        out_axes=['num_stages', ...],
        axis_resources={},
    )
    for step in range(num_steps):
        # shift_and_insert's jnp.where broadcasts this (B, ...) slice into
        # slot 0 of the shifted state.
        state = shift_and_insert(state, padded_inputs[step])
        state = run_all_stages(params, state)
        # The last stage has just finished microbatch (step - num_stages + 1).
        outputs = outputs.at[step].set(state[-1])
    # The first num_stages - 1 steps only fill the pipeline; drop them.
    return outputs[num_stages - 1:]
def shift_and_insert(arr, x):
    """Shift ``arr`` down one slot along axis 0 and place ``x`` in slot 0.

    The last slice of ``arr`` is dropped. ``x`` is merged with ``jnp.where``,
    so it only has to broadcast against ``arr``. When ``x`` has the full
    shape of ``arr``, only ``x[0]`` ends up in the result.
    """
    shifted = jnp.concatenate([jnp.zeros_like(arr[:1]), arr[:-1]], axis=0)
    slot = jax.lax.broadcasted_iota('int32', shifted.shape, 0)
    return jnp.where(slot == 0, x, shifted)
# Run the pipelined version on the same batch, split into M microbatches of
# B rows, then flatten back to (N, F).
outputs2 = spmd_pipeline(fn, params, inputs.reshape(M, B, F)).reshape(N, F)
print(outputs2)
# The pipeline must reproduce the sequential reference `outputs` computed above.
assert jnp.allclose(outputs, outputs2), 'spmd_pipeline disagrees with the sequential reference'
def loss(params, inputs):
    """Scalar objective: the sum of every pipeline output.

    ``inputs`` must hold ``M * B * F`` elements, e.g. the flat ``(N, F)``
    batch. It is reshaped into the module-level ``(M, B, F)`` microbatch
    layout before being passed to ``spmd_pipeline``.
    """
    microbatches = inputs.reshape(M, B, F)
    return spmd_pipeline(fn, params, microbatches).sum()
# Gradient of the summed pipeline output with respect to `inputs` (argnums=1).
input_grad = jax.grad(loss, argnums=1)(params, inputs)
print(input_grad)
# End of spmd_pipeline_xmap.py