

@Mikubill
Last active November 1, 2023 07:29
Use NAFNet to upscale images in batches.
#
# NAFNet Image Upscaler
# Mikubill, MIT License
#
# Usage:
# python nafnet.py -i [path_to_images]
#
# Parameters:
# -m, --model: Specifies the model to use. The default is NAFNet-REDS-width64.
# -i, --input: Path to the directory containing images to upscale. This parameter is required.
# -b, --batch-size: The number of patches to process in a single batch. Default is 4.
# --scale: The upscale factor. Default is 1.5.
# --postfix: Suffix appended to output file names. Default is "_upscaled".
# --ir-first: Run restoration before the Lanczos resize instead of after it.
# --tile-size: Side length of the tiles fed to the network. Default is 768.
# --file-batch-size: The number of image files to load per batch. Default is 4.
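#
# Example (hypothetical paths):
# python nafnet.py -i ./photos -m NAFNet-SIDD-width64 --scale 2.0 -b 8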
# ------------------------------------------------------------------------
# Copyright (c) 2022 megvii-model. All Rights Reserved.
# ------------------------------------------------------------------------
import os
import time
import torch
import torch.nn as nn
import torch.nn.functional as F
# From: https://github.com/megvii-research/NAFNet/blob/main/basicsr/models/archs/local_arch.py
class AvgPool2d(nn.Module):
    def __init__(self, kernel_size=None, base_size=None, auto_pad=True, fast_imp=False, train_size=None):
        super().__init__()
        self.kernel_size = kernel_size
        self.base_size = base_size
        self.auto_pad = auto_pad
        # only used for fast implementation
        self.fast_imp = fast_imp
        self.rs = [5, 4, 3, 2, 1]
        self.max_r1 = self.rs[0]
        self.max_r2 = self.rs[0]
        self.train_size = train_size

    def extra_repr(self) -> str:
        return 'kernel_size={}, base_size={}, stride={}, fast_imp={}'.format(
            self.kernel_size, self.base_size, self.kernel_size, self.fast_imp
        )

    def forward(self, x):
        if self.kernel_size is None and self.base_size:
            train_size = self.train_size
            if isinstance(self.base_size, int):
                self.base_size = (self.base_size, self.base_size)
            self.kernel_size = list(self.base_size)
            self.kernel_size[0] = x.shape[2] * self.base_size[0] // train_size[-2]
            self.kernel_size[1] = x.shape[3] * self.base_size[1] // train_size[-1]
            # only used for fast implementation
            self.max_r1 = max(1, self.rs[0] * x.shape[2] // train_size[-2])
            self.max_r2 = max(1, self.rs[0] * x.shape[3] // train_size[-1])

        if self.kernel_size[0] >= x.size(-2) and self.kernel_size[1] >= x.size(-1):
            return F.adaptive_avg_pool2d(x, 1)

        if self.fast_imp:  # Non-equivalent implementation but faster
            h, w = x.shape[2:]
            if self.kernel_size[0] >= h and self.kernel_size[1] >= w:
                out = F.adaptive_avg_pool2d(x, 1)
            else:
                r1 = [r for r in self.rs if h % r == 0][0]
                r2 = [r for r in self.rs if w % r == 0][0]
                # reduction constraint
                r1 = min(self.max_r1, r1)
                r2 = min(self.max_r2, r2)
                s = x[:, :, ::r1, ::r2].cumsum(dim=-1).cumsum(dim=-2)
                n, c, h, w = s.shape
                k1, k2 = min(h - 1, self.kernel_size[0] // r1), min(w - 1, self.kernel_size[1] // r2)
                out = (s[:, :, :-k1, :-k2] - s[:, :, :-k1, k2:] - s[:, :, k1:, :-k2] + s[:, :, k1:, k2:]) / (k1 * k2)
                out = torch.nn.functional.interpolate(out, scale_factor=(r1, r2))
        else:
            n, c, h, w = x.shape
            s = x.cumsum(dim=-1).cumsum_(dim=-2)
            s = torch.nn.functional.pad(s, (1, 0, 1, 0))  # pad 0 for convenience
            k1, k2 = min(h, self.kernel_size[0]), min(w, self.kernel_size[1])
            s1, s2, s3, s4 = s[:, :, :-k1, :-k2], s[:, :, :-k1, k2:], s[:, :, k1:, :-k2], s[:, :, k1:, k2:]
            out = s4 + s1 - s2 - s3
            out = out / (k1 * k2)

        if self.auto_pad:
            n, c, h, w = x.shape
            _h, _w = out.shape[2:]
            pad2d = ((w - _w) // 2, (w - _w + 1) // 2, (h - _h) // 2, (h - _h + 1) // 2)
            out = torch.nn.functional.pad(out, pad2d, mode='replicate')

        return out
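
# The non-fast branch above is a box filter built from an integral image:
# with a zero row/column padded at the top-left, the average over any
# k1 x k2 window of x is (s4 + s1 - s2 - s3) / (k1 * k2). A minimal
# sanity sketch (hypothetical sizes, not part of the original gist):
#
#   x = torch.rand(1, 1, 8, 8)
#   s = F.pad(x.cumsum(-1).cumsum(-2), (1, 0, 1, 0))
#   k = 3
#   box = (s[..., k:, k:] + s[..., :-k, :-k]
#          - s[..., :-k, k:] - s[..., k:, :-k]) / (k * k)
#   assert torch.allclose(box, F.avg_pool2d(x, k, stride=1), atol=1e-5)
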
def replace_layers(model, base_size, train_size, fast_imp, **kwargs):
    for n, m in model.named_children():
        if len(list(m.children())) > 0:
            # compound module, go inside it
            replace_layers(m, base_size, train_size, fast_imp, **kwargs)

        if isinstance(m, nn.AdaptiveAvgPool2d):
            pool = AvgPool2d(base_size=base_size, fast_imp=fast_imp, train_size=train_size)
            assert m.output_size == 1
            setattr(model, n, pool)
'''
ref.
@article{chu2021tlsc,
  title={Revisiting Global Statistics Aggregation for Improving Image Restoration},
  author={Chu, Xiaojie and Chen, Liangyu and Chen, Chengpeng and Lu, Xin},
  journal={arXiv preprint arXiv:2112.04491},
  year={2021}
}
'''
class LocalBase():
    def convert(self, *args, train_size, **kwargs):
        replace_layers(self, *args, train_size=train_size, **kwargs)
        # dry-run forward pass so each AvgPool2d derives its kernel size
        # from the train_size / input-size ratio
        imgs = torch.rand(train_size)
        with torch.no_grad():
            self.forward(imgs)
# From: https://github.com/megvii-research/NAFNet/blob/main/basicsr/models/archs/arch_util.py
class LayerNormFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, weight, bias, eps):
        ctx.eps = eps
        N, C, H, W = x.size()
        mu = x.mean(1, keepdim=True)
        var = (x - mu).pow(2).mean(1, keepdim=True)
        y = (x - mu) / (var + eps).sqrt()
        ctx.save_for_backward(y, var, weight)
        y = weight.view(1, C, 1, 1) * y + bias.view(1, C, 1, 1)
        return y

    @staticmethod
    def backward(ctx, grad_output):
        eps = ctx.eps
        N, C, H, W = grad_output.size()
        y, var, weight = ctx.saved_tensors  # saved_variables is deprecated
        g = grad_output * weight.view(1, C, 1, 1)
        mean_g = g.mean(dim=1, keepdim=True)
        mean_gy = (g * y).mean(dim=1, keepdim=True)
        gx = 1. / torch.sqrt(var + eps) * (g - y * mean_gy - mean_g)
        return gx, (grad_output * y).sum(dim=3).sum(dim=2).sum(dim=0), \
            grad_output.sum(dim=3).sum(dim=2).sum(dim=0), None
class LayerNorm2d(nn.Module):
    def __init__(self, channels, eps=1e-6):
        super(LayerNorm2d, self).__init__()
        self.register_parameter('weight', nn.Parameter(torch.ones(channels)))
        self.register_parameter('bias', nn.Parameter(torch.zeros(channels)))
        self.eps = eps

    def forward(self, x):
        return LayerNormFunction.apply(x, self.weight, self.bias, self.eps)
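
# LayerNorm2d normalizes each pixel across channels only, so it matches
# F.layer_norm applied to a channels-last view of the tensor. A minimal
# equivalence sketch (hypothetical sizes, not part of the original gist):
#
#   x = torch.rand(2, 8, 4, 4)
#   ln = LayerNorm2d(8)
#   ref = F.layer_norm(x.permute(0, 2, 3, 1), (8,), ln.weight, ln.bias,
#                      ln.eps).permute(0, 3, 1, 2)
#   assert torch.allclose(ln(x), ref, atol=1e-5)
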
# From: https://github.com/megvii-research/NAFNet/blob/main/basicsr/models/archs/NAFNet_arch.py
class SimpleGate(nn.Module):
    def forward(self, x):
        x1, x2 = x.chunk(2, dim=1)
        return x1 * x2
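
# SimpleGate replaces nonlinear activations with an element-wise product of
# the two channel halves, halving the channel count:
# (N, 2c, H, W) -> (N, c, H, W).
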
class NAFBlock(nn.Module):
    def __init__(self, c, DW_Expand=2, FFN_Expand=2, drop_out_rate=0.):
        super().__init__()
        dw_channel = c * DW_Expand
        self.conv1 = nn.Conv2d(in_channels=c, out_channels=dw_channel, kernel_size=1, padding=0, stride=1, groups=1, bias=True)
        self.conv2 = nn.Conv2d(in_channels=dw_channel, out_channels=dw_channel, kernel_size=3, padding=1, stride=1,
                               groups=dw_channel, bias=True)
        self.conv3 = nn.Conv2d(in_channels=dw_channel // 2, out_channels=c, kernel_size=1, padding=0, stride=1, groups=1, bias=True)

        # Simplified Channel Attention
        self.sca = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(in_channels=dw_channel // 2, out_channels=dw_channel // 2, kernel_size=1, padding=0, stride=1,
                      groups=1, bias=True),
        )

        # SimpleGate
        self.sg = SimpleGate()

        ffn_channel = FFN_Expand * c
        self.conv4 = nn.Conv2d(in_channels=c, out_channels=ffn_channel, kernel_size=1, padding=0, stride=1, groups=1, bias=True)
        self.conv5 = nn.Conv2d(in_channels=ffn_channel // 2, out_channels=c, kernel_size=1, padding=0, stride=1, groups=1, bias=True)

        self.norm1 = LayerNorm2d(c)
        self.norm2 = LayerNorm2d(c)

        self.dropout1 = nn.Dropout(drop_out_rate) if drop_out_rate > 0. else nn.Identity()
        self.dropout2 = nn.Dropout(drop_out_rate) if drop_out_rate > 0. else nn.Identity()

        self.beta = nn.Parameter(torch.zeros((1, c, 1, 1)), requires_grad=True)
        self.gamma = nn.Parameter(torch.zeros((1, c, 1, 1)), requires_grad=True)

    def forward(self, inp):
        x = inp

        x = self.norm1(x)
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.sg(x)
        x = x * self.sca(x)
        x = self.conv3(x)
        x = self.dropout1(x)

        y = inp + x * self.beta

        x = self.conv4(self.norm2(y))
        x = self.sg(x)
        x = self.conv5(x)
        x = self.dropout2(x)

        return y + x * self.gamma
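
# A NAFBlock preserves its input shape, and because beta and gamma are
# zero-initialized, a freshly constructed block is an identity map. A minimal
# sketch (hypothetical sizes, not part of the original gist):
#
#   blk = NAFBlock(32).eval()
#   x = torch.rand(1, 32, 16, 16)
#   assert blk(x).shape == x.shape
#   assert torch.allclose(blk(x), x)  # residual scales start at zero
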
class NAFNet(nn.Module):
    def __init__(
        self,
        img_channel=3,
        width=16,
        middle_blk_num=1,
        enc_blk_nums=[],
        dec_blk_nums=[]
    ):
        super().__init__()
        self.intro = nn.Conv2d(
            in_channels=img_channel,
            out_channels=width,
            kernel_size=3,
            padding=1,
            stride=1,
            groups=1,
            bias=True
        )
        self.ending = nn.Conv2d(
            in_channels=width,
            out_channels=img_channel,
            kernel_size=3,
            padding=1,
            stride=1,
            groups=1,
            bias=True
        )

        self.encoders = nn.ModuleList()
        self.decoders = nn.ModuleList()
        self.middle_blks = nn.ModuleList()
        self.ups = nn.ModuleList()
        self.downs = nn.ModuleList()

        chan = width
        for num in enc_blk_nums:
            self.encoders.append(
                nn.Sequential(
                    *[NAFBlock(chan) for _ in range(num)]
                )
            )
            self.downs.append(
                nn.Conv2d(chan, 2 * chan, 2, 2)
            )
            chan = chan * 2

        self.middle_blks = \
            nn.Sequential(
                *[NAFBlock(chan) for _ in range(middle_blk_num)]
            )

        for num in dec_blk_nums:
            self.ups.append(
                nn.Sequential(
                    nn.Conv2d(chan, chan * 2, 1, bias=False),
                    nn.PixelShuffle(2)
                )
            )
            chan = chan // 2
            self.decoders.append(
                nn.Sequential(
                    *[NAFBlock(chan) for _ in range(num)]
                )
            )

        self.padder_size = 2 ** len(self.encoders)

    def forward(self, inp):
        B, C, H, W = inp.shape
        inp = self.check_image_size(inp)

        x = self.intro(inp)

        encs = []
        for encoder, down in zip(self.encoders, self.downs):
            x = encoder(x)
            encs.append(x)
            x = down(x)

        x = self.middle_blks(x)

        for decoder, up, enc_skip in zip(self.decoders, self.ups, encs[::-1]):
            x = up(x)
            x = x + enc_skip
            x = decoder(x)

        x = self.ending(x)
        x = x + inp

        return x[:, :, :H, :W]

    def check_image_size(self, x):
        _, _, h, w = x.size()
        mod_pad_h = (self.padder_size - h % self.padder_size) % self.padder_size
        mod_pad_w = (self.padder_size - w % self.padder_size) % self.padder_size
        x = F.pad(x, (0, mod_pad_w, 0, mod_pad_h))
        return x
class NAFNetLocal(LocalBase, NAFNet):
    def __init__(self, *args, train_size=(1, 3, 256, 256), fast_imp=False, **kwargs):
        LocalBase.__init__(self)
        NAFNet.__init__(self, *args, **kwargs)
        N, C, H, W = train_size
        base_size = (int(H * 1.5), int(W * 1.5))

        self.eval()
        with torch.no_grad():
            self.convert(base_size=base_size, train_size=train_size, fast_imp=fast_imp)
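
# NAFNetLocal swaps every AdaptiveAvgPool2d in a trained NAFNet for the TLSC
# AvgPool2d above, so test-time statistics are aggregated over local windows
# sized 1.5x the training crop rather than over the whole image. A minimal
# sketch (hypothetical sizes and random weights, not part of the original
# gist):
#
#   net = NAFNetLocal(width=16, enc_blk_nums=[1, 1], middle_blk_num=1,
#                     dec_blk_nums=[1, 1], train_size=(1, 3, 64, 64))
#   out = net(torch.rand(1, 3, 128, 128))  # same shape as the input
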
import threading
import argparse
import numpy as np

from queue import Queue
from pathlib import Path
from PIL import Image
from tqdm.auto import tqdm
from torchvision.transforms import ToTensor
model_list = {
    'NAFNet-REDS-width64': {
        "width": 64,
        "enc_blk_nums": [1, 1, 1, 28],
        "middle_blk_num": 1,
        "dec_blk_nums": [1, 1, 1, 1],
    },
    'NAFNet-SIDD-width64': {
        "width": 64,
        "enc_blk_nums": [2, 2, 4, 8],
        "middle_blk_num": 12,
        "dec_blk_nums": [2, 2, 2, 2],
    },
    'NAFNet-GoPro-width64': {
        "width": 64,
        "enc_blk_nums": [1, 1, 1, 28],
        "middle_blk_num": 1,
        "dec_blk_nums": [1, 1, 1, 1],
    },
}
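
# Each entry in model_list mirrors the keyword arguments of the NAFNet
# constructor, so a network can be built straight from a config. A minimal
# sketch with random weights (the pretrained weights are loaded below):
#
#   cfg = model_list["NAFNet-SIDD-width64"]
#   net = NAFNet(**cfg)
#   assert net(torch.rand(1, 3, 64, 64)).shape == (1, 3, 64, 64)
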
# Tiled, batched image-restoration wrapper around NAFNet
class ImageRestoration(nn.Module):
    def __init__(self, model_name='NAFNet-REDS-width64', tile_size=768, tile_overlap=32, bar=None):
        super().__init__()
        if model_name not in model_list:
            raise ValueError(f'Invalid model name: {model_name}')

        config = model_list[model_name]
        state_dict = torch.hub.load_state_dict_from_url(
            f'https://huggingface.co/nyanko7/nafnet-models/resolve/main/{model_name}.pth',
            model_dir=".",
        )
        if 'params' in state_dict:
            state_dict = state_dict['params']

        net = NAFNet(**config)
        net.load_state_dict(state_dict)
        net = net.cuda().to(torch.float32)
        try:
            # torch.compile returns the compiled module; it does not compile in place
            net = torch.compile(net, fullgraph=True, mode="max-autotune")
        except Exception as e:
            print(f"Skip torch.compile: {e}")

        self.tile_size = tile_size
        self.tile_overlap = tile_overlap
        self.model = net
        self.bar = bar
    @torch.no_grad()
    def forward(self, img_batch, batch_size=4):
        # img_batch is a list of (1, c, h, w) tensors; the images may differ in size
        canvas, counter, all_patch, all_idx = [], [], [], []
        tile = min(self.tile_size, *(img.shape[-1] for img in img_batch), *(img.shape[-2] for img in img_batch))
        stride = tile - self.tile_overlap

        # collect overlapping tile patches from every image in the batch
        for i, img in enumerate(img_batch):
            b, c, h, w = img.shape
            canvas.append(torch.zeros(b, c, h, w, dtype=img.dtype, device=img.device))
            counter.append(torch.zeros(b, c, h, w, dtype=img.dtype, device=img.device))
            h_idx_list = list(range(0, h - tile, stride)) + [h - tile]
            w_idx_list = list(range(0, w - tile, stride)) + [w - tile]
            for h_idx in h_idx_list:
                for w_idx in w_idx_list:
                    in_patch = img[..., h_idx: h_idx + tile, w_idx: w_idx + tile]
                    all_patch.append(in_patch)
                    all_idx.append((i, h_idx, w_idx))

        # restore the patches in mini-batches and accumulate them on the canvases
        for j in range(0, len(all_patch), batch_size):
            out_patch = self.model(torch.cat(all_patch[j:j + batch_size]))
            for idx, (p_idx, h_idx, w_idx) in enumerate(all_idx[j:j + batch_size]):
                out_patch_mask = torch.ones_like(out_patch[idx])
                canvas[p_idx][..., h_idx: h_idx + tile, w_idx: w_idx + tile].add_(out_patch[idx])
                counter[p_idx][..., h_idx: h_idx + tile, w_idx: w_idx + tile].add_(out_patch_mask)
            if self.bar is not None:
                self.bar.set_description_str(f"Processing Images (blk {j}/{len(all_patch)})")

        # average the overlapping regions
        for i in range(len(canvas)):
            canvas[i].div_(counter[i])
        return canvas
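
# Tiles start every stride = tile - tile_overlap pixels, with a final tile
# pinned to the bottom/right edge; overlapping predictions are summed on
# canvas and normalized by counter. A worked example of the index grid
# (hypothetical sizes):
#
#   h, tile, overlap = 1000, 768, 32
#   stride = tile - overlap                          # 736
#   h_idx_list = list(range(0, h - tile, stride)) + [h - tile]
#   # -> [0, 232]: tiles cover rows 0-767 and 232-999,
#   #    overlapping on rows 232-767
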
def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("-m", "--model", type=str, default="NAFNet-REDS-width64")
    parser.add_argument("-i", "--input", "--path", type=str, default=None, required=True)
    parser.add_argument("-b", "--batch-size", type=int, default=4)
    parser.add_argument("--postfix", type=str, default="_upscaled")
    parser.add_argument("--scale", type=float, default=1.5)
    parser.add_argument("--ir-first", action="store_true", default=False)
    parser.add_argument("--tile-size", type=int, default=768)
    parser.add_argument("--file-batch-size", type=int, default=4)
    args = parser.parse_args()
    return args
def resize_and_save(queue, ir_first):
    while True:
        fo, target_w, target_h, restored = queue.get()
        # with --ir-first the image is restored at its original size, so the
        # Lanczos resize to the target size happens here instead of up front
        if ir_first:
            restored = restored.resize((target_w, target_h), Image.Resampling.LANCZOS)
        restored.save(fo)
        queue.task_done()
def main():
    args = parse_args()

    # collect all image files under the input directory
    file_to_process = []
    for file in Path(args.input).glob("**/*"):
        if file.is_file() and file.suffix.lower() in [".jpg", ".png", ".jpeg", ".webp"]:
            file_to_process.append((file, file.parent / (file.stem + args.postfix + file.suffix)))

    batch_idx = []
    batch = []
    bar = tqdm(total=len(file_to_process), desc="Processing Images")
    irmodel = ImageRestoration(args.model, tile_size=args.tile_size, bar=bar)

    # save on a background thread so disk I/O overlaps with GPU work
    queue = Queue()
    thr = threading.Thread(target=resize_and_save, args=(queue, args.ir_first), daemon=True)
    thr.start()

    for idx, (fi, fo) in enumerate(file_to_process):
        img = Image.open(fi).convert('RGB')
        target_h, target_w = int(img.height * args.scale), int(img.width * args.scale)
        if not args.ir_first:
            img = img.resize((target_w, target_h), Image.Resampling.LANCZOS)
        img = ToTensor()(img)
        img = img.unsqueeze(0).cuda().to(torch.float32)  # match the model's dtype
        batch.append(img)
        batch_idx.append([fo, target_w, target_h])
        if len(batch) == args.file_batch_size or idx == len(file_to_process) - 1:
            result = irmodel(batch, batch_size=args.batch_size)
            for (fo, target_w, target_h), restored in zip(batch_idx, result):
                out = restored.squeeze(0).permute(1, 2, 0).float().cpu()
                out = torch.clamp(out, 0, 1)
                img_array = (out * 255).numpy().astype(np.uint8)
                restored = Image.fromarray(img_array)
                queue.put((fo, target_w, target_h, restored))
            bar.update(len(batch))  # the last batch may be smaller than file_batch_size
            batch_idx = []
            batch = []

    while not queue.empty():
        time.sleep(0.1)
        bar.set_description_str(f"Saving Images {len(file_to_process) - queue.qsize()}")
    queue.join()
if __name__ == "__main__":
    main()