Created December 1, 2022 19:23
A ResNet example for JAX/Flax.
# Adapted from https://github.com/phlippe/uvadlc_notebooks_benchmarking/blob/main/PyTorch/Tutorial5_Inception_ResNet_DenseNet.py
from torchvision.datasets import CIFAR10
from torchvision import transforms
from torch.utils.tensorboard import SummaryWriter
import torch.utils.data as data
import torch
from flax.training import train_state, checkpoints
from flax import linen as nn
from jax import random
import jax.numpy as jnp
import jax
from tqdm.auto import tqdm
import os
import numpy as np
from typing import Any
from collections import defaultdict
import time
import optax

DATASET_PATH = "../data"
CHECKPOINT_PATH = "../saved_models/tutorial5_jax"
timestr = time.strftime("%Y_%m_%d__%H_%M_%S")
LOG_FILE = open(f'../logs/tutorial5_jax_{timestr}.txt', 'w')
main_rng = random.PRNGKey(42)
print("Device:", jax.devices()[0])
train_dataset = CIFAR10(root=DATASET_PATH, train=True, download=True)
DATA_MEANS = (train_dataset.data / 255.0).mean(axis=(0, 1, 2))
DATA_STD = (train_dataset.data / 255.0).std(axis=(0, 1, 2))


def image_to_numpy(img):
    img = np.array(img, dtype=np.float32)
    img = (img / 255. - DATA_MEANS) / DATA_STD
    return img


def numpy_collate(batch):
    if isinstance(batch[0], np.ndarray):
        return np.stack(batch)
    elif isinstance(batch[0], (tuple, list)):
        transposed = zip(*batch)
        return [numpy_collate(samples) for samples in transposed]
    else:
        return np.array(batch)
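# Added note (illustrative, not part of the original gist): numpy_collate plays the
# role of PyTorch's default_collate but keeps everything as NumPy arrays, which JAX
# can consume directly. A minimal sketch of what it produces for (image, label) pairs:
#   batch = numpy_collate([(np.zeros((32, 32, 3), np.float32), 0),
#                          (np.ones((32, 32, 3), np.float32), 1)])
#   # batch[0].shape == (2, 32, 32, 3), batch[1] == array([0, 1])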
test_transform = image_to_numpy
train_transform = transforms.Compose([transforms.RandomHorizontalFlip(),
                                      transforms.RandomResizedCrop(
                                          (32, 32), scale=(0.8, 1.0), ratio=(0.9, 1.1)),
                                      image_to_numpy
                                      ])
train_dataset = CIFAR10(root=DATASET_PATH, train=True,
                        transform=train_transform, download=True)
val_dataset = CIFAR10(root=DATASET_PATH, train=True,
                      transform=test_transform, download=True)
train_set, _ = torch.utils.data.random_split(
    train_dataset, [45000, 5000], generator=torch.Generator().manual_seed(42))
_, val_set = torch.utils.data.random_split(
    val_dataset, [45000, 5000], generator=torch.Generator().manual_seed(42))
test_set = CIFAR10(root=DATASET_PATH, train=False,
                   transform=test_transform, download=True)
train_loader = data.DataLoader(train_set,
                               batch_size=128,
                               shuffle=True,
                               drop_last=True,
                               collate_fn=numpy_collate,
                               num_workers=8,
                               persistent_workers=True)
val_loader = data.DataLoader(val_set,
                             batch_size=128,
                             shuffle=False,
                             drop_last=False,
                             collate_fn=numpy_collate,
                             num_workers=4,
                             persistent_workers=True)
test_loader = data.DataLoader(test_set,
                              batch_size=128,
                              shuffle=False,
                              drop_last=False,
                              collate_fn=numpy_collate,
                              num_workers=4,
                              persistent_workers=True)
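# Added sanity check (illustrative, not in the original gist): each loader yields
# channels-last NumPy batches, e.g.
#   imgs, labels = next(iter(val_loader))
#   # imgs.shape == (128, 32, 32, 3), labels.shape == (128,)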
class TrainState(train_state.TrainState):
    # A simple extension of TrainState to also include batch statistics
    batch_stats: Any


class TrainerModule:

    def __init__(self,
                 model_name: str,
                 model_class: nn.Module,
                 model_hparams: dict,
                 optimizer_name: str,
                 optimizer_hparams: dict,
                 exmp_imgs: Any,
                 seed=42):
        """
        Module for summarizing all training functionalities for classification on CIFAR10.

        Inputs:
            model_name - String of the class name, used for logging and saving
            model_class - Class implementing the neural network
            model_hparams - Hyperparameters of the model, used as input to model constructor
            optimizer_name - String of the optimizer name, supporting ['sgd', 'adam', 'adamw']
            optimizer_hparams - Hyperparameters of the optimizer, including learning rate as 'lr'
            exmp_imgs - Example imgs, used as input to initialize the model
            seed - Seed to use in the model initialization
        """
        super().__init__()
        self.model_name = model_name
        self.model_class = model_class
        self.model_hparams = model_hparams
        self.optimizer_name = optimizer_name
        self.optimizer_hparams = optimizer_hparams
        self.seed = seed
        # Create empty model. Note: no parameters yet
        self.model = self.model_class(**self.model_hparams)
        # Prepare logging
        self.log_dir = os.path.join(CHECKPOINT_PATH, self.model_name)
        self.logger = SummaryWriter(log_dir=self.log_dir)
        # Create jitted training and eval functions
        self.create_functions()
        # Initialize model
        self.init_model(exmp_imgs)
    def create_functions(self):
        # Function to calculate the classification loss and accuracy for a model
        def calculate_loss(params, batch_stats, batch, train):
            imgs, labels = batch
            labels_onehot = jax.nn.one_hot(
                labels, num_classes=self.model.num_classes)
            # Run model. During training, we need to update the BatchNorm statistics.
            outs = self.model.apply({'params': params, 'batch_stats': batch_stats},
                                    imgs,
                                    train=train,
                                    mutable=['batch_stats'] if train else False)
            logits, new_model_state = outs if train else (outs, None)
            loss = optax.softmax_cross_entropy(logits, labels_onehot).mean()
            acc = (logits.argmax(axis=-1) == labels).mean()
            return loss, (acc, new_model_state)

        # Training function
        def train_step(state, batch):
            def loss_fn(params):
                return calculate_loss(params, state.batch_stats, batch, train=True)
            # Get loss, gradients for loss, and other outputs of loss function
            ret, grads = jax.value_and_grad(
                loss_fn, has_aux=True)(state.params)
            loss, acc, new_model_state = ret[0], *ret[1]
            # Update parameters and batch statistics
            state = state.apply_gradients(
                grads=grads, batch_stats=new_model_state['batch_stats'])
            return state, loss, acc

        # Eval function
        def eval_step(state, batch):
            # Return the accuracy for a single batch
            _, (acc, _) = calculate_loss(state.params,
                                         state.batch_stats, batch, train=False)
            return acc

        # jit for efficiency
        self.train_step = jax.jit(train_step)
        self.eval_step = jax.jit(eval_step)
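        # Added note (not in the original gist): value_and_grad(loss_fn, has_aux=True)
        # returns ((loss, aux), grads) with aux = (acc, new_model_state), which is why
        # train_step above unpacks ret[0] and *ret[1] before applying the gradients.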
    def init_model(self, exmp_imgs):
        # Initialize model
        init_rng = jax.random.PRNGKey(self.seed)
        variables = self.model.init(init_rng, exmp_imgs, train=True)
        self.init_params, self.init_batch_stats = variables['params'], variables['batch_stats']
        self.state = None

    def init_optimizer(self, num_epochs, num_steps_per_epoch):
        # Initialize learning rate schedule and optimizer
        if self.optimizer_name.lower() == 'adam':
            opt_class = optax.adam
        elif self.optimizer_name.lower() == 'adamw':
            opt_class = optax.adamw
        elif self.optimizer_name.lower() == 'sgd':
            opt_class = optax.sgd
        else:
            assert False, f'Unknown optimizer "{self.optimizer_name}"'
        # We decrease the learning rate by a factor of 0.1 after 60% and 85% of the training
        lr_schedule = optax.piecewise_constant_schedule(
            init_value=self.optimizer_hparams.pop('lr'),
            boundaries_and_scales={int(num_steps_per_epoch*num_epochs*0.6): 0.1,
                                   int(num_steps_per_epoch*num_epochs*0.85): 0.1}
        )
        # Clip gradients at max value, and optionally apply weight decay
        transf = [optax.clip(1.0)]
        if opt_class == optax.sgd and 'weight_decay' in self.optimizer_hparams:  # wd is integrated in adamw
            transf.append(optax.add_decayed_weights(
                self.optimizer_hparams.pop('weight_decay')))
        optimizer = optax.chain(
            *transf,
            opt_class(lr_schedule, **self.optimizer_hparams)
        )
        # Initialize training state
        self.state = TrainState.create(apply_fn=self.model.apply,
                                       params=self.init_params if self.state is None else self.state.params,
                                       batch_stats=self.init_batch_stats if self.state is None else self.state.batch_stats,
                                       tx=optimizer)
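        # Added note (illustrative, not part of the original gist): with the splits and
        # loader settings above, len(train_loader) == 45000 // 128 == 351 (drop_last=True).
        # For num_epochs=200 the schedule would drop the learning rate by 10x near steps
        # 42120 and 59670; the num_epochs=1 run at the bottom of this file places those
        # boundaries at steps 210 and 298 instead.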
    def train_model(self, train_loader, val_loader, num_epochs=200):
        # Train model for defined number of epochs
        # We first need to create optimizer and the scheduler for the given number of epochs
        self.init_optimizer(num_epochs, len(train_loader))
        # Track best eval accuracy
        best_eval = 0.0
        for epoch_idx in tqdm(range(1, num_epochs+1)):
            self.train_epoch(train_loader, epoch=epoch_idx)
            if epoch_idx % 2 == 0:
                eval_acc = self.eval_model(val_loader)
                self.logger.add_scalar(
                    'val/acc', eval_acc, global_step=epoch_idx)
                if eval_acc >= best_eval:
                    best_eval = eval_acc
                    self.save_model(step=epoch_idx)
                self.logger.flush()

    def train_epoch(self, train_loader, epoch):
        # Train model for one epoch, and log avg loss and accuracy
        metrics = defaultdict(list)
        for batch in tqdm(train_loader, desc='Training', leave=False):
            self.state, loss, acc = self.train_step(self.state, batch)
            metrics['loss'].append(loss)
            metrics['acc'].append(acc)
        for key in metrics:
            avg_val = np.stack(jax.device_get(metrics[key])).mean()
            self.logger.add_scalar('train/'+key, avg_val, global_step=epoch)

    def eval_model(self, data_loader):
        # Test model on all images of a data loader and return the average accuracy
        correct_class, count = 0, 0
        for batch in data_loader:
            acc = self.eval_step(self.state, batch)
            correct_class += acc * batch[0].shape[0]
            count += batch[0].shape[0]
        eval_acc = (correct_class / count).item()
        return eval_acc

    def save_model(self, step=0):
        # Save current model at certain training iteration
        checkpoints.save_checkpoint(ckpt_dir=self.log_dir,
                                    target={'params': self.state.params,
                                            'batch_stats': self.state.batch_stats},
                                    step=step,
                                    overwrite=True)

    def load_model(self, pretrained=False):
        # Load model. We use a different checkpoint for pretrained models
        if not pretrained:
            state_dict = checkpoints.restore_checkpoint(
                ckpt_dir=self.log_dir, target=None)
        else:
            state_dict = checkpoints.restore_checkpoint(ckpt_dir=os.path.join(
                CHECKPOINT_PATH, f'{self.model_name}.ckpt'), target=None)
        self.state = TrainState.create(apply_fn=self.model.apply,
                                       params=state_dict['params'],
                                       batch_stats=state_dict['batch_stats'],
                                       tx=self.state.tx if self.state else optax.sgd(0.1)  # Default optimizer
                                       )

    def checkpoint_exists(self):
        # Check whether a pretrained checkpoint exists for this model
        return os.path.isfile(os.path.join(CHECKPOINT_PATH, f'{self.model_name}.ckpt'))
def train_classifier(*args, num_epochs=200, **kwargs):
    # Create a trainer module with specified hyperparameters
    trainer = TrainerModule(*args, **kwargs)
    start_time = time.time()
    with jax.profiler.trace("/tmp/jax-trace"):
        trainer.train_model(train_loader, val_loader, num_epochs=num_epochs)
    train_time = time.time()
    print(trainer.model_name, ' - Full training time:',
          time.strftime('%H:%M:%S', time.gmtime(train_time - start_time)),
          file=LOG_FILE, flush=True)
    return None, None
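# Added note (assumption about the intent, not in the original gist): jax.profiler.trace
# writes a trace under /tmp/jax-trace that can be inspected with TensorBoard's profiler
# plugin (or the Perfetto UI) to see where the training time is spent.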
resnet_kernel_init = nn.initializers.variance_scaling(
    2.0, mode='fan_out', distribution='normal')


class ResNetBlock(nn.Module):
    act_fn: callable          # Activation function
    c_out: int                # Output feature size
    subsample: bool = False   # If True, we apply a stride inside F

    @nn.compact
    def __call__(self, x, train=True):
        # Network representing F
        z = nn.Conv(self.c_out, kernel_size=(3, 3),
                    strides=(1, 1) if not self.subsample else (2, 2),
                    kernel_init=resnet_kernel_init,
                    use_bias=False)(x)
        z = nn.BatchNorm()(z, use_running_average=not train)
        z = self.act_fn(z)
        z = nn.Conv(self.c_out, kernel_size=(3, 3),
                    kernel_init=resnet_kernel_init,
                    use_bias=False)(z)
        z = nn.BatchNorm()(z, use_running_average=not train)
        if self.subsample:
            x = nn.Conv(self.c_out, kernel_size=(1, 1), strides=(2, 2),
                        kernel_init=resnet_kernel_init)(x)
        x_out = self.act_fn(z + x)
        return x_out
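# Added sketch (illustrative, not part of the original gist): a subsampling block
# halves the spatial resolution while switching to c_out channels, e.g.
#   block = ResNetBlock(act_fn=nn.relu, c_out=32, subsample=True)
#   variables = block.init(random.PRNGKey(0), jnp.ones((1, 32, 32, 16)))
#   out, _ = block.apply(variables, jnp.ones((1, 32, 32, 16)), mutable=['batch_stats'])
#   # out.shape == (1, 16, 16, 32)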
class ResNet(nn.Module):
    num_classes: int
    act_fn: callable
    block_class: nn.Module
    num_blocks: tuple = (3, 3, 3)
    c_hidden: tuple = (16, 32, 64)

    @nn.compact
    def __call__(self, x, train=True):
        # A first convolution on the original image to scale up the channel size
        x = nn.Conv(self.c_hidden[0], kernel_size=(3, 3),
                    kernel_init=resnet_kernel_init, use_bias=False)(x)
        if self.block_class == ResNetBlock:
            # For the standard ResNetBlock we apply BatchNorm and the non-linearity here;
            # a pre-activation block would apply them inside the block instead.
            x = nn.BatchNorm()(x, use_running_average=not train)
            x = self.act_fn(x)
        # Creating the ResNet blocks
        for block_idx, block_count in enumerate(self.num_blocks):
            for bc in range(block_count):
                # Subsample the first block of each group, except the very first one.
                subsample = (bc == 0 and block_idx > 0)
                # ResNet block
                x = self.block_class(c_out=self.c_hidden[block_idx],
                                     act_fn=self.act_fn,
                                     subsample=subsample)(x, train=train)
        # Mapping to classification output
        x = x.mean(axis=(1, 2))
        x = nn.Dense(self.num_classes)(x)
        return x
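# Added sketch (illustrative, not part of the original gist): the model can be
# initialized and run outside the trainer to check the output shape, e.g.
#   model = ResNet(num_classes=10, act_fn=nn.relu, block_class=ResNetBlock)
#   variables = model.init(random.PRNGKey(0), jnp.ones((1, 32, 32, 3)), train=True)
#   logits, _ = model.apply(variables, jnp.ones((1, 32, 32, 3)),
#                           train=True, mutable=['batch_stats'])
#   # logits.shape == (1, 10)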
resnet_trainer, resnet_results = train_classifier(model_name="ResNet",
                                                  model_class=ResNet,
                                                  model_hparams={"num_classes": 10,
                                                                 "c_hidden": (16, 32, 64),
                                                                 "num_blocks": (3, 3, 3),
                                                                 "act_fn": nn.relu,
                                                                 "block_class": ResNetBlock},
                                                  optimizer_name="SGD",
                                                  optimizer_hparams={"lr": 0.1,
                                                                     "momentum": 0.9,
                                                                     "weight_decay": 1e-4},
                                                  exmp_imgs=jax.device_put(
                                                      next(iter(train_loader))[0]),
                                                  num_epochs=1)