Skip to content

Instantly share code, notes, and snippets.

@bearpelican
Created April 15, 2020 22:35
Show Gist options
  • Save bearpelican/e93f3f17f3787825b91610c7aa0521d1 to your computer and use it in GitHub Desktop.
Save bearpelican/e93f3f17f3787825b91610c7aa0521d1 to your computer and use it in GitHub Desktop.
# AUTOGENERATED! DO NOT EDIT! File to edit: 01_metrics.ipynb (unless otherwise specified).
# Names exported by ``from <this module> import *``.
__all__ = [
    'gaussian', 'create_window', 'SSIM', 'ssim', 'psnr',
    'mutual_information', 'nmi',
    'metric_fastai', 'ssim_fastai', 'psnr_fastai', 'nmi_fastai',
]
# Cell
import torch
import torch.nn.functional as F
from torch.autograd import Variable
import numpy as np
from math import exp
from fastai2.vision.all import *
# Cell
def gaussian(window_size, sigma):
    """Return a normalized 1-D Gaussian kernel of length ``window_size``.

    The kernel is centred on index ``window_size // 2`` with standard
    deviation ``sigma``, and is divided by its sum so the weights add to 1.
    """
    centre = window_size // 2
    denom = float(2 * sigma ** 2)
    weights = [exp(-(idx - centre) ** 2 / denom) for idx in range(window_size)]
    kernel = torch.Tensor(weights)
    return kernel / kernel.sum()
def create_window(window_size, channel):
    """Build a depthwise 2-D Gaussian window (sigma = 1.5) for grouped conv2d.

    The 2-D kernel is the outer product of a 1-D Gaussian with itself, then
    broadcast to shape ``(channel, 1, window_size, window_size)`` so that it
    can be used with ``F.conv2d(..., groups=channel)``.
    """
    column = gaussian(window_size, 1.5).unsqueeze(1)       # (window_size, 1)
    kernel_2d = column.mm(column.t()).float()[None, None]  # (1, 1, w, w)
    stacked = kernel_2d.expand(channel, 1, window_size, window_size)
    return Variable(stacked.contiguous())
def _ssim(img1, img2, window, window_size, channel, data_range=255., padding=False, size_average = True):
if padding:
padding = window_size // 2
mu1 = F.conv2d(img1, window, padding = int(padding), groups = channel)
mu2 = F.conv2d(img2, window, padding = int(padding), groups = channel)
mu1_sq = mu1.pow(2)
mu2_sq = mu2.pow(2)
mu1_mu2 = mu1*mu2
sigma1_sq = F.conv2d(img1*img1, window, padding = int(padding), groups = channel) - mu1_sq
sigma2_sq = F.conv2d(img2*img2, window, padding = int(padding), groups = channel) - mu2_sq
sigma12 = F.conv2d(img1*img2, window, padding = int(padding), groups = channel) - mu1_mu2
C1 = (0.01*data_range)**2
C2 = (0.03*data_range)**2
ssim_map = ((2*mu1_mu2 + C1)*(2*sigma12 + C2))/((mu1_sq + mu2_sq + C1)*(sigma1_sq + sigma2_sq + C2))
if size_average:
return ssim_map.mean()
else:
return ssim_map.mean(1).mean(1).mean(1)
class SSIM(torch.nn.Module):
    """Structural similarity index (SSIM) is a commonly used metric for CycleGAN
    experiments as it evaluates the preservation of content rather than style.

    Module form of ``ssim``. The Gaussian window is cached between calls and
    rebuilt only when the input's channel count, device or dtype changes.

    Args:
        window_size: side length of the Gaussian window.
        data_range: dynamic range of the pixel values (255 for 8-bit images).
        padding: if truthy, pad convolutions by ``window_size // 2``.
        size_average: if True return a scalar; otherwise one value per sample.
    """

    def __init__(self, window_size=11, data_range=255, padding=False, size_average=True):
        super(SSIM, self).__init__()
        self.window_size = window_size
        self.size_average = size_average
        self.data_range = data_range
        self.padding = padding
        self.channel = 1
        self.window = create_window(window_size, self.channel)

    def forward(self, img1, img2):
        # Unpacking enforces a 4-D input; only the channel count is needed.
        (_, channel, _, _) = img1.size()
        window = self.window
        # Compare device and dtype directly. The old ``.type()`` string check
        # missed moves between CUDA devices (cuda:0 vs cuda:1).
        if (channel != self.channel
                or window.device != img1.device
                or window.dtype != img1.dtype):
            window = create_window(self.window_size, channel).to(
                device=img1.device, dtype=img1.dtype)
            self.window = window
            self.channel = channel
        # Bug fix: this previously passed the undefined bare name ``padding``.
        return _ssim(img1, img2, window, self.window_size, channel,
                     self.data_range, self.padding, self.size_average)
def ssim(img1, img2, window_size=11, data_range=255, padding=False, size_average=True):
    """Functional SSIM between two 4-D image batches.

    A fresh Gaussian window is built for the channel count of ``img1``, moved
    to ``img1``'s CUDA device when needed, cast to ``img1``'s type, and then
    passed to ``_ssim`` with the remaining options.
    """
    _, n_channels, _, _ = img1.size()
    kernel = create_window(window_size, n_channels)
    kernel = kernel.cuda(img1.get_device()) if img1.is_cuda else kernel
    kernel = kernel.type_as(img1)
    return _ssim(img1, img2, kernel, window_size, n_channels,
                 data_range, padding, size_average)
# Cell
def psnr(pred, targs, data_range=255):
    """Peak signal-to-noise ratio in dB: ``20 * log10(data_range / RMSE)``.

    The RMSE is the square root of the mean squared error over all elements.
    Identical inputs give an infinite score (division by a zero RMSE).
    """
    rmse = torch.sqrt(F.mse_loss(pred, targs))
    return 20 * torch.log10(data_range / rmse)
# Cell
def mutual_information(img1, img2, numBins=20):
    """Estimate the mutual information (in nats) between two arrays.

    Both inputs are flattened and binned into a ``numBins`` x ``numBins``
    joint histogram, which is normalized to a joint distribution. The
    function returns ``sum p(x, y) * log(p(x, y) / (p(x) p(y)))`` over the
    non-empty cells. Inputs are expected to be numpy arrays (e.g. grayscale
    images); ``numBins`` controls the histogram resolution and so affects
    the score.
    """
    joint, _, _ = np.histogram2d(img1.ravel(), img2.ravel(), bins=numBins)
    p_joint = joint / float(joint.sum())
    p_x = p_joint.sum(axis=1)   # marginal of img1 (rows)
    p_y = p_joint.sum(axis=0)   # marginal of img2 (columns)
    independent = np.outer(p_x, p_y)
    # Empty cells contribute 0 (and would give log(0)), so drop them.
    mask = p_joint > 0
    nonzero = p_joint[mask]
    return np.sum(nonzero * np.log(nonzero / independent[mask]))
def nmi(img1, img2, numBins=20):
    """Mean per-image mutual information between two batches of 3-channel images.

    Each image is collapsed to one channel as
    ``0.2989 * c0 + 0.5870 * c1 + 0.1140 * c2``. This presumably expects
    RGB channel order — TODO confirm with callers. Then
    ``mutual_information(img1[i], img2[i], numBins)`` is computed for every
    pair in the batch and the mean is returned.

    Note: despite the name, the score is the raw (unnormalized) mutual
    information in nats.

    Args:
        img1, img2: tensors indexed as ``[batch, channel, ...]`` with at least
            3 channels and the same shape.
        numBins: histogram bins per axis, forwarded to ``mutual_information``.

    Raises:
        ValueError: if the two batches have different shapes.
    """
    def _to_gray(x):
        # Weighted channel sum -> shape (batch, ...).
        return x[:, 0, ...] * 0.2989 + x[:, 1, ...] * 0.5870 + x[:, 2, ...] * 0.1140

    # Convert once, instead of on every loop iteration as before.
    gray1 = _to_gray(img1).detach().cpu().numpy()
    gray2 = _to_gray(img2).detach().cpu().numpy()
    if gray1.shape != gray2.shape:
        raise ValueError(f"nmi: shape mismatch {gray1.shape} vs {gray2.shape}")

    # Bug fix: the previous loop ignored its index (it scored the whole batch
    # each time) and dropped ``numBins``; score each pair individually.
    scores = [mutual_information(a, b, numBins) for a, b in zip(gray1, gray2)]
    return np.mean(scores)
# Cell
def metric_fastai(xb, yb, _, metric):
    """Adapt an image metric to the ``(xb, yb, _, metric)`` call used below.

    ``xb[0]`` is unpacked as ``(real_A, real_B)`` and ``yb`` as
    ``(fake_A, fake_B, idt_A, idt_B)``. ``metric`` is then applied to
    ``real_A`` and ``fake_B`` after the affine map ``(t / 2 + 0.5) * 255``,
    which sends [-1, 1] to [0, 255]. This presumes the tensors are
    normalized to [-1, 1] — TODO confirm against the data pipeline.
    """
    def _to_pixels(t):
        return (t / 2 + 0.5) * 255

    (real_A, _real_B) = xb[0]
    _fake_A, fake_B, _idt_A, _idt_B = yb
    return metric(_to_pixels(real_A), _to_pixels(fake_B))


# Ready-to-use metric callables with the concrete metric bound.
ssim_fastai = partial(metric_fastai, metric=ssim)
psnr_fastai = partial(metric_fastai, metric=psnr)
nmi_fastai = partial(metric_fastai, metric=nmi)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment