@YouJiacheng
Last active November 4, 2024 12:31
import os
from typing import cast
import torch
import torch._inductor.config as config
import torch.distributed as dist
def zeropower_via_newtonschulz5(G, steps=10, eps=1e-7) -> torch.Tensor: ...
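# NOTE: the gist leaves the Newton-Schulz body as a stub above. The definition below
# is a sketch, not the author's code: it follows the quintic Newton-Schulz iteration
# commonly used for Muon-style orthogonalization and overrides the stub so the file
# runs end to end.
def zeropower_via_newtonschulz5(G, steps=10, eps=1e-7) -> torch.Tensor:
    assert G.ndim == 2
    # coefficients of the quintic iteration X <- a*X + (b*A + c*A@A) @ X, with A = X @ X.T
    a, b, c = (3.4445, -4.7750, 2.0315)
    X = G.bfloat16()
    X /= X.norm() + eps  # scale so the spectral norm is at most ~1
    transposed = G.size(0) > G.size(1)
    if transposed:
        X = X.T  # iterate on the wide orientation so A = X @ X.T stays small
    for _ in range(steps):
        A = X @ X.T
        B = b * A + c * A @ A
        X = a * X + B @ X
    if transposed:
        X = X.T
    return X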
# the compiled function mutates bufs_shard in place; allow input mutation under CUDA graphs
config.triton.cudagraph_support_input_mutation = True
@torch.compile(mode="reduce-overhead", fullgraph=True)
def compute_updates(
    params_shard: list[torch.Tensor],
    bufs_shard: list[torch.Tensor],
    slice_list_shard: list[slice],
    scale_list_shard: list[float],
    params: list[torch.Tensor],
    slice_list: list[slice],
    momentum: float,
):
    n = sum(p.numel() for p in params)
    updates_full_flat = torch.zeros(n, dtype=torch.bfloat16, device="cuda")
    # assert all(p.grad is not None for p in params_shard)
    grads_shard = [cast(torch.Tensor, p.grad) for p in params_shard]
    # in-place momentum update: buf = momentum * buf + grad
    torch._foreach_mul_(bufs_shard, momentum)
    torch._foreach_add_(bufs_shard, grads_shard)
    # Nesterov lookahead v = grad + momentum * buf, computed out of place
    # to avoid mutating inputs from eager
    vs = torch._foreach_add(grads_shard, bufs_shard, alpha=momentum)
    update_views_shard = [updates_full_flat[s] for s in slice_list_shard]
    for u, s, v in zip(update_views_shard, scale_list_shard, vs):
        # orthogonalize this rank's shard and write the scaled result into its slice
        torch.mul(zeropower_via_newtonschulz5(v, steps=5).flatten(), s, out=u)
    # sync updates across devices. we are not memory-constrained, so this simple
    # dense all_reduce of the full flat buffer is fine
    dist.all_reduce(updates_full_flat, op=dist.ReduceOp.SUM)
    return [updates_full_flat[s].view_as(p) for s, p in zip(slice_list, params)]
class Muon(torch.optim.Optimizer):
    """Muon: SGD-momentum where each 2-D parameter's update is orthogonalized via Newton-Schulz.

    Work is sharded round-robin across ranks: every rank orthogonalizes only the
    parameters it owns, writes them into a flat bf16 buffer, and an all_reduce
    reassembles the full update on all ranks.
    """

    def __init__(
        self,
        params,
        lr: float | torch.Tensor = 0.02,
        momentum=0.95,
        nesterov=True,
        backend="newtonschulz5",
        backend_steps=5,
    ):
        assert nesterov
        defaults = dict(
            lr=lr,
            momentum=momentum,
            nesterov=nesterov,
            backend=backend,
            backend_steps=backend_steps,
        )
        super().__init__(params, defaults)
        assert len(self.param_groups) == 1
        group = self.param_groups[0]
        self.params_shard: list[torch.Tensor] = []
        self.momentum_buffer_list_shard: list[torch.Tensor] = []
        self.slice_list: list[slice] = []
        self.slice_list_shard: list[slice] = []
        self.scale_list_shard: list[float] = []
        offset = 0
        for i, p in enumerate(group["params"]):
            assert isinstance(p, torch.Tensor)
            _slice = slice(offset, offset + p.numel())
            self.slice_list.append(_slice)
            # round-robin sharding: rank r owns parameters i with i % WORLD_SIZE == r
            if i % int(os.environ["WORLD_SIZE"]) == int(os.environ["RANK"]):
                self.params_shard.append(p)
                buf = torch.zeros_like(p)
                # keep the buffer at a fixed address so cudagraphs can mutate it in place
                torch._dynamo.mark_static_address(buf)
                self.momentum_buffer_list_shard.append(buf)
                self.state[p]["momentum_buffer"] = buf
                self.slice_list_shard.append(_slice)
                # update scale: sqrt(max(1, rows / cols))
                self.scale_list_shard.append(max(1, p.size(0) / p.size(1)) ** 0.5)
            offset += p.numel()
    # passing lr as a tensor into the compiled function is slower than applying lr outside it
    @torch.no_grad()
    def step(self):
        group = self.param_groups[0]
        update_views = compute_updates(
            self.params_shard,
            self.momentum_buffer_list_shard,
            self.slice_list_shard,
            self.scale_list_shard,
            group["params"],
            self.slice_list,
            group["momentum"],
        )
        # apply updates to every parameter: p <- p - lr * update
        torch._foreach_add_(group["params"], update_views, alpha=-group["lr"])
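# Usage sketch (an assumption, not part of the original gist): Muon as written
# expects torchrun-style WORLD_SIZE / RANK env vars, an already-initialized
# process group, and 2-D parameters only (it reads p.size(0) and p.size(1)).
# Roughly:
#
#     dist.init_process_group(backend="nccl")
#     matrix_params = [p for p in model.parameters() if p.ndim == 2]
#     opt = Muon(matrix_params, lr=0.02, momentum=0.95)
#     loss.backward()   # populate p.grad on every rank
#     opt.step()        # each rank orthogonalizes its shard, then all_reduce syncs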