ConvNeXt V2 Image Classifier
import os
import random
import math
import warnings
from copy import deepcopy

import numpy as np
import lightning
import torch
import torch.nn as nn
import torch.nn.functional as F
from PIL import Image
from lightning.pytorch import seed_everything
from lightning.pytorch.callbacks import ModelCheckpoint
from torch.utils.data import DataLoader, Dataset, random_split
from torchvision import transforms
from ml_collections import ConfigDict
def drop_path(x, drop_prob: float = 0.0, training: bool = False, scale_by_keep: bool = True):
    """Drop paths (Stochastic Depth) per sample (when applied in the main path of residual blocks).

    This is the same as the DropConnect impl created for EfficientNet, etc. networks; however,
    the original name is misleading, as 'Drop Connect' is a different form of dropout from a
    separate paper. See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956
    We opt for 'drop path' as the layer and argument name rather than mixing DropConnect as a
    layer name with 'survival rate' as the argument.
    """
    if drop_prob == 0.0 or not training:
        return x
    keep_prob = 1 - drop_prob
    # Work with tensors of any rank, not just 2D ConvNet feature maps:
    # one Bernoulli draw per sample, broadcast over the remaining dims.
    shape = (x.shape[0],) + (1,) * (x.ndim - 1)
    random_tensor = x.new_empty(shape).bernoulli_(keep_prob)
    if keep_prob > 0.0 and scale_by_keep:
        random_tensor.div_(keep_prob)
    return x * random_tensor
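
# Illustration (added for this write-up, not part of the original gist): a
# quick sanity check of drop_path semantics. In eval mode the input passes
# through unchanged; in training mode each sample's path is either zeroed or
# rescaled by 1/keep_prob, keeping the output's expected value equal to x.
def _demo_drop_path():
    x = torch.ones(8, 3, 4, 4)
    assert torch.equal(drop_path(x, drop_prob=0.5, training=False), x)
    out = drop_path(x, drop_prob=0.5, training=True)
    # Each sample is all zeros or all 2.0 (= 1 / keep_prob).
    assert set(out.unique().tolist()) <= {0.0, 2.0}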
def _trunc_normal_(tensor, mean, std, a, b):
    # Cut & paste from PyTorch official master until it's in a few official releases - RW
    # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
    def norm_cdf(x):
        # Computes the standard normal cumulative distribution function
        return (1.0 + math.erf(x / math.sqrt(2.0))) / 2.0

    if (mean < a - 2 * std) or (mean > b + 2 * std):
        warnings.warn(
            "mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
            "The distribution of values may be incorrect.",
            stacklevel=2,
        )

    # Values are generated by using a truncated uniform distribution and
    # then using the inverse CDF for the normal distribution.
    # Get upper and lower cdf values
    l = norm_cdf((a - mean) / std)
    u = norm_cdf((b - mean) / std)

    # Uniformly fill tensor with values from [l, u], then translate to [2l-1, 2u-1].
    tensor.uniform_(2 * l - 1, 2 * u - 1)

    # Use inverse cdf transform for normal distribution to get truncated standard normal
    tensor.erfinv_()

    # Transform to proper mean, std
    tensor.mul_(std * math.sqrt(2.0))
    tensor.add_(mean)

    # Clamp to ensure it's in the proper range
    tensor.clamp_(min=a, max=b)
    return tensor


def trunc_normal_(tensor, mean=0.0, std=1.0, a=-2.0, b=2.0):
    with torch.no_grad():
        return _trunc_normal_(tensor, mean, std, a, b)
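
# Illustration (added here): trunc_normal_ fills a tensor in place with
# samples from N(mean, std) truncated to [a, b], so every value should land
# inside the truncation bounds.
def _demo_trunc_normal():
    w = torch.empty(256, 256)
    trunc_normal_(w, std=0.02)  # same init the model uses below
    assert w.min().item() >= -2.0 and w.max().item() <= 2.0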
class DropPath(nn.Module):
    """Drop paths (Stochastic Depth) per sample (when applied in the main path of residual blocks)."""

    def __init__(self, drop_prob: float = 0.0, scale_by_keep: bool = True):
        super().__init__()
        self.drop_prob = drop_prob
        self.scale_by_keep = scale_by_keep

    def forward(self, x):
        return drop_path(x, self.drop_prob, self.training, self.scale_by_keep)

    def extra_repr(self):
        return f"drop_prob={round(self.drop_prob, 3):0.3f}"
class LayerNorm(nn.Module):
    """LayerNorm that supports two data formats: channels_last (default) or channels_first,
    referring to the ordering of the dimensions in the input. channels_last corresponds to
    inputs of shape (batch_size, height, width, channels), while channels_first corresponds
    to inputs of shape (batch_size, channels, height, width).
    """

    def __init__(self, normalized_shape, eps=1e-6, data_format="channels_last"):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(normalized_shape))
        self.bias = nn.Parameter(torch.zeros(normalized_shape))
        self.eps = eps
        self.data_format = data_format
        if self.data_format not in ["channels_last", "channels_first"]:
            raise NotImplementedError(f"unsupported data format: {self.data_format}")
        self.normalized_shape = (normalized_shape,)

    def forward(self, x):
        if self.data_format == "channels_last":
            return F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
        elif self.data_format == "channels_first":
            # Normalize over the channel dim manually, then apply the affine params.
            u = x.mean(1, keepdim=True)
            s = (x - u).pow(2).mean(1, keepdim=True)
            x = (x - u) / torch.sqrt(s + self.eps)
            x = self.weight[:, None, None] * x + self.bias[:, None, None]
            return x
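
# Illustration (added here): the channels_first branch computes the same
# statistics as F.layer_norm over the channel dim, so the two data formats
# agree after a permute.
def _demo_layernorm_formats():
    x = torch.randn(2, 8, 4, 4)
    ln_first = LayerNorm(8, data_format="channels_first")
    ln_last = LayerNorm(8, data_format="channels_last")
    a = ln_first(x)
    b = ln_last(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
    assert torch.allclose(a, b, atol=1e-5)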
class GRN(nn.Module):
    """GRN (Global Response Normalization) layer, operating on (N, H, W, C) inputs."""

    def __init__(self, dim):
        super().__init__()
        self.gamma = nn.Parameter(torch.zeros(1, 1, 1, dim))
        self.beta = nn.Parameter(torch.zeros(1, 1, 1, dim))

    def forward(self, x):
        # Per-channel L2 norm over the spatial dims, divided by its mean across
        # channels, measures each channel's relative response; it gates x
        # through a learnable affine with a residual connection.
        Gx = torch.norm(x, p=2, dim=(1, 2), keepdim=True)
        Nx = Gx / (Gx.mean(dim=-1, keepdim=True) + 1e-6)
        return self.gamma * (x * Nx) + self.beta + x
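
# Illustration (added here): with gamma and beta initialized to zero, GRN
# starts out as the identity mapping, so it does not perturb a freshly
# initialized block.
def _demo_grn_identity_at_init():
    x = torch.randn(2, 4, 4, 16)  # (N, H, W, C)
    assert torch.allclose(GRN(16)(x), x)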
class Block(nn.Module):
    """ConvNeXt V2 block.

    Args:
        dim (int): Number of input channels.
        drop_path (float): Stochastic depth rate. Default: 0.0
    """

    def __init__(self, dim, drop_path=0.0):
        super().__init__()
        self.dwconv = nn.Conv2d(dim, dim, kernel_size=7, padding=3, groups=dim)  # depthwise conv
        self.norm = LayerNorm(dim, eps=1e-6)
        self.pwconv1 = nn.Linear(dim, 4 * dim)  # pointwise/1x1 convs, implemented with linear layers
        self.act = nn.GELU()
        self.grn = GRN(4 * dim)
        self.pwconv2 = nn.Linear(4 * dim, dim)
        self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()

    def forward(self, x):
        shortcut = x
        x = self.dwconv(x)
        x = x.permute(0, 2, 3, 1)  # (N, C, H, W) -> (N, H, W, C)
        x = self.norm(x)
        x = self.pwconv1(x)
        x = self.act(x)
        x = self.grn(x)
        x = self.pwconv2(x)
        x = x.permute(0, 3, 1, 2)  # (N, H, W, C) -> (N, C, H, W)
        return shortcut + self.drop_path(x)
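
# Illustration (added here): a Block is shape-preserving, which is what lets
# each stage stack an arbitrary number of them at a fixed resolution.
def _demo_block_shape():
    blk = Block(dim=96, drop_path=0.1)
    x = torch.randn(2, 96, 14, 14)
    assert blk(x).shape == x.shape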
class ConvNeXtV2(nn.Module):
    """ConvNeXt V2

    Args:
        in_chans (int): Number of input image channels. Default: 3
        num_classes (int): Number of classes for the classification head. Default: 1000
        depths (list(int)): Number of blocks at each stage. Default: [3, 3, 9, 3]
        dims (list(int)): Feature dimension at each stage. Default: [96, 192, 384, 768]
        drop_path_rate (float): Stochastic depth rate. Default: 0.0
        head_init_scale (float): Init scaling value for classifier weights and biases. Default: 1.0
    """

    def __init__(
        self,
        in_chans=3,
        num_classes=1000,
        depths=[3, 3, 9, 3],
        dims=[96, 192, 384, 768],
        drop_path_rate=0.0,
        head_init_scale=1.0,
    ):
        super().__init__()
        self.depths = depths
        # Stem and 3 intermediate downsampling conv layers.
        self.downsample_layers = nn.ModuleList()
        stem = nn.Sequential(
            nn.Conv2d(in_chans, dims[0], kernel_size=4, stride=4),
            LayerNorm(dims[0], eps=1e-6, data_format="channels_first"),
        )
        self.downsample_layers.append(stem)
        for i in range(3):
            downsample_layer = nn.Sequential(
                LayerNorm(dims[i], eps=1e-6, data_format="channels_first"),
                nn.Conv2d(dims[i], dims[i + 1], kernel_size=2, stride=2),
            )
            self.downsample_layers.append(downsample_layer)

        # 4 feature resolution stages, each consisting of multiple residual blocks.
        # The stochastic depth rate increases linearly from 0 to drop_path_rate
        # across all blocks in the network.
        self.stages = nn.ModuleList()
        dp_rates = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]
        cur = 0
        for i in range(4):
            stage = nn.Sequential(
                *[Block(dim=dims[i], drop_path=dp_rates[cur + j]) for j in range(depths[i])]
            )
            self.stages.append(stage)
            cur += depths[i]

        self.norm = nn.LayerNorm(dims[-1], eps=1e-6)  # final norm layer
        self.head = nn.Linear(dims[-1], num_classes)

        self.apply(self._init_weights)
        self.head.weight.data.mul_(head_init_scale)
        self.head.bias.data.mul_(head_init_scale)

    def _init_weights(self, m):
        if isinstance(m, (nn.Conv2d, nn.Linear)):
            trunc_normal_(m.weight, std=0.02)
            nn.init.constant_(m.bias, 0)

    def forward_features(self, x):
        for i in range(4):
            x = self.downsample_layers[i](x)
            x = self.stages[i](x)
        # Global average pooling: (N, C, H, W) -> (N, C).
        return self.norm(x.mean([-2, -1]))

    def forward(self, x):
        x = self.forward_features(x)
        x = self.head(x)
        return x
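
# Illustration (added here): an end-to-end shape check. The stem reduces the
# input by 4x and each of the three downsampling layers by a further 2x, so
# a 224x224 input reaches the head as a 7x7 feature map before pooling.
def _demo_convnextv2_forward():
    model = ConvNeXtV2(depths=[2, 2, 6, 2], dims=[40, 80, 160, 320], num_classes=10)
    x = torch.randn(1, 3, 224, 224)
    assert model(x).shape == (1, 10)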
def convnextv2_atto(**kwargs):
    return ConvNeXtV2(depths=[2, 2, 6, 2], dims=[40, 80, 160, 320], **kwargs)


def convnextv2_femto(**kwargs):
    return ConvNeXtV2(depths=[2, 2, 6, 2], dims=[48, 96, 192, 384], **kwargs)


def convnextv2_pico(**kwargs):
    return ConvNeXtV2(depths=[2, 2, 6, 2], dims=[64, 128, 256, 512], **kwargs)


def convnextv2_nano(**kwargs):
    return ConvNeXtV2(depths=[2, 2, 8, 2], dims=[80, 160, 320, 640], **kwargs)


def convnextv2_tiny(**kwargs):
    return ConvNeXtV2(depths=[3, 3, 9, 3], dims=[96, 192, 384, 768], **kwargs)


def convnextv2_base(**kwargs):
    return ConvNeXtV2(depths=[3, 3, 27, 3], dims=[128, 256, 512, 1024], **kwargs)


def convnextv2_large(**kwargs):
    return ConvNeXtV2(depths=[3, 3, 27, 3], dims=[192, 384, 768, 1536], **kwargs)


def convnextv2_huge(**kwargs):
    return ConvNeXtV2(depths=[3, 3, 27, 3], dims=[352, 704, 1408, 2816], **kwargs)
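
# Illustration (added here): the factories only vary depths/dims; any extra
# kwargs (num_classes, drop_path_rate, ...) pass straight through to
# ConvNeXtV2, so picking a smaller variant is a one-line change.
def _demo_factory_usage():
    model = convnextv2_atto(num_classes=1, drop_path_rate=0.1)
    n_params = sum(p.numel() for p in model.parameters())
    print(f"convnextv2_atto: {n_params / 1e6:.1f}M parameters")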
# ========= Dataset =========
def rescale_pad(image, output_size, random_pad=False):
    """Resize so the longer side fits output_size (keeping aspect ratio), then
    pad to a square. With random_pad, the scale and pad offsets are jittered
    and the padding is filled with a random gray value instead of black."""
    h, w = image.shape[-2:]
    if h != output_size or w != output_size:
        r = min(output_size / h, output_size / w)
        new_h, new_w = int(h * r), int(w * r)
        if random_pad:
            r2 = random.uniform(0.9, 1)
            new_h, new_w = int(new_h * r2), int(new_w * r2)
        ph = output_size - new_h
        pw = output_size - new_w
        left = random.randint(0, pw) if random_pad else pw // 2
        right = pw - left
        top = random.randint(0, ph) if random_pad else ph // 2
        bottom = ph - top
        image = transforms.functional.resize(image, [new_h, new_w])
        image = transforms.functional.pad(
            image, [left, top, right, bottom], random.uniform(0, 1) if random_pad else 0
        )
    return image
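
# Illustration (added here): rescale_pad always yields a square tensor of
# the requested size regardless of the input aspect ratio.
def _demo_rescale_pad():
    image = torch.rand(3, 300, 500)
    out = rescale_pad(image, 224)
    assert out.shape == (3, 224, 224)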
def random_crop(image, min_rate=0.8):
    h, w = image.shape[-2:]
    new_h = int(h * random.uniform(min_rate, 1))
    new_w = int(w * random.uniform(min_rate, 1))
    # The +1 keeps randint's exclusive upper bound valid when the sampled crop
    # size equals the full image (np.random.randint(0, 0) would raise).
    top = np.random.randint(0, h - new_h + 1)
    left = np.random.randint(0, w - new_w + 1)
    return image[:, top : top + new_h, left : left + new_w]
class ClassifierDataset(Dataset):
    """Reads images from one sub-folder per class; the (sorted) folder order
    defines the float label each image receives."""

    def __init__(self, path, img_size):
        self.path = path
        self.img_size = img_size
        self.image_list = []
        self.label_list = []
        self.label_dict = {}
        # Sort so the label assignment is deterministic across runs.
        label_folders = sorted(d for d in os.listdir(path) if os.path.isdir(os.path.join(path, d)))
        for i, label_folder in enumerate(label_folders):
            self.label_dict[label_folder] = i
            folder_path = os.path.join(path, label_folder)
            for root, _dirs, files in os.walk(folder_path):
                for fname in files:
                    if os.path.splitext(fname)[1].lower() in [".png", ".jpg", ".jpeg", ".webp"]:
                        self.image_list.append(os.path.join(root, fname))
                        self.label_list.append(i)
        print("Label sequence:", self.label_dict)

    def __len__(self):
        # Doubled: the second half of the index range serves the horizontally
        # flipped copy of each image.
        return len(self.image_list) * 2

    def __getitem__(self, index):
        real_len = len(self.image_list)
        # image_list entries are already full paths (joined in __init__).
        fname = self.image_list[index % real_len]
        label = self.label_list[index % real_len]
        image = Image.open(fname).convert("RGB")
        image = transforms.functional.to_tensor(image)
        image = random_crop(image, min_rate=0.8)
        image = rescale_pad(image, self.img_size, True)
        if index // real_len != 0:
            image = transforms.functional.hflip(image)
        label = torch.tensor([label], dtype=torch.float32)
        return image, label
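
# Illustration (added here; the directory below is hypothetical): the dataset
# expects path/<class_name>/*.png|jpg|jpeg|webp and yields a
# (3, img_size, img_size) tensor plus a one-element float label.
def _demo_dataset_usage():
    dataset = ClassifierDataset("/tmp/classifier-dataset", img_size=768)
    image, label = dataset[0]
    print(image.shape, label)  # e.g. torch.Size([3, 768, 768]) tensor([0.])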
# ========= Trainer =========
class Classifier(lightning.LightningModule):
    def __init__(self, model, drop_path_rate=0.0, ema_decay=0.9999):
        super().__init__()
        self.net = model(in_chans=3, num_classes=1, drop_path_rate=drop_path_rate)
        # Keep a frozen EMA copy of the network for more stable validation.
        self.ema = deepcopy(self.net)
        self.ema_decay = ema_decay
        self.ema.requires_grad_(False)

    def configure_optimizers(self):
        return torch.optim.AdamW(self.net.parameters(), weight_decay=0.05)

    def forward(self, x, use_ema=False):
        # Map inputs from [0, 1] to [-1, 1].
        x = transforms.functional.normalize(x, 0.5, 0.5)
        return self.ema(x) if use_ema else self.net(x)

    def training_step(self, batch, batch_idx):
        images, labels = batch
        loss = F.mse_loss(self.forward(images), labels)
        self.log_dict({"train/loss": loss})
        return loss

    def validation_step(self, batch, batch_idx):
        images, labels = batch
        mae = F.l1_loss(self.forward(images), labels)
        mae_ema = F.l1_loss(self.forward(images, use_ema=True), labels)
        self.log_dict({"val/mae": mae, "val/mae_ema": mae_ema}, sync_dist=True, prog_bar=True)

    def on_train_batch_end(self, outputs, batch, batch_idx):
        # Exponential moving average update after every training batch.
        if self.ema is not None:
            with torch.no_grad():
                for ema_v, model_v in zip(
                    self.ema.state_dict().values(), self.net.state_dict().values()
                ):
                    ema_v.copy_(self.ema_decay * ema_v + (1.0 - self.ema_decay) * model_v)
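
# Note (added here): the update ema <- d * ema + (1 - d) * net has an
# effective averaging horizon of roughly 1 / (1 - d) steps, i.e. about
# 10,000 training batches at the default decay of 0.9999.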
def main(opt):
    print(opt)
    seed_everything(opt.seed)
    torch.set_float32_matmul_precision("high")

    generator = torch.Generator().manual_seed(opt.seed)
    full_dataset = ClassifierDataset(opt.dataset.path, opt.dataset.img_size)
    split_ratio = [opt.dataset.train_split, 1 - opt.dataset.train_split]
    train_dataset, val_dataset = random_split(full_dataset, split_ratio, generator)

    dataloader_args = {
        "num_workers": 8,
        "pin_memory": True,
        "persistent_workers": True,
    }
    train_dataloader = DataLoader(train_dataset, batch_size=opt.bsz_train, shuffle=True, **dataloader_args)
    val_dataloader = DataLoader(val_dataset, batch_size=opt.bsz_val, shuffle=False, **dataloader_args)

    classifier = Classifier(model=opt.model, drop_path_rate=opt.model_drop_path)
    callbacks = [
        ModelCheckpoint(
            monitor="val/mae",
            mode="min",
            save_top_k=1,
            save_last=True,
            auto_insert_metric_name=False,
            filename="epoch={epoch},mae={val/mae:.4f}",
        ),
        ModelCheckpoint(
            monitor="val/mae_ema",
            mode="min",
            save_top_k=1,
            save_last=False,
            auto_insert_metric_name=False,
            filename="epoch={epoch},mae-ema={val/mae_ema:.4f}",
        ),
    ]
    trainer = lightning.Trainer(**opt.trainer, callbacks=callbacks)
    trainer.fit(classifier, train_dataloader, val_dataloader)
def get_config():
    config = ConfigDict()

    # model args
    config.model = convnextv2_tiny
    config.model_drop_path = 0.1
    config.resume = ""
    config.seed = 42
    config.bsz_train = 32
    config.bsz_val = 2

    # dataset args
    dataset = config.dataset = ConfigDict()
    dataset.path = "/tmp/classifier-dataset"
    dataset.train_split = 0.9999
    dataset.img_size = 768

    # training args
    trainer = config.trainer = ConfigDict()
    trainer.max_epochs = 100
    trainer.accumulate_grad_batches = 4
    trainer.accelerator = "gpu"
    trainer.precision = 32
    trainer.benchmark = True
    trainer.log_every_n_steps = 1
    trainer.val_check_interval = 0.025
    trainer.strategy = "auto"  # or "ddp_find_unused_parameters_false" for multi-GPU
    return config
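
# Illustration (added here): since get_config() returns a plain ConfigDict,
# settings can be overridden in code before launching a run, e.g. swapping
# the backbone or shrinking the schedule for a smoke test.
def _demo_config_override():
    cfg = get_config()
    cfg.model = convnextv2_atto
    cfg.dataset.img_size = 256
    cfg.trainer.max_epochs = 1
    return cfg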
if __name__ == "__main__":
    main(get_config())