Created
March 20, 2025 22:00
-
-
Save Nikolaj-K/5b87cf21dcf6f13a3c25ec9dc10fa7cc to your computer and use it in GitHub Desktop.
Inverse Transform Sampling
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
""" | |
Code explained in | |
https://youtu.be/tSyMnVd6DsY | |
* Theorem: | |
V := CDF_X^{-1}(U), with U from the uniform distribution on [0,1], has the same distribution as X. | |
Proof: | |
* Note (monotone maps preserve inequalities): {r \in Q_{>0} | \sqrt{3} < r} = {r \in Q_{>0} | 3 < r^2} | |
* Pr(V <= x) = Pr(CDF_X^{-1}(U) <= x) = Pr(U <= CDF_X(x)) = CDF_X(x) | |
See: | |
* https://en.wikipedia.org/wiki/Inverse_transform_sampling | |
Sanity check of method: | |
Where the pdf of X is low along any x-axis interval (say from 2 to 4), | |
the cdf remains flat, the inverse cdf is steep, | |
and so the y-axis interval to which the x-axis values are associated is short. | |
So if we sample uniformly on the y-axis, we collect few points in any such narrow y-axis interval. | |
""" | |
import numpy as np | |
import math | |
import matplotlib.pyplot as plt | |
from scipy.integrate import quad | |
from scipy.interpolate import interp1d | |
class Config:
    """Run parameters for the inverse-transform-sampling demo."""

    # Number of uniform draws that are pushed through the inverse CDF.
    SAMPLE_SIZE: int = 3000
    # Left end of the interval on which the density is normalized and sampled.
    # Written as a float literal so the value matches its annotation.
    LOWER: float = -3.0
    # Right end of that interval.
    UPPER: float = 5.0
def my_example_func(x: np.ndarray) -> np.ndarray:
    """Evaluate the unnormalized example density at ``x``.

    The result is the sum of four non-negative terms: a narrow Gaussian bump
    centred at 0.2, a wider Gaussian bump centred at 2.5, a shallow parabola
    with its minimum at 1, and a rectified sine wave.
    """
    narrow_bump = 3.0 * np.exp(-0.5 * ((x - 0.2) / 0.3) ** 2)
    wide_bump = 4.0 * np.exp(-0.5 * ((x - 2.5) / 0.5) ** 2)
    parabola = 0.1 * (x - 1) ** 2
    rectified_sine = 0.7 * np.abs(np.sin(np.pi * x))
    return narrow_bump + wide_bump + parabola + rectified_sine
def normalize_func(f: callable, a: float, b: float) -> callable:
    """Rescale ``f`` so that it integrates to 1 over the interval from ``a`` to ``b``.

    Args:
        f: Function of one variable to be normalized.
        a: Lower integration limit.
        b: Upper integration limit.

    Returns:
        A function ``pdf`` with ``pdf(x) == f(x) / c``, where ``c`` is the
        numerically computed integral of ``f`` from ``a`` to ``b``.

    Raises:
        ValueError: If that integral is zero or not finite, in which case no
            normalization constant exists.
    """
    total, _ = quad(f, a, b)
    # Fail here rather than handing back a function that only produces
    # inf/nan (or raises ZeroDivisionError) once it is evaluated.
    if total == 0 or not math.isfinite(total):
        raise ValueError(
            f"cannot normalize: integral over [{a}, {b}] is {total}"
        )

    def pdf(x: float) -> float:
        return f(x) / total

    return pdf
def get_cdf(pdf: callable, a: float, b: float, num_points: int = 100) -> tuple[np.ndarray, np.ndarray]:
    """Tabulate the cumulative integral of ``pdf`` on an even grid over ``[a, b]``.

    Args:
        pdf: Density to integrate.
        a: Left end of the grid; the tabulated CDF is 0 there.
        b: Right end of the grid.
        num_points: Number of grid points, endpoints included.

    Returns:
        ``(xs, ys)`` where ``xs`` is the grid and ``ys[i]`` is the integral of
        ``pdf`` from ``a`` to ``xs[i]``.
    """
    xs = np.linspace(a, b, num_points)
    # Integrate each grid cell once and accumulate, instead of re-integrating
    # from ``a`` for every grid point. Each cell is short, which is easier on
    # the quadrature, and for a non-negative ``pdf`` the running sum cannot
    # decrease through unrelated quadrature errors of independent integrals.
    cell_masses = [quad(pdf, left, right)[0] for left, right in zip(xs[:-1], xs[1:])]
    ys = np.zeros(len(xs))
    ys[1:] = np.cumsum(cell_masses)
    return xs, ys
def inverse_monotone_func(xs, ys) -> interp1d:
    """Build the inverse of a strictly increasing tabulated function.

    Args:
        xs: Abscissae of the tabulated function.
        ys: Function values at ``xs``; must be strictly increasing.

    Returns:
        An interpolant mapping values of ``ys`` back to ``xs``. Because it is
        built with ``bounds_error=False``, arguments outside the range of
        ``ys`` do not raise.

    Raises:
        ValueError: If ``ys`` is not strictly increasing (ties included), since
            the inverse would not be a function.
    """
    # An explicit raise rather than ``assert``: assertions are stripped under
    # ``python -O``, and the check must not disappear with them.
    if not np.all(np.diff(np.asarray(ys)) > 0):
        raise ValueError("ys must be strictly increasing to be inverted")
    return interp1d(ys, xs, bounds_error=False)
def make_inverse_cdf(pdf: callable, a: float, b: float, num_points: int = 100) -> interp1d:
    """Return an interpolated inverse of the CDF of ``pdf`` on ``[a, b]``.

    The CDF is tabulated with ``get_cdf`` on ``num_points`` grid points and the
    table is then inverted with ``inverse_monotone_func``.
    """
    return inverse_monotone_func(*get_cdf(pdf, a, b, num_points))
def plot_histogram(pdf: callable, samples: np.ndarray, a: float, b: float, bins: int = 70) -> None:
    """Plot the samples against the target density, next to the inverse CDF.

    Left panel: histogram of ``samples``, their empirical CDF, the density
    ``pdf`` and its tabulated CDF. Both CDF curves are rescaled to the peak of
    the plotted density so that they share one axis with it.
    Right panel: the inverse CDF on [0, 1], with guide lines from 30 evenly
    spaced values on the U-axis to the corresponding sample values.

    Args:
        pdf: Density to plot; it is called with a whole array of x values.
        samples: Values obtained by applying the inverse CDF to uniform draws.
        a: Left end of the interval on which the CDF is tabulated.
        b: Right end of that interval.
        bins: Number of histogram bins.
    """
    # Reaches beyond [a, b] for visualization sake. Each endpoint is moved
    # outwards by 30% of its magnitude, which equals 1.3 * a and 1.3 * b when
    # a < 0 < b but still widens the view when both have the same sign.
    xs = np.linspace(a - 0.3 * abs(a), b + 0.3 * abs(b), 1000)
    pdf_vals = pdf(xs)
    peak = np.max(pdf_vals)

    xs_cdf, cdf_vals = get_cdf(pdf, a, b)
    cdf_vals_normalized = cdf_vals * peak / np.max(cdf_vals)

    sorted_samples = np.sort(samples)
    empirical_cdf_vals = np.arange(1, len(sorted_samples) + 1) / len(sorted_samples)
    empirical_cdf_vals_normalized = empirical_cdf_vals * peak / np.max(empirical_cdf_vals)

    # Invert the table computed above. This honours the a and b arguments
    # (rather than the global Config bounds) and avoids tabulating twice.
    inverse_cdf = inverse_monotone_func(xs_cdf, cdf_vals)
    # The inverse CDF lives on [0, 1], so it gets its own grid there.
    us = np.linspace(0, 1, 1000)

    DARK_GREEN = '#006400'
    fig, axes = plt.subplots(1, 2, figsize=(16, 6))

    left = axes[0]
    left.hist(sorted_samples, bins=bins, density=True, alpha=0.6, color='green', label='CDF^{-1}(uniform dist. samples) histogram')
    left.plot(sorted_samples, empirical_cdf_vals_normalized, label='CDF of uniform dist. samples (re-normalized)', color=DARK_GREEN, lw=2)
    left.plot(xs, pdf_vals, label='PDF ("Analytical")', color='black', lw=2)
    left.plot(xs_cdf, cdf_vals_normalized, label='Analytical CDF (re-normalized)', color='gray', lw=2)
    left.set_xlabel('x, resp. sample value domain')
    left.set_ylabel('Density / CDF Value')
    left.legend(loc='best')

    right = axes[1]
    right.plot(us, inverse_cdf(us), label='Inverse CDF', color='gray', lw=2)
    # Guide lines: up from each u on the horizontal axis to the curve, then
    # across to the vertical axis, showing where uniform draws end up.
    for xi in np.linspace(0, 1, 30):
        yi = inverse_cdf(xi)
        right.plot([xi, xi], [a, yi], color='blue', linewidth=0.5)
        right.plot([0, xi], [yi, yi], color='blue', linewidth=0.5)
    right.set_xlim(0, 1)
    right.set_ylim(a, b)
    right.set_xlabel('U')
    right.legend(loc='best')
    right.grid(False)

    plt.tight_layout()
    plt.show()
if __name__ == '__main__':
    lower, upper = Config.LOWER, Config.UPPER

    # Turn the example function into a density on [lower, upper].
    target_pdf = normalize_func(my_example_func, lower, upper)

    # Inverse transform sampling: push uniform draws on [0, 1] through the
    # interpolated inverse CDF of the target density.
    inverse_cdf = make_inverse_cdf(target_pdf, lower, upper)
    uniform_draws = np.random.uniform(0, 1, Config.SAMPLE_SIZE)
    transformed_draws = inverse_cdf(uniform_draws)

    plot_histogram(target_pdf, transformed_draws, lower, upper)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment