Skip to content

Instantly share code, notes, and snippets.

import math
from typing import Protocol
import torch
from torch.distributed.tensor import DTensor
from torch.distributed import gather, scatter
from collections import deque
@torch.compile(fullgraph=True)
def nsloop_torch(X: torch.Tensor, steps: int, *, a=3.4445, b=-4.7750, c=2.0315):
# /// script
# requires-python = ">=3.9"
# dependencies = [
# "wandb",
# "numpy",
# "pandas",
# ]
# ///
from argparse import ArgumentParser
@samsja
samsja / muon_fsdp_2_opt.py
Created March 16, 2025 22:38
muon_fsdp_2_opt.py
# ruff: noqa
# type: ignore
# fmt: off
# credits to https://gist.github.com/main-horse/7314170780e36f7443d1926418d75823
from typing import Generator
from collections import deque
import torch