Created
August 28, 2018 16:24
-
-
Save charliememory/91a62ff0a102bc0bfe715711fd5c538d to your computer and use it in GitHub Desktop.
tensorflow implement of Multiscale SSIM
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#################### MS_SSIM Loss ##################### | |
## ref code: https://stackoverflow.com/questions/39051451/ssim-ms-ssim-for-tensorflow | |
def _tf_fspecial_gauss(size, sigma): | |
"""Function to mimic the 'fspecial' gaussian MATLAB function | |
""" | |
x_data, y_data = np.mgrid[-size//2 + 1:size//2 + 1, -size//2 + 1:size//2 + 1] | |
x_data = np.expand_dims(x_data, axis=-1) | |
x_data = np.expand_dims(x_data, axis=-1) | |
y_data = np.expand_dims(y_data, axis=-1) | |
y_data = np.expand_dims(y_data, axis=-1) | |
x = tf.constant(x_data, dtype=tf.float32) | |
y = tf.constant(y_data, dtype=tf.float32) | |
g = tf.exp(-((x**2 + y**2)/(2.0*sigma**2))) | |
return g / tf.reduce_sum(g) | |
def tf_ssim(img1, img2, cs_map=False, mean_metric=True, size=11, sigma=1.5): | |
window = _tf_fspecial_gauss(size, sigma) # window shape [size, size] | |
K1 = 0.01 | |
K2 = 0.03 | |
L = 1 # depth of image (255 in case the image has a differnt scale) | |
C1 = (K1*L)**2 | |
C2 = (K2*L)**2 | |
mu1 = tf.nn.conv2d(img1, window, strides=[1,1,1,1], padding='VALID') | |
mu2 = tf.nn.conv2d(img2, window, strides=[1,1,1,1],padding='VALID') | |
mu1_sq = mu1*mu1 | |
mu2_sq = mu2*mu2 | |
mu1_mu2 = mu1*mu2 | |
sigma1_sq = tf.nn.conv2d(img1*img1, window, strides=[1,1,1,1],padding='VALID') - mu1_sq | |
sigma2_sq = tf.nn.conv2d(img2*img2, window, strides=[1,1,1,1],padding='VALID') - mu2_sq | |
sigma12 = tf.nn.conv2d(img1*img2, window, strides=[1,1,1,1],padding='VALID') - mu1_mu2 | |
if cs_map: | |
value = (((2*mu1_mu2 + C1)*(2*sigma12 + C2))/((mu1_sq + mu2_sq + C1)* | |
(sigma1_sq + sigma2_sq + C2)), | |
(2.0*sigma12 + C2)/(sigma1_sq + sigma2_sq + C2)) | |
else: | |
value = ((2*mu1_mu2 + C1)*(2*sigma12 + C2))/((mu1_sq + mu2_sq + C1)* | |
(sigma1_sq + sigma2_sq + C2)) | |
if mean_metric: | |
value = tf.reduce_mean(value) | |
return value | |
def tf_ms_ssim(img1, img2, mean_metric=True, level=5): | |
weight = tf.constant([0.0448, 0.2856, 0.3001, 0.2363, 0.1333], dtype=tf.float32) | |
mssim = [] | |
mcs = [] | |
for l in range(level): | |
ssim_map, cs_map = tf_ssim(img1, img2, cs_map=True, mean_metric=False) | |
mssim.append(tf.reduce_mean(ssim_map)) | |
mcs.append(tf.reduce_mean(cs_map)) | |
filtered_im1 = tf.nn.avg_pool(img1, [1,2,2,1], [1,2,2,1], padding='SAME') | |
filtered_im2 = tf.nn.avg_pool(img2, [1,2,2,1], [1,2,2,1], padding='SAME') | |
img1 = filtered_im1 | |
img2 = filtered_im2 | |
# list to tensor of dim D+1 | |
mssim = tf.stack(mssim, axis=0) | |
mcs = tf.stack(mcs, axis=0) | |
value = (tf.reduce_prod(mcs[0:level-1]**weight[0:level-1])* | |
(mssim[level-1]**weight[level-1])) | |
if mean_metric: | |
value = tf.reduce_mean(value) | |
return value | |
########################################################### |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
@smoein see the solution @aakash30jan contributed to the following Github issue: tensorflow/tensorflow#33840.