# Image Histogram Matching with Polynomials
# GitHub gist by @mcdickenson, gist id 40722e0ca7cbb5440f3ac320228b047f
# (last active August 17, 2018 15:49).
import numpy as np
import matplotlib.pyplot as plt
import operator
import os
from PIL import Image
# see https://stackoverflow.com/questions/32655686/histogram-matching-of-two-images-in-python-2-x
def hist_match(source, template, polynomial_order=4):
    """
    Adjust the pixel values of one image so that its histogram matches another.

    For each of the first three channels, the source pixel values and the
    center-cropped template pixel values are both expanded into ascending
    sequences. Pairing them by rank gives (source value, template value)
    points, through which a polynomial of degree ``polynomial_order`` is
    fitted. The three fitted polynomials are then applied to the source by
    color_transform_polynoms.

    :param source: Image to transform (numpy array of size MxNx3)
    :param template: Image to match (numpy array, expected to be 3-channel but
        can be larger than source; it is center-cropped to the source's MxN)
    :param polynomial_order: degree of the per-channel polynomial fit
    :return: the transformed image as returned by color_transform_polynoms
    :raises ValueError: if template is smaller than source in height or width,
        or if the two images yield different pixel counts for a channel
        (e.g. values outside the 256-level histogram range)
    """
    h, w = source.shape[:2]
    if template.shape[0] < h or template.shape[1] < w:
        raise ValueError(
            'template ({}x{}) must be at least as large as source ({}x{})'.format(
                template.shape[0], template.shape[1], h, w))
    # Crop once so both images contribute the same number of pixels per channel.
    cropped = crop_center(template, (h, w))
    levels = np.arange(256)
    coefficients = {}
    for channel_ix, channel in enumerate(['polyR', 'polyG', 'polyB']):
        train_hist = count_pixel_values(source[:, :, channel_ix])
        target_hist = count_pixel_values(cropped[:, :, channel_ix])
        # Expanding each histogram into an ascending value array pairs the
        # k-th smallest source value with the k-th smallest template value
        # (quantile matching). np.repeat avoids building huge Python lists.
        train_values = np.repeat(levels, train_hist)
        target_values = np.repeat(levels, target_hist)
        if len(train_values) != len(target_values):
            raise ValueError(
                'source and template pixel counts differ in channel {} ({} vs {})'.format(
                    channel_ix, len(train_values), len(target_values)))
        coefficients[channel] = np.polyfit(train_values, target_values, polynomial_order)
    # color_transform_polynoms works on a copy, so source is not modified.
    return color_transform_polynoms(
        source, coefficients['polyR'], coefficients['polyG'], coefficients['polyB'])
def count_pixel_values(ary):
    """
    Helper for counting pixel values in a 256-color image.

    Returns a length-256 integer array. Entry i is the number of elements of
    ``ary`` in histogram bin [i, i+1). As in numpy.histogram, the last bin,
    [255, 256], is closed on the right.
    """
    edges = np.arange(257)
    histogram = np.histogram(ary, bins=edges)
    return histogram[0]
def color_transform_polynoms(img, polyR, polyG, polyB):
    """
    Performs a color transform on the rgb channels of a numpy array.

    Each polynomial is evaluated at the levels 0..255 to build a lookup
    table. The table is rounded to the nearest integer and clipped to
    [0, 255], then used to remap channels 0, 1 and 2 of ``img``. Any further
    channels (e.g. alpha) are copied unchanged. ``img`` itself is not
    modified.

    :param img: a numpy array, rgb color space; channels on the last axis.
        Pixel values are used as lookup-table indices, so they are assumed to
        be integers in 0..255 (e.g. uint8).
    :param polyR: polynom defining mapping on Red Channel, array of values see numpy.polynomial
    :param polyG: polynom defining mapping on Green Channel, array of values see numpy.polynomial
    :param polyB: polynom defining mapping on Blue Channel, array of values see numpy.polynomial
    :return: a numpy image array, RGB, dtype uint8
    """
    levels = np.arange(256)
    lut = np.stack([np.polyval(poly, levels) for poly in (polyR, polyG, polyB)])
    # Round before storing into an integer image: a plain cast truncates, so
    # e.g. 99.9999999 from a fitted identity polynomial would become 99.
    lut = np.clip(np.rint(lut), 0, 255).astype(np.uint8)
    new_img = img.copy()
    for channel_ix in range(3):
        new_img[..., channel_ix] = lut[channel_ix][img[..., channel_ix]]
    return np.uint8(new_img)
def ecdf(x):
    """
    Convenience function for computing the empirical CDF.

    Returns ``(vals, cdf)``. ``vals`` holds the sorted distinct values of
    ``x`` (np.unique flattens its input). ``cdf[i]`` is the float64 fraction
    of elements of ``x`` that are <= ``vals[i]``, so the last entry is 1.0.
    An empty ``x`` raises IndexError.
    """
    distinct, freq = np.unique(x, return_counts=True)
    running_total = np.cumsum(freq, dtype=np.float64)
    return distinct, running_total / running_total[-1]
def crop_center(img, bounding):
    """
    Return the central region of ``img`` with size ``bounding``.

    ``bounding`` may have fewer entries than ``img`` has axes; only the
    leading ``len(bounding)`` axes are cropped and trailing axes (e.g. color
    channels) are kept whole. If a requested size exceeds the axis length,
    the start is clamped to 0 and the whole axis is kept, instead of a
    negative start silently wrapping around to the end of the axis.

    :param img: numpy array to crop
    :param bounding: sequence of target sizes for the leading axes
    :return: a view of ``img`` (basic slicing, no copy)
    """
    slices = []
    for size, target in zip(img.shape, bounding):
        start = max(size // 2 - target // 2, 0)
        slices.append(slice(start, start + target))
    return img[tuple(slices)]
def _load_rgb(path):
    """Open an image file, close it, and return its pixels as an RGB numpy array."""
    # convert('RGB') guarantees the MxNx3 layout hist_match expects; RGBA,
    # palette or grayscale files would otherwise add an alpha channel to the
    # plotted CDFs or break the per-channel indexing.
    with Image.open(path) as im:
        return np.array(im.convert('RGB'))


def _plot_cdfs(ax, curves, xlim):
    """Draw ((x, y), style, label) empirical-CDF curves as percentages on ax."""
    for (x, y), style, label in curves:
        ax.plot(x, y * 100, style, lw=3, label=label)
    ax.set_xlim(*xlim)
    ax.set_xlabel('Pixel value')
    ax.set_ylabel('Cumulative %')
    ax.legend(loc='best')


def main():
    """For polynomial orders 1-5, match london.png to bondi.jpg and save plots/images."""
    source_name = 'london.png'
    template_name = 'bondi.jpg'
    source = _load_rgb(source_name)
    template = _load_rgb(template_name)
    source_cdf = ecdf(source.ravel())
    template_cdf = ecdf(template.ravel())
    xlim = (source_cdf[0][0], source_cdf[0][-1])
    for polynomial_order in range(1, 6):
        matched = hist_match(source, template, polynomial_order)
        curves = [
            (source_cdf, '-r', 'Source'),
            (template_cdf, '-k', 'Template'),
            (ecdf(matched.ravel()), '--r', 'Matched'),
        ]
        # Images plus histograms
        fig = plt.figure()
        gs = plt.GridSpec(2, 3)
        ax1 = fig.add_subplot(gs[0, 0])
        ax2 = fig.add_subplot(gs[0, 1], sharex=ax1, sharey=ax1)
        ax3 = fig.add_subplot(gs[0, 2], sharex=ax1, sharey=ax1)
        ax4 = fig.add_subplot(gs[1, :])
        panels = (
            (ax1, source, 'Source'),
            (ax2, template, 'Template'),
            (ax3, matched, 'Matched (order={})'.format(polynomial_order)),
        )
        for ax, image, title in panels:
            ax.set_axis_off()
            ax.imshow(image)
            ax.set_title(title)
        _plot_cdfs(ax4, curves, xlim)
        fig.savefig('fig{}.png'.format(polynomial_order))
        # Release the figure; pyplot otherwise keeps every figure alive.
        plt.close(fig)
        # Larger histograms
        fig = plt.figure()
        ax = fig.add_subplot(plt.GridSpec(1, 1)[0, 0])
        _plot_cdfs(ax, curves, xlim)
        fig.savefig('fig_small{}.png'.format(polynomial_order))
        plt.close(fig)
        # Matched images
        Image.fromarray(matched).save('matched{}.png'.format(polynomial_order))


if __name__ == "__main__":
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment