# AUTOGENERATED! DO NOT EDIT! File to edit: 01_metrics.ipynb (unless otherwise specified).

__all__ = ['gaussian', 'create_window', 'SSIM', 'ssim', 'psnr', 'mutual_information', 'nmi', 'metric_fastai',
           'ssim_fastai', 'psnr_fastai', 'nmi_fastai']

# Cell
import torch
import torch.nn.functional as F
from torch.autograd import Variable
import numpy as np
from math import exp
from fastai2.vision.all import *

# Cell
def gaussian(window_size, sigma):
    gauss = torch.Tensor([exp(-(x - window_size//2)**2/float(2*sigma**2)) for x in range(window_size)])
    return gauss/gauss.sum()

def create_window(window_size, channel):
    _1D_window = gaussian(window_size, 1.5).unsqueeze(1)
    _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0)
    window = Variable(_2D_window.expand(channel, 1, window_size, window_size).contiguous())
    return window
def _ssim(img1, img2, window, window_size, channel, data_range=255., padding=False, size_average=True):
    if padding:
        padding = window_size // 2
    mu1 = F.conv2d(img1, window, padding=int(padding), groups=channel)
    mu2 = F.conv2d(img2, window, padding=int(padding), groups=channel)

    mu1_sq = mu1.pow(2)
    mu2_sq = mu2.pow(2)
    mu1_mu2 = mu1*mu2

    sigma1_sq = F.conv2d(img1*img1, window, padding=int(padding), groups=channel) - mu1_sq
    sigma2_sq = F.conv2d(img2*img2, window, padding=int(padding), groups=channel) - mu2_sq
    sigma12 = F.conv2d(img1*img2, window, padding=int(padding), groups=channel) - mu1_mu2

    C1 = (0.01*data_range)**2
    C2 = (0.03*data_range)**2

    ssim_map = ((2*mu1_mu2 + C1)*(2*sigma12 + C2))/((mu1_sq + mu2_sq + C1)*(sigma1_sq + sigma2_sq + C2))

    if size_average:
        return ssim_map.mean()
    else:
        return ssim_map.mean(1).mean(1).mean(1)
class SSIM(torch.nn.Module):
    "Structural similarity index (SSIM) is a commonly used metric for CycleGAN experiments as it evaluates the preservation of content rather than style"
    def __init__(self, window_size=11, data_range=255, padding=False, size_average=True):
        super(SSIM, self).__init__()
        self.window_size = window_size
        self.size_average = size_average
        self.data_range = data_range
        self.padding = padding
        self.channel = 1
        self.window = create_window(window_size, self.channel)

    def forward(self, img1, img2):
        (_, channel, _, _) = img1.size()

        if channel == self.channel and self.window.data.type() == img1.data.type():
            window = self.window
        else:
            # Rebuild the window to match the input's channel count, device, and dtype.
            window = create_window(self.window_size, channel)
            if img1.is_cuda: window = window.cuda(img1.get_device())
            window = window.type_as(img1)
            self.window = window
            self.channel = channel

        return _ssim(img1, img2, window, self.window_size, channel, self.data_range, self.padding, self.size_average)
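# Illustrative sketch (not part of the autogenerated export): the SSIM module form,
# evaluated on random batches in the 0-255 range assumed by data_range=255.
if __name__ == '__main__':
    _x = torch.rand(2, 3, 64, 64) * 255
    _y = _x + torch.randn_like(_x) * 5        # lightly perturbed copy of _x
    print(SSIM(window_size=11)(_x, _y))       # close to 1.0 for near-identical images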
def ssim(img1, img2, window_size=11, data_range=255, padding=False, size_average=True):
    (_, channel, _, _) = img1.size()
    window = create_window(window_size, channel)

    if img1.is_cuda: window = window.cuda(img1.get_device())
    window = window.type_as(img1)

    return _ssim(img1, img2, window, window_size, channel, data_range, padding, size_average)
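# Illustrative sketch (not part of the autogenerated export): the functional form
# should agree with the SSIM module above for the same window size and data range.
if __name__ == '__main__':
    _x = torch.rand(2, 3, 64, 64) * 255
    _y = _x + torch.randn_like(_x) * 5        # lightly perturbed copy of _x
    print(ssim(_x, _y))                       # matches SSIM(window_size=11)(_x, _y)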
# Cell
def psnr(pred, targs, data_range=255):
    mse = F.mse_loss(pred, targs)
    return 20 * torch.log10(data_range / torch.sqrt(mse))
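# Illustrative sketch (not part of the autogenerated export): PSNR here is
# 20 * log10(data_range / RMSE), so less noise means a higher score.
if __name__ == '__main__':
    _a = torch.rand(1, 3, 32, 32) * 255
    _b = _a + torch.randn_like(_a)            # small noise -> high PSNR (identical images -> inf)
    print(psnr(_b, _a))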
# Cell
def mutual_information(img1, img2, numBins=20):
    # Compute the mutual information between img1 and img2, which are assumed to be
    # grayscale images stored as numpy arrays. numBins affects the mutual information score.
    hgram, x_edges, y_edges = np.histogram2d(img1.ravel(), img2.ravel(), bins=numBins)
    pxy = hgram / float(np.sum(hgram))
    px = np.sum(pxy, axis=1)            # marginal for x over y
    py = np.sum(pxy, axis=0)            # marginal for y over x
    px_py = px[:, None] * py[None, :]   # Broadcast to multiply marginals
    nzs = pxy > 0                       # Only non-zero pxy values contribute to the sum
    return np.sum(pxy[nzs] * np.log(pxy[nzs] / px_py[nzs]))
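# Illustrative sketch (not part of the autogenerated export): mutual information of a
# grayscale numpy image with itself versus with unrelated noise.
if __name__ == '__main__':
    _g = np.random.rand(64, 64)
    print(mutual_information(_g, _g))                       # high: identical images
    print(mutual_information(_g, np.random.rand(64, 64)))   # near zero: independent images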
def nmi(img1, img2, numBins=20):
    # Convert RGB batches to grayscale with the standard luma weights, then average
    # the per-item mutual information across the batch.
    img1 = (img1[:,0,...]*0.2989 + img1[:,1,...]*0.5870 + img1[:,2,...]*0.1140).unsqueeze(1)
    img2 = (img2[:,0,...]*0.2989 + img2[:,1,...]*0.5870 + img2[:,2,...]*0.1140).unsqueeze(1)
    nmi = []
    for i in range(img1.shape[0]):
        nmi.append(mutual_information(img1[i].cpu().detach().numpy(), img2[i].cpu().detach().numpy(), numBins))
    return np.mean(nmi)
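# Illustrative sketch (not part of the autogenerated export): the batch-level score is
# maximal for an image compared with itself and much lower for unrelated images.
if __name__ == '__main__':
    _p = torch.rand(2, 3, 64, 64)
    print(nmi(_p, _p))                        # high score: identical content
    print(nmi(_p, torch.rand_like(_p)))       # low score: unrelated content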
# Cell
def metric_fastai(xb, yb, _, metric):
    (real_A, real_B) = xb[0]
    fake_A, fake_B, idt_A, idt_B = yb
    # Rescale from [-1, 1] to [0, 255] before comparing the real A image with the translated B->A output's counterpart fake_B.
    return metric((real_A/2 + 0.5)*255, (fake_B/2 + 0.5)*255)

ssim_fastai = partial(metric_fastai, metric=ssim)
psnr_fastai = partial(metric_fastai, metric=psnr)
nmi_fastai = partial(metric_fastai, metric=nmi)
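# Illustrative sketch (not part of the autogenerated export): the wrappers assume
# CycleGAN-style batches where xb[0] = (real_A, real_B), yb = (fake_A, fake_B, idt_A, idt_B),
# and all tensors lie in [-1, 1]; how they are attached to a Learner depends on the
# accompanying training code and is not shown here.
if __name__ == '__main__':
    _real_A = torch.rand(1, 3, 64, 64) * 2 - 1
    _real_B = torch.rand(1, 3, 64, 64) * 2 - 1
    _fake_B, _fake_A = _real_A.clone(), _real_B.clone()   # pretend the generators are perfect
    _idt_A, _idt_B = _real_A.clone(), _real_B.clone()
    print(ssim_fastai(((_real_A, _real_B),), (_fake_A, _fake_B, _idt_A, _idt_B), None))  # ~1.0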