Created August 5, 2022 02:49
-
-
Save michaelchughes/f39ef71c352dc5f784ec45fd42251cac to your computer and use it in GitHub Desktop.
Create a function that monotonically transforms the intensity values of images drawn from a "target" distribution so that they match a desired "source" distribution.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import numpy as np | |
import scipy.stats | |
import matplotlib.pyplot as plt | |
from statsmodels.distributions.empirical_distribution import ECDF | |
def create_transform_func_to_match_source(target_x_ND, src_x_MD, n_quantiles=1000):
    '''Create a monotone map that sends "target" values onto the "source" distribution.

    Parameters
    ----------
    target_x_ND : array_like
        Sample of target values used to fit the target empirical CDF.
        Only the values matter (the array is flattened); NaNs are ignored.
    src_x_MD : array_like
        Sample of source values whose distribution the output should match.
        Flattened; NaNs are ignored.
    n_quantiles : int, optional
        Number of evenly spaced quantile levels in [0, 1] (both endpoints
        included) at which the source distribution is tabulated.

    Returns
    -------
    transform : func
        Maps a given target array (or scalar / nested list) of any shape
        into a new array of the same shape. NaN inputs map to NaN.

    Raises
    ------
    AssertionError
        If the two samples' shapes differ anywhere beyond the first axis.
    ValueError
        If n_quantiles < 1, or target_x_ND has no non-NaN values.

    Notes
    -----
    1) Map each target value to its empirical quantile (value in 0-1):
       the fraction of non-NaN fitted target samples that are <= it.
    2) Then map that quantile to the source value tabulated at the
       nearest quantile level.
    '''
    target_x_ND = np.asarray(target_x_ND)
    src_x_MD = np.asarray(src_x_MD)
    assert target_x_ND.shape[1:] == src_x_MD.shape[1:]
    if n_quantiles < 1:
        raise ValueError(f'n_quantiles must be >= 1, got {n_quantiles!r}')

    # Sorted non-NaN target values define the right-continuous ECDF.
    # NaNs are dropped so they neither inflate the sample count (which would
    # keep real values from ever reaching quantile 1) nor count as "largest".
    target_vals_T = target_x_ND.ravel()
    target_sorted_T = np.sort(target_vals_T[~np.isnan(target_vals_T)])
    n_target = target_sorted_T.size
    if n_target == 0:
        raise ValueError('target_x_ND must contain at least one non-NaN value')

    # Source values at the evenly spaced levels 0, 1/(Q-1), ..., 1
    src_qs_Q = np.linspace(0, 1, n_quantiles, endpoint=True)
    x_quantiles_Q = np.nanpercentile(src_x_MD.ravel(), src_qs_Q * 100)
    x_quantiles_Q = np.sort(x_quantiles_Q)  # in increasing order
    last_id = n_quantiles - 1

    def transform(targetx_ND):
        targetx_ND = np.asarray(targetx_ND)  # accept scalars and lists too
        x_A = targetx_ND.ravel()  # unlike reshape((np.prod(shape),)), safe for 0-d
        # ECDF: fraction of fitted target samples <= each value, in [0, 1]
        qs_A = np.searchsorted(target_sorted_T, x_A, side='right') / n_target
        # Round to the NEAREST tabulated level. Rounding up (searchsorted on
        # the level grid) biases outputs upward and leaves the lowest source
        # quantile unreachable by any fitted target value.
        ids_A = np.rint(qs_A * last_id).astype(np.intp)
        out_A = x_quantiles_Q[ids_A]  # fancy indexing returns a fresh copy
        out_A[np.isnan(x_A)] = np.nan  # missing in -> missing out
        return out_A.reshape(targetx_ND.shape)
    return transform
def make_new_fig_with_subplots():
    '''Open a new figure: a 4x4 grid of image panels above a wide histogram panel.

    Returns
    -------
    axes : dict
        Mapping from each mosaic label to its Axes, as returned by
        plt.subplot_mosaic.
    im_keys : list of str
        Labels of the image panels ('A' through 'P'), sorted.
    hist_key : str
        Label of the histogram panel, which spans the bottom two rows.
    '''
    hist_key = 'X'
    layout = '''
    ABCD
    EFGH
    IJKL
    MNOP
    XXXX
    XXXX
    '''
    _, axes = plt.subplot_mosaic(layout)
    im_keys = sorted(label for label in axes if label != hist_key)
    return axes, im_keys, hist_key
if __name__ == '__main__':
    # Demo: synthesize a "source" and a "target" stack of square grayscale
    # images with different intensity distributions, fit the transform on
    # the target stack, and save one figure per stack (SOURCE, TARGET,
    # TRANSFORMED_TO_SRC) showing example images plus a pixel histogram.
    N = 100  # number of images per stack
    D = 32   # each image is D x D pixels
    src_dist = scipy.stats.norm(0.3, 0.05)
    target_dist = scipy.stats.norm(0.5, 0.1)

    # Create many square images from each distribution
    # Each one is dark border, then lighter, then lighter still
    src_x_NDD = src_dist.rvs(size=(N, D, D), random_state=42)
    target_x_NDD = target_dist.rvs(size=(N, D, D), random_state=43)

    # Brighten the central square spanning the middle half of each axis
    m = D // 4
    M = 3 * D // 4
    src_x_NDD[:, m:M, m:M] += 0.2
    target_x_NDD[:, m:M, m:M] += 0.2

    # Add black border: scale every pixel within B of any edge by 0.03,
    # exactly once. Scaling the four edge strips one after another would
    # hit the corner pixels twice, because the strips overlap there.
    B = 3
    border_mask_DD = np.ones((D, D), dtype=bool)
    border_mask_DD[B:-B, B:-B] = False
    for arr in [src_x_NDD, target_x_NDD]:
        arr[:, border_mask_DD] *= 0.03

    # Zero the main diagonal of a few target images (a feature the source
    # stack does not have)
    for img_id in (0, 5, 10):
        target_x_NDD[img_id][np.diag_indices(D)] = 0

    transform = create_transform_func_to_match_source(target_x_NDD, src_x_NDD)
    txfm_x_NDD = transform(target_x_NDD)

    for (arr, suptitle_str) in [
            (src_x_NDD, 'SOURCE images'),
            (target_x_NDD, 'TARGET images'),
            (txfm_x_NDD, 'TRANSFORMED_TO_SRC images')]:
        axes, im_keys, hist_key = make_new_fig_with_subplots()
        # Images 0, 1, 2, ... go on the image panels, on a fixed [0, 1] gray scale
        for ii, key in enumerate(im_keys):
            cur_ax = axes[key]
            cur_ax.imshow(arr[ii], vmin=0, vmax=1, cmap='gray')
            cur_ax.set_xticks([])
            cur_ax.set_yticks([])
        # Pooled histogram of every pixel in the stack
        ax = axes[hist_key]
        ax.hist(
            arr.flatten(),
            density=True,
            bins=np.linspace(0, 1, 128))
        ax.set_xlabel('pixel value')
        ax.set_ylabel('density')
        ax.set_ylim([0, 5])
        plt.suptitle(suptitle_str)
        plt.savefig(suptitle_str.replace(" ", "_") + '.png', bbox_inches='tight', pad_inches=0)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.
Mapping from target values to source values