Compare the distributed APIs of OneFlow and JAX by implementing DistributedDataParallel.
#%%
import jax
import jax.numpy as jnp
from jax import grad, pmap, lax
import numpy as np
from functools import partial

print(jax.devices())

# %%
def synthetic_data(w, b, num_examples):
    # generate linear data y = Xw + b with a small amount of Gaussian noise
    X = np.random.normal(size=(num_examples, len(w)))
    y = np.matmul(X, w.reshape((2, 1))) + b
    y += np.random.normal(size=y.shape) * 0.01
    return X, y.reshape((-1, 1))

#%%
def model(X, w, b):
    return jnp.matmul(X, w) + b

def loss_func(y_hat, y):
    return ((y_hat - y.reshape(y_hat.shape)) ** 2 / 2).sum()

def model_loss(w, b, X, y):
    return loss_func(model(X, w, b), y)

def main():
    true_w = np.array([2, -3.4])
    true_b = 4.2
    print(f"{true_w = }, {true_b = }")

    w = np.array([[100.0], [100.0]])
    b = np.array([1.0])
    print(f"{w = }, {b = }")

    num_devices = jax.device_count()
    replicate_array = lambda x: np.broadcast_to(x, (num_devices,) + x.shape)
    # replicate params to all devices
    rep_w, rep_b = replicate_array(w), replicate_array(b)

    num_epochs, num_iters, batch_size = 5, 1000, 32
    assert (
        batch_size % num_devices == 0
    ), "batch_size must be divisible by num_devices"

    # add an extra leading device dimension to the data so that pmap can
    # dispatch one shard to each device
    def shard_data(X, y):
        num_shards = num_devices
        shard_size = batch_size // num_shards
        shard_X = X.reshape(num_devices, shard_size, *X.shape[1:])
        shard_y = y.reshape(num_devices, shard_size, *y.shape[1:])
        return shard_X, shard_y

    lr = 1e-3

    @partial(pmap, axis_name="batch")
    def spmd_update(w, b, X, y):
        # per-device gradients on the local shard ...
        grads = grad(model_loss, argnums=[0, 1])(w, b, X, y)
        # ... all-reduced across devices, so every replica applies the same update
        dw, db = [lax.psum(delta, "batch") for delta in grads]
        return w - lr * dw, b - lr * db

    for epoch in range(num_epochs):
        for i in range(num_iters):
            X, y = synthetic_data(true_w, true_b, batch_size)
            X, y = shard_data(X, y)
            rep_w, rep_b = spmd_update(rep_w, rep_b, X, y)
        print(f"{rep_w[0] = }, {rep_b[0] = }")

if __name__ == "__main__":
    main()
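The heart of the JAX version is the `pmap`/`lax.psum` pair: every device runs the same update on its own shard, and the collective sum makes all replicas agree. A minimal sketch of that pattern in isolation (not part of the gist) looks like this:

import jax
import jax.numpy as jnp
from jax import pmap, lax

num_devices = jax.device_count()
xs = jnp.arange(float(num_devices))            # one scalar per device
# each device contributes its own value; psum over the named axis returns
# the same summed result on every device
total = pmap(lambda x: lax.psum(x, "i"), axis_name="i")(xs)
print(total)                                   # every entry equals xs.sum()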
| """ This is a demo shows how to implement distributed data parallel with | |
| OneFlow's global tensor and SBP. | |
| We build a simple linear regression model in this demo. The training data is | |
| generated from a ground truth model given by the equation: | |
| Y = WX + B | |
| with some adding noise. Each rank generates its own dataset, simulating reading | |
| from different data shards in data parallel. The model parameter `w` and `b` | |
| are replicated across all ranks. The result shows model parameters converges to | |
| the ground truth value. | |
| To run this demo, we need to start at least 2 shells: | |
| # shell 0 | |
| MASTER_ADDR=127.0.0.1 MASTER_PORT=12701 WORLD_SIZE=2 RANK=0 LOCAL_RANK=0 python a.py | |
| # shell 1 | |
| MASTER_ADDR=127.0.0.1 MASTER_PORT=12701 WORLD_SIZE=2 RANK=1 LOCAL_RANK=1 python a.py | |
| """ | |
| import os | |
| import oneflow as flow | |
| def synthetic_data(w, b, num_examples): | |
| X = flow.randn(num_examples, len(w)) | |
| y = flow.matmul(X, w.reshape((2, 1)).to(X.device)) + b | |
| y += flow.randn(*y.shape) * 0.01 | |
| return X, y.reshape((-1, 1)) | |
| def model(X, w, b): | |
| return flow.matmul(X, w) + b | |
| def loss_func(y_hat, y): | |
| return (y_hat - y.reshape(y_hat.shape)) ** 2 / 2 | |
| def main(): | |
| rank = int(os.environ["RANK"]) | |
| world_size = int(os.environ["WORLD_SIZE"]) | |
| placement = flow.placement("cpu", list(range(world_size))) | |
| sbp0 = flow.sbp.split(0) | |
| # params of the ground truth model | |
| true_w = flow.tensor([2, -3.4]) | |
| true_b = 4.2 | |
| print(f"{rank = }, {true_w = }, {true_b = }") | |
| # model params, replicated to all ranks with `sbp=flow.sbp.broadcast` | |
| w = flow.tensor( | |
| [[100.0], [100.0]], | |
| requires_grad=True, | |
| placement=placement, | |
| sbp=flow.sbp.broadcast, | |
| ) | |
| b = flow.zeros( | |
| 1, requires_grad=True, placement=placement, sbp=flow.sbp.broadcast | |
| ) | |
| print(f"{rank = }, {w = }, {b = }") | |
| optimizer = flow.optim.SGD([w, b], lr=1e-3) | |
| num_epochs, num_iters, batch_size = 5, 100, 10 | |
| for epoch in range(num_epochs): | |
| for i in range(num_iters): | |
| # X, y are local tensors, we need to create global tensor | |
| # as model input to implement data parallel | |
| X, y = synthetic_data(true_w, true_b, batch_size) | |
| g_X = X.to_global(placement=placement, sbp=sbp0) | |
| g_y = y.to_global(placement=placement, sbp=sbp0) | |
| l = loss_func(model(g_X, w, b), g_y) | |
| optimizer.zero_grad() | |
| l.sum().backward() | |
| optimizer.step() | |
| print(f"{rank = }, {w = }, {b = }") | |
| if __name__ == "__main__": | |
| main() |
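For comparison, the OneFlow side hinges on SBP signatures: `flow.sbp.broadcast` replicates a global tensor to every rank, while `flow.sbp.split(0)` treats each rank's local tensor as one shard along dim 0. A minimal sketch of that local-to-global conversion (not part of the gist, assuming the same two-shell launch as in the docstring above) could look like:

import os
import oneflow as flow

rank = int(os.environ["RANK"])
world_size = int(os.environ["WORLD_SIZE"])
placement = flow.placement("cpu", list(range(world_size)))

local = flow.ones(2, 1) * rank                 # rank 0 holds 0s, rank 1 holds 1s
# split(0) stitches the per-rank local tensors into one global tensor along dim 0
g = local.to_global(placement=placement, sbp=flow.sbp.split(0))
print(f"{rank = }, global shape = {g.shape}")  # (4, 1) with world_size = 2
print(f"{rank = }, local shard = {g.to_local()}")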