Skip to content

Instantly share code, notes, and snippets.

@AruniRC
Created June 19, 2020 17:19
Show Gist options
  • Save AruniRC/562c2eb4d6d5c4a4603e66bfe942a72c to your computer and use it in GitHub Desktop.
Save AruniRC/562c2eb4d6d5c4a4603e66bfe942a72c to your computer and use it in GitHub Desktop.
Histogram specification demo code
import os
import sys
import pickle
import json
import numpy as np
import sys
import matplotlib
matplotlib.use('Agg')
from matplotlib import pyplot as plt
import os.path as osp
# Module-wide switch: when True, match_histograms saves diagnostic plots
# (source/target/matched CDFs and the source->matched score mapping) as PNG files.
DEBUG = True
def bin_scores(scores, score_bins):
    """
    Quantize scores to bin-edge values.

    Arguments:
    -----------
    scores: array-like (or scalar) of scores to quantize
    score_bins: array-like of monotonic bin edges (as required by np.digitize)

    Returns:
    -----------
    quantized: np.ndarray (or scalar for scalar input)
        For a score x with bins[i-1] <= x < bins[i], the value bins[i]
        (the upper edge of x's bin). Scores below bins[0] map to bins[0];
        scores at or above bins[-1] map to bins[-1].
    """
    # Accept plain lists: a Python list cannot be indexed by an index array.
    score_bins = np.asarray(score_bins)
    # bins[i-1] <= x < bins[i]; np.digitize yields i in 0..len(bins), where
    # i == len(bins) (x >= bins[-1]) would index past the end, so clip it.
    bin_idx = np.digitize(scores, score_bins)
    bin_idx = np.clip(bin_idx, 0, len(score_bins) - 1)
    return score_bins[bin_idx]  # bin values
def _cdf_from_counts(counts):
    """Return the normalised cumulative distribution (last entry 1.0) of counts."""
    cdf = np.cumsum(counts).astype(np.float64)
    cdf /= cdf[-1]
    return cdf


def _save_current_figure(fig_path, message):
    """Save the active pyplot figure to fig_path, print message % fig_path, close it."""
    out_dir = osp.dirname(fig_path)
    # A bare filename has dirname '' and os.makedirs('') raises
    # FileNotFoundError, so only create a directory when one is named.
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    plt.savefig(fig_path, bbox_inches='tight')
    print(message % fig_path)
    plt.close()


def _save_debug_plots(source, source_matched, s_val, s_cdf, t_val, t_cdf,
                      out_dir=''):
    """Write the CDF-comparison plot and the source->matched mapping plot."""
    sm_val, sm_counts = np.unique(source_matched, return_counts=True)
    sm_cdf = _cdf_from_counts(sm_counts)

    plt.plot(s_val, s_cdf, label='CDF source', alpha=0.8)
    plt.plot(t_val, t_cdf, label='CDF target', alpha=0.8)
    plt.plot(sm_val, sm_cdf, label='CDF source-matched', alpha=0.8)
    plt.title('Score histograms')
    plt.grid()
    plt.legend()
    _save_current_figure(osp.join(out_dir, 'hist_src_target-DEBUG.png'),
                         'Score histogram saved at: %s')

    plt.plot(source, source_matched, 'bo', label='Map source')
    plt.title('Score mapping')
    plt.xlabel('Source scores')
    plt.ylabel('Mapped scores')
    plt.grid()
    plt.legend()
    _save_current_figure(osp.join(out_dir, 'hist_map-DEBUG.png'),
                         'Score mapping saved at: %s')


def match_histograms(source, template, *, debug=None, out_dir=''):
    """
    Histogram of source data matches that of template data

    Each distinct source value is replaced by the template value at the same
    empirical-CDF position (linear interpolation between template values).

    Arguments:
    -----------
    source: np.ndarray (1-D) after bin quantization
    template: np.ndarray (1-D) after bin quantization
    debug: bool or None
        Whether to save diagnostic plots; None uses the module-level DEBUG.
    out_dir: str
        Directory for the diagnostic PNGs; '' writes them to the current
        working directory.

    Returns:
    -----------
    matched: np.ndarray
        The transformed source data (float64)

    Raises:
    -----------
    ValueError: if template is empty (no distribution to match against).
    """
    if debug is None:
        debug = DEBUG
    source = np.asarray(source)
    template = np.asarray(template)
    if template.size == 0:
        raise ValueError('template must contain at least one value')
    if source.size == 0:
        # Nothing to map; the CDF normalisation below would index an empty array.
        return np.empty(source.shape, dtype=np.float64)

    # unique values, inverse indices (source position -> unique value) and counts
    s_val, bin_idx, s_counts = np.unique(source, return_inverse=True,
                                         return_counts=True)
    t_val, t_counts = np.unique(template, return_counts=True)
    s_cdf = _cdf_from_counts(s_counts)
    t_cdf = _cdf_from_counts(t_counts)
    # mapping: template value at each source CDF level (t_cdf is strictly
    # increasing because every count is >= 1, as np.interp requires)
    interp_t_val = np.interp(s_cdf, t_cdf, t_val)
    # broadcast the per-unique-value mapping back to every source element
    source_matched = interp_t_val[bin_idx]
    if debug:
        _save_debug_plots(source, source_matched, s_val, s_cdf, t_val, t_cdf,
                          out_dir)
    return source_matched
if __name__ == '__main__':
    # TESTING: smoke-run match_histograms on synthetic integer scores.
    # Source sample: 1000 integers drawn uniformly from [0, 10).
    uniform_scores = np.random.randint(0, 10, 1000)
    # Target sample: 5000 draws from a normal with mean 5 and std 5,
    # truncated to int64 and then made non-negative with abs().
    normal_scores = np.abs(np.random.normal(5, 5, 5000).astype(np.int64))
    unif_to_normal = match_histograms(uniform_scores, normal_scores)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment