Skip to content

Instantly share code, notes, and snippets.

@duhaime
Last active September 10, 2024 20:14
Show Gist options
  • Save duhaime/211365edaddf7ff89c0a36d9f3f7956c to your computer and use it in GitHub Desktop.
Save duhaime/211365edaddf7ff89c0a36d9f3f7956c to your computer and use it in GitHub Desktop.
Compare image similarity in Python using Structural Similarity, Pixel Comparisons, Wasserstein Distance (Earth Mover's Distance), and ORB keypoint matching (the `sift_sim` function)
import warnings
from skimage.measure import compare_ssim
from skimage.transform import resize
from scipy.stats import wasserstein_distance
from scipy.misc import imsave
from scipy.ndimage import imread
import numpy as np
import cv2
##
# Globals
##

# Suppress every warning process-wide from this point on.
warnings.filterwarnings('ignore')

# Target size, in pixels, that get_img resizes images to (1024 x 1024).
height = width = 1024
##
# Functions
##
def get_img(path, norm_size=True, norm_exposure=False):
    '''
    Prepare an image for image processing tasks
    @args:
        {str} path: the path to an image file
        {bool} norm_size: resize the image to the global height x width
        {bool} norm_exposure: equalize the image's gray-level histogram
            (see normalize_exposure)
    @returns:
        {numpy.ndarray} a 2-D grayscale array. Its values are ints, except
        when norm_size is True and norm_exposure is False; then they are the
        float output of resize, kept in the original value range.
    '''
    # flatten returns a 2d grayscale float array; round (rather than
    # truncate) to the nearest integer gray level
    img = np.rint(imread(path, flatten=True)).astype(int)
    if norm_size:
        # preserve_range keeps the original value scale, but resize returns
        # floats; they are intentionally not converted back to ints here
        img = resize(img, (height, width), anti_aliasing=True, preserve_range=True)
    if norm_exposure:
        # normalize_exposure casts its input to int and returns ints
        img = normalize_exposure(img)
    return img
def get_histogram(img):
    '''
    Get the normalized gray-level histogram of an image. For an 8-bit,
    grayscale image, the histogram is a 256-unit vector in which the nth
    value is the fraction of pixels with gray level n. The values sum to 1.
    @args:
        {numpy.ndarray} img: an integer array of gray levels in 0..255
            (any shape; it is flattened)
    @returns:
        {numpy.ndarray} a float vector of length 256
    @raises:
        ValueError: if img is empty or has values outside 0..255
        TypeError: if img does not hold integer values
    '''
    pixels = np.asarray(img).ravel()
    if pixels.size == 0:
        raise ValueError('cannot compute the histogram of an empty image')
    if pixels.min() < 0 or pixels.max() > 255:
        raise ValueError('pixel values must lie in the range 0..255')
    # a single vectorized count replaces a per-pixel Python loop, which
    # would run height*width iterations for every call
    counts = np.bincount(pixels, minlength=256)
    return counts / pixels.size
def normalize_exposure(img):
    '''
    Normalize the exposure of an image by histogram equalization: every
    gray level is remapped through the image's cumulative histogram.
    @args:
        {numpy.ndarray} img: a grayscale image; it is cast to int, and the
            values are expected to lie in 0..255
    @returns:
        {numpy.ndarray} an int array with the same shape as img
    '''
    img = img.astype(int)
    hist = get_histogram(img)
    # cumulative fraction of pixels at or below each gray level
    cdf = np.cumsum(hist)
    # lookup table mapping each gray level to its equalized level
    # (the uint8 cast truncates toward zero)
    lut = np.uint8(255 * cdf)
    # fancy indexing applies the lookup table to every pixel at once,
    # instead of looping over pixels in Python
    return lut[img].astype(int)
def earth_movers_distance(path_a, path_b):
    '''
    Measure the Earth Mover's distance between the gray-level histograms
    of two images
    @args:
        {str} path_a: the path to an image file
        {str} path_b: the path to an image file
    @returns:
        {float} the Wasserstein-1 distance between the two normalized
        histograms, measured in gray levels: 0 for identical histograms,
        at most 255
    '''
    img_a = get_img(path_a, norm_exposure=True)
    img_b = get_img(path_b, norm_exposure=True)
    hist_a = get_histogram(img_a)
    hist_b = get_histogram(img_b)
    # The histograms are weights over the gray levels. Passing them as the
    # *values* would compare the distributions of bin frequencies instead,
    # ignoring which gray level each bin belongs to.
    gray_levels = np.arange(hist_a.size)
    # NOTE(review): exposure normalization pushes both histograms toward
    # uniform, which may shrink these distances; consider norm_exposure=False.
    return wasserstein_distance(gray_levels, gray_levels,
                                u_weights=hist_a, v_weights=hist_b)
def structural_sim(path_a, path_b):
    '''
    Measure the structural similarity (SSIM) between two images
    @args:
        {str} path_a: the path to an image file
        {str} path_b: the path to an image file
    @returns:
        {float} a float in [-1, 1] that measures structural similarity
        between the input images; 1 means identical
    '''
    img_a = get_img(path_a)
    img_b = get_img(path_b)
    # With the default norm_size=True, get_img returns resized floats that
    # keep the original scale (0..255 for 8-bit sources; assumed here).
    # Without an explicit data_range, SSIM derives the range from the float
    # dtype (-1..1) and mis-scales its stabilizing constants; newer
    # scikit-image versions refuse to guess. full=True was dropped because
    # the per-pixel SSIM image was never used.
    return compare_ssim(img_a, img_b, data_range=255)
def pixel_sim(path_a, path_b):
    '''
    Measure the pixel-level difference between two images
    @args:
        {str} path_a: the path to an image file
        {str} path_b: the path to an image file
    @returns:
        {float} the mean absolute difference between the two
        exposure-normalized images, scaled to [0, 1]. 0 means the normalized
        images are identical, so despite the name, larger values mean the
        images are *less* similar.
    '''
    img_a = get_img(path_a, norm_exposure=True)
    img_b = get_img(path_b, norm_exposure=True)
    # average over the actual pixel count instead of the global
    # height*width, so the result does not depend on that module state
    return np.mean(np.absolute(img_a - img_b)) / 255
def sift_sim(path_a, path_b):
    '''
    Use keypoint features to measure image similarity. Despite the name,
    this uses OpenCV's ORB detector/descriptor (cv2.ORB_create), not SIFT.
    @args:
        {str} path_a: the path to an image file
        {str} path_b: the path to an image file
    @returns:
        {float} the fraction of cross-checked descriptor matches whose
        distance is below 70, in [0, 1]; 0 if either image yields no
        descriptors or nothing matches
    @raises:
        ValueError: if either path cannot be read as an image
    '''
    # initialize the ORB feature detector
    orb = cv2.ORB_create()
    # get the images; cv2.imread signals failure by returning None
    img_a = cv2.imread(path_a)
    img_b = cv2.imread(path_b)
    for path, img in ((path_a, img_a), (path_b, img_b)):
        if img is None:
            raise ValueError('could not read image: %s' % path)
    # find the keypoints and descriptors with ORB
    _, desc_a = orb.detectAndCompute(img_a, None)
    _, desc_b = orb.detectAndCompute(img_b, None)
    # descriptors are None when no keypoints were found; BFMatcher would fail
    if desc_a is None or desc_b is None:
        return 0
    # brute-force matcher with Hamming distance for ORB's binary descriptors
    bf = cv2.BFMatcher(cv2.NORM_HAMMING, crossCheck=True)
    # match.distance is a Hamming distance (0..256 for ORB's default
    # 32-byte descriptors) - lower means more similar
    matches = bf.match(desc_a, desc_b)
    if not matches:
        return 0
    similar_regions = [m for m in matches if m.distance < 70]
    return len(similar_regions) / len(matches)
if __name__ == '__main__':
    img_a = 'a.jpg'
    img_b = 'b.jpg'
    # get the similarity values; use distinct names so the module-level
    # functions structural_sim / pixel_sim / sift_sim are not rebound
    ssim_score = structural_sim(img_a, img_b)
    pixel_score = pixel_sim(img_a, img_b)
    sift_score = sift_sim(img_a, img_b)
    emd = earth_movers_distance(img_a, img_b)
    print(ssim_score, pixel_score, sift_score, emd)
@duhaime
Copy link
Author

duhaime commented Oct 17, 2019

Thanks for this research, ohjho--these are great references!

I would also encourage you to check out more modern image similarity techniques, like using pretrained neural networks (e.g. Inception) or training your own Autoencoder to measure image similarity.

These latter techniques can capture much more flexible notions of image similarity than the older methods shown above!

@clerein
Copy link

clerein commented Dec 5, 2019

@FLS-BP-US
Copy link

Can you update the `from scipy.ndimage import imread` import?

@qathom
Copy link

qathom commented Jul 29, 2020

You can replace the import `from scipy.ndimage import imread` with `from imageio import imread`.

Then replace get_img() with:

from imageio import imread

def get_img(path, norm_size=True, norm_exposure=False):
  '''
  Prepare an image for image processing tasks.

  The grayscale conversion is done here rather than with
  imread(path, as_gray=True), because not every imageio plugin accepts
  as_gray (the tifffile plugin raises
  "TypeError: unexpected keyword argument: as_gray").
  '''
  img = np.asarray(imread(path), dtype=float)
  if img.ndim == 3:
    if img.shape[-1] >= 3:
      # ITU-R 601-2 luma weights on the RGB channels (alpha is ignored)
      img = img[..., :3] @ np.array([0.299, 0.587, 0.114])
    else:
      # single gray channel, possibly with alpha: keep the gray channel
      img = img[..., 0]
  # NOTE(review): 16-bit TIFFs yield values above 255, which the 256-bin
  # histogram code does not support; rescale such images first if needed.
  img = np.rint(img).astype(int)

  # resize returns floats in the original value range (preserve_range)
  if norm_size:
    img = resize(img, (height, width), anti_aliasing=True, preserve_range=True)
  if norm_exposure:
    img = normalize_exposure(img)
  return img

@mufticsanjin
Copy link

Thank you very much for the code. I would like to use it to compare images across formats (e.g., TIFFs with JPGs, so that I can see whether the files are versions of each other). I run into an issue when working with TIFFs, and I wanted to know if there is a way around it:

~/env/lib/python3.7/site-packages/imageio/plugins/tifffile.py in _open(self, **kwargs)
    224                 self._f = None
    225                 f = self.request.get_file()
--> 226             self._tf = _tifffile.TiffFile(f, **kwargs)
    227 
    228             # metadata is the same for all images

~/env/lib/python3.7/site-packages/tifffile/tifffile.py in __init__(self, arg, name, offset, size, multifile, _useframes, _master, **kwargs)
   2189                         setattr(self, key, bool(value))
   2190                 else:
-> 2191                     raise TypeError(f'unexpected keyword argument: {key}')
   2192 
   2193         fh = FileHandle(arg, mode='rb', name=name, offset=offset, size=size)

TypeError: unexpected keyword argument: as_gray

Many thanks

@duhaime
Copy link
Author

duhaime commented Feb 15, 2023

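@mufticsanjin I'd try an earlier SciPy version, maybe 1.0.0 or so; `scipy.ndimage.imread`, which the gist originally uses, was removed in later releases, and it avoids imageio's `as_gray` argument entirely.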

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment