Created April 2, 2025 23:13
Save berserker1/4f8102e4e9e415880e0cebcd5bf3bf16 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
""" | |
Implementation of the Alpha Divergence Estimators from Section 4.3 of the paper | |
"Empirical Squared Hellinger Distance Estimator and Generalizations to a Family of α-Divergence Estimators" | |
https://pmc.ncbi.nlm.nih.gov/articles/PMC10137612/#sec4-entropy-25-00612 | |
""" | |
import numpy as np | |
from scipy.spatial import KDTree | |
from scipy.special import gamma, digamma | |
import matplotlib.pyplot as plt | |
from typing import Tuple, Optional, Union, List | |
def _as_sample_matrix(samples: np.ndarray, name: str) -> np.ndarray:
    """Return *samples* as a float array of shape (n_samples, n_features).

    A 1-D array is treated as n_samples observations of a single feature.
    """
    arr = np.asarray(samples, dtype=float)
    if arr.ndim == 1:
        arr = arr[:, np.newaxis]
    if arr.ndim != 2:
        raise ValueError(f"{name} must be a 1-D or 2-D array, got shape {arr.shape}")
    return arr


def _kth_neighbor_distances(tree: KDTree, points: np.ndarray, k: int, exclude_self: bool = False) -> np.ndarray:
    """Return the distance from each row of *points* to its k-th nearest neighbour in *tree*.

    With ``exclude_self=True`` the points must be the data *tree* was built from: one
    extra neighbour is requested so that each point's own zero distance is skipped.
    """
    n_neighbors = k + 1 if exclude_self else k
    dist, _ = tree.query(points, k=n_neighbors)
    # KDTree.query drops the neighbour axis when asked for a single neighbour,
    # so restore it before taking the last (k-th) column.
    return np.asarray(dist, dtype=float).reshape(len(points), -1)[:, -1]


def _within_and_cross_knn_distances(source: np.ndarray, target: np.ndarray, k: int) -> Tuple[np.ndarray, np.ndarray]:
    """Return (rho, nu): k-NN distance of each source point within *source* (self excluded) and to *target*."""
    rho = _kth_neighbor_distances(KDTree(source), source, k, exclude_self=True)
    nu = _kth_neighbor_distances(KDTree(target), source, k)
    # A zero distance would make the log / power terms below infinite or NaN.
    if np.any(rho <= 0) or np.any(nu <= 0):
        raise ValueError(
            "encountered a zero k-NN distance (duplicate points); "
            "increase k or remove duplicate samples"
        )
    return rho, nu


def alpha_divergence_estimator(X: np.ndarray, Y: np.ndarray, alpha: float = 1.0, k: int = 5) -> float:
    """
    Estimate the alpha-divergence between two distributions from samples using k-NN distances.

    With p the density of X and q the density of Y, the target quantity is

        D_alpha(P||Q) = (1 - A_alpha) / (alpha * (1 - alpha)),   A_alpha = ∫ p^alpha q^(1 - alpha),

    which tends to KL(P||Q) as alpha -> 1 and to KL(Q||P) as alpha -> 0.

    * alpha == 1: k-NN KL estimator
          d * mean(log(nu_i / rho_i)) + log(n_Y / (n_X - 1)),
      where rho_i is the k-NN distance of X_i among the other X samples and
      nu_i is its k-NN distance to the Y samples.
    * alpha == 0: the same estimator with the roles of X and Y swapped (KL(Q||P)).
    * any other alpha: A_alpha = E_P[(q/p)^beta] with beta = 1 - alpha is estimated
      from the k-NN density ratio q/p ≈ ((n_X - 1) / n_Y) * (rho_i / nu_i)^d.
      The result is multiplied by Gamma(k)^2 / (Gamma(k + beta) * Gamma(k - beta))
      to remove the bias of that ratio for fixed k.
    * alpha == 0.5: returns the squared Hellinger distance 1 - A_0.5
      (equal to D_0.5 / 4), which lies in [0, 1] up to estimation noise.

    NOTE(review): this was originally described as Section 4.3 (Eqs. 27, 28, 31) of
    "Empirical Squared Hellinger Distance Estimator and Generalizations to a Family of
    α-Divergence Estimators". That correspondence has not been verified; the code
    implements the k-NN estimators described above.

    Args:
        X: Sample from the first distribution P, shape (n_samples_X, n_features).
            A 1-D array is treated as a single feature.
        Y: Sample from the second distribution Q, shape (n_samples_Y, n_features).
            A 1-D array is treated as a single feature.
        alpha: Parameter of the alpha-divergence family. Default 1.0 (KL divergence D(P||Q)).
        k: Number of nearest neighbours to use. Default 5.

    Returns:
        The estimated divergence as a float. Finite-sample estimates may be slightly negative.

    Raises:
        ValueError: if either input is not 1-D/2-D; if the inputs have different numbers
            of features; if k < 1 or k is too large for the sample sizes; if
            k <= |1 - alpha| (the bias correction is then undefined); or if a zero
            k-NN distance occurs (duplicate points).
    """
    # Local import: the module header only brings gamma/digamma into scope, and
    # gammaln avoids the overflow Gamma(k) hits for large k.
    from scipy.special import gammaln

    X = _as_sample_matrix(X, "X")
    Y = _as_sample_matrix(Y, "Y")
    if X.shape[1] != Y.shape[1]:
        raise ValueError(f"X and Y must have the same number of features, got {X.shape[1]} and {Y.shape[1]}")
    if k < 1:
        raise ValueError(f"k must be >= 1, got {k}")
    d = X.shape[1]

    # alpha == 0 estimates D(Q||P): the KL estimator with the samples' roles swapped.
    source, target = (Y, X) if alpha == 0 else (X, Y)
    n_source, n_target = len(source), len(target)
    if k > n_source - 1 or k > n_target:
        raise ValueError(
            f"k={k} is too large: need k <= {n_source - 1} (other points in the own sample) "
            f"and k <= {n_target} (points in the other sample)"
        )

    rho, nu = _within_and_cross_knn_distances(source, target, k)

    if alpha == 0 or alpha == 1:
        return float(d * np.mean(np.log(nu / rho)) + np.log(n_target / (n_source - 1)))

    beta = 1.0 - alpha
    if k <= abs(beta):
        raise ValueError(f"k must exceed |1 - alpha| = {abs(beta)} for the bias correction, got k={k}")

    # Log of the k-NN estimate of q/p at every X_i (computed in log space to avoid overflow for large d).
    log_density_ratio = np.log((n_source - 1) / n_target) + d * np.log(rho / nu)
    # For fixed k the estimated ratio is the true ratio times G1/G2, with G1, G2 ~ Gamma(k, 1)
    # independent. E[(G1/G2)^beta] = Gamma(k+beta) Gamma(k-beta) / Gamma(k)^2, so divide it out.
    log_correction = 2.0 * gammaln(k) - gammaln(k + beta) - gammaln(k - beta)
    affinity = np.exp(log_correction) * np.mean(np.exp(beta * log_density_ratio))

    if alpha == 0.5:
        # Squared Hellinger distance, convention H^2 = 1 - ∫ sqrt(p q).
        return float(1.0 - affinity)
    return float((1.0 - affinity) / (alpha * (1.0 - alpha)))
def hellinger_distance_estimator(X: np.ndarray, Y: np.ndarray, k: int = 5) -> float:
    """
    k-NN estimate of the squared Hellinger distance between the distributions of X and Y.

    This is a thin wrapper that forwards to ``alpha_divergence_estimator`` with ``alpha=0.5``.

    Args:
        X: Samples from the first distribution, shape (n_samples_X, n_features).
        Y: Samples from the second distribution, shape (n_samples_Y, n_features).
        k: Number of nearest neighbours used by the estimator (default 5).

    Returns:
        The estimated squared Hellinger distance.
    """
    hellinger_alpha = 0.5  # the member of the alpha family the Hellinger distance corresponds to
    return alpha_divergence_estimator(X, Y, k=k, alpha=hellinger_alpha)
def kl_divergence_estimator(X: np.ndarray, Y: np.ndarray, k: int = 5) -> float:
    """
    k-NN estimate of the Kullback-Leibler divergence D(P||Q), with P sampled by X and Q by Y.

    This is a thin wrapper that forwards to ``alpha_divergence_estimator`` with ``alpha=1.0``.

    Args:
        X: Samples from P, shape (n_samples_X, n_features).
        Y: Samples from Q, shape (n_samples_Y, n_features).
        k: Number of nearest neighbours used by the estimator (default 5).

    Returns:
        The estimated D(P||Q).
    """
    forward_kl_alpha = 1.0
    return alpha_divergence_estimator(X, Y, k=k, alpha=forward_kl_alpha)
def reverse_kl_divergence_estimator(X: np.ndarray, Y: np.ndarray, k: int = 5) -> float:
    """
    k-NN estimate of the reverse Kullback-Leibler divergence D(Q||P), with P sampled by X and Q by Y.

    This is a thin wrapper that forwards to ``alpha_divergence_estimator`` with ``alpha=0.0``.

    Args:
        X: Samples from P, shape (n_samples_X, n_features).
        Y: Samples from Q, shape (n_samples_Y, n_features).
        k: Number of nearest neighbours used by the estimator (default 5).

    Returns:
        The estimated D(Q||P).
    """
    reverse_kl_alpha = 0.0
    return alpha_divergence_estimator(X, Y, k=k, alpha=reverse_kl_alpha)
def _gaussian_reference_divergence(alpha: float, mahalanobis_sq: float) -> Tuple[float, str]:
    """Return (closed-form value, legend label) of the alpha-divergence between two equal-covariance Gaussians.

    For N(m1, S) and N(m2, S), ∫ p^alpha q^(1-alpha) = exp(-alpha (1 - alpha) delta^2 / 2),
    where delta^2 = (m1 - m2)^T S^{-1} (m1 - m2).
    """
    if alpha == 1:
        return 0.5 * mahalanobis_sq, 'True KL'
    if alpha == 0:
        # KL is symmetric when the covariances are equal.
        return 0.5 * mahalanobis_sq, 'True reverse KL'
    affinity = float(np.exp(-0.5 * alpha * (1 - alpha) * mahalanobis_sq))
    if alpha == 0.5:
        # Squared Hellinger distance, convention H^2 = 1 - ∫ sqrt(p q).
        return 1 - affinity, 'True Hellinger'
    return (1 - affinity) / (alpha * (1 - alpha)), f'True D_α'


def evaluate_on_gaussians(
    d: int = 2,
    n_samples: int = 1000,
    means_distance: float = 2.0,
    k_values: Optional[List[int]] = None,
    alpha_values: Optional[List[float]] = None,
    seed: int = 42,
) -> None:
    """
    Plot k-NN alpha-divergence estimates against k for two unit-covariance Gaussians.

    P = N(0, I_d) and Q = N(m, I_d) with m = (means_distance, 0, ..., 0). Each alpha gets one
    panel that shows the estimates for every k, with a dashed line at the closed-form value.

    Args:
        d: Dimensionality of the distributions.
        n_samples: Number of samples drawn from each distribution.
        means_distance: Distance between the means, along the first coordinate.
        k_values: k values to try; defaults to [1, 3, 5, 7, 10].
        alpha_values: alpha values to evaluate; defaults to [0, 0.5, 1.0].
        seed: Seed of the local random generator. NumPy's global RNG state is not touched.

    Returns:
        None. Shows a matplotlib figure.
    """
    # None sentinels avoid shared mutable default lists.
    if k_values is None:
        k_values = [1, 3, 5, 7, 10]
    if alpha_values is None:
        alpha_values = [0, 0.5, 1.0]

    mean1 = np.zeros(d)
    mean2 = np.zeros(d)
    mean2[0] = means_distance  # Shift the first dimension
    cov = np.eye(d)  # Identity covariance matrix for both distributions

    rng = np.random.default_rng(seed)
    X = rng.multivariate_normal(mean1, cov, n_samples)
    Y = rng.multivariate_normal(mean2, cov, n_samples)

    diff = mean1 - mean2
    mahalanobis_sq = float(diff @ np.linalg.solve(cov, diff))

    # squeeze=False gives a 2-D axes array even when there is a single panel.
    _, axes = plt.subplots(len(alpha_values), 1, figsize=(10, 4 * len(alpha_values)), squeeze=False)
    for ax, alpha in zip(axes[:, 0], alpha_values):
        estimates = [alpha_divergence_estimator(X, Y, alpha=alpha, k=k) for k in k_values]
        ax.plot(k_values, estimates, 'o-', label=f'Estimated α={alpha}')

        true_value, true_label = _gaussian_reference_divergence(alpha, mahalanobis_sq)
        ax.axhline(y=true_value, color='r', linestyle='--', label=f'{true_label}={true_value:.4f}')

        ax.set_xlabel('Number of neighbors (k)')
        ax.set_ylabel(f'α-Divergence (α={alpha})')
        ax.set_title(f'Estimation of α-Divergence with α={alpha}')
        ax.legend()
        ax.grid(True)
    plt.tight_layout()
    plt.show()
# Example usage
def main() -> None:
    """Print k-NN divergence estimates for two 2-D Gaussian samples, then plot a sweep over k and alpha."""
    # A local generator keeps the demo reproducible without reseeding NumPy's global RNG.
    rng = np.random.default_rng(42)

    d = 2  # Dimensionality
    n_samples = 1000  # Number of samples

    # First distribution: Gaussian centred at the origin, identity covariance.
    mean1 = np.zeros(d)
    cov1 = np.eye(d)
    X = rng.multivariate_normal(mean1, cov1, n_samples)

    # Second distribution: same covariance, mean shifted by 2 along the first axis.
    # mean2 is derived from d so the example stays valid if the dimensionality changes.
    mean2 = np.zeros(d)
    mean2[0] = 2.0
    cov2 = np.eye(d)
    Y = rng.multivariate_normal(mean2, cov2, n_samples)

    kl_div = kl_divergence_estimator(X, Y, k=5)
    print(f"Estimated KL divergence D(P||Q): {kl_div:.4f}")

    rev_kl_div = reverse_kl_divergence_estimator(X, Y, k=5)
    print(f"Estimated reverse KL divergence D(Q||P): {rev_kl_div:.4f}")

    hellinger_dist = hellinger_distance_estimator(X, Y, k=5)
    print(f"Estimated squared Hellinger distance: {hellinger_dist:.4f}")

    alpha_div = alpha_divergence_estimator(X, Y, alpha=0.3, k=5)
    print(f"Estimated alpha-divergence (alpha=0.3): {alpha_div:.4f}")

    # Evaluate the estimators for different k values and alpha values.
    evaluate_on_gaussians(k_values=[1, 3, 5, 7, 10, 15], alpha_values=[0.0, 0.5, 1.0])


if __name__ == "__main__":
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.
This code implements the α-divergence estimators from Section 4.3 of the paper "Empirical Squared Hellinger Distance Estimator and Generalizations to a Family of α-Divergence Estimators." The implementation focuses on:
The main alpha_divergence_estimator function that implements the k-nearest neighbor (k-NN) based approach for estimating α-divergences between two probability distributions given finite samples
Special cases for common divergences:
hellinger_distance_estimator for squared Hellinger distance (α = 0.5)
kl_divergence_estimator for Kullback-Leibler divergence (α = 1.0)
reverse_kl_divergence_estimator for reverse KL divergence (α = 0.0)
An evaluation function, evaluate_on_gaussians, that tests the estimators on pairs of Gaussian distributions, for which the true divergences have closed-form expressions
The implementation is described as following Section 4.3 of the paper, in particular equations 27, 28, and 31, which provide estimators for different values of α in the α-divergence family. This correspondence should be checked against the paper, because the code itself implements k-nearest-neighbor estimators. The code uses k-nearest neighbor distances to compute non-parametric estimates of these divergences from finite samples. It does not fit an explicit density model, although the k-NN distances act as implicit local density estimates.