Distributed training strategy submission

import math
import os
import torch
import torch.nn.utils as nn_utils
import torch.distributed as dist
import torch.fft
from einops import rearrange
import datetime
from copy import deepcopy
from dataclasses import dataclass
from torch.optim.lr_scheduler import LambdaLR
from typing import List, Type, Union, Optional, Dict, Any, TypeAlias, Callable, Iterable, Tuple
from abc import ABC, abstractmethod

from exogym.aux.utils import LogModule

ParamsT: TypeAlias = Union[Iterable[torch.Tensor], Iterable[dict[str, Any]]]


def mps_compatible(func):
    """Wrap a torch.distributed collective so MPS tensors are staged through CPU."""

    def all_gather_wrapper(tensor_list, tensor, *args, **kwargs):
        is_tensor_mps = hasattr(tensor, "device") and tensor.device.type == "mps"
        is_list_mps = any(hasattr(t, "device") and t.device.type == "mps" for t in tensor_list)
        if is_tensor_mps or is_list_mps:
            # Move MPS tensors to CPU, run the collective there, then copy results back.
            cpu_tensor = tensor.data.to("cpu") if is_tensor_mps else tensor
            cpu_tensor_list = [
                t.data.to("cpu") if hasattr(t, "device") and t.device.type == "mps" else t
                for t in tensor_list
            ]
            result = func(cpu_tensor_list, cpu_tensor, *args, **kwargs)
            if is_tensor_mps:
                tensor.data.copy_(cpu_tensor.to("mps"))
            for i, t in enumerate(tensor_list):
                if hasattr(t, "device") and t.device.type == "mps":
                    t.data.copy_(cpu_tensor_list[i].to("mps"))
            return result
        return func(tensor_list, tensor, *args, **kwargs)

    def standard_wrapper(tensor, *args, **kwargs):
        if hasattr(tensor, "device") and tensor.device.type == "mps":
            cpu_tensor = tensor.data.to("cpu")
            result = func(cpu_tensor, *args, **kwargs)
            tensor.data.copy_(cpu_tensor.to("mps"))
            return result
        return func(tensor, *args, **kwargs)

    # all_gather takes (tensor_list, tensor); the other collectives take a single tensor.
    return all_gather_wrapper if func.__name__ == "all_gather" else standard_wrapper


@mps_compatible
def broadcast(tensor, src=0):
    return dist.broadcast(tensor, src=src)


@mps_compatible
def all_reduce(tensor, op=dist.ReduceOp.SUM):
    return dist.all_reduce(tensor, op=op)


@mps_compatible
def all_gather(tensor_list, tensor, group=None, async_op=False):
    return dist.all_gather(tensor_list, tensor, group=group, async_op=async_op)
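

# --- Illustrative sketch (not part of the submitted strategy): averaging a tensor across
# ranks with the MPS-safe collectives above. Assumes dist.init_process_group() has already
# been run by the launcher; the function name and `world_size` argument are hypothetical.
def _average_across_ranks(tensor: torch.Tensor, world_size: int) -> torch.Tensor:
    all_reduce(tensor, op=dist.ReduceOp.SUM)  # staged through CPU if `tensor` lives on MPS
    tensor.div_(world_size)
    return tensor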


@dataclass
class OptimSpec:
    cls: Type[torch.optim.Optimizer]
    kwargs: Dict[str, Any]

    def build(self, model):
        return self.cls(model.parameters(), **(self.kwargs or {}))


def ensure_optim_spec(
    optim: Union[str, OptimSpec, None], default: Optional[OptimSpec] = None, **kwargs
) -> OptimSpec:
    if optim is None:
        return default or OptimSpec(torch.optim.AdamW, kwargs)
    if isinstance(optim, OptimSpec):
        return optim
    raise TypeError(f"Cannot build an OptimSpec from {type(optim).__name__}")
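

# --- Illustrative sketch (not part of the submitted strategy): how an OptimSpec resolves
# and builds an optimizer. The helper name and the toy Linear model are hypothetical.
def _optim_spec_example():
    spec = ensure_optim_spec(None, lr=3e-4)  # falls back to OptimSpec(torch.optim.AdamW, {"lr": 3e-4})
    toy_model = torch.nn.Linear(8, 8)
    return spec.build(toy_model)  # AdamW over toy_model.parameters() with lr=3e-4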


class Strategy(ABC, LogModule):
    def __init__(self, lr_scheduler=None, lr_scheduler_kwargs=None, **kwargs):
        self.lr_scheduler = lr_scheduler
        self.lr_scheduler_kwargs = lr_scheduler_kwargs or {}
        self.kwargs = kwargs
        self.scheduler = None
        self.lr_callbacks = []
        self.max_steps = 1

    def _init_node(self, model, rank, num_nodes):
        self.model = model
        self.rank = rank
        self.num_nodes = num_nodes
        self.local_step = 0

    @abstractmethod
    def step(self):
        self.local_step += 1

    def zero_grad(self):
        self.optim.zero_grad()

    def _setup_scheduler(self):
        def lr_lambda(step):
            # Linear warmup to the base lr, then cosine decay towards zero
            # (assumes max_steps > warmup_steps).
            warmup = self.lr_scheduler_kwargs.get("warmup_steps", 1)
            max_steps = self.lr_scheduler_kwargs.get("max_steps", self.max_steps)
            if step < warmup:
                return step / max(1, warmup)
            progress = (step - warmup) / max(1, max_steps - warmup)
            return 0.5 * (1.0 + math.cos(math.pi * progress))

        if self.lr_scheduler == "lambda_cosine":
            self.scheduler = LambdaLR(self.optim, lr_lambda)

    def __config__(self):
        return {"strategy": self.__class__.__name__}
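

# --- Illustrative sketch (not part of the submitted strategy): multipliers produced by the
# "lambda_cosine" schedule above, assuming warmup_steps=100 and max_steps=1000 (hypothetical
# values chosen for the illustration):
#   step 0    -> 0.00  (start of linear warmup)
#   step 50   -> 0.50
#   step 100  -> 1.00  (warmup done; cosine decay begins)
#   step 550  -> 0.50  (decay midpoint: 0.5 * (1 + cos(pi / 2)))
#   step 1000 -> 0.00  (end of schedule)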


class CommunicationModule(ABC):
    @abstractmethod
    def communicate(self, model, rank, num_nodes, local_step):
        pass

    @abstractmethod
    def _init_node(self, model, rank, num_nodes):
        pass


class CommunicateOptimizeStrategy(Strategy):
    def __init__(self, communication_modules, optim_spec=None, max_norm=None, **kwargs):
        super().__init__(**kwargs)
        self.communication_modules = communication_modules
        self.optim_spec = optim_spec
        self.max_norm = max_norm
        for m in self.communication_modules:
            m.strategy = self

    def _init_node(self, model, rank, num_nodes):
        super()._init_node(model, rank, num_nodes)
        self.optim = self.optim_spec.build(model)
        self._setup_scheduler()
        for m in self.communication_modules:
            m._init_node(model, rank, num_nodes)

    def step(self):
        # Optional gradient clipping, then the local (inner) optimizer step.
        if self.max_norm:
            nn_utils.clip_grad_norm_(self.model.parameters(), self.max_norm)
        self.optim.step()
        # Give each communication module a chance to synchronize across nodes.
        for m in self.communication_modules:
            m.communicate(self.model, self.rank, self.num_nodes, self.local_step)
        if self.scheduler:
            self.scheduler.step()
        self.local_step += 1


class DiLoCoCommunicator(CommunicationModule):
    def __init__(self, H=25, outer_optim_spec=None):
        self.H = H
        self.outer_optim_spec = outer_optim_spec

    def _init_node(self, model, rank, num_nodes):
        self.pg = dist.new_group(backend="gloo", timeout=datetime.timedelta(days=60))
        self.master_model = deepcopy(model).to("cpu")
        for p in self.master_model.parameters():
            p.requires_grad = True
        self.outer_optim = self.outer_optim_spec.cls(
            self.master_model.parameters(),
            process_group=self.pg,
            **self.outer_optim_spec.kwargs,
        )

    def communicate(self, model, rank, num_nodes, local_step):
        # Every H local steps, take one outer (DiLoCo) step on the CPU master copy.
        if num_nodes > 1 and local_step > 0 and local_step % self.H == 0:
            self.outer_optim.zero_grad()
            local_state = model.state_dict()
            for n, p in self.master_model.named_parameters():
                # Outer pseudo-gradient: master weights minus the locally trained weights.
                p.grad = p.data - local_state[n].data.to("cpu")
            self.outer_optim.step()
            # Copy the updated master weights back into the local model.
            master_state = self.master_model.state_dict()
            for n, p in model.named_parameters():
                p.data.copy_(master_state[n].to(p.device))
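

# --- Illustrative sketch (not part of the submitted strategy): the DiLoCo-style outer update
# above in isolation. `master` and `local` stand for one master parameter and its locally
# trained counterpart after H inner steps; both names are hypothetical.
def _pseudo_gradient(master: torch.Tensor, local: torch.Tensor) -> torch.Tensor:
    # The outer optimizer treats (master - local) as the gradient of the master weights, so a
    # plain SGD outer step with lr=1.0 and no momentum would exactly adopt the local weights;
    # lr < 1 interpolates between the master and local weights.
    return master - local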


class DiLoCoStrategy(CommunicateOptimizeStrategy):
    def __init__(self, optim_spec, outer_optim_spec, H=25, **kwargs):
        self.comm = DiLoCoCommunicator(H=H, outer_optim_spec=outer_optim_spec)
        super().__init__(
            communication_modules=[self.comm],
            optim_spec=optim_spec,
            **kwargs,
        )


class SparseLoCo(torch.optim.SGD):
    def __init__(
        self,
        params,
        lr,
        momentum=0.9,
        weight_decay=0.05,
        top_k=64,
        chunk_size=64,
        use_dct=True,
        use_quantization=True,
        quantization_bins=4,
        quantization_range=6,
        process_group=None,
        **kwargs,
    ):
        # Momentum SGD used as the DiLoCo outer optimizer. The base SGD runs with
        # weight_decay=0.0; the configured value is kept as decoupled_weight_decay.
        # The compression arguments (top_k, chunk_size, use_dct, use_quantization,
        # quantization_bins, quantization_range) are accepted for configuration but
        # are not acted on in this implementation.
        super().__init__(
            params,
            lr=lr,
            momentum=momentum,
            weight_decay=0.0,
            **kwargs,
        )
        self.decoupled_weight_decay = weight_decay
        self.process_group = process_group

    @torch.no_grad()
    def step(self, closure=None):
        return super().step(closure)


STRATEGY = DiLoCoStrategy(
    optim_spec=OptimSpec(
        torch.optim.AdamW,
        {"lr": 0.001},
    ),
    outer_optim_spec=OptimSpec(
        SparseLoCo,
        {
            "lr": 0.8,
            "momentum": 0.9,
            "weight_decay": 0.05,
            "top_k": 64,
            "chunk_size": 64,
            "use_dct": True,
            "use_quantization": True,
            "quantization_bins": 4,
            "quantization_range": 6,
        },
    ),
    lr_scheduler="lambda_cosine",
    lr_scheduler_kwargs={
        "warmup_steps": 800,
        "max_steps": 100,
    },
    max_norm=1.5,
    H=25,
)
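

# --- Illustrative sketch (not part of the submitted strategy): how a per-rank training loop
# might drive STRATEGY. The model, data loader, loss function, and device are hypothetical;
# the real driver is presumably the exogym training harness.
def _training_loop_sketch(model, data_loader, loss_fn, rank, num_nodes, device):
    STRATEGY._init_node(model, rank, num_nodes)  # builds the inner AdamW, scheduler, and DiLoCo master copy
    for batch, target in data_loader:
        STRATEGY.zero_grad()
        loss = loss_fn(model(batch.to(device)), target.to(device))
        loss.backward()
        STRATEGY.step()  # clip, inner step, possible DiLoCo outer sync, scheduler step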