Last active
May 3, 2025 07:14
-
-
Save dvruette/5cb21202b4497adea9fae46a0ca8e07f to your computer and use it in GitHub Desktop.
Minimal muP for MLP
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| import numpy as np | |
| import torch.nn.functional as F | |
| from torchvision import datasets, transforms | |
| import torch | |
| from torch import nn | |
| from torch.optim import SGD | |
| import matplotlib.pyplot as plt | |
# Shared configuration plus CIFAR-10 datasets/loaders used by the coordinate
# check further down.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
batch_size = 128
data_dir = '/tmp'

# ToTensor yields values in [0, 1]; Normalize then maps each channel to
# [-1, 1] via (x - 0.5) / 0.5.
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
    ]
)

# Train split first, then test split (same order as downloads are checked).
train_ds, test_ds = (
    datasets.CIFAR10(root=data_dir, train=is_train, download=True, transform=transform)
    for is_train in (True, False)
)

# NOTE(review): with num_workers > 0 on platforms that spawn worker processes
# (e.g. Windows/macOS), running this file as a script may need an
# `if __name__ == "__main__":` guard — confirm for your platform.
train_dl = torch.utils.data.DataLoader(
    train_ds, batch_size=batch_size, shuffle=True, num_workers=2
)
test_dl = torch.utils.data.DataLoader(
    test_ds, batch_size=batch_size, shuffle=False, num_workers=2
)
| # muP implementation of MLP, following Tab. 3, Tab. 8, and Tab. 9 from https://arxiv.org/pdf/2203.03466 | |
class muMLPTab3(nn.Module):
    """Bias-free 3-layer ReLU MLP in muP, following Table 3 of arXiv:2203.03466.

    Per-layer scaling (n = ``width``, d = ``in_features``)::

        layer          init std    multiplier   SGD lr
        fc_1 (input)   d ** -0.5   1            base_lr * n
        fc_2 (hidden)  n ** -0.5   1            base_lr
        fc_3 (output)  1 / n       1            base_lr / n

    Args:
        width: hidden width n, the dimension the parametrization scales with.
        num_classes: number of output logits.
        in_features: length of each flattened input vector. Defaults to 3072
            (a flattened 3x32x32 CIFAR-10 image), the value this class
            previously hard-coded.
    """

    def __init__(self, width=128, num_classes=10, in_features=3072):
        super().__init__()
        self.width = width
        self.in_features = in_features
        self.input_mult = 1.0
        self.output_mult = 1.0
        # Keep construction and init order fixed: both draw from torch's
        # global RNG, so reordering would change the weights for a given seed.
        self.fc_1 = nn.Linear(in_features, width, bias=False)
        self.fc_2 = nn.Linear(width, width, bias=False)
        self.fc_3 = nn.Linear(width, num_classes, bias=False)
        self.reset_parameters()

    def reset_parameters(self):
        """Re-draw every weight from a zero-mean normal with the muP std."""
        nn.init.normal_(self.fc_1.weight, std=self.in_features**-0.5)
        nn.init.normal_(self.fc_2.weight, std=self.width**-0.5)
        # Output layer: variance 1/n**2 (std 1/n), smaller than the 1/n
        # fan-in init used by the hidden layer.
        nn.init.normal_(self.fc_3.weight, std=1 / self.width)

    def forward(self, x):
        """Return ``(logits, activations)``.

        ``x`` must have ``in_features`` as its last dimension, e.g. a
        ``(batch, in_features)`` tensor of flattened images. ``activations``
        lists, in order, the pre-ReLU output of fc_1 (scaled by
        ``input_mult``), the pre-ReLU output of fc_2, and the logits (which
        are the same tensor as the first return value).
        """
        activations = []
        h = self.input_mult * self.fc_1(x)
        activations.append(h)
        h = self.fc_2(F.relu(h))
        activations.append(h)
        h = self.output_mult * self.fc_3(F.relu(h))
        activations.append(h)
        return h, activations

    def get_param_groups(self, base_lr):
        """Return SGD parameter groups carrying the per-layer muP learning rates.

        Each ``params`` entry is a list rather than the one-shot generator
        returned by ``Module.parameters()``, so the groups can be inspected or
        iterated more than once before being handed to an optimizer.
        """
        return [
            dict(params=list(self.fc_1.parameters()), lr=base_lr * self.width),
            dict(params=list(self.fc_2.parameters()), lr=base_lr),
            dict(params=list(self.fc_3.parameters()), lr=base_lr / self.width),
        ]
class muMLPTab8(nn.Module):
    """Bias-free 3-layer ReLU MLP in muP, following Table 8 of arXiv:2203.03466.

    Per-layer scaling (n = ``width``, d = ``in_features``)::

        layer          init std    multiplier   SGD lr
        fc_1 (input)   d ** -0.5   1            base_lr * n
        fc_2 (hidden)  n ** -0.5   1            base_lr
        fc_3 (output)  1.0         1 / n        base_lr * n

    Args:
        width: hidden width n, the dimension the parametrization scales with.
        num_classes: number of output logits.
        in_features: length of each flattened input vector. Defaults to 3072
            (a flattened 3x32x32 CIFAR-10 image), the value this class
            previously hard-coded.
    """

    def __init__(self, width=128, num_classes=10, in_features=3072):
        super().__init__()
        self.width = width
        self.in_features = in_features
        self.input_mult = 1.0
        # The 1/n readout multiplier replaces Table 3's small output init.
        self.output_mult = 1.0 / self.width
        # Keep construction and init order fixed: both draw from torch's
        # global RNG, so reordering would change the weights for a given seed.
        self.fc_1 = nn.Linear(in_features, width, bias=False)
        self.fc_2 = nn.Linear(width, width, bias=False)
        self.fc_3 = nn.Linear(width, num_classes, bias=False)
        self.reset_parameters()

    def reset_parameters(self):
        """Re-draw every weight from a zero-mean normal with the muP std."""
        nn.init.normal_(self.fc_1.weight, std=self.in_features**-0.5)
        nn.init.normal_(self.fc_2.weight, std=self.width**-0.5)
        # Unit-variance output weights; scaling comes from output_mult.
        nn.init.normal_(self.fc_3.weight, std=1.0)

    def forward(self, x, return_activations=False):
        """Return ``(logits, activations)``.

        ``x`` must have ``in_features`` as its last dimension, e.g. a
        ``(batch, in_features)`` tensor of flattened images. ``activations``
        lists, in order, the pre-ReLU output of fc_1 (scaled by
        ``input_mult``), the pre-ReLU output of fc_2, and the logits (which
        are the same tensor as the first return value).

        ``return_activations`` is accepted only for call compatibility and is
        ignored: the activations list is always returned.
        """
        activations = []
        h = self.input_mult * self.fc_1(x)
        activations.append(h)
        h = self.fc_2(F.relu(h))
        activations.append(h)
        h = self.output_mult * self.fc_3(F.relu(h))
        activations.append(h)
        return h, activations

    def get_param_groups(self, base_lr):
        """Return SGD parameter groups carrying the per-layer muP learning rates.

        Each ``params`` entry is a list rather than the one-shot generator
        returned by ``Module.parameters()``, so the groups can be inspected or
        iterated more than once before being handed to an optimizer.
        """
        return [
            dict(params=list(self.fc_1.parameters()), lr=base_lr * self.width),
            dict(params=list(self.fc_2.parameters()), lr=base_lr),
            dict(params=list(self.fc_3.parameters()), lr=base_lr * self.width),
        ]
class muMLPTab9(nn.Module):
    """Bias-free 3-layer ReLU MLP in muP, following Table 9 of arXiv:2203.03466.

    Here width scaling lives in the multipliers and init, and every layer
    shares one learning rate (n = ``width``, d = ``in_features``)::

        layer          init std         multiplier   SGD lr
        fc_1 (input)   (n * d) ** -0.5  n ** 0.5     base_lr
        fc_2 (hidden)  n ** -0.5        1            base_lr
        fc_3 (output)  n ** -0.5        n ** -0.5    base_lr

    Args:
        width: hidden width n, the dimension the parametrization scales with.
        num_classes: number of output logits.
        in_features: length of each flattened input vector. Defaults to 3072
            (a flattened 3x32x32 CIFAR-10 image), the value this class
            previously hard-coded.
    """

    def __init__(self, width=128, num_classes=10, in_features=3072):
        super().__init__()
        self.width = width
        self.in_features = in_features
        self.input_mult = self.width**0.5
        self.output_mult = self.width**-0.5
        # Keep construction and init order fixed: both draw from torch's
        # global RNG, so reordering would change the weights for a given seed.
        self.fc_1 = nn.Linear(in_features, width, bias=False)
        self.fc_2 = nn.Linear(width, width, bias=False)
        self.fc_3 = nn.Linear(width, num_classes, bias=False)
        self.reset_parameters()

    def reset_parameters(self):
        """Re-draw every weight from a zero-mean normal with the muP std."""
        # The extra 1/n in the input variance cancels the sqrt(n) input_mult.
        nn.init.normal_(self.fc_1.weight, std=(self.width * self.in_features) ** -0.5)
        nn.init.normal_(self.fc_2.weight, std=self.width**-0.5)
        nn.init.normal_(self.fc_3.weight, std=self.width**-0.5)

    def forward(self, x):
        """Return ``(logits, activations)``.

        ``x`` must have ``in_features`` as its last dimension, e.g. a
        ``(batch, in_features)`` tensor of flattened images. ``activations``
        lists, in order, the pre-ReLU output of fc_1 (scaled by
        ``input_mult``), the pre-ReLU output of fc_2, and the logits (which
        are the same tensor as the first return value).
        """
        activations = []
        h = self.input_mult * self.fc_1(x)
        activations.append(h)
        h = self.fc_2(F.relu(h))
        activations.append(h)
        h = self.output_mult * self.fc_3(F.relu(h))
        activations.append(h)
        return h, activations

    def get_param_groups(self, base_lr):
        """Return SGD parameter groups, all at ``base_lr`` in this formulation.

        Each ``params`` entry is a list rather than the one-shot generator
        returned by ``Module.parameters()``, so the groups can be inspected or
        iterated more than once before being handed to an optimizer.
        """
        return [
            dict(params=list(self.fc_1.parameters()), lr=base_lr),
            dict(params=list(self.fc_2.parameters()), lr=base_lr),
            dict(params=list(self.fc_3.parameters()), lr=base_lr),
        ]
# Coordinate check for muP correctness: for each width, take a few SGD steps on
# one fixed batch and record the mean absolute value of every layer's
# activations. With a correct muP parametrization these curves should stay
# roughly flat as width grows; a curve that trends with width points to a
# scaling mistake.
widths = [64, 128, 256, 512, 1024, 2048, 4096]
max_t = 5  # SGD steps (and activation snapshots) per run
num_seeds = 5  # independent inits averaged per width
base_lr = 0.1
model_cls = muMLPTab9  # swap for muMLPTab3 / muMLPTab8 to check those tables

# Reuse one batch for every step so that width is the only thing varying.
# Seeding before drawing makes the shuffled batch reproducible across runs.
torch.manual_seed(0)
fixed_data, fixed_target = next(iter(train_dl))
fixed_data = fixed_data.to(device)
fixed_data = fixed_data.view(fixed_data.size(0), -1)  # flatten each image to a vector
fixed_target = fixed_target.to(device)

all_metrics = []
for width in widths:
    per_seed = []
    for seed in range(num_seeds):
        torch.manual_seed(seed)
        model = model_cls(width=width).to(device)
        optimizer = SGD(model.get_param_groups(base_lr))
        scales_t = []  # scales_t[t][layer] = mean |activation| before step t
        for _ in range(max_t):
            output, acts = model(fixed_data)
            scales_t.append([a.detach().abs().mean().item() for a in acts])
            loss = F.cross_entropy(output, fixed_target)
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()
        # Drop both references so the parameters can actually be freed.
        del model, optimizer
        per_seed.append(scales_t)
    all_metrics.append(np.mean(per_seed, axis=0))
all_metrics = np.array(all_metrics)  # shape: (len(widths), max_t, num_layers)

for layer_idx in range(all_metrics.shape[2]):
    fig, ax = plt.subplots()
    ax.set_title(f"layer_idx={layer_idx+1}")
    for t in range(max_t):
        ax.plot(widths, all_metrics[:, t, layer_idx], label=f"t={t+1}")
    ax.set_xlabel("width")
    ax.set_ylabel("activation scale")
    ax.legend()
    ax.set_xscale("log")
    ax.set_yscale("log")
# Display the figures when run as a plain script; notebooks render them anyway.
plt.show()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment