Skip to content

Instantly share code, notes, and snippets.

@Mikubill
Created November 19, 2023 15:24
Show Gist options
  • Save Mikubill/2360c5c1e9d0c2e8a031a2352778ee1c to your computer and use it in GitHub Desktop.
Save Mikubill/2360c5c1e9d0c2e8a031a2352778ee1c to your computer and use it in GitHub Desktop.
ConvNeXt V2 Image Classifier
import os
import random
import math, warnings
from copy import deepcopy
import numpy as np
import lightning
import torch
import torch.nn as nn
import torch.nn.functional as F
from PIL import Image
from lightning.pytorch import seed_everything
from lightning.pytorch.callbacks import ModelCheckpoint
from torch.utils.data import DataLoader, Dataset, random_split
from torchvision import transforms
from torchvision.transforms import functional
from ml_collections import ConfigDict
def drop_path(
    x, drop_prob: float = 0.0, training: bool = False, scale_by_keep: bool = True
):
    """Randomly zero whole samples of ``x`` (stochastic depth).

    One Bernoulli draw is made per entry along the first dimension and
    broadcast over all remaining dimensions, so each sample is either kept
    entirely or dropped entirely.

    Args:
        x: Input tensor; the first dimension is treated as the sample axis.
        drop_prob: Probability of dropping a sample.
        training: The input is returned untouched unless this is true.
        scale_by_keep: When true (and the keep probability is positive),
            surviving samples are divided by the keep probability.

    Returns:
        ``x`` itself when nothing is dropped, otherwise ``x`` multiplied by
        the per-sample mask.
    """
    if not training or drop_prob == 0.0:
        return x

    survival = 1 - drop_prob
    # (N, 1, 1, ...) so the mask broadcasts for tensors of any rank.
    mask_shape = (x.shape[0],) + (1,) * (x.ndim - 1)
    mask = x.new_empty(mask_shape).bernoulli_(survival)
    if scale_by_keep and survival > 0.0:
        mask.div_(survival)
    return x * mask
def _trunc_normal_(tensor, mean, std, a, b):
# Cut & paste from PyTorch official master until it's in a few official releases - RW
# Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
def norm_cdf(x):
# Computes standard normal cumulative distribution function
return (1.0 + math.erf(x / math.sqrt(2.0))) / 2.0
if (mean < a - 2 * std) or (mean > b + 2 * std):
warnings.warn(
"mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
"The distribution of values may be incorrect.",
stacklevel=2,
)
# Values are generated by using a truncated uniform distribution and
# then using the inverse CDF for the normal distribution.
# Get upper and lower cdf values
l = norm_cdf((a - mean) / std)
u = norm_cdf((b - mean) / std)
# Uniformly fill tensor with values from [l, u], then translate to
# [2l-1, 2u-1].
tensor.uniform_(2 * l - 1, 2 * u - 1)
# Use inverse cdf transform for normal distribution to get truncated
# standard normal
tensor.erfinv_()
# Transform to proper mean, std
tensor.mul_(std * math.sqrt(2.0))
tensor.add_(mean)
# Clamp to ensure it's in the proper range
tensor.clamp_(min=a, max=b)
return tensor
def trunc_normal_(tensor, mean=0.0, std=1.0, a=-2.0, b=2.0):
    """Fill ``tensor`` in place from a truncated normal, outside autograd.

    Thin wrapper that runs ``_trunc_normal_`` under ``torch.no_grad()`` so the
    in-place fill is not recorded by autograd.

    Args:
        tensor: Tensor to overwrite.
        mean: Mean of the underlying normal distribution.
        std: Standard deviation of the underlying normal distribution.
        a: Lower truncation bound.
        b: Upper truncation bound.

    Returns:
        The same ``tensor`` object, after filling.
    """
    with torch.no_grad():
        filled = _trunc_normal_(tensor, mean, std, a, b)
    return filled
class DropPath(nn.Module):
    """Module wrapper around ``drop_path`` (per-sample stochastic depth).

    Args:
        drop_prob: Probability of dropping a sample.
        scale_by_keep: Whether surviving samples are rescaled by the keep
            probability.
    """

    def __init__(self, drop_prob: float = 0.0, scale_by_keep: bool = True):
        super().__init__()
        self.drop_prob = drop_prob
        self.scale_by_keep = scale_by_keep

    def forward(self, x):
        # ``self.training`` is the nn.Module train/eval flag.
        return drop_path(
            x,
            self.drop_prob,
            self.training,
            self.scale_by_keep,
        )

    def extra_repr(self):
        rounded = round(self.drop_prob, 3)
        return f"drop_prob={rounded:0.3f}"
class LayerNorm(nn.Module):
    """Layer normalization over the channel axis for two memory layouts.

    ``channels_last`` (default) expects the channel axis last, e.g.
    (batch_size, height, width, channels), and delegates to ``F.layer_norm``.
    ``channels_first`` expects (batch_size, channels, height, width) and
    normalizes over dimension 1 by hand.

    Args:
        normalized_shape: Number of channels.
        eps: Value added to the variance for numerical stability.
        data_format: ``"channels_last"`` or ``"channels_first"``.

    Raises:
        NotImplementedError: If ``data_format`` is neither supported value.
    """

    def __init__(self, normalized_shape, eps=1e-6, data_format="channels_last"):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(normalized_shape))
        self.bias = nn.Parameter(torch.zeros(normalized_shape))
        self.eps = eps
        self.data_format = data_format
        if data_format not in ("channels_last", "channels_first"):
            raise NotImplementedError
        self.normalized_shape = (normalized_shape,)

    def forward(self, x):
        if self.data_format == "channels_last":
            return F.layer_norm(
                x, self.normalized_shape, self.weight, self.bias, self.eps
            )
        elif self.data_format == "channels_first":
            # Biased mean/variance over the channel dimension (dim 1).
            mean = x.mean(1, keepdim=True)
            variance = (x - mean).pow(2).mean(1, keepdim=True)
            normalized = (x - mean) / torch.sqrt(variance + self.eps)
            # Broadcast the per-channel affine parameters over (H, W).
            scale = self.weight[:, None, None]
            shift = self.bias[:, None, None]
            return scale * normalized + shift
class GRN(nn.Module):
    """GRN (Global Response Normalization) layer.

    The forward pass indexes the input as four-dimensional with the feature
    axis last; the L2 norm is taken over dimensions 1 and 2. ``gamma`` and
    ``beta`` start at zero, so a freshly constructed layer is the identity.

    Args:
        dim: Size of the last (feature) dimension.
    """

    def __init__(self, dim):
        super().__init__()
        self.gamma = nn.Parameter(torch.zeros(1, 1, 1, dim))
        self.beta = nn.Parameter(torch.zeros(1, 1, 1, dim))

    def forward(self, x):
        # Per-feature L2 norm aggregated over dims 1 and 2.
        global_norm = torch.norm(x, p=2, dim=(1, 2), keepdim=True)
        # Each feature's norm relative to the mean norm across features.
        mean_norm = global_norm.mean(dim=-1, keepdim=True)
        relative_norm = global_norm / (mean_norm + 1e-6)
        return self.gamma * (x * relative_norm) + self.beta + x
class Block(nn.Module):
    """ConvNeXtV2 residual block.

    Pipeline: 7x7 depthwise conv -> LayerNorm -> linear expand (4x) -> GELU
    -> GRN -> linear project back -> optional drop path -> residual add.

    Args:
        dim (int): Number of input channels.
        drop_path (float): Stochastic depth rate. Default: 0.0
    """

    def __init__(self, dim, drop_path=0.0):
        super().__init__()
        hidden_dim = 4 * dim
        # Depthwise conv: one 7x7 filter per channel, spatial size preserved.
        self.dwconv = nn.Conv2d(dim, dim, kernel_size=7, padding=3, groups=dim)
        self.norm = LayerNorm(dim, eps=1e-6)
        # Pointwise (1x1) convs are implemented as linear layers on the
        # channels-last tensor.
        self.pwconv1 = nn.Linear(dim, hidden_dim)
        self.act = nn.GELU()
        self.grn = GRN(hidden_dim)
        self.pwconv2 = nn.Linear(hidden_dim, dim)
        if drop_path > 0.0:
            self.drop_path = DropPath(drop_path)
        else:
            self.drop_path = nn.Identity()

    def forward(self, x):
        shortcut = x
        out = self.dwconv(x)
        out = out.permute(0, 2, 3, 1)  # (N, C, H, W) -> (N, H, W, C)
        for layer in (self.norm, self.pwconv1, self.act, self.grn, self.pwconv2):
            out = layer(out)
        out = out.permute(0, 3, 1, 2)  # (N, H, W, C) -> (N, C, H, W)
        return shortcut + self.drop_path(out)
class ConvNeXtV2(nn.Module):
    """ConvNeXt V2 backbone with a linear head.

    The network is a stem (4x4 stride-4 conv + LayerNorm) followed by one
    stage of residual blocks per entry in ``dims``; consecutive stages are
    separated by a LayerNorm + 2x2 stride-2 conv downsampling layer. Features
    are globally average-pooled, normalized and fed to a linear head.

    Args:
        in_chans (int): Number of input image channels. Default: 3
        num_classes (int): Number of outputs of the head. Default: 1000
        depths (sequence of int): Number of blocks at each stage.
            Default: (3, 3, 9, 3)
        dims (sequence of int): Feature dimension at each stage.
            Default: (96, 192, 384, 768)
        drop_path_rate (float): Stochastic depth rate of the last block; the
            per-block rates are linearly spaced from 0 to this value.
            Default: 0.
        head_init_scale (float): Factor applied to the head's weights and
            biases after initialization. Default: 1.

    Raises:
        ValueError: If ``dims`` is empty or ``depths`` and ``dims`` differ in
            length.
    """

    def __init__(
        self,
        in_chans=3,
        num_classes=1000,
        depths=(3, 3, 9, 3),
        dims=(96, 192, 384, 768),
        drop_path_rate=0.0,
        head_init_scale=1.0,
    ):
        super().__init__()
        # Immutable defaults plus a private copy: the instance never shares a
        # list with the signature defaults or with the caller.
        depths = list(depths)
        dims = list(dims)
        if not dims or len(depths) != len(dims):
            raise ValueError(
                "depths and dims must be non-empty and of equal length, "
                f"got {len(depths)} depths and {len(dims)} dims"
            )
        num_stages = len(dims)
        self.depths = depths

        # Stem followed by one downsampling layer between consecutive stages.
        self.downsample_layers = nn.ModuleList()
        stem = nn.Sequential(
            nn.Conv2d(in_chans, dims[0], kernel_size=4, stride=4),
            LayerNorm(dims[0], eps=1e-6, data_format="channels_first"),
        )
        self.downsample_layers.append(stem)
        for i in range(num_stages - 1):
            downsample_layer = nn.Sequential(
                LayerNorm(dims[i], eps=1e-6, data_format="channels_first"),
                nn.Conv2d(dims[i], dims[i + 1], kernel_size=2, stride=2),
            )
            self.downsample_layers.append(downsample_layer)

        # One feature-resolution stage per entry in dims, each a sequence of
        # residual blocks with a linearly increasing drop-path rate.
        self.stages = nn.ModuleList()
        dp_rates = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]
        cur = 0
        for i in range(num_stages):
            stage = nn.Sequential(
                *[
                    Block(dim=dims[i], drop_path=dp_rates[cur + j])
                    for j in range(depths[i])
                ]
            )
            self.stages.append(stage)
            cur += depths[i]

        self.norm = nn.LayerNorm(dims[-1], eps=1e-6)  # final norm layer
        self.head = nn.Linear(dims[-1], num_classes)

        self.apply(self._init_weights)
        self.head.weight.data.mul_(head_init_scale)
        self.head.bias.data.mul_(head_init_scale)

    def _init_weights(self, m):
        """Truncated-normal init for conv/linear weights, zero for biases."""
        if isinstance(m, (nn.Conv2d, nn.Linear)):
            trunc_normal_(m.weight, std=0.02)
            # Layers built with bias=False have ``bias is None``.
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)

    def forward_features(self, x):
        """Run the backbone and return pooled, normalized features (N, C)."""
        for downsample, stage in zip(self.downsample_layers, self.stages):
            x = downsample(x)
            x = stage(x)
        # Global average pooling, (N, C, H, W) -> (N, C).
        return self.norm(x.mean([-2, -1]))

    def forward(self, x):
        """Return the head outputs for a batch of images."""
        x = self.forward_features(x)
        x = self.head(x)
        return x
def convnextv2_atto(**kwargs):
    """Build the "atto" ConvNeXtV2: depths [2, 2, 6, 2], dims [40, 80, 160, 320].

    Extra keyword arguments are forwarded to ``ConvNeXtV2``.
    """
    return ConvNeXtV2(depths=[2, 2, 6, 2], dims=[40, 80, 160, 320], **kwargs)
def convnextv2_femto(**kwargs):
    """Build the "femto" ConvNeXtV2: depths [2, 2, 6, 2], dims [48, 96, 192, 384].

    Extra keyword arguments are forwarded to ``ConvNeXtV2``.
    """
    return ConvNeXtV2(depths=[2, 2, 6, 2], dims=[48, 96, 192, 384], **kwargs)
def convnextv2_pico(**kwargs):
    """Build the "pico" ConvNeXtV2: depths [2, 2, 6, 2], dims [64, 128, 256, 512].

    Extra keyword arguments are forwarded to ``ConvNeXtV2``.
    """
    model = ConvNeXtV2(depths=[2, 2, 6, 2], dims=[64, 128, 256, 512], **kwargs)
    return model


# Original name of this factory; it lacks the "v2" that every sibling factory
# carries, so it is kept as an alias for existing callers.
convnext_pico = convnextv2_pico
def convnextv2_nano(**kwargs):
    """Build the "nano" ConvNeXtV2: depths [2, 2, 8, 2], dims [80, 160, 320, 640].

    Extra keyword arguments are forwarded to ``ConvNeXtV2``.
    """
    return ConvNeXtV2(depths=[2, 2, 8, 2], dims=[80, 160, 320, 640], **kwargs)
def convnextv2_tiny(**kwargs):
    """Build the "tiny" ConvNeXtV2: depths [3, 3, 9, 3], dims [96, 192, 384, 768].

    Extra keyword arguments are forwarded to ``ConvNeXtV2``.
    """
    return ConvNeXtV2(depths=[3, 3, 9, 3], dims=[96, 192, 384, 768], **kwargs)
def convnextv2_base(**kwargs):
    """Build the "base" ConvNeXtV2: depths [3, 3, 27, 3], dims [128, 256, 512, 1024].

    Extra keyword arguments are forwarded to ``ConvNeXtV2``.
    """
    return ConvNeXtV2(depths=[3, 3, 27, 3], dims=[128, 256, 512, 1024], **kwargs)
def convnextv2_large(**kwargs):
    """Build the "large" ConvNeXtV2: depths [3, 3, 27, 3], dims [192, 384, 768, 1536].

    Extra keyword arguments are forwarded to ``ConvNeXtV2``.
    """
    return ConvNeXtV2(depths=[3, 3, 27, 3], dims=[192, 384, 768, 1536], **kwargs)
def convnextv2_huge(**kwargs):
    """Build the "huge" ConvNeXtV2: depths [3, 3, 27, 3], dims [352, 704, 1408, 2816].

    Extra keyword arguments are forwarded to ``ConvNeXtV2``.
    """
    return ConvNeXtV2(depths=[3, 3, 27, 3], dims=[352, 704, 1408, 2816], **kwargs)
# ========= Dataset =========
def rescale_pad(image, output_size, random_pad=False):
    """Resize ``image`` to fit an ``output_size`` square, then pad to fill it.

    The image is scaled (aspect ratio preserved) so its longer side matches
    ``output_size`` and is then padded to ``output_size`` x ``output_size``.
    An image that already has that size is returned unchanged.

    Args:
        image: Image tensor whose last two dimensions are (height, width).
        output_size: Side length of the square output.
        random_pad: When false, the image is centered and padded with 0.
            When true, the image is additionally shrunk by a random factor in
            [0.9, 1], placed at a random offset and padded with a random
            constant drawn from [0, 1].

    Returns:
        The resized and padded image, or ``image`` itself when no change is
        needed.
    """
    h, w = image.shape[-2:]
    if h == output_size and w == output_size:
        return image

    scale = min(output_size / h, output_size / w)
    new_h, new_w = int(h * scale), int(w * scale)
    if random_pad:
        shrink = random.uniform(0.9, 1)
        new_h, new_w = int(new_h * shrink), int(new_w * shrink)
    # For very elongated images the shorter side can truncate to 0, which is
    # not a valid resize target; keep at least one pixel.
    new_h, new_w = max(new_h, 1), max(new_w, 1)

    pad_h = output_size - new_h
    pad_w = output_size - new_w
    left = random.randint(0, pad_w) if random_pad else pad_w // 2
    right = pad_w - left
    top = random.randint(0, pad_h) if random_pad else pad_h // 2
    bottom = pad_h - top
    fill = random.uniform(0, 1) if random_pad else 0

    image = transforms.functional.resize(image, [new_h, new_w])
    image = transforms.functional.pad(image, [left, top, right, bottom], fill)
    return image
def random_crop(image, min_rate=0.8):
    """Crop a random window covering at least ``min_rate`` of each side.

    Height and width are scaled independently by factors drawn uniformly from
    [``min_rate``, 1], and the window is placed at a uniformly random offset
    that keeps it inside the image.

    Args:
        image: Tensor indexed as (channels, height, width).
        min_rate: Smallest fraction of each side that the crop may keep.

    Returns:
        The cropped view ``image[:, top:top + new_h, left:left + new_w]``.
    """
    h, w = image.shape[-2:]
    new_h = int(h * random.uniform(min_rate, 1))
    new_w = int(w * random.uniform(min_rate, 1))
    # Keep at least one pixel (small sides truncate to 0) and never exceed
    # the source size.
    new_h = min(h, max(new_h, 1))
    new_w = min(w, max(new_w, 1))
    # np.random.randint excludes its upper bound, so "+ 1" is needed both to
    # allow the window to touch the bottom/right edge and to avoid the
    # ValueError raised for an empty range when the crop is full-size.
    top = int(np.random.randint(0, h - new_h + 1))
    left = int(np.random.randint(0, w - new_w + 1))
    return image[:, top : top + new_h, left : left + new_w]
class ClassifierDataset(Dataset):
    """Image dataset read from a directory with one sub-folder per label.

    Every immediate sub-directory of ``path`` is a label; labels are numbered
    in sorted folder-name order so the mapping is the same on every run and
    machine. Images are collected recursively from each label folder. The
    dataset reports twice as many items as there are files: indices in the
    second half return the horizontally flipped version of the same file.

    Args:
        path: Root directory containing the label folders.
        img_size: Side length of the square images returned by
            ``__getitem__``.

    Attributes are filled in by ``__init__``: ``image_list`` (file paths),
    ``label_list`` (label index per file) and ``label_dict`` (folder name ->
    label index).
    """

    # File extensions (lower-case) that are treated as images.
    IMAGE_EXTENSIONS = (".png", ".jpg", ".jpeg", ".webp")

    def __init__(self, path, img_size):
        self.path = path
        self.img_size = img_size
        self.image_list = []
        self.label_list = []
        self.label_dict = {}

        # os.listdir order is arbitrary; sort so label indices are stable.
        label_folders = sorted(
            d for d in os.listdir(path) if os.path.isdir(os.path.join(path, d))
        )
        for i, label_folder in enumerate(label_folders):
            self.label_dict[label_folder] = i
            folder_path = os.path.join(path, label_folder)
            for root, dirs, files in os.walk(folder_path):
                # Sorting in place makes os.walk visit sub-folders in a
                # deterministic order as well.
                dirs.sort()
                for fname in sorted(files):
                    extension = os.path.splitext(fname)[1].lower()
                    if extension in self.IMAGE_EXTENSIONS:
                        self.image_list.append(os.path.join(root, fname))
                        self.label_list.append(i)

        print("Label sequence:", self.label_dict)

    def __len__(self):
        # Each file appears twice: as-is and horizontally flipped.
        return len(self.image_list) * 2

    def __getitem__(self, index):
        """Return ``(image, label)`` for ``index``.

        The image is loaded as RGB, randomly cropped, rescaled and randomly
        padded to ``img_size``; the label is a float32 tensor of shape (1,).

        Raises:
            IndexError: If ``index`` is outside ``[-len(self), len(self))``.
        """
        real_len = len(self.image_list)
        total = real_len * 2
        if not -total <= index < total:
            raise IndexError(f"index {index} out of range for dataset of size {total}")

        fname = self.image_list[index % real_len]
        label = self.label_list[index % real_len]

        # Entries of image_list already start with the dataset root, so they
        # are opened as-is; joining them with self.path again would duplicate
        # the prefix whenever the root is a relative path.
        with Image.open(fname) as img:
            image = img.convert("RGB")
        image = transforms.functional.to_tensor(image)
        image = random_crop(image, min_rate=0.8)
        image = rescale_pad(image, self.img_size, True)
        if index // real_len != 0:
            image = transforms.functional.hflip(image)

        label = torch.tensor([label], dtype=torch.float32)
        return image, label
#
class Classifier(lightning.LightningModule):
    """Lightning module that trains a single-output network with MSE loss.

    Alongside the trained network (``self.net``) it keeps a frozen copy
    (``self.ema``) that is updated after every training batch as an
    exponential moving average of ``self.net``.

    Args:
        model: Callable that builds the network; it is called as
            ``model(in_chans=3, num_classes=1, drop_path_rate=...)``.
        drop_path_rate: Forwarded to ``model``.
        ema_decay: EMA decay factor. A falsy value (the default ``0``, or
            ``None``) selects ``DEFAULT_EMA_DECAY``.
    """

    # Decay used when the caller does not request a specific one.
    DEFAULT_EMA_DECAY = 0.9999

    def __init__(self, model, drop_path_rate=0.0, ema_decay=0):
        super().__init__()
        self.net = model(in_chans=3, num_classes=1, drop_path_rate=drop_path_rate)
        self.ema = deepcopy(self.net)
        # The argument used to be ignored in favour of a hard-coded 0.9999;
        # honour it, keeping 0.9999 for the default/falsy case so existing
        # callers see no change.
        self.ema_decay = ema_decay if ema_decay else self.DEFAULT_EMA_DECAY
        self.ema.requires_grad_(False)

    def configure_optimizers(self):
        """Optimize only ``self.net`` with AdamW (weight decay 0.05)."""
        optimizer = torch.optim.AdamW(self.net.parameters(), weight_decay=0.05)
        return optimizer

    def forward(self, x, use_ema=False):
        """Normalize ``x`` with mean 0.5 / std 0.5 and run the chosen network."""
        x = transforms.functional.normalize(x, 0.5, 0.5)
        return self.ema(x) if use_ema else self.net(x)

    def training_step(self, batch, batch_idx):
        """Return the MSE between predictions and labels and log it."""
        images, labels = batch
        loss = F.mse_loss(self.forward(images), labels)
        self.log_dict({"train/loss": loss})
        return loss

    def validation_step(self, batch, batch_idx):
        """Log the mean absolute error of both the raw and the EMA network."""
        images, labels = batch
        pred = self.forward(images)
        mae = F.l1_loss(pred, labels)
        pred_ema = self.forward(images, True)
        mae_ema = F.l1_loss(pred_ema, labels)
        logs = {
            "val/mae": mae,
            "val/mae_ema": mae_ema,
        }
        self.log_dict(logs, sync_dist=True, prog_bar=True)

    def on_train_batch_end(self, outputs, batch, batch_idx):
        """Blend the current weights of ``self.net`` into ``self.ema``.

        NOTE(review): this hook runs once per batch, so with gradient
        accumulation the EMA is presumably updated several times per
        optimizer step — confirm that is intended.
        """
        if self.ema is None:
            return
        decay = self.ema_decay
        with torch.no_grad():
            net_state = self.net.state_dict()
            # Pair tensors by name rather than by position.
            for name, ema_v in self.ema.state_dict().items():
                model_v = net_state[name]
                if ema_v.is_floating_point():
                    ema_v.copy_(decay * ema_v + (1.0 - decay) * model_v)
                else:
                    # Non-float entries (e.g. integer buffers) cannot be
                    # averaged meaningfully; mirror the current value.
                    ema_v.copy_(model_v)
def main(opt):
    """Build the dataset, model and trainer from ``opt`` and run training.

    Args:
        opt: Configuration object providing ``seed``, ``bsz_train``,
            ``bsz_val``, ``model``, ``model_drop_path``, ``dataset.path``,
            ``dataset.img_size``, ``dataset.train_split`` and ``trainer``
            (keyword arguments for ``lightning.Trainer``). An optional
            ``resume`` entry, when non-empty, is passed to ``Trainer.fit`` as
            the checkpoint to resume from.
    """
    print(opt)
    seed_everything(opt.seed)
    torch.set_float32_matmul_precision("high")

    # Seeded generator so the train/val split is reproducible.
    generator = torch.Generator().manual_seed(opt.seed)
    full_dataset = ClassifierDataset(opt.dataset.path, opt.dataset.img_size)
    split_ratio = [opt.dataset.train_split, 1 - opt.dataset.train_split]
    train_dataset, val_dataset = random_split(full_dataset, split_ratio, generator)

    dataloader_args = {
        "num_workers": 8,
        "pin_memory": True,
        "persistent_workers": True,
    }
    train_dataloader = DataLoader(
        train_dataset, batch_size=opt.bsz_train, shuffle=True, **dataloader_args
    )
    val_dataloader = DataLoader(
        val_dataset, batch_size=opt.bsz_val, shuffle=False, **dataloader_args
    )

    classifier = Classifier(model=opt.model, drop_path_rate=opt.model_drop_path)

    # Two checkpoint callbacks: best raw-model MAE (plus "last"), and best
    # EMA-model MAE.
    callbacks = [
        ModelCheckpoint(
            monitor="val/mae",
            mode="min",
            save_top_k=1,
            save_last=True,
            auto_insert_metric_name=False,
            filename="epoch={epoch},mae={val/mae:.4f}",
        ),
        ModelCheckpoint(
            monitor="val/mae_ema",
            mode="min",
            save_top_k=1,
            save_last=False,
            auto_insert_metric_name=False,
            filename="epoch={epoch},mae-ema={val/mae_ema:.4f}",
        ),
    ]

    trainer = lightning.Trainer(**opt.trainer, callbacks=callbacks)
    # The config carries a ``resume`` entry that was previously never used;
    # an empty or missing value means "start from scratch".
    ckpt_path = getattr(opt, "resume", "") or None
    trainer.fit(classifier, train_dataloader, val_dataloader, ckpt_path=ckpt_path)
def get_config():
    """Return the default training configuration as a ``ConfigDict``.

    Top-level fields describe the model and batch sizes; ``dataset`` and
    ``trainer`` are nested ``ConfigDict`` objects.
    """
    cfg = ConfigDict()

    # model args
    cfg.model = convnextv2_tiny
    cfg.model_drop_path = 0.1
    cfg.resume = ""
    cfg.seed = 42
    cfg.bsz_train = 32
    cfg.bsz_val = 2

    # dataset args
    cfg.dataset = ConfigDict()
    cfg.dataset.path = "/tmp/classifier-dataset"
    cfg.dataset.train_split = 0.9999
    cfg.dataset.img_size = 768

    # training args
    cfg.trainer = ConfigDict()
    cfg.trainer.max_epochs = 100
    cfg.trainer.accumulate_grad_batches = 4
    cfg.trainer.accelerator = "gpu"
    cfg.trainer.precision = 32
    cfg.trainer.benchmark = True
    cfg.trainer.log_every_n_steps = 1
    cfg.trainer.val_check_interval = 0.025
    cfg.trainer.strategy = "auto"  # "ddp_find_unused_parameters_false"

    return cfg
if __name__ == "__main__":
    # Script entry point: train with the default configuration.
    default_config = get_config()
    main(default_config)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment