Use NAFNet to upscale images in batch.
#
# NAFNet Image Upscaler
# Mikubill, MIT License
#
# Usage:
#   python nafnet.py -i [path_to_images]
#
# Parameters:
#   -m, --model:        Model to use. Default: NAFNet-REDS-width64.
#   -i, --input:        Path to the directory containing images to upscale. Required.
#   -b, --batch-size:   Number of tile patches processed per forward pass. Default: 4.
#   --scale:            Upscale factor. Default: 1.5.
#   --postfix:          Suffix appended to output filenames. Default: _upscaled.
#   --ir-first:         Run restoration on the original image, then resize
#                       (default order is resize first, then restore).
#   --tile-size:        Tile size used to split large images. Default: 768.
#   --file-batch-size:  Number of image files loaded per batch. Default: 4.
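#
# Example invocation (hypothetical paths; assumes a CUDA-capable machine and
# that the default weights can be downloaded from Hugging Face):
#   python nafnet.py -i ./photos --scale 2.0 -b 8 --tile-size 512
#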
# ------------------------------------------------------------------------
# Copyright (c) 2022 megvii-model. All Rights Reserved.
# ------------------------------------------------------------------------
import time

import torch
import torch.nn as nn
import torch.nn.functional as F
# From: https://github.com/megvii-research/NAFNet/blob/main/basicsr/models/archs/local_arch.py
class AvgPool2d(nn.Module):
    def __init__(self, kernel_size=None, base_size=None, auto_pad=True, fast_imp=False, train_size=None):
        super().__init__()
        self.kernel_size = kernel_size
        self.base_size = base_size
        self.auto_pad = auto_pad

        # only used for fast implementation
        self.fast_imp = fast_imp
        self.rs = [5, 4, 3, 2, 1]
        self.max_r1 = self.rs[0]
        self.max_r2 = self.rs[0]
        self.train_size = train_size

    def extra_repr(self) -> str:
        return 'kernel_size={}, base_size={}, stride={}, fast_imp={}'.format(
            self.kernel_size, self.base_size, self.kernel_size, self.fast_imp
        )
    def forward(self, x):
        if self.kernel_size is None and self.base_size:
            train_size = self.train_size
            if isinstance(self.base_size, int):
                self.base_size = (self.base_size, self.base_size)
            self.kernel_size = list(self.base_size)
            self.kernel_size[0] = x.shape[2] * self.base_size[0] // train_size[-2]
            self.kernel_size[1] = x.shape[3] * self.base_size[1] // train_size[-1]

            # only used for fast implementation
            self.max_r1 = max(1, self.rs[0] * x.shape[2] // train_size[-2])
            self.max_r2 = max(1, self.rs[0] * x.shape[3] // train_size[-1])

        if self.kernel_size[0] >= x.size(-2) and self.kernel_size[1] >= x.size(-1):
            return F.adaptive_avg_pool2d(x, 1)

        if self.fast_imp:  # Non-equivalent implementation but faster
            h, w = x.shape[2:]
            if self.kernel_size[0] >= h and self.kernel_size[1] >= w:
                out = F.adaptive_avg_pool2d(x, 1)
            else:
                r1 = [r for r in self.rs if h % r == 0][0]
                r2 = [r for r in self.rs if w % r == 0][0]
                # reduction constraint
                r1 = min(self.max_r1, r1)
                r2 = min(self.max_r2, r2)
                s = x[:, :, ::r1, ::r2].cumsum(dim=-1).cumsum(dim=-2)
                n, c, h, w = s.shape
                k1, k2 = min(h - 1, self.kernel_size[0] // r1), min(w - 1, self.kernel_size[1] // r2)
                out = (s[:, :, :-k1, :-k2] - s[:, :, :-k1, k2:] - s[:, :, k1:, :-k2] + s[:, :, k1:, k2:]) / (k1 * k2)
                out = torch.nn.functional.interpolate(out, scale_factor=(r1, r2))
        else:
            n, c, h, w = x.shape
            s = x.cumsum(dim=-1).cumsum_(dim=-2)
            s = torch.nn.functional.pad(s, (1, 0, 1, 0))  # pad 0 for convenience
            k1, k2 = min(h, self.kernel_size[0]), min(w, self.kernel_size[1])
            s1, s2, s3, s4 = s[:, :, :-k1, :-k2], s[:, :, :-k1, k2:], s[:, :, k1:, :-k2], s[:, :, k1:, k2:]
            out = s4 + s1 - s2 - s3
            out = out / (k1 * k2)

        if self.auto_pad:
            n, c, h, w = x.shape
            _h, _w = out.shape[2:]
            pad2d = ((w - _w) // 2, (w - _w + 1) // 2, (h - _h) // 2, (h - _h + 1) // 2)
            out = torch.nn.functional.pad(out, pad2d, mode='replicate')

        return out
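
# Sanity-check sketch (not from the upstream code): on the exact
# (fast_imp=False) path, the summed-area-table output before auto_pad equals a
# stride-1 average pool:
#   x = torch.rand(1, 3, 32, 32)
#   pool = AvgPool2d(kernel_size=(7, 7), auto_pad=False)
#   assert torch.allclose(pool(x), F.avg_pool2d(x, (7, 7), stride=1), atol=1e-5)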
def replace_layers(model, base_size, train_size, fast_imp, **kwargs):
    for n, m in model.named_children():
        if len(list(m.children())) > 0:
            # compound module, recurse into it
            replace_layers(m, base_size, train_size, fast_imp, **kwargs)

        if isinstance(m, nn.AdaptiveAvgPool2d):
            pool = AvgPool2d(base_size=base_size, fast_imp=fast_imp, train_size=train_size)
            assert m.output_size == 1
            setattr(model, n, pool)
'''
ref.
@article{chu2021tlsc,
  title={Revisiting Global Statistics Aggregation for Improving Image Restoration},
  author={Chu, Xiaojie and Chen, Liangyu and Chen, Chengpeng and Lu, Xin},
  journal={arXiv preprint arXiv:2112.04491},
  year={2021}
}
'''
class LocalBase():
    def convert(self, *args, train_size, **kwargs):
        replace_layers(self, *args, train_size=train_size, **kwargs)
        imgs = torch.rand(train_size)
        with torch.no_grad():
            self.forward(imgs)
# From: https://github.com/megvii-research/NAFNet/blob/main/basicsr/models/archs/arch_util.py
class LayerNormFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, weight, bias, eps):
        ctx.eps = eps
        N, C, H, W = x.size()
        mu = x.mean(1, keepdim=True)
        var = (x - mu).pow(2).mean(1, keepdim=True)
        y = (x - mu) / (var + eps).sqrt()
        ctx.save_for_backward(y, var, weight)
        y = weight.view(1, C, 1, 1) * y + bias.view(1, C, 1, 1)
        return y

    @staticmethod
    def backward(ctx, grad_output):
        eps = ctx.eps
        N, C, H, W = grad_output.size()
        # saved_tensors is the current API (saved_variables is long deprecated)
        y, var, weight = ctx.saved_tensors
        g = grad_output * weight.view(1, C, 1, 1)
        mean_g = g.mean(dim=1, keepdim=True)
        mean_gy = (g * y).mean(dim=1, keepdim=True)
        gx = 1. / torch.sqrt(var + eps) * (g - y * mean_gy - mean_g)
        return gx, (grad_output * y).sum(dim=3).sum(dim=2).sum(dim=0), \
            grad_output.sum(dim=3).sum(dim=2).sum(dim=0), None
class LayerNorm2d(nn.Module):
    def __init__(self, channels, eps=1e-6):
        super(LayerNorm2d, self).__init__()
        self.register_parameter('weight', nn.Parameter(torch.ones(channels)))
        self.register_parameter('bias', nn.Parameter(torch.zeros(channels)))
        self.eps = eps

    def forward(self, x):
        return LayerNormFunction.apply(x, self.weight, self.bias, self.eps)
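
# Note: LayerNorm2d normalizes across the channel dim at each spatial location.
# A minimal equivalence sketch against the built-in layer norm:
#   ln = LayerNorm2d(8); x = torch.rand(2, 8, 4, 4)
#   ref = F.layer_norm(x.permute(0, 2, 3, 1), (8,), ln.weight, ln.bias, ln.eps)
#   assert torch.allclose(ln(x), ref.permute(0, 3, 1, 2), atol=1e-5)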
# From: https://github.com/megvii-research/NAFNet/blob/main/basicsr/models/archs/NAFNet_arch.py
class SimpleGate(nn.Module):
    def forward(self, x):
        x1, x2 = x.chunk(2, dim=1)
        return x1 * x2
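
# SimpleGate halves the channel count, (N, 2C, H, W) -> (N, C, H, W), by
# multiplying the two channel halves elementwise. This is why conv3 and conv5
# below consume dw_channel // 2 and ffn_channel // 2 channels respectively.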
class NAFBlock(nn.Module):
    def __init__(self, c, DW_Expand=2, FFN_Expand=2, drop_out_rate=0.):
        super().__init__()
        dw_channel = c * DW_Expand
        self.conv1 = nn.Conv2d(in_channels=c, out_channels=dw_channel, kernel_size=1, padding=0, stride=1, groups=1, bias=True)
        self.conv2 = nn.Conv2d(in_channels=dw_channel, out_channels=dw_channel, kernel_size=3, padding=1, stride=1,
                               groups=dw_channel, bias=True)
        self.conv3 = nn.Conv2d(in_channels=dw_channel // 2, out_channels=c, kernel_size=1, padding=0, stride=1, groups=1, bias=True)

        # Simplified Channel Attention
        self.sca = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(in_channels=dw_channel // 2, out_channels=dw_channel // 2, kernel_size=1, padding=0, stride=1,
                      groups=1, bias=True),
        )

        # SimpleGate
        self.sg = SimpleGate()

        ffn_channel = FFN_Expand * c
        self.conv4 = nn.Conv2d(in_channels=c, out_channels=ffn_channel, kernel_size=1, padding=0, stride=1, groups=1, bias=True)
        self.conv5 = nn.Conv2d(in_channels=ffn_channel // 2, out_channels=c, kernel_size=1, padding=0, stride=1, groups=1, bias=True)

        self.norm1 = LayerNorm2d(c)
        self.norm2 = LayerNorm2d(c)

        self.dropout1 = nn.Dropout(drop_out_rate) if drop_out_rate > 0. else nn.Identity()
        self.dropout2 = nn.Dropout(drop_out_rate) if drop_out_rate > 0. else nn.Identity()

        self.beta = nn.Parameter(torch.zeros((1, c, 1, 1)), requires_grad=True)
        self.gamma = nn.Parameter(torch.zeros((1, c, 1, 1)), requires_grad=True)

    def forward(self, inp):
        x = inp

        x = self.norm1(x)
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.sg(x)
        x = x * self.sca(x)
        x = self.conv3(x)
        x = self.dropout1(x)
        y = inp + x * self.beta

        x = self.conv4(self.norm2(y))
        x = self.sg(x)
        x = self.conv5(x)
        x = self.dropout2(x)

        return y + x * self.gamma
class NAFNet(nn.Module):
    def __init__(
        self,
        img_channel=3,
        width=16,
        middle_blk_num=1,
        enc_blk_nums=[],
        dec_blk_nums=[]
    ):
        super().__init__()
        self.intro = nn.Conv2d(
            in_channels=img_channel,
            out_channels=width,
            kernel_size=3,
            padding=1,
            stride=1,
            groups=1,
            bias=True
        )
        self.ending = nn.Conv2d(
            in_channels=width,
            out_channels=img_channel,
            kernel_size=3,
            padding=1,
            stride=1,
            groups=1,
            bias=True
        )

        self.encoders = nn.ModuleList()
        self.decoders = nn.ModuleList()
        self.middle_blks = nn.ModuleList()
        self.ups = nn.ModuleList()
        self.downs = nn.ModuleList()

        chan = width
        for num in enc_blk_nums:
            self.encoders.append(
                nn.Sequential(
                    *[NAFBlock(chan) for _ in range(num)]
                )
            )
            self.downs.append(
                nn.Conv2d(chan, 2 * chan, 2, 2)
            )
            chan = chan * 2

        self.middle_blks = \
            nn.Sequential(
                *[NAFBlock(chan) for _ in range(middle_blk_num)]
            )

        for num in dec_blk_nums:
            self.ups.append(
                nn.Sequential(
                    nn.Conv2d(chan, chan * 2, 1, bias=False),
                    nn.PixelShuffle(2)
                )
            )
            chan = chan // 2
            self.decoders.append(
                nn.Sequential(
                    *[NAFBlock(chan) for _ in range(num)]
                )
            )

        self.padder_size = 2 ** len(self.encoders)
    def forward(self, inp):
        B, C, H, W = inp.shape
        inp = self.check_image_size(inp)

        x = self.intro(inp)

        encs = []
        for encoder, down in zip(self.encoders, self.downs):
            x = encoder(x)
            encs.append(x)
            x = down(x)

        x = self.middle_blks(x)

        for decoder, up, enc_skip in zip(self.decoders, self.ups, encs[::-1]):
            x = up(x)
            x = x + enc_skip
            x = decoder(x)

        x = self.ending(x)
        x = x + inp

        return x[:, :, :H, :W]

    def check_image_size(self, x):
        _, _, h, w = x.size()
        mod_pad_h = (self.padder_size - h % self.padder_size) % self.padder_size
        mod_pad_w = (self.padder_size - w % self.padder_size) % self.padder_size
        x = F.pad(x, (0, mod_pad_w, 0, mod_pad_h))
        return x
class NAFNetLocal(LocalBase, NAFNet):
    def __init__(self, *args, train_size=(1, 3, 256, 256), fast_imp=False, **kwargs):
        LocalBase.__init__(self)
        NAFNet.__init__(self, *args, **kwargs)
        N, C, H, W = train_size
        base_size = (int(H * 1.5), int(W * 1.5))

        self.eval()
        with torch.no_grad():
            self.convert(base_size=base_size, train_size=train_size, fast_imp=fast_imp)
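
# NAFNetLocal applies TLSC at test time: convert() swaps every
# nn.AdaptiveAvgPool2d(1) for the local AvgPool2d above, then one dry-run
# forward on a train_size-shaped dummy fixes each pool's kernel size.
# Minimal sketch (the width/block numbers here are illustrative, not one of
# the shipped configs):
#   net = NAFNetLocal(width=32, enc_blk_nums=[1, 1], dec_blk_nums=[1, 1])
#   out = net(torch.rand(1, 3, 256, 256))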
import threading
import argparse

import numpy as np
from queue import Queue
from pathlib import Path
from PIL import Image
from tqdm.auto import tqdm
from torchvision.transforms import ToTensor
model_list = {
    'NAFNet-REDS-width64': {
        "width": 64,
        "enc_blk_nums": [1, 1, 1, 28],
        "middle_blk_num": 1,
        "dec_blk_nums": [1, 1, 1, 1],
    },
    'NAFNet-SIDD-width64': {
        "width": 64,
        "enc_blk_nums": [2, 2, 4, 8],
        "middle_blk_num": 12,
        "dec_blk_nums": [2, 2, 2, 2],
    },
    "NAFNet-GoPro-width64": {
        "width": 64,
        "enc_blk_nums": [1, 1, 1, 28],
        "middle_blk_num": 1,
        "dec_blk_nums": [1, 1, 1, 1],
    },
}
# ImageRestoration: tiled NAFNet inference with overlap averaging
class ImageRestoration(nn.Module):
    def __init__(self, model_name='NAFNet-REDS-width64', tile_size=768, tile_overlap=32, bar=None):
        super().__init__()
        if model_name not in model_list:
            raise ValueError(f'Invalid model name: {model_name}')

        config = model_list[model_name]
        state_dict = torch.hub.load_state_dict_from_url(
            f'https://huggingface.co/nyanko7/nafnet-models/resolve/main/{model_name}.pth',
            model_dir=".",
            map_location="cpu",
        )
        if 'params' in state_dict:
            state_dict = state_dict['params']

        net = NAFNet(**config)
        net.load_state_dict(state_dict)
        # main() feeds bfloat16 tensors, so the weights must use the same dtype.
        net = net.cuda().to(torch.bfloat16)
        try:
            # torch.compile returns the compiled module; keep the result.
            net = torch.compile(net, fullgraph=True, mode="max-autotune")
        except Exception as e:
            print(f"Skip torch.compile: {e}")

        self.tile_size = tile_size
        self.tile_overlap = tile_overlap
        self.model = net
        self.bar = bar
    @torch.no_grad()
    def forward(self, img_batch, batch_size=4):
        # img_batch: list of (1, c, h, w) tensors; images may differ in size.
        canvas, counter, all_patch, all_idx = [], [], [], []
        tile = min(self.tile_size, *(img.shape[-1] for img in img_batch), *(img.shape[-2] for img in img_batch))
        stride = tile - self.tile_overlap

        # Split every image into overlapping tile x tile patches.
        for i, img in enumerate(img_batch):
            b, c, h, w = img.shape
            canvas.append(torch.zeros(b, c, h, w, dtype=img.dtype, device=img.device))
            counter.append(torch.zeros(b, c, h, w, dtype=img.dtype, device=img.device))
            h_idx_list = list(range(0, h - tile, stride)) + [h - tile]
            w_idx_list = list(range(0, w - tile, stride)) + [w - tile]
            for h_idx in h_idx_list:
                for w_idx in w_idx_list:
                    in_patch = img[..., h_idx: h_idx + tile, w_idx: w_idx + tile]
                    all_patch.append(in_patch)
                    all_idx.append((i, h_idx, w_idx))

        # Restore patches in mini-batches, accumulating results on each canvas.
        for j in range(0, len(all_patch), batch_size):
            out_patch = self.model(torch.cat(all_patch[j:j + batch_size]))
            for idx, (p_idx, h_idx, w_idx) in enumerate(all_idx[j:j + batch_size]):
                out_patch_mask = torch.ones_like(out_patch[idx])
                canvas[p_idx][..., h_idx: h_idx + tile, w_idx: w_idx + tile].add_(out_patch[idx])
                counter[p_idx][..., h_idx: h_idx + tile, w_idx: w_idx + tile].add_(out_patch_mask)
            if self.bar is not None:
                self.bar.set_description_str(f"Processing Images (blk {j}/{len(all_patch)})")

        # Average the overlapping regions.
        for i in range(len(canvas)):
            canvas[i].div_(counter[i])
        return canvas
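
# Minimal direct-usage sketch (assumes a CUDA device; 'photo.png' is a
# hypothetical file):
#   ir = ImageRestoration('NAFNet-SIDD-width64', tile_size=512)
#   x = ToTensor()(Image.open('photo.png').convert('RGB'))
#   out = ir([x.unsqueeze(0).cuda().to(torch.bfloat16)])[0]  # (1, 3, H, W)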
def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("-m", "--model", type=str, default="NAFNet-REDS-width64")
    parser.add_argument("-i", "--input", "--path", type=str, default=None, required=True)
    parser.add_argument("-b", "--batch-size", type=int, default=4)
    parser.add_argument("--postfix", type=str, default="_upscaled")
    parser.add_argument("--scale", type=float, default=1.5)
    parser.add_argument("--ir-first", action="store_true", default=False)
    parser.add_argument("--tile-size", type=int, default=768)
    parser.add_argument("--file-batch-size", type=int, default=4)
    args = parser.parse_args()
    return args
def resize_and_save(queue, ir_first):
    # Worker thread so resizing and disk I/O overlap with GPU inference.
    while True:
        fo, target_w, target_h, restored = queue.get()
        if ir_first:
            # Restoration ran on the original image; upscale it here.
            # (Without --ir-first the image was already resized in main().)
            restored = restored.resize((target_w, target_h), Image.Resampling.LANCZOS)
        restored.save(fo)
        queue.task_done()
def main():
    args = parse_args()

    # Collect all image files under the input directory.
    file_to_process = []
    for file in Path(args.input).glob("**/*"):
        if file.is_file() and file.suffix.lower() in [".jpg", ".png", ".jpeg", ".webp"]:
            file_to_process.append((file, file.parent / (file.stem + args.postfix + file.suffix)))

    batch_idx = []
    batch = []
    bar = tqdm(total=len(file_to_process), desc="Processing Images")
    irmodel = ImageRestoration(args.model, tile_size=args.tile_size, bar=bar)

    queue = Queue()
    thr = threading.Thread(target=resize_and_save, args=(queue, args.ir_first), daemon=True)
    thr.start()

    for idx, (fi, fo) in enumerate(file_to_process):
        img = Image.open(fi).convert('RGB')
        target_h, target_w = int(img.height * args.scale), int(img.width * args.scale)
        if not args.ir_first:
            img = img.resize((target_w, target_h), Image.Resampling.LANCZOS)

        img = ToTensor()(img)
        img = img.unsqueeze(0).cuda().to(torch.bfloat16)
        batch.append(img)
        batch_idx.append([fo, target_w, target_h])

        if len(batch) == args.file_batch_size or idx == len(file_to_process) - 1:
            result = irmodel(batch, batch_size=args.batch_size)
            for (fo, target_w, target_h), restored in zip(batch_idx, result):
                out = restored.squeeze(0).permute(1, 2, 0).float().cpu()
                out = torch.clamp(out, 0, 1)
                img_array = (out * 255).numpy().astype(np.uint8)
                restored = Image.fromarray(img_array)
                queue.put((fo, target_w, target_h, restored))
            # len(result) avoids overcounting on the final partial batch.
            bar.update(len(result))
            batch_idx = []
            batch = []

    # Wait for the saver thread to drain the queue.
    while not queue.empty():
        time.sleep(0.1)
        bar.set_description_str(f"Saving Images {len(file_to_process) - queue.qsize()}")
    queue.join()


if __name__ == "__main__":
    main()