Created
February 29, 2012 11:30
-
-
Save attilaolah/1940208 to your computer and use it in GitHub Desktop.
Fast image comparison with Python
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import math | |
import Image | |
import Levenshtein | |
class BWImageCompare(object):
    """Compare two images on a shared, down-scaled single-band grid.

    Both inputs are resized to one common size (no larger than either image
    and no larger than ``maxsize`` per side) and, unless ``_colour`` is set
    by a subclass, converted with ``convert('I')``.  Every metric is a
    property that is computed on first access and then cached on the
    instance.
    """

    # Value the root mean square error is measured against in psnr/nrmsd.
    _pixel = 255
    # When false, __init__ converts both images to mode 'I' before storing.
    _colour = False

    def __init__(self, imga, imgb, maxsize=64):
        """Resize both images to a common size and keep the copies.

        imga, imgb -- image objects offering ``size``, ``resize()`` and
                      ``convert()``.
        maxsize -- upper bound for the width and height of the copies.
        """
        (widtha, heighta), (widthb, heightb) = imga.size, imgb.size
        common = (min(widtha, widthb, maxsize),
                  min(heighta, heightb, maxsize))
        # Bring both images onto the same pixel grid.
        imga = imga.resize(common, Image.BICUBIC)
        imgb = imgb.resize(common, Image.BICUBIC)
        if not self._colour:
            # Keep integer single-band copies only.
            imga = imga.convert('I')
            imgb = imgb.convert('I')
        self._imga, self._imgb = imga, imgb
        # Remember the shared size; the metrics divide by x and y.
        self.x, self.y = common

    def _img_int(self, img):
        """Yield the pixels of ``img`` column by column (x outer, y inner)."""
        width, height = img.size
        for col in range(width):
            for row in range(height):
                yield img.getpixel((col, row))

    @property
    def imga_int(self):
        """Tuple of pixel values of the first image (cached)."""
        if hasattr(self, '_imga_int'):
            return self._imga_int
        self._imga_int = tuple(self._img_int(self._imga))
        return self._imga_int

    @property
    def imgb_int(self):
        """Tuple of pixel values of the second image (cached)."""
        if hasattr(self, '_imgb_int'):
            return self._imgb_int
        self._imgb_int = tuple(self._img_int(self._imgb))
        return self._imgb_int

    @property
    def mse(self):
        """Mean square error between the two pixel tuples (cached)."""
        if hasattr(self, '_mse'):
            return self._mse
        squared_error = 0
        for pixa, pixb in zip(self.imga_int, self.imgb_int):
            squared_error += (pixa - pixb) ** 2
        self._mse = float(squared_error) / self.x / self.y
        return self._mse

    @property
    def psnr(self):
        """Peak signal-to-noise ratio (cached).

        Raises ZeroDivisionError when the mean square error is zero.
        """
        if hasattr(self, '_psnr'):
            return self._psnr
        rmse = math.sqrt(self.mse)
        self._psnr = 20 * math.log(self._pixel / rmse, 10)
        return self._psnr

    @property
    def nrmsd(self):
        """Root mean square deviation divided by ``_pixel`` (cached)."""
        if hasattr(self, '_nrmsd'):
            return self._nrmsd
        self._nrmsd = math.sqrt(self.mse) / self._pixel
        return self._nrmsd

    @property
    def levenshtein(self):
        """Levenshtein distance per pixel between the two images (cached).

        Each pixel value is turned into one character with ``chr`` before
        the two strings are compared.
        """
        if hasattr(self, '_lv'):
            return self._lv
        stra = ''.join(map(chr, self.imga_int))
        strb = ''.join(map(chr, self.imgb_int))
        distance = Levenshtein.distance(stra, strb)
        self._lv = float(distance) / self.x / self.y
        return self._lv
class ImageCompare(BWImageCompare):
    """Compare two images in colour.

    Every pixel is packed into a single integer (first band in bits 0-7,
    second in bits 8-15, third in bits 16-23) so that the inherited
    mse/psnr/nrmsd properties can be reused unchanged.
    """

    # Largest value _img_int can yield: 255 | 255 << 8 | 255 << 16.
    # (255 ** 3 is smaller than that, which allowed nrmsd to exceed 1.)
    _pixel = 0xFFFFFF
    # Tells the inherited __init__ to skip the single-band conversion.
    _colour = True

    def _img_int(self, img):
        """Yield every pixel of ``img`` packed into one 24-bit integer.

        Images that are not in mode 'RGB' are converted first, because
        getpixel() returns a bare int (not a tuple) for single-band modes
        and the packing below needs three bands.
        Pixels are produced column by column (x outer, y inner).
        """
        if img.mode != 'RGB':
            img = img.convert('RGB')
        width, height = img.size
        for col in range(width):
            for row in range(height):
                pixel = img.getpixel((col, row))
                yield pixel[0] | (pixel[1] << 8) | (pixel[2] << 16)

    @property
    def levenshtein(self):
        """Mean per-band Levenshtein distance per pixel (cached).

        The packed pixels are split back into their three 8-bit bands; each
        band is compared as a string of characters and the three distances
        are averaged, then divided by the number of pixels.
        """
        if hasattr(self, '_lv'):
            return self._lv
        total = 0
        # Bits 16-23, 8-15 and 0-7 hold the third, second and first band.
        for shift in (16, 8, 0):
            stra = ''.join(chr((v >> shift) & 0xff) for v in self.imga_int)
            strb = ''.join(chr((v >> shift) & 0xff) for v in self.imgb_int)
            total += Levenshtein.distance(stra, strb)
        self._lv = total / 3. / self.x / self.y
        return self._lv
class FuzzyImageCompare(object):
    """Estimate the similarity of two images from several metrics.

    The images are compared with ImageCompare at doubling sizes
    (2, 4, 8, ...) and each metric is followed until its value stops
    changing faster than it did one step earlier.
    """

    # Metrics read from ImageCompare, in the order they are evaluated.
    _METRICS = ('levenshtein', 'nrmsd', 'psnr')

    def __init__(self, imga, imgb, lb=1, tol=15):
        """Store the images and the tuning parameters.

        imga, imgb -- the images, handed unchanged to ImageCompare.
        lb -- look-back distance used by the convergence test in compare().
        tol -- stored for the PSNR weighting that similarity() does not
               apply yet.
        """
        self._imga, self._imgb, self._lb, self._tol = imga, imgb, lb, tol

    @staticmethod
    def _converged(history, lb):
        """Tell whether the latest change is no larger than the previous one.

        Needs at least ``lb + 2`` samples; with fewer it returns False.
        """
        if len(history) < lb + 2:
            return False
        latest_step = abs(history[-1] - history[-lb - 1])
        previous_step = abs(history[-lb - 1] - history[-lb - 2])
        return latest_step <= previous_step

    def compare(self):
        """Run all comparisons and return the final value of each metric.

        Returns a dict with the keys 'levenshtein', 'nrmsd' and 'psnr'.
        The first two are reported as ``100 - value * 100``; 'psnr' is the
        last measured value, or 100.0 when that measurement hit a zero mean
        square error.  The result is cached on the instance.
        """
        if hasattr(self, '_compare'):
            return self._compare
        lb, size = self._lb, 2
        diffs = dict((metric, []) for metric in self._METRICS)
        settled = set()
        while len(settled) < len(self._METRICS):
            comparison = ImageCompare(self._imga, self._imgb, size)
            for metric in self._METRICS:
                history = diffs[metric]
                if self._converged(history, lb):
                    settled.add(metric)
                    continue
                if metric == 'psnr':
                    try:
                        value = comparison.psnr
                    except ZeroDivisionError:
                        # Zero error at this size: the images are identical.
                        value = -1
                else:
                    value = getattr(comparison, metric)
                history.append(value)
            size *= 2
        last_psnr = diffs['psnr'][-1]
        self._compare = {
            'levenshtein': 100 - diffs['levenshtein'][-1] * 100,
            'nrmsd': 100 - diffs['nrmsd'][-1] * 100,
            'psnr': 100.0 if last_psnr == -1 else last_psnr,
        }
        return self._compare

    def similarity(self):
        """Return the mean of the 'levenshtein' and 'nrmsd' scores.

        TODO: fix psnr!  The intended result also weights in the PSNR as
        ``min(lnrmsd * psnr / tol, 100.0)``; until that is sorted out the
        'psnr' score and ``tol`` are not used here.
        """
        scores = self.compare()
        return (scores['levenshtein'] + scores['nrmsd']) / 2
if __name__ == '__main__':
    # Command-line entry point: compare every pair of the given image files
    # and report the pair with the highest similarity.
    import sys

    # A file name given twice would only add a comparison of an image with
    # itself, so keep the first occurrence and preserve command-line order.
    names = []
    for name in sys.argv[1:]:
        if name not in names:
            names.append(name)
    if len(names) < 2:
        # A string argument makes sys.exit() write it to stderr and exit
        # with status 1, so misuse is visible to calling scripts.
        sys.exit('usage: %s image-file-1.jpg image-file-2.jpg ...'
                 % sys.argv[0])

    images = [(name, Image.open(name)) for name in names]
    # Number of unordered pairs: n * (n - 1) / 2.
    tot = len(images) * (len(images) - 1) // 2
    print('Comparing %d image pairs:' % tot)

    best_pair, best_sim = None, None
    done = 0
    for index, (namea, imga) in enumerate(images):
        for nameb, imgb in images[index + 1:]:
            done += 1
            # Announce the pair before the comparison starts.
            sys.stdout.write(' * %2d / %2d: %s %s ... '
                             % (done, tot, namea, nameb))
            sys.stdout.flush()
            sim = FuzzyImageCompare(imga, imgb).similarity()
            print('%.2f %%' % sim)
            # Keep the first pair that reaches the highest score.
            if best_sim is None or sim > best_sim:
                best_pair, best_sim = (namea, nameb), sim

    print('Most similar images: %s %s (%.2f %%)'
          % (best_pair[0], best_pair[1], best_sim))
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment