""" | |
clean_airbench.py | |
Variant of airbench94 which removes the following: | |
* Lookahead optimization | |
* Progressive freezing | |
And increases the training duration slightly via the following: | |
* Epochs 9.9 -> 10.0 | |
* Batch size 1024 -> 1000 | |
+imports the dataloader from airbench package. | |
Attains 93.97 mean accuracy (n=200). | |
""" | |
#############################################
#           Setup/Hyperparameters           #
#############################################

import torch
from torch import nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as T

from airbench import evaluate, CifarLoader

torch.backends.cudnn.benchmark = True

hyp = {
    'opt': {
        'epochs': 10,
        'batch_size': 1000,
        'lr': 10.0,               # learning rate per 1024 examples -- 5.0 is optimal with no smoothing, 10.0 with smoothing.
        'momentum': 0.85,
        'weight_decay': 0.015,    # weight decay per 1024 examples (decoupled from learning rate)
        'bias_scaler': 64.0,      # scales up learning rate (but not weight decay) for BatchNorm biases
        'label_smoothing': 0.2,
    },
    'aug': {
        'flip': True,
        'translate': 2,
    },
    'net': {
        'widths': {
            'block1': 64,
            'block2': 256,
            'block3': 256,
        },
        'scaling_factor': 1/9,
        'tta_level': 2,
    },
}
#############################################
#            Network Components             #
#############################################

class Flatten(nn.Module):
    def forward(self, x):
        return x.view(x.size(0), -1)

class Mul(nn.Module):
    def __init__(self, scale):
        super().__init__()
        self.scale = scale
    def forward(self, x):
        return x * self.scale

class BatchNorm(nn.BatchNorm2d):
    def __init__(self, num_features, eps=1e-12,
                 weight=False, bias=True):
        super().__init__(num_features, eps=eps)
        self.weight.requires_grad = weight
        self.bias.requires_grad = bias
        # Note that PyTorch already initializes the weights to one and bias to zero

class Conv(nn.Conv2d):
    def __init__(self, in_channels, out_channels, kernel_size=3, padding='same', bias=False):
        super().__init__(in_channels, out_channels, kernel_size=kernel_size, padding=padding, bias=bias)

    def reset_parameters(self):
        super().reset_parameters()
        if self.bias is not None:
            self.bias.data.zero_()
        w = self.weight.data
        torch.nn.init.dirac_(w[:w.size(1)])

class ConvGroup(nn.Module):
    def __init__(self, channels_in, channels_out):
        super().__init__()
        self.conv1 = Conv(channels_in, channels_out)
        self.pool = nn.MaxPool2d(2)
        self.norm1 = BatchNorm(channels_out)
        self.conv2 = Conv(channels_out, channels_out)
        self.norm2 = BatchNorm(channels_out)
        self.activ = nn.GELU()

    def forward(self, x):
        x = self.conv1(x)
        x = self.pool(x)
        x = self.norm1(x)
        x = self.activ(x)
        x = self.conv2(x)
        x = self.norm2(x)
        x = self.activ(x)
        return x
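
# Note on shapes: a ConvGroup maps (N, channels_in, H, W) to (N, channels_out, H//2, W//2);
# the 2x2 max-pool runs before the first BatchNorm/GELU, and both convolutions share the
# same (stateless) GELU module.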
#############################################
#            Network Definition             #
#############################################

## The eigenvectors of the covariance matrix of 2x2 patches of CIFAR-10 divided by sqrt eigenvalues.
eigenvectors_scaled = torch.tensor([
 8.9172e-02,  8.9172e-02,  8.8684e-02,  8.8623e-02,  9.3872e-02,  9.3872e-02,  9.3018e-02,  9.3018e-02,
 9.0027e-02,  9.0027e-02,  8.9233e-02,  8.9172e-02,  3.8818e-01,  3.8794e-01,  3.9111e-01,  3.9038e-01,
-1.0767e-03, -1.3609e-03,  3.8567e-03,  3.2330e-03, -3.9087e-01, -3.9087e-01, -3.8428e-01, -3.8452e-01,
-4.8242e-01, -4.7485e-01,  4.5435e-01,  4.6216e-01, -4.6240e-01, -4.5557e-01,  4.8975e-01,  4.9658e-01,
-4.3311e-01, -4.2725e-01,  4.2285e-01,  4.2896e-01, -5.0781e-01,  5.1514e-01, -5.1562e-01,  5.0879e-01,
-5.1807e-01,  5.2783e-01, -5.2539e-01,  5.1904e-01, -4.6460e-01,  4.7070e-01, -4.7168e-01,  4.6240e-01,
-4.7290e-01, -4.7461e-01, -5.0635e-01, -5.0684e-01,  9.5410e-01,  9.5117e-01,  9.2090e-01,  9.1846e-01,
-4.7363e-01, -4.7607e-01, -5.0439e-01, -5.0586e-01, -1.2539e+00,  1.2490e+00,  1.2383e+00, -1.2354e+00,
-1.2637e+00,  1.2666e+00,  1.2715e+00, -1.2725e+00, -1.1396e+00,  1.1416e+00,  1.1494e+00, -1.1514e+00,
-2.8262e+00, -2.7578e+00,  2.7617e+00,  2.8438e+00,  3.9404e-01,  3.7622e-01, -3.8330e-01, -3.9502e-01,
 2.6602e+00,  2.5801e+00, -2.6055e+00, -2.6738e+00, -2.9473e+00,  3.0312e+00, -3.0488e+00,  2.9648e+00,
 3.9111e-01, -4.0063e-01,  3.7939e-01, -3.7451e-01,  2.8242e+00, -2.9023e+00,  2.8789e+00, -2.8008e+00,
 2.6582e+00,  2.3105e+00, -2.3105e+00, -2.6484e+00, -5.9336e+00, -5.1680e+00,  5.1719e+00,  5.9258e+00,
 3.6855e+00,  3.2285e+00, -3.2148e+00, -3.6992e+00, -2.4668e+00,  2.8281e+00, -2.8379e+00,  2.4785e+00,
 5.4062e+00, -6.2031e+00,  6.1797e+00, -5.3906e+00, -3.3223e+00,  3.8164e+00, -3.8223e+00,  3.3340e+00,
-8.0000e+00,  8.0000e+00,  8.0000e+00, -8.0078e+00,  9.7656e-01, -9.9414e-01, -9.8584e-01,  1.0039e+00,
 7.5938e+00, -7.5820e+00, -7.6133e+00,  7.6016e+00,  5.5508e+00, -5.5430e+00, -5.5430e+00,  5.5352e+00,
-1.2133e+01,  1.2133e+01,  1.2148e+01, -1.2148e+01,  7.4141e+00, -7.4180e+00, -7.4219e+00,  7.4297e+00,
]).reshape(12, 3, 2, 2)
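
# For reference, a hedged sketch of how whitening filters like the hard-coded tensor above
# can be derived: flatten all 2x2 patches of the (float) training images, eigendecompose
# their covariance, and scale each eigenvector by 1/sqrt(eigenvalue). The helper below is
# illustrative only and is never called in this script; signs and ordering of the resulting
# filters may differ from the values above.
def compute_patch_whitening(images, kernel_size=2):
    # images: float tensor of shape (N, 3, 32, 32)
    patches = images.unfold(2, kernel_size, 1).unfold(3, kernel_size, 1)
    patches = patches.permute(0, 2, 3, 1, 4, 5).reshape(-1, 3 * kernel_size**2).double()
    eigenvalues, eigenvectors = torch.linalg.eigh(torch.cov(patches.T))  # ascending eigenvalues
    scaled = (eigenvectors / eigenvalues.sqrt()).T                       # rows are scaled eigenvectors
    return scaled.reshape(-1, 3, kernel_size, kernel_size).float()
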
def make_net():
    widths = hyp['net']['widths']
    whiten_kernel_size = 2
    whiten_width = 2 * 3 * whiten_kernel_size**2
    net = nn.Sequential(
        Conv(3, whiten_width, whiten_kernel_size, padding=0, bias=True),
        nn.GELU(),
        ConvGroup(whiten_width, widths['block1']),
        ConvGroup(widths['block1'], widths['block2']),
        ConvGroup(widths['block2'], widths['block3']),
        nn.MaxPool2d(3),
        Flatten(),
        nn.Linear(widths['block3'], 10, bias=False),
        Mul(hyp['net']['scaling_factor']),
    )
    net[0].weight.data[:] = torch.cat((eigenvectors_scaled, -eigenvectors_scaled))
    net[0].weight.requires_grad = False
    net = net.half().cuda()
    net = net.to(memory_format=torch.channels_last)
    for mod in net.modules():
        if isinstance(mod, BatchNorm):
            mod.float()
    return net
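
# Shape walk-through for a 32x32 CIFAR-10 input (derived from the layers above): the
# whitening Conv (kernel 2, padding 0) gives 31x31 maps, the three ConvGroups reduce
# them to 15x15 -> 7x7 -> 3x3, and MaxPool2d(3) collapses to 1x1, so Flatten feeds a
# widths['block3']-dimensional vector into the final Linear layer. A quick smoke test
# (illustrative, not executed by this script):
#   x = torch.randn(8, 3, 32, 32, device='cuda', dtype=torch.half).to(memory_format=torch.channels_last)
#   make_net()(x).shape  # -> torch.Size([8, 10])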
############################################
#          Training and Inference          #
############################################

def train(train_loader):

    momentum = hyp['opt']['momentum']
    epochs = hyp['opt']['epochs']
    # Assuming gradients are constant in time, for Nesterov momentum, the below ratio is how much
    # larger the default steps will be than the underlying per-example gradients. We divide the
    # learning rate by this ratio in order to ensure steps are the same scale as gradients, regardless
    # of the choice of momentum.
    kilostep_scale = 1024 * (1 + 1 / (1 - momentum))
    lr = hyp['opt']['lr'] / kilostep_scale        # un-decoupled learning rate for PyTorch SGD
    wd = hyp['opt']['weight_decay'] * train_loader.batch_size / kilostep_scale
    lr_biases = lr * hyp['opt']['bias_scaler']
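    # Worked example with the defaults above: momentum=0.85 gives
    # kilostep_scale = 1024 * (1 + 1/0.15) ~ 7851, so lr ~ 10.0/7851 ~ 1.27e-3,
    # wd ~ 0.015 * 1000/7851 ~ 1.91e-3, and lr_biases ~ 64 * lr ~ 8.2e-2.
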
    model = make_net()
    loss_fn = nn.CrossEntropyLoss(label_smoothing=hyp['opt']['label_smoothing'], reduction='none')

    total_train_steps = epochs * len(train_loader)

    norm_biases = [p for k, p in model.named_parameters() if 'norm' in k]
    other_params = [p for k, p in model.named_parameters() if 'norm' not in k]
    param_configs = [dict(params=norm_biases, lr=lr_biases, weight_decay=wd/lr_biases),
                     dict(params=other_params, lr=lr, weight_decay=wd/lr)]
    optimizer = torch.optim.SGD(param_configs, momentum=momentum, nesterov=True)

    def get_lr(step):
        warmup_steps = int(total_train_steps * 0.2)
        warmdown_steps = total_train_steps - warmup_steps
        if step < warmup_steps:
            frac = step / warmup_steps
            return 0.2 * (1 - frac) + 1.0 * frac
        else:
            frac = (step - warmup_steps) / warmdown_steps
            return (1 - frac)
    scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, get_lr)
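    # With the defaults (50,000 CIFAR-10 training images, batch_size=1000, 10 epochs) there
    # are 500 total steps: the LR multiplier ramps linearly from 0.2 to 1.0 over the first
    # 100 steps (20% warmup), then decays linearly from 1.0 to 0.0 over the remaining 400.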
    current_steps = 0

    train_loader.epoch = 0
    model.train()
    from tqdm import tqdm
    for epoch in tqdm(range(epochs)):
        for inputs, labels in train_loader:
            outputs = model(inputs)
            loss = loss_fn(outputs, labels).sum()
            optimizer.zero_grad(set_to_none=True)
            loss.backward()
            optimizer.step()
            scheduler.step()
            current_steps += 1
            if current_steps == total_train_steps:
                break

    return model

if __name__ == '__main__':

    train_loader = CifarLoader('/tmp/cifar10', train=True, batch_size=hyp['opt']['batch_size'], aug=hyp['aug'], altflip=True)
    test_loader = CifarLoader('/tmp/cifar10', train=False)
    print(evaluate(train(train_loader), test_loader, tta_level=hyp['net']['tta_level']))
    print(torch.std_mean(torch.tensor([evaluate(train(train_loader), test_loader, tta_level=hyp['net']['tta_level']) for _ in range(50)])))