Skip to content

Instantly share code, notes, and snippets.

@recoilme
Created October 25, 2024 19:05
Show Gist options
  • Save recoilme/9693e968e5abcbd71d43dcaa74681740 to your computer and use it in GitHub Desktop.
Save recoilme/9693e968e5abcbd71d43dcaa74681740 to your computer and use it in GitHub Desktop.
Simple ESRGAN inference script
import cv2
import glob
import numpy as np
import os
import torch
#download: https://github.com/xinntao/ESRGAN/blob/master/RRDBNet_arch.py
import RRDBNet_arch as arch
"""Run ESRGAN (RRDBNet) super-resolution on a single image and save it as JPEG."""

MODEL_PATH = '/home/recoilme/forge/models/ESRGAN/ESRGAN_4x.pth'
INPUT_PATH = "8.jpg"
JPEG_QUALITY = 97


def load_model(model_path, device):
    """Build the RRDBNet generator, load weights from *model_path*, move it to *device*.

    The constructor arguments must match the checkpoint layout; ``strict=True``
    makes any missing/unexpected key fail loudly. (Positional args presumably
    are in_nc, out_nc, nf, nb — see RRDBNet_arch.py to confirm.)
    """
    model = arch.RRDBNet(3, 3, 64, 23, gc=32)
    # map_location lets a checkpoint saved from GPU tensors load on a CPU-only
    # machine. torch.load unpickles the file: only load trusted checkpoints.
    state_dict = torch.load(model_path, map_location=device)
    model.load_state_dict(state_dict, strict=True)
    model.eval()
    return model.to(device)


def image_to_tensor(img_bgr, device):
    """Convert an (H, W, 3) BGR uint8 image to a (1, 3, H, W) RGB float32 tensor in [0, 1]."""
    rgb = img_bgr[:, :, [2, 1, 0]].astype(np.float32) / 255.0
    chw = np.ascontiguousarray(np.transpose(rgb, (2, 0, 1)))
    return torch.from_numpy(chw).unsqueeze(0).to(device)


def tensor_to_image(output):
    """Convert a (1, 3, H, W) RGB float tensor to an (H, W, 3) BGR uint8 array.

    Values are clamped to [0, 1] before scaling to 0..255. The input tensor is
    not modified.
    """
    # squeeze(0) drops only the batch dim, so a 1-pixel-high/wide result keeps its shape.
    rgb = output.detach().squeeze(0).float().cpu().clamp(0, 1).numpy()
    bgr = np.transpose(rgb[[2, 1, 0], :, :], (1, 2, 0))
    return (bgr * 255.0).round().astype(np.uint8)


def main(path=INPUT_PATH, model_path=MODEL_PATH):
    """Upscale the image at *path* and write ``<name>_ESRGAN.jpg`` to the working directory.

    Raises:
        FileNotFoundError: if *path* cannot be read as an image.
        OSError: if the output JPEG cannot be written.
    """
    # Fall back to CPU automatically instead of crashing when CUDA is unavailable.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = load_model(model_path, device)

    imgname = os.path.splitext(os.path.basename(path))[0]
    print('Testing', imgname)

    # cv2.imread returns None (instead of raising) for missing/unreadable files.
    img_bgr = cv2.imread(path, cv2.IMREAD_COLOR)
    if img_bgr is None:
        raise FileNotFoundError(f'cannot read image: {path}')
    img = image_to_tensor(img_bgr, device)

    try:
        with torch.no_grad():
            output = model(img)
    except Exception as error:
        # Best-effort: report the inference failure and skip saving instead of crashing.
        print('Error', error, imgname)
        return

    out_path = f'{imgname}_ESRGAN.jpg'
    # The array is already BGR, which is what cv2.imwrite expects.
    ok = cv2.imwrite(out_path, tensor_to_image(output),
                     [cv2.IMWRITE_JPEG_QUALITY, JPEG_QUALITY])
    if not ok:
        raise OSError(f'cannot write image: {out_path}')
    print('ok')


if __name__ == "__main__":
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment