Created
April 13, 2024 02:10
-
-
Save KellerJordan/5248bffd70694df7f5f0c259475b819c to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
""" | |
BatchNorm-free variant of airbench94 | |
90.6% mean accuracy in ~6 seconds on an H100 | |
Changes relative to airbench94: | |
- removed BatchNorms and added conv biases | |
- reduced batch size 1024 -> 384 | |
- reduced weight decay 0.015 -> 0.001 | |
- reduced lr 11.5 -> 10.0 | |
- increased epochs 9.9 -> 11 | |
""" | |
############################################# | |
# Setup/Hyperparameters # | |
############################################# | |
import os | |
import sys | |
import uuid | |
from math import ceil | |
import torch | |
from torch import nn | |
import torch.nn.functional as F | |
import torchvision | |
import torchvision.transforms as T | |
torch.backends.cudnn.benchmark = True

# The main training hyperparameters (batch size, learning rate, momentum, and weight decay) are
# written in "decoupled" form so that each one can be tuned on its own:
#   * Assuming time-constant gradients, the average step size depends only on the lr.
#   * The size of the weight decay update depends only on the wd.
# In contrast, with the usual parametrization, raising the (Nesterov) momentum also enlarges the
# step size (for a constant gradient g, PyTorch's Nesterov SGD settles at a step of
# lr * g / (1 - momentum); main() normalizes by 1 + 1 / (1 - momentum)), so momentum cannot be
# changed without re-tuning the learning rate. Likewise, raising the learning rate also
# strengthens the weight decay, requiring a proportional decrease in wd to keep the same decay.
# main() converts these decoupled values into the arguments PyTorch's SGD expects.
#
# The practical payoff is faster tuning, since each knob can be adjusted independently.
# See https://myrtle.ai/learn/how-to-train-your-resnet-5-hyperparameters/.
hyp = dict(
    opt=dict(
        train_epochs=11,
        batch_size=384,
        lr=10.0,              # learning rate per 1024 examples
        momentum=0.85,
        weight_decay=0.001,   # weight decay per 1024 examples (decoupled from learning rate)
        label_smoothing=0.2,
        whiten_bias_epochs=3, # number of epochs the whitening-layer bias trains before it is frozen
    ),
    aug=dict(
        flip=True,
        translate=2,
    ),
    net=dict(
        widths=dict(
            block1=64,
            block2=256,
            block3=256,
        ),
        batchnorm_momentum=0.6,
        scaling_factor=1/9,
        tta_level=2,          # test-time augmentation: 0=none, 1=mirror, 2=mirror+translate
    ),
)
############################################# | |
# DataLoader # | |
############################################# | |
# Per-channel mean and standard deviation (channel order 0, 1, 2) used by CifarLoader's
# T.Normalize on images that have already been scaled to [0, 1].
CIFAR_MEAN, CIFAR_STD = (
    torch.tensor([0.4914, 0.4822, 0.4465]),
    torch.tensor([0.2470, 0.2435, 0.2616]),
)
def batch_flip_lr(inputs):
    """Mirror a random subset of a 4-D image batch along its last (width) axis.

    One uniform draw per image decides its fate: images whose draw is below 0.5 are
    returned flipped left-to-right, the rest unchanged. Returns a new tensor of the
    same shape; ``inputs`` is not modified.
    """
    coin = torch.rand(len(inputs), device=inputs.device) < 0.5
    # Broadcast the per-image decision over (C, H, W).
    return torch.where(coin[:, None, None, None], inputs.flip(-1), inputs)
def batch_crop(images, crop_size):
    """Randomly crop every image of a batch to ``crop_size`` x ``crop_size``.

    ``r = (images.size(-1) - crop_size) // 2`` is the maximum shift. Each image gets
    an independent shift (sy, sx) drawn uniformly from {-r, ..., r}^2, and its output
    is the window whose top-left corner is (r + sy, r + sx). When ``images`` has been
    padded by r on every side (as CifarLoader does), this is a random translation of
    the unpadded image by up to r pixels.

    Args:
        images: (N, C, H, W) batch, any number of channels. r is derived from the
            width, so H is assumed to be at least W (square images in practice).
        crop_size: side length of the output crops.

    Returns:
        A new (N, C, crop_size, crop_size) tensor with the dtype and device of
        ``images``.
    """
    n, channels = len(images), images.size(1)
    r = (images.size(-1) - crop_size)//2
    shifts = torch.randint(-r, r+1, size=(n, 2), device=images.device)
    images_out = torch.empty((n, channels, crop_size, crop_size), device=images.device, dtype=images.dtype)
    # The two cropping methods in this if-else produce equivalent results, but the second is faster for r > 2.
    if r <= 2:
        # One masked copy per (sy, sx) pair: (2r+1)^2 copies in total.
        for sy in range(-r, r+1):
            for sx in range(-r, r+1):
                mask = (shifts[:, 0] == sy) & (shifts[:, 1] == sx)
                images_out[mask] = images[mask, :, r+sy:r+sy+crop_size, r+sx:r+sx+crop_size]
    else:
        # Crop rows first (keeping every column), then columns: 2*(2r+1) copies in total.
        # The intermediate keeps the full input width so an odd size difference also works.
        images_tmp = torch.empty((n, channels, crop_size, images.size(-1)), device=images.device, dtype=images.dtype)
        for s in range(-r, r+1):
            mask = (shifts[:, 0] == s)
            images_tmp[mask] = images[mask, :, r+s:r+s+crop_size, :]
        for s in range(-r, r+1):
            mask = (shifts[:, 1] == s)
            images_out[mask] = images_tmp[mask, :, :, r+s:r+s+crop_size]
    return images_out
class CifarLoader:
    """In-memory CIFAR-10 loader that holds an entire split on one device.

    The split is read from ``<path>/train.pt`` or ``<path>/test.pt``; if that file is
    missing it is first built from torchvision's CIFAR10 (downloaded into ``path``).
    Images are converted to fp16 in [0, 1], permuted (0, 3, 1, 2) and stored in
    channels_last memory format. Iterating yields ``(images, labels)`` batches.

    Args:
        path: directory holding (or receiving) the cached ``.pt`` file.
        train: use the training split if True, otherwise the test split.
        batch_size: number of examples per yielded batch.
        aug: optional dict with keys 'flip' (bool) and/or 'translate' (pad in pixels).
        drop_last: drop the final partial batch; defaults to ``train``.
        shuffle: reshuffle every epoch; defaults to ``train``.
        gpu: device index used as ``map_location`` when loading the cache.
    """

    def __init__(self, path, train=True, batch_size=500, aug=None, drop_last=None, shuffle=None, gpu=0):
        data_path = os.path.join(path, 'train.pt' if train else 'test.pt')
        if not os.path.exists(data_path):
            # First use: download through torchvision and cache the raw tensors.
            dset = torchvision.datasets.CIFAR10(path, download=True, train=train)
            torch.save({'images': torch.tensor(dset.data),
                        'labels': torch.tensor(dset.targets),
                        'classes': dset.classes}, data_path)

        data = torch.load(data_path, map_location=torch.device(gpu))
        self.images, self.labels, self.classes = (data[key] for key in ('images', 'labels', 'classes'))
        # It's faster to load+process uint8 data than to load preprocessed fp16 data
        scaled = (self.images.half() / 255).permute(0, 3, 1, 2)
        self.images = scaled.to(memory_format=torch.channels_last)

        self.normalize = T.Normalize(CIFAR_MEAN, CIFAR_STD)
        self.proc_images = {}  # results of the one-time preprocessing done on the first epoch
        self.epoch = 0

        self.aug = aug or {}
        for key in self.aug:
            assert key in ['flip', 'translate'], 'Unrecognized key: %s' % key

        self.batch_size = batch_size
        self.drop_last = drop_last if drop_last is not None else train
        self.shuffle = shuffle if shuffle is not None else train

    def __len__(self):
        """Number of batches per epoch."""
        total = len(self.images)
        if self.drop_last:
            return total // self.batch_size
        return ceil(total / self.batch_size)

    def __iter__(self):
        """Yield one epoch of ``(images, labels)`` batches and advance ``self.epoch``."""
        use_flip = self.aug.get('flip', False)
        pad = self.aug.get('translate', 0)

        if self.epoch == 0:
            # One-time preprocessing, cached in self.proc_images for later epochs.
            base = self.proc_images['norm'] = self.normalize(self.images)
            if use_flip:
                # Randomly pre-flip half the images; the whole-set flip on odd epochs
                # below then alternates each image's orientation from epoch to epoch.
                base = self.proc_images['flip'] = batch_flip_lr(base)
            if pad > 0:
                # Pad once up front so each epoch only needs a random crop.
                self.proc_images['pad'] = F.pad(base, (pad,)*4, 'reflect')

        if pad > 0:
            images = batch_crop(self.proc_images['pad'], self.images.shape[-2])
        elif use_flip:
            images = self.proc_images['flip']
        else:
            images = self.proc_images['norm']

        # Flip all images together every other epoch. This increases diversity relative to random flipping
        if use_flip and self.epoch % 2 == 1:
            images = images.flip(-1)

        self.epoch += 1

        count = len(images)
        if self.shuffle:
            order = torch.randperm(count, device=images.device)
        else:
            order = torch.arange(count, device=images.device)
        step = self.batch_size
        for start in range(0, len(self) * step, step):
            idxs = order[start:start + step]
            yield images[idxs], self.labels[idxs]
############################################# | |
# Network Components # | |
############################################# | |
class Flatten(nn.Module):
    """Collapse every dimension after the first (batch) one into a single axis.

    Uses ``view``, so the input must be viewable as (batch, -1) without a copy.
    """

    def forward(self, x):
        batch = x.size(0)
        return x.view(batch, -1)
class Mul(nn.Module):
    """Multiply the input by a fixed scale, stored as a plain attribute (not a Parameter)."""

    def __init__(self, scale):
        super().__init__()
        self.scale = scale

    def forward(self, x):
        return self.scale * x
class BatchNorm(nn.BatchNorm2d):
    """``nn.BatchNorm2d`` with the momentum expressed as the running-stat decay.

    PyTorch updates running stats as ``(1 - m) * running + m * batch``; passing
    ``1 - momentum`` means ``momentum`` here is the weight kept on the old running
    value. ``weight``/``bias`` choose whether the affine scale and shift are trainable
    (PyTorch already initializes them to one and zero).
    """

    def __init__(self, num_features, momentum, eps=1e-12,
                 weight=False, bias=True):
        super().__init__(num_features, eps=eps, momentum=1-momentum)
        for param, trainable in ((self.weight, weight), (self.bias, bias)):
            param.requires_grad = trainable
class Conv(nn.Conv2d):
    """``nn.Conv2d`` defaulting to a 3x3 kernel with 'same' padding and no bias.

    After PyTorch's default initialization, any bias is zeroed and the first
    ``in_channels`` filters are overwritten by ``dirac_``, so those output channels
    initially copy the corresponding input channels.
    """

    def __init__(self, in_channels, out_channels, kernel_size=3, padding='same', bias=False):
        super().__init__(in_channels, out_channels, kernel_size, padding=padding, bias=bias)

    def reset_parameters(self):
        super().reset_parameters()
        weight = self.weight.data
        if self.bias is not None:
            self.bias.data.zero_()
        torch.nn.init.dirac_(weight[:weight.size(1)])
class ConvGroup(nn.Module):
    """conv -> 2x2 max-pool -> GELU -> conv -> GELU.

    Both convs use ``Conv`` with a bias. The max-pool halves the spatial size
    (rounding down).

    ``batchnorm_momentum`` is kept so the signature matches the BatchNorm variant of
    this block. This variant has no normalization layers, so the argument is unused.
    """

    def __init__(self, channels_in, channels_out, batchnorm_momentum):
        super().__init__()
        self.conv1 = Conv(channels_in, channels_out, bias=True)
        self.pool = nn.MaxPool2d(2)
        self.conv2 = Conv(channels_out, channels_out, bias=True)
        self.activ = nn.GELU()

    def forward(self, x):
        hidden = self.activ(self.pool(self.conv1(x)))
        return self.activ(self.conv2(hidden))
############################################# | |
# Network Definition # | |
############################################# | |
def make_net(widths=hyp['net']['widths'], batchnorm_momentum=hyp['net']['batchnorm_momentum']):
    """Build the classifier network: fp16, channels_last, on CUDA.

    Layers, in order:
      * a 2x2 unpadded whitening conv (3 -> 24 channels) whose weight is frozen;
        its filters are meant to be filled later by ``init_whitening_conv``,
      * GELU,
      * three ConvGroups with ``widths['block1']``, ``['block2']``, ``['block3']``
        output channels,
      * a 3x3 max-pool and Flatten (for 32x32 inputs the map is 1x1 here),
      * a bias-free linear layer to 10 logits, scaled by
        ``hyp['net']['scaling_factor']``.

    Args:
        widths: dict with keys 'block1', 'block2', 'block3'. The default is bound
            from ``hyp`` when this function is defined.
        batchnorm_momentum: forwarded to each ConvGroup (unused there in this
            BatchNorm-free variant).
    """
    whiten_kernel_size = 2
    # 2 * (channels * k * k): the whitening init fills it with the eigenvector filters and their negations.
    whiten_width = 2 * 3 * whiten_kernel_size**2
    whiten = Conv(3, whiten_width, whiten_kernel_size, padding=0, bias=True)
    layers = [
        whiten,
        nn.GELU(),
        ConvGroup(whiten_width, widths['block1'], batchnorm_momentum),
        ConvGroup(widths['block1'], widths['block2'], batchnorm_momentum),
        ConvGroup(widths['block2'], widths['block3'], batchnorm_momentum),
        nn.MaxPool2d(3),
        Flatten(),
        nn.Linear(widths['block3'], 10, bias=False),
        Mul(hyp['net']['scaling_factor']),
    ]
    net = nn.Sequential(*layers)
    whiten.weight.requires_grad = False
    net = net.half().cuda().to(memory_format=torch.channels_last)
    # Any BatchNorm layers are kept in fp32.
    for bn in (m for m in net.modules() if isinstance(m, BatchNorm)):
        bn.float()
    return net
############################################# | |
# Whitening Conv Initialization # | |
############################################# | |
def get_patches(x, patch_shape):
    """Extract every stride-1 ``patch_shape`` patch from a 4-D image batch.

    Args:
        x: (N, C, H, W) tensor.
        patch_shape: (h, w) patch size.

    Returns:
        float32 tensor of shape (N * (W-w+1) * (H-h+1), C, h, w). Patches are grouped
        by image and, within an image, ordered column-offset-major.
    """
    channels = x.shape[1]
    ph, pw = patch_shape
    windows = x.unfold(2, ph, 1).unfold(3, pw, 1)  # (N, C, H', W', ph, pw)
    windows = windows.transpose(1, 3)              # (N, W', H', C, ph, pw)
    return windows.reshape(-1, channels, ph, pw).float()
def get_whitening_parameters(patches):
    """Eigendecompose the (uncentered) second-moment matrix of a set of patches.

    Args:
        patches: (n, c, h, w) tensor of n patches. It need not be contiguous.

    Returns:
        (eigenvalues, eigenvectors): eigenvalues of shape (c*h*w, 1, 1, 1) in
        descending order, and the matching orthonormal eigenvectors reshaped into
        (c*h*w, c, h, w) conv filters. Filter i pairs with eigenvalue i.
    """
    n, c, h, w = patches.shape
    # reshape (not view) so non-contiguous patch tensors are accepted as well.
    patches_flat = patches.reshape(n, -1)
    # No mean is subtracted: this is the raw second-moment matrix.
    est_patch_covariance = (patches_flat.T @ patches_flat) / n
    # eigh returns ascending eigenvalues with eigenvectors as columns; transpose to
    # rows and flip both so the largest-variance direction comes first.
    eigenvalues, eigenvectors = torch.linalg.eigh(est_patch_covariance, UPLO='U')
    return eigenvalues.flip(0).view(-1, 1, 1, 1), eigenvectors.T.reshape(c*h*w, c, h, w).flip(0)
def init_whitening_conv(layer, train_set, eps=5e-4):
    """Overwrite a conv layer's weights with a patch-whitening transform.

    Patches the size of ``layer``'s kernel are taken from ``train_set``. Their
    eigenvectors, each divided by sqrt(eigenvalue + eps), form the first half of the
    filters, and the negated copies form the second half. ``layer.weight`` must
    therefore have 2 * (C * kh * kw) output channels, as ``make_net`` builds it.
    The bias is left untouched.

    Args:
        layer: conv module whose ``weight`` is overwritten in place.
        train_set: (N, C, H, W) image batch to estimate patch statistics from.
        eps: added to each eigenvalue before the square root, bounding the
            amplification of low-variance directions.
    """
    patch_shape = layer.weight.data.shape[2:]
    eigenvalues, eigenvectors = get_whitening_parameters(get_patches(train_set, patch_shape=patch_shape))
    scaled = eigenvectors / (eigenvalues + eps).sqrt()
    layer.weight.data[:] = torch.cat([scaled, -scaled])
############################################ | |
# Lookahead # | |
############################################ | |
class LookaheadState:
    """Exponential moving average of a network's state, periodically synced back into it.

    ``net_ema`` starts as a clone of ``net.state_dict()``.
    """

    def __init__(self, net):
        self.net_ema = {name: tensor.clone() for name, tensor in net.state_dict().items()}

    def update(self, net, decay):
        """Blend ``net`` into the average, then overwrite ``net`` with the average.

        For each fp16/fp32 state entry: ``ema = decay * ema + (1 - decay) * live``,
        followed by ``live <- ema``. Entries of any other dtype (e.g. integer
        buffers, fp64) are left untouched in both. ``decay=1.0`` simply resets ``net``
        to the stored average.
        """
        float_dtypes = (torch.half, torch.float)
        pairs = zip(self.net_ema.values(), net.state_dict().values())
        for ema, live in pairs:
            if live.dtype not in float_dtypes:
                continue
            ema.lerp_(live, 1-decay)
            live.copy_(ema)
############################################ | |
# Logging # | |
############################################ | |
def print_columns(columns_list, is_head=False, is_final_entry=False):
    """Print one row of the '|'-delimited log table.

    Each entry is rendered as ``'| <entry> '`` and the row is closed with ``'|'``.
    A header row (``is_head``) gets a dashed rule above and below it; a final row
    (``is_final_entry``) gets a dashed rule below it.
    """
    row = ''.join('| %s ' % col for col in columns_list) + '|'
    rule = '-' * len(row)
    if is_head:
        print(rule)
    print(row)
    if is_head or is_final_entry:
        print(rule)


# Log table headers. Each header's length sets its column width, and the stripped
# header is the variable name looked up by print_training_details (the trailing space
# in 'run ' only widens that column).
logging_columns_list = ['run ', 'epoch', 'train_loss', 'train_acc', 'val_acc', 'tta_val_acc', 'total_time_seconds']


def print_training_details(variables, is_final_entry):
    """Print one log row from a mapping of variable names to values (e.g. ``locals()``).

    Each column shows ``variables[header.strip()]``:
      * exact ints and strs via ``str``,
      * exact floats with 4 decimals,
      * missing/None as blank.
    Each cell is right-justified to the header's width. Any other type, including
    bool, fails an assertion.
    """
    def render(value):
        if type(value) is float:
            return '{:0.4f}'.format(value)
        if type(value) in (int, str):
            return str(value)
        assert value is None
        return ''

    cells = [render(variables.get(header.strip(), None)).rjust(len(header))
             for header in logging_columns_list]
    print_columns(cells, is_final_entry=is_final_entry)
############################################ | |
# Evaluation # | |
############################################ | |
def infer(model, loader, tta_level=0):
    """Compute ``model``'s logits for every image held by ``loader``.

    The model is put in eval mode. ``loader.images`` are passed through
    ``loader.normalize`` and evaluated under ``torch.no_grad()`` in chunks of 2000,
    and the per-chunk logits are concatenated in order.

    Args:
        model: module mapping an (N, C, H, W) batch to (N, num_classes) logits.
        loader: object exposing ``images`` and ``normalize`` (e.g. a CifarLoader).
        tta_level: test-time augmentation level:
            0 -- the images as-is;
            1 -- mean of the logits for the image and its left-right mirror;
            2 -- 0.5 * (level-1 logits) plus 0.25 * (level-1 logits) for each of
                 two one-pixel diagonal translations (up-left and down-right),
                 i.e. 6 forward passes per image. Translations use reflect padding
                 and crop windows matching the input's own height and width.

    Returns:
        Tensor of logits with one row per image in ``loader.images``.
    """
    def infer_basic(inputs, net):
        return net(inputs).clone()

    def infer_mirror(inputs, net):
        return 0.5 * net(inputs) + 0.5 * net(inputs.flip(-1))

    def infer_mirror_translate(inputs, net):
        logits = infer_mirror(inputs, net)
        pad = 1
        height, width = inputs.shape[-2:]
        padded_inputs = F.pad(inputs, (pad,)*4, 'reflect')
        # Window at offset 0 shifts the content down-right by `pad`; offset 2*pad shifts it up-left.
        inputs_translate_list = [
            padded_inputs[:, :, 0:height, 0:width],
            padded_inputs[:, :, 2*pad:2*pad+height, 2*pad:2*pad+width],
        ]
        logits_translate_list = [infer_mirror(inputs_translate, net)
                                 for inputs_translate in inputs_translate_list]
        logits_translate = torch.stack(logits_translate_list).mean(0)
        return 0.5 * logits + 0.5 * logits_translate

    model.eval()
    test_images = loader.normalize(loader.images)
    infer_fn = [infer_basic, infer_mirror, infer_mirror_translate][tta_level]
    with torch.no_grad():
        return torch.cat([infer_fn(inputs, model) for inputs in test_images.split(2000)])
def evaluate(model, loader, tta_level=0):
    """Return the fraction (a Python float) of ``loader``'s images classified correctly.

    Predictions are the argmax of ``infer(model, loader, tta_level)``, compared
    against ``loader.labels``.
    """
    predicted = infer(model, loader, tta_level).argmax(dim=1)
    return predicted.eq(loader.labels).float().mean().item()
############################################ | |
# Training # | |
############################################ | |
def main(run):
    """Train one network from scratch on CIFAR-10 and return its TTA test accuracy.

    Prints one log row per epoch plus a final 'eval' row. ``total_time_seconds`` covers
    the whitening init, the training loops and the final TTA evaluation; the per-epoch
    ``val_acc`` evaluation is not timed.

    Args:
        run: label shown in the 'run' column of the first row (cleared afterwards).
            The value 'warmup' also replaces the training labels with random ones.

    Returns:
        Test accuracy (float) with test-time augmentation level hyp['net']['tta_level'].

    Note: log rows are produced from ``locals()``, so the local names run, epoch,
    train_loss, train_acc, val_acc, tta_val_acc and total_time_seconds are part of
    the logging contract; do not rename them.
    """
    batch_size = hyp['opt']['batch_size']
    epochs = hyp['opt']['train_epochs']
    momentum = hyp['opt']['momentum']
    # Assuming gradients are constant in time, for Nesterov momentum, the below ratio is how much
    # larger the default steps will be than the underlying per-example gradients. We divide the
    # learning rate by this ratio in order to ensure steps are the same scale as gradients, regardless
    # of the choice of momentum.
    # NOTE(review): for torch.optim.SGD(nesterov=True, dampening=0) a constant gradient g settles
    # at a step of lr * g / (1 - momentum), not lr * g * (1 + 1 / (1 - momentum)). The difference
    # is a constant for fixed momentum and hyp['opt']['lr'] was tuned against this formula, so do
    # not change it without re-tuning.
    kilostep_scale = 1024 * (1 + 1 / (1 - momentum))
    lr = hyp['opt']['lr'] / kilostep_scale # un-decoupled learning rate for PyTorch SGD
    # The loss below is summed over the batch, so each step's decay is scaled by batch_size to keep
    # the decay per 1024 examples equal to hyp['opt']['weight_decay'].
    wd = hyp['opt']['weight_decay'] * batch_size / kilostep_scale

    # reduction='none' + .sum() below: gradients are summed, not averaged, over the batch.
    loss_fn = nn.CrossEntropyLoss(label_smoothing=hyp['opt']['label_smoothing'], reduction='none')

    test_loader = CifarLoader('cifar10', train=False, batch_size=2000)
    train_loader = CifarLoader('cifar10', train=True, batch_size=batch_size, aug=hyp['aug'])
    if run == 'warmup':
        # The only purpose of the first run is to warmup, so we can use dummy data
        train_loader.labels = torch.randint(0, 10, size=(len(train_loader.labels),), device=train_loader.labels.device)
    total_train_steps = ceil(len(train_loader) * epochs)

    model = make_net()
    current_steps = 0

    # PyTorch SGD multiplies weight_decay by lr, so dividing by lr makes the decay per step equal wd.
    optimizer = torch.optim.SGD(model.parameters(), lr=lr, weight_decay=wd/lr, momentum=momentum, nesterov=True)

    def triangle(steps, start=0, end=0, peak=0.5):
        # Piecewise-linear schedule with steps+1 entries: rises from `start` at step 0 to 1 at
        # step int(peak * steps), then falls linearly to `end` at step `steps`.
        xp = torch.tensor([0, int(peak * steps), steps])
        fp = torch.tensor([start, 1, end])
        x = torch.arange(1+steps)
        m = (fp[1:] - fp[:-1]) / (xp[1:] - xp[:-1])
        b = fp[:-1] - (m * xp[:-1])
        # Segment index for each step (the final step is clamped into the last segment).
        indices = torch.sum(torch.ge(x[:, None], xp[None, :]), 1) - 1
        indices = torch.clamp(indices, 0, len(m) - 1)
        return m[indices] * x + b[indices]
    lr_schedule = triangle(total_train_steps, start=0.2, end=0.07, peak=0.23)
    # LambdaLR multiplies the base lr by lr_schedule[step].
    scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lambda i: lr_schedule[i])

    # EMA decay for the lookahead updates (applied every 5 steps): grows cubically from 0 to
    # 0.95**5, i.e. at most a 0.95-per-step decay compounded over the 5-step interval.
    alpha_schedule = 0.95**5 * (torch.arange(total_train_steps+1) / total_train_steps)**3
    lookahead_state = LookaheadState(model)

    # For accurately timing GPU code
    starter = torch.cuda.Event(enable_timing=True)
    ender = torch.cuda.Event(enable_timing=True)
    total_time_seconds = 0.0

    # Initialize the whitening layer using training images
    starter.record()
    train_images = train_loader.normalize(train_loader.images[:5000])
    init_whitening_conv(model[0], train_images)
    ender.record()
    torch.cuda.synchronize()
    total_time_seconds += 1e-3 * starter.elapsed_time(ender)

    for epoch in range(ceil(epochs)):

        # Train the whitening bias only for the first whiten_bias_epochs epochs; its weight is always frozen.
        model[0].bias.requires_grad = (epoch < hyp['opt']['whiten_bias_epochs'])

        ####################
        #     Training     #
        ####################

        starter.record()

        model.train()
        for inputs, labels in train_loader:

            outputs = model(inputs)
            loss = loss_fn(outputs, labels).sum()
            optimizer.zero_grad(set_to_none=True)
            loss.backward()
            optimizer.step()
            scheduler.step()

            current_steps += 1

            if current_steps % 5 == 0:
                lookahead_state.update(model, decay=alpha_schedule[current_steps].item())

            if current_steps >= total_train_steps:
                # decay=1.0 copies the EMA weights into the model for the final evaluation.
                # NOTE(review): lookahead_state is always set above, so this None check is always true.
                if lookahead_state is not None:
                    lookahead_state.update(model, decay=1.0)
                break

        ender.record()
        torch.cuda.synchronize()
        total_time_seconds += 1e-3 * starter.elapsed_time(ender)

        ####################
        #    Evaluation    #
        ####################

        # Save the accuracy and loss from the last training batch of the epoch
        # (relies on the loader having yielded at least one batch this epoch).
        train_acc = (outputs.detach().argmax(1) == labels).float().mean().item()
        train_loss = loss.item() / batch_size
        val_acc = evaluate(model, test_loader, tta_level=0)
        print_training_details(locals(), is_final_entry=False)
        run = None # Only print the run number once

    ####################
    #  TTA Evaluation  #
    ####################

    starter.record()
    tta_val_acc = evaluate(model, test_loader, tta_level=hyp['net']['tta_level'])
    ender.record()
    torch.cuda.synchronize()
    total_time_seconds += 1e-3 * starter.elapsed_time(ender)

    epoch = 'eval'
    print_training_details(locals(), is_final_entry=True)

    return tta_val_acc
if __name__ == "__main__":
    # Store this script's own source in the log so each result records the exact code that produced it.
    with open(sys.argv[0], encoding='utf-8') as f:
        code = f.read()

    print_columns(logging_columns_list, is_head=True)
    # Uncomment to do one warmup run (random labels; see main) before the measured runs.
    #main('warmup')
    accs = torch.tensor([main(run) for run in range(25)])
    print('Mean: %.4f Std: %.4f' % (accs.mean(), accs.std()))

    log = {'code': code, 'accs': accs}
    log_dir = os.path.join('logs', str(uuid.uuid4()))
    os.makedirs(log_dir, exist_ok=True)
    log_path = os.path.join(log_dir, 'log.pt')
    print(os.path.abspath(log_path))
    torch.save(log, log_path)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment