Minimal muP for MLP
import numpy as np
import torch.nn.functional as F
from torchvision import datasets, transforms
import torch
from torch import nn
from torch.optim import SGD
import matplotlib.pyplot as plt
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
batch_size = 128
data_dir = '/tmp'
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])
train_ds = datasets.CIFAR10(root=data_dir, train=True, download=True, transform=transform)
train_dl = torch.utils.data.DataLoader(train_ds, batch_size=batch_size, shuffle=True, num_workers=2)
test_ds = datasets.CIFAR10(root=data_dir, train=False, download=True, transform=transform)
test_dl = torch.utils.data.DataLoader(test_ds, batch_size=batch_size, shuffle=False, num_workers=2)
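# Each CIFAR-10 image is flattened to 3 * 32 * 32 = 3072 features before the first layer.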
# muP implementation of MLP, following Tab. 3, Tab. 8, and Tab. 9 from https://arxiv.org/pdf/2203.03466
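# The three classes below express the same parametrization in three (equivalent) forms,
# differing only in where the width scaling lives: muMLPTab3 puts it in width-dependent
# per-layer SGD learning rates, muMLPTab8 moves the output scaling into a 1/width output
# multiplier, and muMLPTab9 absorbs it into the input/output multipliers and the init
# so that every layer can use the same learning rate. If the implementation is correct,
# all three should pass the coordinate check at the bottom of the file.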
class muMLPTab3(nn.Module):
    def __init__(self, width=128, num_classes=10):
        super().__init__()
        self.width = width
        self.input_mult = 1.0
        self.output_mult = 1.0
        self.fc_1 = nn.Linear(3072, width, bias=False)
        self.fc_2 = nn.Linear(width, width, bias=False)
        self.fc_3 = nn.Linear(width, num_classes, bias=False)
        self.reset_parameters()

    def reset_parameters(self):
        nn.init.normal_(self.fc_1.weight, std=3072**-0.5)
        nn.init.normal_(self.fc_2.weight, std=self.width**-0.5)
        nn.init.normal_(self.fc_3.weight, std=1/self.width)

    def forward(self, x):
        activations = []
        h = self.input_mult * self.fc_1(x)
        activations.append(h)
        h = self.fc_2(F.relu(h))
        activations.append(h)
        h = self.output_mult * self.fc_3(F.relu(h))
        activations.append(h)
        return h, activations

    def get_param_groups(self, base_lr):
        return [
            dict(params=self.fc_1.parameters(), lr=base_lr*self.width),
            dict(params=self.fc_2.parameters(), lr=base_lr),
            dict(params=self.fc_3.parameters(), lr=base_lr/self.width),
        ]
class muMLPTab8(nn.Module):
    def __init__(self, width=128, num_classes=10):
        super().__init__()
        self.width = width
        self.input_mult = 1.0
        self.output_mult = 1.0 / self.width
        self.fc_1 = nn.Linear(3072, width, bias=False)
        self.fc_2 = nn.Linear(width, width, bias=False)
        self.fc_3 = nn.Linear(width, num_classes, bias=False)
        self.reset_parameters()

    def reset_parameters(self):
        nn.init.normal_(self.fc_1.weight, std=3072**-0.5)
        nn.init.normal_(self.fc_2.weight, std=self.width**-0.5)
        nn.init.normal_(self.fc_3.weight, std=1.0)

    def forward(self, x):
        activations = []
        h = self.input_mult * self.fc_1(x)
        activations.append(h)
        h = self.fc_2(F.relu(h))
        activations.append(h)
        h = self.output_mult * self.fc_3(F.relu(h))
        activations.append(h)
        return h, activations

    def get_param_groups(self, base_lr):
        return [
            dict(params=self.fc_1.parameters(), lr=base_lr*self.width),
            dict(params=self.fc_2.parameters(), lr=base_lr),
            dict(params=self.fc_3.parameters(), lr=base_lr*self.width),
        ]
class muMLPTab9(nn.Module):
    def __init__(self, width=128, num_classes=10):
        super().__init__()
        self.width = width
        self.input_mult = self.width**0.5
        self.output_mult = self.width**-0.5
        self.fc_1 = nn.Linear(3072, width, bias=False)
        self.fc_2 = nn.Linear(width, width, bias=False)
        self.fc_3 = nn.Linear(width, num_classes, bias=False)
        self.reset_parameters()

    def reset_parameters(self):
        nn.init.normal_(self.fc_1.weight, std=(self.width*3072)**-0.5)
        nn.init.normal_(self.fc_2.weight, std=self.width**-0.5)
        nn.init.normal_(self.fc_3.weight, std=self.width**-0.5)

    def forward(self, x):
        activations = []
        h = self.input_mult * self.fc_1(x)
        activations.append(h)
        h = self.fc_2(F.relu(h))
        activations.append(h)
        h = self.output_mult * self.fc_3(F.relu(h))
        activations.append(h)
        return h, activations

    def get_param_groups(self, base_lr):
        return [
            dict(params=self.fc_1.parameters(), lr=base_lr),
            dict(params=self.fc_2.parameters(), lr=base_lr),
            dict(params=self.fc_3.parameters(), lr=base_lr),
        ]
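# Optional sanity check (a minimal sketch, not part of the original gist): since the three
# parametrizations are meant to be equivalent, their per-layer activation scales at
# initialization should agree up to sampling noise for any fixed width.
# `compare_init_scales` is a hypothetical helper, added purely for illustration.
@torch.no_grad()
def compare_init_scales(width=1024, n_samples=256):
    x = torch.randn(n_samples, 3072, device=device)
    for cls in (muMLPTab3, muMLPTab8, muMLPTab9):
        torch.manual_seed(0)
        _, acts = cls(width=width).to(device)(x)
        scales = [round(a.abs().mean().item(), 3) for a in acts]
        print(f"{cls.__name__}: per-layer activation scales = {scales}")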
# Run a coordinate check to test correctness: train for a few steps at several widths
# and record the typical activation magnitude of every layer at every step.
widths = [64, 128, 256, 512, 1024, 2048, 4096]
max_t = 5
num_seeds = 5
base_lr = 0.1
# Reuse the same batch for all steps so that every width and seed sees identical data.
dataset = [next(iter(train_dl))] * max_t

all_metrics = []
for width in widths:
    metrics = []
    for seed in range(num_seeds):
        torch.manual_seed(seed)
        # model = muMLPTab3(width=width).to(device)
        # model = muMLPTab8(width=width).to(device)
        model = muMLPTab9(width=width).to(device)
        optimizer = SGD(model.get_param_groups(base_lr))
        acts_t = []
        for batch_idx, (data, target) in enumerate(dataset):
            if batch_idx >= max_t:
                break
            data, target = data.to(device), target.to(device)
            output, acts = model(data.view(data.size(0), -1))
            acts_t.append([a.detach().cpu() for a in acts])
            loss = F.cross_entropy(output, target)
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()
        del model
        # Mean absolute activation per step and per layer, shape (max_t, num_layers).
        act_scales = [[h.abs().mean().item() for h in acts] for acts in acts_t]
        metrics.append(act_scales)
    # Average over seeds.
    all_metrics.append(np.stack(metrics, axis=0).mean(axis=0))
all_metrics = np.array(all_metrics)  # shape (num_widths, max_t, num_layers)
for layer_idx in range(3):
    fig, ax = plt.subplots()
    ax.set_title(f"layer_idx={layer_idx+1}")
    for t in range(max_t):
        ax.plot(widths, all_metrics[:, t, layer_idx], label=f"t={t+1}")
    ax.set_xlabel("width")
    ax.set_ylabel("activation scale")
    ax.legend()
    ax.set_xscale("log")
    ax.set_yscale("log")
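# If the parametrization is correct, each curve should be roughly flat: the activation
# scale of every layer stays approximately constant as the width grows, at every step t.
plt.show()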