@yhcharles
Created July 17, 2022 16:53
Compare the distributed APIs of OneFlow and JAX by implementing DistributedDataParallel
#%%
import jax
import jax.numpy as jnp
from jax import grad, pmap, lax
import numpy as np
from functools import partial

print(jax.devices())


# %%
def synthetic_data(w, b, num_examples):
    """Generate y = Xw + b plus a small amount of noise."""
    X = np.random.normal(size=(num_examples, len(w)))
    y = np.matmul(X, w.reshape((2, 1))) + b
    y += np.random.normal(size=y.shape) * 0.01
    return X, y.reshape((-1, 1))


#%%
def model(X, w, b):
    return jnp.matmul(X, w) + b


def loss_func(y_hat, y):
    return ((y_hat - y.reshape(y_hat.shape)) ** 2 / 2).sum()


def model_loss(w, b, X, y):
    return loss_func(model(X, w, b), y)


def main():
    # params of the ground truth model
    true_w = np.array([2, -3.4])
    true_b = 4.2
    print(f"{true_w = }, {true_b = }")

    # initial model params
    w = np.array([[100.0], [100.0]])
    b = np.array([1.0])
    print(f"{w = }, {b = }")

    num_devices = jax.device_count()
    replicate_array = lambda x: np.broadcast_to(x, (num_devices,) + x.shape)
    # replicate params to all devices
    rep_w, rep_b = replicate_array(w), replicate_array(b)

    num_epochs, num_iters, batch_size = 5, 1000, 32
    assert (
        batch_size % num_devices == 0
    ), "batch_size must be divisible by num_devices"

    # add an extra leading device dimension to the data so that pmap
    # dispatches one shard to each device
    def shard_data(X, y):
        num_shards = num_devices
        shard_size = batch_size // num_shards
        shard_X = X.reshape(num_shards, shard_size, *X.shape[1:])
        shard_y = y.reshape(num_shards, shard_size, *y.shape[1:])
        return shard_X, shard_y

    lr = 1e-3

    @partial(pmap, axis_name="batch")
    def spmd_update(w, b, X, y):
        grads = grad(model_loss, argnums=[0, 1])(w, b, X, y)
        # all-reduce the per-shard gradients across devices, as DDP does,
        # so every replica applies the full-batch gradient
        dw, db = [lax.psum(delta, "batch") for delta in grads]
        return w - lr * dw, b - lr * db

    for epoch in range(num_epochs):
        for i in range(num_iters):
            X, y = synthetic_data(true_w, true_b, batch_size)
            X, y = shard_data(X, y)
            rep_w, rep_b = spmd_update(rep_w, rep_b, X, y)
        print(f"{rep_w[0] = }, {rep_b[0] = }")


if __name__ == "__main__":
    main()
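Side note (not part of the original gist): the `pmap` version above only shards work when `jax.device_count() > 1`. On a single-CPU machine you can emulate several host devices with JAX's `XLA_FLAGS` host-platform device-count flag. A minimal sketch, assuming a CPU-only setup and that the flag is applied before JAX is first imported; the count of 4 is arbitrary:

import os

# Expose 4 emulated CPU devices; this must happen before the first `import jax`.
os.environ["XLA_FLAGS"] = (
    os.environ.get("XLA_FLAGS", "") + " --xla_force_host_platform_device_count=4"
)

import jax  # noqa: E402

print(jax.devices())  # expect 4 CPU devices, so the pmap above runs 4-way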
""" This is a demo shows how to implement distributed data parallel with
OneFlow's global tensor and SBP.
We build a simple linear regression model in this demo. The training data is
generated from a ground truth model given by the equation:
Y = WX + B
with some adding noise. Each rank generates its own dataset, simulating reading
from different data shards in data parallel. The model parameter `w` and `b`
are replicated across all ranks. The result shows model parameters converges to
the ground truth value.
To run this demo, we need to start at least 2 shells:
# shell 0
MASTER_ADDR=127.0.0.1 MASTER_PORT=12701 WORLD_SIZE=2 RANK=0 LOCAL_RANK=0 python a.py
# shell 1
MASTER_ADDR=127.0.0.1 MASTER_PORT=12701 WORLD_SIZE=2 RANK=1 LOCAL_RANK=1 python a.py
"""
import os

import oneflow as flow


def synthetic_data(w, b, num_examples):
    X = flow.randn(num_examples, len(w))
    y = flow.matmul(X, w.reshape((2, 1)).to(X.device)) + b
    y += flow.randn(*y.shape) * 0.01
    return X, y.reshape((-1, 1))


def model(X, w, b):
    return flow.matmul(X, w) + b


def loss_func(y_hat, y):
    return (y_hat - y.reshape(y_hat.shape)) ** 2 / 2


def main():
    rank = int(os.environ["RANK"])
    world_size = int(os.environ["WORLD_SIZE"])
    placement = flow.placement("cpu", list(range(world_size)))
    # split along dim 0: each rank holds one shard of the global batch
    sbp0 = flow.sbp.split(0)

    # params of the ground truth model
    true_w = flow.tensor([2, -3.4])
    true_b = 4.2
    print(f"{rank = }, {true_w = }, {true_b = }")

    # model params, replicated to all ranks with `sbp=flow.sbp.broadcast`
    w = flow.tensor(
        [[100.0], [100.0]],
        requires_grad=True,
        placement=placement,
        sbp=flow.sbp.broadcast,
    )
    b = flow.zeros(
        1, requires_grad=True, placement=placement, sbp=flow.sbp.broadcast
    )
    print(f"{rank = }, {w = }, {b = }")

    optimizer = flow.optim.SGD([w, b], lr=1e-3)
    num_epochs, num_iters, batch_size = 5, 100, 10
    for epoch in range(num_epochs):
        for i in range(num_iters):
            # X, y are local tensors; create global tensors as model
            # inputs to implement data parallelism
            X, y = synthetic_data(true_w, true_b, batch_size)
            g_X = X.to_global(placement=placement, sbp=sbp0)
            g_y = y.to_global(placement=placement, sbp=sbp0)
            l = loss_func(model(g_X, w, b), g_y)
            optimizer.zero_grad()
            l.sum().backward()
            optimizer.step()
        print(f"{rank = }, {w = }, {b = }")


if __name__ == "__main__":
    main()
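To make the SBP terms above concrete, here is a minimal sketch (assumed to be launched the same way as the demo, with 2 ranks) contrasting `flow.sbp.split(0)`, which shards a tensor along dim 0 across ranks, with `flow.sbp.broadcast`, which keeps a full replica on every rank. `to_local()` is OneFlow's accessor for the piece the current rank physically holds; it is not used in the original script and is included here only for illustration.

import os

import oneflow as flow

rank = int(os.environ["RANK"])
world_size = int(os.environ["WORLD_SIZE"])
placement = flow.placement("cpu", list(range(world_size)))

# Each rank contributes its local (4, 2) tensor as one shard along dim 0,
# so with 2 ranks the global shape is (8, 2) while each local piece stays (4, 2).
x = flow.randn(4, 2).to_global(placement=placement, sbp=flow.sbp.split(0))
print(f"{rank = }, global {x.shape}, local {x.to_local().shape}")

# broadcast keeps an identical full copy on every rank: the global and
# local shapes are both (3,).
p = flow.ones(3).to_global(placement=placement, sbp=flow.sbp.broadcast)
print(f"{rank = }, global {p.shape}, local {p.to_local().shape}")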