Last active
May 11, 2019 15:28
-
-
Save ata4/619b28422d288605685200a8c0edfd6b to your computer and use it in GitHub Desktop.
ESRGAN launcher with tiling support
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/usr/bin/python3 | |
import sys | |
import os | |
import glob | |
import math | |
import argparse | |
import cv2 | |
import numpy as np | |
import torch | |
import architecture as arch | |
class ESRGAN:
    """ESRGAN super-resolution wrapper that can upscale large images in tiles.

    Tiling bounds the amount of data fed to the model per forward pass.
    Tiles are processed independently and without overlap, so seams may be
    visible at tile borders.
    """

    def __init__(self, model_path, device, scale_factor=4, tile_size=256):
        """Load an ``arch.RRDB_Net`` generator and move it to *device*.

        Args:
            model_path: Path to a saved state dict for ``arch.RRDB_Net``.
            device: ``torch.device`` used for inference.
            scale_factor: Upscaling factor passed to ``RRDB_Net``; it has to
                match the checkpoint, since loading uses ``strict=True``.
            tile_size: Edge length, in input pixels, of each tile. Values
                <= 0 disable tiling and process whole images at once.
        """
        self.scale_factor = scale_factor
        self.tile_size = tile_size
        model = arch.RRDB_Net(3, 3, 64, 23, upscale=self.scale_factor)
        # Load onto the CPU first so checkpoints saved from a GPU also load
        # on CPU-only machines; the model is moved to *device* below.
        # NOTE: torch.load unpickles the file -- only load trusted models.
        state_dict = torch.load(model_path, map_location='cpu')
        model.load_state_dict(state_dict, strict=True)
        model.eval()
        for param in model.parameters():
            param.requires_grad = False
        self.model = model.to(device)
        self.device = device

    def upscale(self, img):
        """Run the model once on a whole BGR image array.

        Args:
            img: BGR image of shape (H, W, 3) with values in [0, 255].

        Returns:
            Float array in BGR channel order, (H', W', 3) layout, with the
            model output clamped to [0, 1], scaled to [0, 255] and rounded.
        """
        img = img * 1.0 / 255
        # BGR -> RGB and HWC -> CHW, as the model expects.
        img = torch.from_numpy(np.transpose(img[:, :, [2, 1, 0]], (2, 0, 1))).float()
        img_lr = img.unsqueeze(0).to(self.device)
        # Inference only: no autograd graph is needed.
        with torch.no_grad():
            output = self.model(img_lr)
        output = output.detach().squeeze(0).float().cpu().clamp_(0, 1).numpy()
        # RGB -> BGR and CHW -> HWC for OpenCV.
        output = np.transpose(output[[2, 1, 0], :, :], (1, 2, 0))
        return (output * 255.0).round()

    def upscale_image(self, img):
        """Upscale an in-memory BGR image, splitting it into tiles if enabled.

        Args:
            img: uint8 BGR image of shape (H, W, C), as returned by
                ``cv2.imread(..., cv2.IMREAD_COLOR)``.

        Returns:
            uint8 image of shape (H * scale_factor, W * scale_factor, C).
        """
        # OpenCV arrays are (rows, cols, channels) = (height, width, depth).
        height, width, depth = img.shape
        tile = self.tile_size

        # tile_size <= 0 means "no tiling"; small images fit into one tile.
        if tile <= 0 or (width <= tile and height <= tile):
            return self.upscale(img).astype(np.uint8)

        scale = self.scale_factor
        output = np.zeros((height * scale, width * scale, depth), np.uint8)
        tiles_x = math.ceil(width / tile)
        tiles_y = math.ceil(height / tile)
        for y in range(tiles_y):
            for x in range(tiles_x):
                # Tile bounds in input pixels, clipped at the image border.
                start_x = x * tile
                end_x = min(start_x + tile, width)
                start_y = y * tile
                end_y = min(start_y + tile, height)

                tile_idx = y * tiles_x + x + 1
                print('Tile %d/%d (x=%d y=%d %dx%d)' % (
                    tile_idx, tiles_x * tiles_y, x, y,
                    end_x - start_x, end_y - start_y), flush=True)

                output_tile = self.upscale(img[start_y:end_y, start_x:end_x])
                output[start_y * scale:end_y * scale,
                       start_x * scale:end_x * scale] = output_tile
        return output

    def process(self, input_path, output_path):
        """Upscale the image file at *input_path* and write *output_path*.

        Raises:
            OSError: if the input cannot be decoded as an image or the
                output file cannot be written.
        """
        img = cv2.imread(input_path, cv2.IMREAD_COLOR)
        # cv2.imread signals failure by returning None rather than raising.
        if img is None:
            raise OSError("Could not read image '%s'" % input_path)
        output = self.upscale_image(img)
        # cv2.imwrite signals failure (e.g. missing directory) with False.
        if not cv2.imwrite(output_path, output):
            raise OSError("Could not write image '%s'" % output_path)
def main(argv=None):
    """Command-line entry point: upscale every image found in the input.

    The ``input`` argument may be a directory (all files directly inside it
    are processed) or a glob pattern such as ``'images/*.png'``. Results are
    written to ``<output>/<name>_esrgan.png``.

    Args:
        argv: Argument list to parse; defaults to ``sys.argv[1:]``.

    Returns:
        Exit status: 0 if every image was processed, 1 if no inputs were
        found or any image failed.
    """
    parser = argparse.ArgumentParser(description='ESRGAN image upscaler with tiling support')
    parser.add_argument('input', help='Path to input folder')
    parser.add_argument('output', help='Path to output folder')
    parser.add_argument('model', help='Path to model file')
    parser.add_argument('--tilesize', type=int, metavar='N', default=256, help='size of tiles in pixels (0 = don\'t use tiles)')
    parser.add_argument('--cpu', action='store_true', help='use CPU instead of GPU/CUDA (very slow!)')
    args = parser.parse_args(argv)

    # Fail early with a clear message instead of crashing on model.to('cuda').
    if not args.cpu and not torch.cuda.is_available():
        parser.error('CUDA is not available; rerun with --cpu')
    device = torch.device('cpu' if args.cpu else 'cuda')

    input_folder = args.input
    output_folder = args.output
    model_path = args.model

    # glob() on a bare directory path only returns the directory itself,
    # so expand directories to their contents; anything else is a pattern.
    if os.path.isdir(input_folder):
        pattern = os.path.join(input_folder, '*')
    else:
        pattern = input_folder
    input_paths = sorted(path for path in glob.glob(pattern) if os.path.isfile(path))
    if not input_paths:
        print("No input images found for '%s'" % input_folder, file=sys.stderr, flush=True)
        return 1

    # cv2.imwrite does not create missing directories.
    os.makedirs(output_folder, exist_ok=True)

    print("Initializing ESRGAN using model '%s'" % os.path.basename(model_path), flush=True)
    esrgan = ESRGAN(model_path, device, tile_size=args.tilesize)

    failures = 0
    for input_path in input_paths:
        input_name = os.path.basename(input_path)
        print('Upscaling', input_name, flush=True)
        output_name = os.path.splitext(input_name)[0] + '_esrgan.png'
        output_path = os.path.join(output_folder, output_name)
        try:
            esrgan.process(input_path, output_path)
        except OSError as err:
            # Keep going with the remaining images; report at the end.
            print('Skipping %s: %s' % (input_name, err), file=sys.stderr, flush=True)
            failures += 1
    return 1 if failures else 0
if __name__ == '__main__':
    # sys.exit rather than the site-injected exit(), which is missing under
    # `python -S` and in frozen executables.
    sys.exit(main())
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment