Created
June 19, 2020 17:19
-
-
Save AruniRC/562c2eb4d6d5c4a4603e66bfe942a72c to your computer and use it in GitHub Desktop.
Histogram specification demo code
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import os | |
import sys | |
import pickle | |
import json | |
import numpy as np | |
import sys | |
import matplotlib | |
matplotlib.use('Agg') | |
from matplotlib import pyplot as plt | |
import os.path as osp | |
DEBUG = True | |
def bin_scores(scores, score_bins):
    """
    Quantize scores to the values of the bin edges they fall under.

    Arguments:
    -----------
        scores: array-like of scores to quantize
        score_bins: array-like (1-D) of monotonic bin edges

    Returns:
    -----------
        np.ndarray with the same shape as `scores`, holding for each score
        the bin edge selected by np.digitize. For increasing edges that is
        the upper edge bins[i] of the interval bins[i-1] <= x < bins[i].
        Scores at or beyond the last edge are assigned the last edge.
    """
    score_bins = np.asarray(score_bins)
    # np.digitize yields indices in [0, len(score_bins)]: 0 for scores below
    # the first edge and len(score_bins) for scores at/above the last edge.
    bin_idx = np.digitize(scores, score_bins)
    # Clamp the overflow index so that out-of-range scores map onto the last
    # edge instead of raising IndexError.
    bin_idx = np.minimum(bin_idx, len(score_bins) - 1)
    return score_bins[bin_idx]
def _empirical_cdf(counts):
    """Return the cumulative sum of `counts` normalised to end at 1.0."""
    cdf = np.cumsum(counts).astype(np.float64)
    cdf /= cdf[-1]
    return cdf


def _save_debug_figure(fig_path, description):
    """Save the current pyplot figure to `fig_path`, report it and close it."""
    out_dir = os.path.dirname(fig_path)
    # dirname() is '' for a bare filename, and os.makedirs('') raises
    # FileNotFoundError, so only create a directory when one is named.
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    plt.savefig(fig_path, bbox_inches='tight')
    print('%s saved at: %s' % (description, fig_path))
    plt.close()


def match_histograms(source, template, debug=None):
    """
    Histogram of source data matches that of template data

    Each distinct source value is replaced by the template value found at
    the same empirical-CDF level (linearly interpolated between template
    values).

    Arguments:
    -----------
        source: np.ndarray (1-D) after bin quantization
        template: np.ndarray (1-D) after bin quantization
        debug: bool or None. When true, the CDFs and the source-to-matched
            mapping are plotted to 'hist_src_target-DEBUG.png' and
            'hist_map-DEBUG.png' in the working directory. None (default)
            defers to the module-level DEBUG flag.
    Returns:
    -----------
        matched: np.ndarray
            The transformed source data
    """
    if debug is None:
        debug = DEBUG

    # get unique values, indices and counts
    s_val, bin_idx, s_counts = np.unique(source, return_inverse=True,
                                         return_counts=True)
    t_val, t_counts = np.unique(template, return_counts=True)

    # calculate empirical CDFs
    s_cdf = _empirical_cdf(s_counts)
    t_cdf = _empirical_cdf(t_counts)

    # mapping: values in template's CDF closest to source's CDF
    interp_t_val = np.interp(s_cdf, t_cdf, t_val)

    # modify source values to match closest template CDF
    source_matched = interp_t_val[bin_idx]

    if debug:
        # Figure 1: source, target and matched CDFs on shared axes.
        sm_val, sm_counts = np.unique(source_matched, return_counts=True)
        sm_cdf = _empirical_cdf(sm_counts)
        plt.plot(s_val, s_cdf, label='CDF source', alpha=0.8)
        plt.plot(t_val, t_cdf, label='CDF target', alpha=0.8)
        plt.plot(sm_val, sm_cdf, label='CDF source-matched', alpha=0.8)
        plt.title('Score histograms')
        plt.grid()
        plt.legend()
        _save_debug_figure('hist_src_target-DEBUG.png', 'Score histogram')

        # Figure 2: scatter of each source score against its mapped value.
        plt.plot(source, source_matched, 'bo', label='Map source')
        plt.title('Score mapping')
        plt.xlabel('Source scores')
        plt.ylabel('Mapped scores')
        plt.grid()
        plt.legend()
        _save_debug_figure('hist_map-DEBUG.png', 'Score mapping')

    return source_matched
if __name__ == '__main__':
    # Smoke test on synthetic scores: 1000 uniform integers drawn from
    # [0, 10) are remapped onto the distribution of 5000 normal draws
    # (mean 5, std 5) that were truncated to integers and made non-negative.
    uniform_scores = np.random.randint(low=0, high=10, size=1000)
    normal_draws = np.random.normal(loc=5, scale=5, size=5000)
    normal_scores = np.absolute(normal_draws.astype(np.int64))
    unif_to_normal = match_histograms(uniform_scores, normal_scores)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment