Created
August 8, 2025 08:06
-
-
Save samsja/0a08c9e8937468de35adb6a44905bd75 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| import math | |
| from typing import Protocol | |
| import torch | |
| from torch.distributed.tensor import DTensor | |
| from torch.distributed import gather, scatter | |
| from collections import deque | |
| @torch.compile(fullgraph=True) | |
| def nsloop_torch(X: torch.Tensor, steps: int, *, a=3.4445, b=-4.7750, c=2.0315): | |
| """ | |
| When compiled down, inductor produces the following steps: | |
| 1. A = matmul X with reinterpret_tensor(X) | |
| 2. (triton) read A -> write b*A and c*A | |
| 3. B = addmm(b*A, c*A, A) | |
| 4. (triton) read X -> write a*X (this is stupid) | |
| 5. X = addmm(a*X, B, X) | |
| """ | |
| for _ in range(steps): | |
| A = X @ X.mT | |
| B = b * A + c * A @ A | |
| X = a * X + B @ X | |
| return X | |
def zeropower_via_newtonschulz5(G, steps: int):
    """Orthogonalize G with a quintic Newton-Schulz iteration.

    Approximates the "zeroth power" of G: writing the SVD G = U S V^T, the
    result is roughly U S' V^T where S'_{ii} ~ Uniform(0.5, 1.5). The
    coefficients maximize the slope at zero rather than exact convergence
    of every singular value to 1, which empirically does not hurt model
    quality relative to a true UV^T.

    Runs in bfloat16 and supports batched inputs (G.ndim >= 2); batched
    variant by @scottjmaddox, put into practice in the record by
    @YouJiacheng.
    """
    assert G.ndim >= 2
    X = G.bfloat16()
    # Work on the wide orientation so the Gram matrix X @ X^T is the
    # smaller of the two options.
    tall = G.size(-2) > G.size(-1)
    if tall:
        X = X.mT
    # Normalize so the spectral norm is at most 1 (needed for convergence).
    X = X / (X.norm(dim=(-2, -1), keepdim=True) + 1e-7)
    # Perform the NS iterations with the standard quintic coefficients.
    X = nsloop_torch(X, steps, a=3.4445, b=-4.7750, c=2.0315)
    if tall:
        X = X.mT
    return X
def apply_momentum(grad, momentum, beta, nesterov):
    """Fold ``grad`` into the momentum buffer and return the raw update.

    Mutates ``momentum`` in place as an EMA with decay ``beta``. With
    ``nesterov`` the returned update is ``grad`` interpolated toward the
    refreshed buffer (this mutates ``grad`` in place as well); otherwise
    the buffer itself is returned. 4-D (conv-filter) updates are flattened
    to 2-D so the downstream orthogonalization sees a matrix.
    """
    momentum.lerp_(grad, 1 - beta)
    if nesterov:
        update = grad.lerp_(momentum, beta)
    else:
        update = momentum
    if update.ndim == 4:
        # Conv filters: collapse to (out_channels, everything else).
        update = update.view(len(update), -1)
    return update
def apply_scaling(grad, rms_scale=False):
    """Rescale the orthogonalized update in place and return it.

    With ``rms_scale=True``, applies the Moonlight rescaling
    0.2 * sqrt(max(rows, cols))
    (https://github.com/MoonshotAI/Moonlight/blob/5afcb6911077e7f182d05865fe90d9f39abcbcbd/examples/toy_train.py#L146),
    which matches the update RMS of Adam so one learning rate can be
    shared across parameter shapes. Otherwise applies the original Muon
    scaling sqrt(max(1, rows / cols))
    (https://github.com/KellerJordan/Muon/blob/f90a42b28e00b8d9d2d05865fe90d9f39abcbcbd/muon.py#L40).

    Fix: the rms branch previously read ``grad.shape[1]``/``grad.shape[0]``,
    which picks batch dims for the batched (>2-D) updates that
    zeropower_via_newtonschulz5 supports; it now uses the trailing matrix
    dims via ``size(-2)``/``size(-1)``, matching the other branch.
    Identical result for 2-D inputs.
    """
    if rms_scale:
        grad *= 0.2 * math.sqrt(max(grad.size(-2), grad.size(-1)))
    else:
        grad *= max(1, grad.size(-2) / grad.size(-1)) ** 0.5
    return grad
def muon_update(grad, momentum, beta, nesterov, ns_steps, rms_scale):
    """Full Muon update pipeline for one parameter.

    Momentum accumulation -> Newton-Schulz orthogonalization -> RMS/shape
    rescaling. Returns the update to be applied with the learning rate.
    """
    update = apply_momentum(grad, momentum, beta, nesterov)
    update = zeropower_via_newtonschulz5(update, ns_steps)
    return apply_scaling(update, rms_scale)
def adam_update(grad, buf1, buf2, step, betas, eps):
    """One bias-corrected Adam direction (learning rate not applied).

    Mutates the first/second-moment buffers in place, then returns
    m_hat / (sqrt(v_hat) + eps) for the given 1-based ``step`` count.
    """
    beta1, beta2 = betas
    buf1.lerp_(grad, 1 - beta1)
    buf2.lerp_(grad.square(), 1 - beta2)
    # Bias correction for the zero-initialized buffers.
    m_hat = buf1 / (1 - beta1 ** step)
    v_hat = buf2 / (1 - beta2 ** step)
    return m_hat / (v_hat.sqrt() + eps)
class MuonSingleDevice(torch.optim.Optimizer):
    """Single-device Muon optimizer (Muon or Adam per param group).

    Adapted from https://github.com/KellerJordan/Muon/blob/f90a42b28e00b8d9d2d05865fe90d9f39abcbcbd/muon.py
    Notable changes:
    - per-group ``rms_scale`` flag: when True, uses the rescaling from the
      Moonlight paper so the update RMS matches Adam's, allowing a single
      learning rate for all parameters
      (https://github.com/MoonshotAI/Moonlight/blob/5afcb6911077e7f182d1d7faa3c2cd45acba4666/examples/toy_train.py#L146)
    - the Newton-Schulz loop is sped up with torch.compile.

    Every param group must contain ``use_muon``; groups with
    ``use_muon=False`` are optimized with Adam. Recognized keys/defaults:
      Muon groups:  lr=0.02, momentum=0.95, weight_decay=0,
                    rms_scale=True, nesterov=True, ns_steps=5
      Adam groups:  lr=3e-4, betas=(0.9, 0.95), eps=1e-10, weight_decay=0
    """

    _MUON_DEFAULTS = {
        "lr": 0.02,
        "momentum": 0.95,
        "weight_decay": 0,
        "rms_scale": True,
        "nesterov": True,
        "ns_steps": 5,
    }
    _ADAM_DEFAULTS = {
        "lr": 3e-4,
        "betas": (0.9, 0.95),
        "eps": 1e-10,
        "weight_decay": 0,
    }

    def __init__(self, param_groups):
        for group in param_groups:
            assert "use_muon" in group
            table = self._MUON_DEFAULTS if group["use_muon"] else self._ADAM_DEFAULTS
            for key, default in table.items():
                group[key] = group.get(key, default)
            # Reject unexpected keys so typos don't silently become no-ops.
            assert set(group.keys()) == {"params", "use_muon", *table}
        super().__init__(param_groups, dict())

    @torch.no_grad()
    def step(self, closure=None):
        """Perform one optimization step; returns the closure's loss if given."""
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()
        for group in self.param_groups:
            for p in group["params"]:
                if p.grad is None:
                    # Deliberately not skipped: a zero grad keeps the update
                    # schedule identical for every parameter (force sync).
                    p.grad = torch.zeros_like(p)
                state = self.state[p]
                decay = 1 - group["lr"] * group["weight_decay"]
                if group["use_muon"]:
                    if not state:
                        state["momentum_buffer"] = torch.zeros_like(p)
                    update = muon_update(
                        p.grad,
                        state["momentum_buffer"],
                        group["momentum"],
                        group["nesterov"],
                        group["ns_steps"],
                        group["rms_scale"],
                    )
                    p.mul_(decay)
                    # muon_update may have flattened conv filters; restore shape.
                    p.add_(update.reshape(p.shape), alpha=-group["lr"])
                else:
                    if not state:
                        state["exp_avg"] = torch.zeros_like(p)
                        state["exp_avg_sq"] = torch.zeros_like(p)
                        state["step"] = 0
                    state["step"] += 1
                    update = adam_update(
                        p.grad,
                        state["exp_avg"],
                        state["exp_avg_sq"],
                        state["step"],
                        group["betas"],
                        group["eps"],
                    )
                    p.mul_(decay)
                    p.add_(update, alpha=-group["lr"])
        return loss
class Work(Protocol):
    """Structural interface for one parameter's (possibly async) Muon update.

    Implementations split the update into two phases so communication can
    overlap with other parameters' work: ``start`` launches momentum and
    any collective ops, ``finish`` waits on them and applies the update.
    """

    def __init__(self, param, state, group, index: int):
        # param: the parameter tensor; state: its optimizer state dict
        # (expected to hold "momentum_buffer"); group: the param-group
        # hyperparameter dict; index: the parameter's position in the group.
        ...

    def start(self):
        """Kick off the update (may launch async collectives)."""
        ...

    def finish(self):
        """Wait for outstanding work and apply the update to the parameter."""
        ...
class Fsdp1dWork:
    """
    Muon update handle for a parameter sharded on a 1-D FSDP2 mesh.

    The full gradient is gathered onto one owning rank (chosen round-robin
    by parameter index), orthogonalized there with Newton-Schulz, then
    scattered back to all shards. start()/finish() are separate so the
    async gather can overlap with other parameters' work.
    """
    def __init__(self, param, state, group, index: int):
        # param: DTensor parameter; state: optimizer state dict (must
        # already contain "momentum_buffer" — Muon.step initializes it);
        # group: hyperparameter dict; index: parameter position, used to
        # pick the owning rank.
        self.param = param
        self.state = state
        self.group = group
        self.index = index
        # Filled by start(): [dest_rank, gather_handle, gather_lists].
        self._intermediate_state = None
    def start(self):
        """Apply momentum shard-locally and launch the async gradient gather."""
        self.param.grad = apply_momentum(self.param.grad, self.state["momentum_buffer"] , self.group["momentum"], self.group["nesterov"])
        grad = self.param.grad
        assert isinstance(grad, DTensor), "only supports DTensor parameters"
        assert grad.device_mesh.ndim == 1, "only supports 1D mesh"
        rank = grad.device_mesh.get_rank()
        world_size = grad.device_mesh.size()
        pg = grad.device_mesh.get_group()
        # Round-robin the Newton-Schulz work across ranks by parameter index.
        dest_rank = self.index % world_size
        # NOTE(review): `rank` is a global rank while `dest_rank` is passed as
        # group_dst (a group rank); these only coincide when the mesh spans
        # the default process group — confirm before using on sub-meshes.
        if rank == dest_rank:
            # Owner pre-allocates one receive buffer per rank.
            gather_lists = [torch.zeros_like(input=grad.to_local()) for _ in range(world_size)]
            gather_handle = gather(grad.to_local(), gather_lists, group_dst=dest_rank, group=pg, async_op=True)
        else:
            gather_lists = None
            gather_handle = gather(grad.to_local(), None, group_dst=dest_rank, group=pg, async_op=True)
        self._intermediate_state = [dest_rank, gather_handle, gather_lists]
    def finish(self):
        """Wait for the gather, orthogonalize on the owner, scatter back, apply."""
        assert self._intermediate_state is not None, "gather work must be called first"
        grad = self.param.grad
        rank = grad.device_mesh.get_rank()
        world_size = grad.device_mesh.size()
        pg = grad.device_mesh.get_group()
        dest_rank, gather_handle, gather_lists = self._intermediate_state
        gather_handle.wait()
        if rank == dest_rank:
            # Owner: reassemble the full block (dim-0 sharding assumed —
            # TODO confirm placements), orthogonalize, and write back in place.
            g_full_block = torch.cat(gather_lists, dim=0)
            g_full_block.copy_(zeropower_via_newtonschulz5(g_full_block, self.group["ns_steps"]))
            g_full_block = g_full_block.type_as(grad)
            chunks = list(g_full_block.chunk(chunks=world_size, dim=0))
            # NOTE(review): the gather above uses `group_dst` but scatter uses
            # the global-rank `src` keyword — verify they agree on sub-meshes.
            scatter(grad.to_local(), scatter_list=chunks, src=dest_rank, group=pg, async_op=False)
        else:
            scatter(grad.to_local(), None, src=dest_rank, group=pg, async_op=False)
        # Shard-local rescale, decoupled weight decay, then parameter update.
        update = apply_scaling(grad, self.group["rms_scale"])
        self.param.mul_(1 - self.group["lr"] * self.group["weight_decay"])
        self.param.add_(update.reshape(self.param.shape), alpha=-self.group["lr"])
class TpFsdp2dWork:
    """
    Muon work for TP + FSDP mesh (2-D). Placeholder: dispatched by
    Muon._get_work_class for 2-D device meshes but not yet implemented.
    """
    def __init__(self, param, state, group, index: int):
        raise NotImplementedError("not implemented")
class EpFsdp2dWork:
    """
    Muon work for EP mesh. Placeholder: not yet implemented (and not yet
    wired into Muon._get_work_class).
    """
    def __init__(self, param, state, group, index: int):
        raise NotImplementedError("not implemented")
class TpEpFsdp3dWork:
    """
    Muon work for TP + EP mesh (3-D). Placeholder: not yet implemented
    (and not yet wired into Muon._get_work_class).
    """
    def __init__(self, param, state, group, index: int):
        raise NotImplementedError("not implemented")
class SingelDeviceWork:
    """Muon update handle for a plain (non-DTensor) parameter.

    There is no communication to overlap, so the whole update happens in
    start() and finish() is a no-op. NOTE: the class name keeps the
    original "Singel" spelling because Muon._get_work_class refers to it.
    """

    def __init__(self, param, state, group, index: int):
        # index is accepted for interface parity with the distributed
        # work classes but is not needed on a single device.
        self.param = param
        self.state = state
        self.group = group

    def start(self):
        """Compute and apply the Muon update immediately."""
        hp = self.group
        update = muon_update(
            self.param.grad,
            self.state["momentum_buffer"],
            hp["momentum"],
            hp["nesterov"],
            hp["ns_steps"],
            hp["rms_scale"],
        )
        self.param.mul_(1 - hp["lr"] * hp["weight_decay"])
        self.param.add_(update.reshape(self.param.shape), alpha=-hp["lr"])

    def finish(self):
        """No asynchronous work to wait on for the single-device path."""
        pass
class Muon(torch.optim.Optimizer):
    """DTensor-aware Muon optimizer (also supports plain single-device tensors).

    Adapted from https://github.com/KellerJordan/Muon/blob/f90a42b28e00b8d9d2d05865fe90d9f39abcbcbd/muon.py
    Notable changes:
    - per-group ``rms_scale`` flag: when True, uses the rescaling from the
      Moonlight paper so the update RMS matches Adam's, allowing a single
      learning rate for all parameters
      (https://github.com/MoonshotAI/Moonlight/blob/5afcb6911077e7f182d1d7faa3c2cd45acba4666/examples/toy_train.py#L146)
    - the Newton-Schulz loop is sped up with torch.compile.

    Muon updates are expressed as Work objects (momentum -> gather ->
    orthogonalize -> scatter) held in a bounded deque so that collectives
    for several parameters can be in flight at once.

    Group keys and defaults match MuonSingleDevice.
    """

    _MUON_DEFAULTS = {
        "lr": 0.02,
        "momentum": 0.95,
        "weight_decay": 0,
        "rms_scale": True,
        "nesterov": True,
        "ns_steps": 5,
    }
    _ADAM_DEFAULTS = {
        "lr": 3e-4,
        "betas": (0.9, 0.95),
        "eps": 1e-10,
        "weight_decay": 0,
    }

    def __init__(self, param_groups):
        for group in param_groups:
            assert "use_muon" in group
            table = self._MUON_DEFAULTS if group["use_muon"] else self._ADAM_DEFAULTS
            for key, default in table.items():
                group[key] = group.get(key, default)
            # Reject unexpected keys so typos don't silently become no-ops.
            assert set(group.keys()) == {"params", "use_muon", *table}
        super().__init__(param_groups, dict())

    def _get_work_class(self, p: torch.Tensor) -> tuple[type[Work], int]:
        """Return (work class, prefetch depth) for this parameter's layout."""
        if not isinstance(p, DTensor):
            return SingelDeviceWork, 1
        if p.device_mesh.ndim == 1:
            return Fsdp1dWork, 8
        if p.device_mesh.ndim == 2:
            return TpFsdp2dWork, 8
        raise ValueError(f"Unsupported mesh dimension: {p.device_mesh.ndim}")

    @torch.no_grad()
    def step(self, closure=None):
        """Perform one optimization step; returns the closure's loss if given."""
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()
        pending: deque[Work] = deque()
        for group in self.param_groups:
            if group["use_muon"]:
                for index, p in enumerate(group["params"]):
                    if p.grad is None:
                        # Deliberately not skipped: a zero grad keeps every
                        # rank's collective schedule identical (force sync).
                        p.grad = torch.zeros_like(p)
                    state = self.state[p]
                    if not state:
                        state["momentum_buffer"] = torch.zeros_like(p)
                    work_cls, prefetch_factor = self._get_work_class(p)
                    work = work_cls(p, state, group, index)
                    work.start()
                    pending.append(work)
                    # Bound in-flight communication: retire the oldest item.
                    if len(pending) > prefetch_factor:
                        pending.popleft().finish()
            else:
                for p in group["params"]:
                    if p.grad is None:
                        p.grad = torch.zeros_like(p)  # force synchronization
                    state = self.state[p]
                    if not state:
                        state["exp_avg"] = torch.zeros_like(p)
                        state["exp_avg_sq"] = torch.zeros_like(p)
                        state["step"] = 0
                    state["step"] += 1
                    update = adam_update(
                        p.grad,
                        state["exp_avg"],
                        state["exp_avg_sq"],
                        state["step"],
                        group["betas"],
                        group["eps"],
                    )
                    p.mul_(1 - group["lr"] * group["weight_decay"])
                    p.add_(update, alpha=-group["lr"])
        # Drain whatever is still in flight.
        while pending:
            pending.popleft().finish()
        return loss
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment