import os
from typing import cast

import torch
import torch._inductor.config as config
import torch.distributed as dist


def zeropower_via_newtonschulz5(G, steps=10, eps=1e-7) -> torch.Tensor: ...
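# NOTE: the body of zeropower_via_newtonschulz5 is elided ("...") above. For
# orientation only, the sketch below shows the standard quintic Newton-Schulz
# orthogonalization used by reference Muon implementations (modded-nanogpt).
# It is an assumption about what the elided body does, not part of this gist;
# the name _newtonschulz5_sketch is hypothetical and the author's coefficients
# or step count may differ.
def _newtonschulz5_sketch(G: torch.Tensor, steps: int = 10, eps: float = 1e-7) -> torch.Tensor:
    # Iterate the odd quintic X <- a*X + b*(X X^T) X + c*(X X^T)^2 X, whose fixed
    # point drives the singular values of X toward 1, so the result approximates
    # U V^T from the SVD of G (the "zeroth power" / orthogonalization of G).
    assert G.ndim == 2
    a, b, c = (3.4445, -4.7750, 2.0315)
    X = G.bfloat16()
    X = X / (X.norm() + eps)  # normalize so the iteration converges
    if G.size(0) > G.size(1):
        X = X.T  # work in the wide orientation so X @ X.T is the smaller Gram matrix
    for _ in range(steps):
        A = X @ X.T
        B = b * A + c * A @ A
        X = a * X + B @ X
    if G.size(0) > G.size(1):
        X = X.T
    return X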


config.triton.cudagraph_support_input_mutation = True


@torch.compile(mode="reduce-overhead", fullgraph=True)
def compute_updates(
    params_shard: list[torch.Tensor],
    bufs_shard: list[torch.Tensor],
    slice_list_shard: list[slice],
    scale_list_shard: list[float],
    params: list[torch.Tensor],
    slice_list: list[slice],
    momentum: float,
):
    n = sum(p.numel() for p in params)
    updates_full_flat = torch.zeros(n, dtype=torch.bfloat16, device="cuda")
    # assert all(p.grad is not None for p in self.params_shard)
    grads_shard = [cast(torch.Tensor, p.grad) for p in params_shard]
    # momentum buffer update: buf <- momentum * buf + grad
    torch._foreach_mul_(bufs_shard, momentum)
    torch._foreach_add_(bufs_shard, grads_shard)
    # Nesterov-style update direction v = grad + momentum * buf;
    # avoid mutating inputs from eager
    vs = torch._foreach_add(grads_shard, bufs_shard, alpha=momentum)
    update_views_shard = [updates_full_flat[s] for s in slice_list_shard]
    for u, s, v in zip(update_views_shard, scale_list_shard, vs):
        torch.mul(zeropower_via_newtonschulz5(v, steps=5).flatten(), s, out=u)
    # sync updates across devices. we are not memory-constrained so can do this simple deserialization
    dist.all_reduce(updates_full_flat, op=dist.ReduceOp.SUM)
    return [updates_full_flat[s].view_as(p) for s, p in zip(slice_list, params)]


class Muon(torch.optim.Optimizer):
    def __init__(
        self,
        params,
        lr: float | torch.Tensor = 0.02,
        momentum=0.95,
        nesterov=True,
        backend="newtonschulz5",
        backend_steps=5,
    ):
        assert nesterov
        defaults = dict(
            lr=lr,
            momentum=momentum,
            nesterov=nesterov,
            backend=backend,
            backend_steps=backend_steps,
        )
        super().__init__(params, defaults)
        assert len(self.param_groups) == 1
        group = self.param_groups[0]
        self.params_shard: list[torch.Tensor] = []
        self.momentum_buffer_list_shard: list[torch.Tensor] = []
        self.slice_list: list[slice] = []
        self.slice_list_shard: list[slice] = []
        self.scale_list_shard: list[float] = []
        offset = 0
        # Round-robin shard parameters across ranks; each rank only orthogonalizes its shard.
        for i, p in enumerate(group["params"]):
            assert isinstance(p, torch.Tensor)
            _slice = slice(offset, offset + p.numel())
            self.slice_list.append(_slice)
            if i % int(os.environ["WORLD_SIZE"]) == int(os.environ["RANK"]):
                self.params_shard.append(p)
                buf = torch.zeros_like(p)
                torch._dynamo.mark_static_address(buf)
                self.momentum_buffer_list_shard.append(buf)
                self.state[p]["momentum_buffer"] = buf
                self.slice_list_shard.append(_slice)
                self.scale_list_shard.append(max(1, p.size(0) / p.size(1)) ** 0.5)
            offset += p.numel()

    # Tensor LR is slower than excluding lr from the compiled function
    @torch.no_grad()
    def step(self):
        group = self.param_groups[0]
        update_views = compute_updates(
            self.params_shard,
            self.momentum_buffer_list_shard,
            self.slice_list_shard,
            self.scale_list_shard,
            group["params"],
            self.slice_list,
            group["momentum"],
        )
        # apply updates
        torch._foreach_add_(group["params"], update_views, alpha=-group["lr"])
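
For context, a minimal usage sketch follows. It is not part of the gist: it assumes the script is launched with torchrun (which sets the RANK, LOCAL_RANK, and WORLD_SIZE environment variables the optimizer reads), that torch.distributed is initialized before the first step() (needed for the all_reduce), that every parameter handed to Muon is a 2-D bfloat16 CUDA matrix, and that each rank owns at least one matrix; the toy model and hyperparameters are placeholders.

    import os
    import torch
    import torch.distributed as dist

    # e.g. launched via: torchrun --nproc_per_node=2 train.py
    dist.init_process_group(backend="nccl")
    torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))

    # Toy stand-in model: Muon is meant for 2-D weight matrices (no biases, embeddings, etc.).
    model = torch.nn.Sequential(
        torch.nn.Linear(1024, 4096, bias=False),
        torch.nn.Linear(4096, 1024, bias=False),
    ).cuda().bfloat16()
    matrix_params = [p for p in model.parameters() if p.ndim == 2]
    opt = Muon(matrix_params, lr=0.02, momentum=0.95)

    x = torch.randn(32, 1024, device="cuda", dtype=torch.bfloat16)
    for _ in range(10):
        loss = model(x).square().mean()
        loss.backward()
        opt.step()  # each rank orthogonalizes its shard, then updates are all-reduced
        model.zero_grad(set_to_none=True)

    dist.destroy_process_group()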