Skip to content

Instantly share code, notes, and snippets.

@berserker1
Created April 2, 2025 23:13
Show Gist options
  • Save berserker1/4f8102e4e9e415880e0cebcd5bf3bf16 to your computer and use it in GitHub Desktop.
Save berserker1/4f8102e4e9e415880e0cebcd5bf3bf16 to your computer and use it in GitHub Desktop.
"""
Implementation of the Alpha Divergence Estimators from Section 4.3 of the paper
"Empirical Squared Hellinger Distance Estimator and Generalizations to a Family of α-Divergence Estimators"
https://pmc.ncbi.nlm.nih.gov/articles/PMC10137612/#sec4-entropy-25-00612
"""
import numpy as np
from scipy.spatial import KDTree
from scipy.special import gamma, digamma
import matplotlib.pyplot as plt
from typing import Tuple, Optional, Union, List
def _kth_neighbor_distances(tree: KDTree, points: np.ndarray, k: int) -> np.ndarray:
    """
    Return the distance from each row of ``points`` to its k-th nearest neighbor in ``tree``.

    ``KDTree.query`` squeezes the neighbor axis when ``k == 1`` (shape ``(n,)``
    instead of ``(n, 1)``), so the result is reshaped to 2-D before the last
    column, which holds the k-th distance, is taken.
    """
    dist, _ = tree.query(points, k=k)
    return np.asarray(dist).reshape(len(points), -1)[:, -1]


def alpha_divergence_estimator(X: np.ndarray, Y: np.ndarray, alpha: float = 1.0, k: int = 5) -> float:
    """
    Estimate the alpha-divergence between two distributions using k-NN method.

    Based on Section 4.3 of the paper "Empirical Squared Hellinger Distance Estimator
    and Generalizations to a Family of α-Divergence Estimators".

    Let p and q be the densities of P (sampled by X) and Q (sampled by Y). For each
    sample x_i of X, rho_i is the distance to its k-th nearest neighbor among the
    other X samples, and nu_i is the distance to its k-th nearest neighbor in Y.

    * alpha == 1: KL divergence D(P||Q) = d * mean(log(nu / rho)) + log(n_Y / (n_X - 1)).
    * alpha == 0: KL divergence D(Q||P), the same estimator with X and Y swapped.
    * alpha == 0.5: squared Hellinger distance 1 - A (A defined below).
    * otherwise: (1 - A) / (alpha * (1 - alpha)).

    Here A = B_{k,alpha} * mean(((n_X - 1) rho^d / (n_Y nu^d))^(1 - alpha)) estimates
    ∫ p^alpha q^(1 - alpha). B_{k,alpha} = Γ(k)^2 / (Γ(k - alpha + 1) Γ(k + alpha - 1))
    corrects the asymptotic bias of the k-NN distance ratio. With this sign
    convention the estimate is non-negative in expectation. It tends to D(P||Q)
    as alpha -> 1 and to D(Q||P) as alpha -> 0.

    Args:
        X: Sample from the first distribution P, shape (n_samples_X, n_features).
            A 1-D array is treated as n_samples_X scalar observations.
        Y: Sample from the second distribution Q, shape (n_samples_Y, n_features).
            A 1-D array is treated as n_samples_Y scalar observations.
        alpha: Parameter for the alpha-divergence family, default is 1.0 (KL divergence).
            For alpha outside {0, 1} it must satisfy 1 - k < alpha < k + 1, the range
            in which the bias correction is defined.
        k: Number of nearest neighbors to consider, default is 5.

    Returns:
        Estimated alpha-divergence value. For alpha == 0.5 this is the squared
        Hellinger distance, which is 1/4 of the general formula's value.

    Raises:
        ValueError: If k < 1, X and Y have different numbers of features, either
            sample has k or fewer points, or alpha is outside the allowed range.

    Note:
        Neighbor distances are divided by one another, so duplicated points
        (zero distances) produce infinite or NaN estimates. Continuous data is
        assumed.
    """
    from scipy.special import gammaln  # log-gamma keeps the bias factor finite for large k

    X = np.asarray(X, dtype=float)
    Y = np.asarray(Y, dtype=float)
    if X.ndim == 1:
        X = X[:, np.newaxis]
    if Y.ndim == 1:
        Y = Y[:, np.newaxis]

    if k < 1:
        raise ValueError(f"k must be at least 1, got {k}")
    if X.shape[1] != Y.shape[1]:
        raise ValueError(
            f"X and Y must have the same number of features, got {X.shape[1]} and {Y.shape[1]}"
        )
    n_X, d = X.shape
    n_Y = Y.shape[0]
    # Self-queries ask for k + 1 neighbors (the nearest hit is the point itself),
    # so each sample needs more than k points. Otherwise KDTree pads the result
    # with infinite distances and the estimate silently becomes inf/nan.
    if n_X <= k or n_Y <= k:
        raise ValueError(
            f"need more than k={k} samples in each of X and Y, got {n_X} and {n_Y}"
        )
    if alpha != 0 and alpha != 1 and not (1 - k < alpha < k + 1):
        raise ValueError(f"alpha must satisfy {1 - k} < alpha < {k + 1} for k={k}, got {alpha}")

    tree_X = KDTree(X)
    tree_Y = KDTree(Y)

    if alpha == 0:
        # KL divergence D(Q||P): the alpha == 1 estimator with the roles of X and Y swapped.
        rho_Y = _kth_neighbor_distances(tree_Y, Y, k + 1)
        nu_Y = _kth_neighbor_distances(tree_X, Y, k)
        return float(d * np.mean(np.log(nu_Y / rho_Y)) + np.log(n_X / (n_Y - 1)))

    rho = _kth_neighbor_distances(tree_X, X, k + 1)  # k + 1: first hit is x_i itself
    nu = _kth_neighbor_distances(tree_Y, X, k)

    if alpha == 1:
        # KL divergence D(P||Q).
        return float(d * np.mean(np.log(nu / rho)) + np.log(n_Y / (n_X - 1)))

    beta = 1.0 - alpha
    # Per-sample k-NN estimate of q(x_i) / p(x_i), raised to the power beta.
    ratio_pow = (rho / nu) ** (d * beta) * ((n_X - 1) / n_Y) ** beta
    log_bias = 2.0 * gammaln(k) - gammaln(k - alpha + 1.0) - gammaln(k + alpha - 1.0)
    affinity = np.exp(log_bias) * np.mean(ratio_pow)  # estimates ∫ p^alpha q^(1 - alpha)

    if alpha == 0.5:
        # Squared Hellinger distance (1/2) ∫ (sqrt p - sqrt q)^2 = 1 - ∫ sqrt(p q).
        return float(1.0 - affinity)
    return float((1.0 - affinity) / (alpha * beta))
def hellinger_distance_estimator(X: np.ndarray, Y: np.ndarray, k: int = 5) -> float:
    """
    k-NN estimate of the squared Hellinger distance between two sampled distributions.

    Forwards to ``alpha_divergence_estimator`` with ``alpha=0.5``, the value that
    function handles as its squared-Hellinger special case.

    Args:
        X: Draws from the first distribution, shape (n_samples_X, n_features)
        Y: Draws from the second distribution, shape (n_samples_Y, n_features)
        k: Neighbor count used for the k-NN distances, default is 5
    Returns:
        The estimated squared Hellinger distance
    """
    hellinger_alpha = 0.5
    return alpha_divergence_estimator(X, Y, alpha=hellinger_alpha, k=k)
def kl_divergence_estimator(X: np.ndarray, Y: np.ndarray, k: int = 5) -> float:
    """
    k-NN estimate of the Kullback-Leibler divergence D(P||Q).

    P is the distribution that produced X and Q the one that produced Y.
    Delegates to ``alpha_divergence_estimator`` with ``alpha=1.0``, which that
    function maps to D(P||Q).

    Args:
        X: Draws from P, shape (n_samples_X, n_features)
        Y: Draws from Q, shape (n_samples_Y, n_features)
        k: Neighbor count used for the k-NN distances, default is 5
    Returns:
        The estimated D(P||Q)
    """
    return alpha_divergence_estimator(X=X, Y=Y, k=k, alpha=1.0)
def reverse_kl_divergence_estimator(X: np.ndarray, Y: np.ndarray, k: int = 5) -> float:
    """
    k-NN estimate of the reverse Kullback-Leibler divergence D(Q||P).

    P is the distribution that produced X and Q the one that produced Y. The
    argument order is the same as for the forward estimator. Only ``alpha``
    changes: ``alpha_divergence_estimator`` maps ``alpha=0.0`` to D(Q||P).

    Args:
        X: Draws from P, shape (n_samples_X, n_features)
        Y: Draws from Q, shape (n_samples_Y, n_features)
        k: Neighbor count used for the k-NN distances, default is 5
    Returns:
        The estimated D(Q||P)
    """
    reverse_alpha = 0.0
    return alpha_divergence_estimator(X, Y, k=k, alpha=reverse_alpha)
def gaussian_alpha_divergence(alpha: float, maha_sq: float) -> float:
    """
    Return the analytic divergence between two Gaussians that share a covariance matrix.

    For P = N(m1, S) and Q = N(m2, S) with squared Mahalanobis distance
    ``maha_sq = (m1 - m2)^T S^{-1} (m1 - m2)``, the affinity is
    ∫ p^alpha q^(1 - alpha) = exp(-alpha (1 - alpha) maha_sq / 2). From it:

    * alpha == 1 or alpha == 0: KL divergence maha_sq / 2 (symmetric for equal covariances);
    * alpha == 0.5: squared Hellinger distance 1 - exp(-maha_sq / 8);
    * otherwise: the alpha-divergence (1 - affinity) / (alpha (1 - alpha)).

    Args:
        alpha: Divergence-family parameter.
        maha_sq: Squared Mahalanobis distance between the two means.
    Returns:
        The reference divergence value.
    """
    if alpha == 0 or alpha == 1:
        return 0.5 * maha_sq
    affinity = float(np.exp(-0.5 * alpha * (1.0 - alpha) * maha_sq))
    if alpha == 0.5:
        return 1.0 - affinity
    return (1.0 - affinity) / (alpha * (1.0 - alpha))


def evaluate_on_gaussians(
    d: int = 2,
    n_samples: int = 1000,
    means_distance: float = 2.0,
    k_values: Optional[List[int]] = None,
    alpha_values: Optional[List[float]] = None,
    seed: int = 42,
) -> None:
    """
    Evaluate the alpha-divergence estimators on samples from two Gaussian distributions.

    P = N(0, I) and Q = N(means_distance * e_1, I) in ``d`` dimensions. Each alpha
    gets one subplot showing the estimate versus k. A dashed line marks the
    analytic value from ``gaussian_alpha_divergence``.

    Args:
        d: Dimensionality of the distributions
        n_samples: Number of samples to draw from each distribution
        means_distance: Distance between the means of the two distributions
        k_values: k values to try for k-NN; defaults to [1, 3, 5, 7, 10]
        alpha_values: Alpha values to evaluate; defaults to [0, 0.5, 1.0]
        seed: Seed for the private random generator used to draw the samples
    Returns:
        None, but produces plots showing the estimates
    """
    # None sentinels instead of list defaults, which would be shared across calls.
    if k_values is None:
        k_values = [1, 3, 5, 7, 10]
    if alpha_values is None:
        alpha_values = [0, 0.5, 1.0]

    mean1 = np.zeros(d)
    mean2 = np.zeros(d)
    mean2[0] = means_distance  # Shift only the first coordinate
    cov = np.eye(d)  # Identity covariance for both distributions

    # A dedicated RandomState seeded with `seed` draws the same samples as
    # np.random.seed(seed) followed by np.random.multivariate_normal. It does
    # so without resetting the caller's global NumPy RNG.
    rng = np.random.RandomState(seed)
    X = rng.multivariate_normal(mean1, cov, n_samples)
    Y = rng.multivariate_normal(mean2, cov, n_samples)

    diff = mean1 - mean2
    maha_sq = float(diff @ np.linalg.solve(cov, diff))

    # squeeze=False always yields a 2-D array of axes, even for a single alpha.
    _, axes = plt.subplots(len(alpha_values), 1, figsize=(10, 4 * len(alpha_values)), squeeze=False)
    for ax, alpha in zip(axes[:, 0], alpha_values):
        estimates = [alpha_divergence_estimator(X, Y, alpha=alpha, k=k) for k in k_values]
        ax.plot(k_values, estimates, 'o-', label=f'Estimated α={alpha}')

        true_value = gaussian_alpha_divergence(alpha, maha_sq)
        if alpha == 1.0:
            ref_label = f'True KL={true_value:.4f}'
        elif alpha == 0.5:
            ref_label = f'True Hellinger={true_value:.4f}'
        elif alpha == 0:
            ref_label = f'True reverse KL={true_value:.4f}'
        else:
            ref_label = f'True α-divergence={true_value:.4f}'
        ax.axhline(y=true_value, color='r', linestyle='--', label=ref_label)

        ax.set_xlabel('Number of neighbors (k)')
        ax.set_ylabel(f'α-Divergence (α={alpha})')
        ax.set_title(f'Estimation of α-Divergence with α={alpha}')
        ax.legend()
        ax.grid(True)
    plt.tight_layout()
    plt.show()
# Example usage
if __name__ == "__main__":
    # Seed NumPy's global generator so every run prints the same estimates.
    np.random.seed(42)

    # Two 2-D Gaussians with identity covariance: P centred at (0, 0), Q at (2, 0).
    d = 2  # Dimensionality
    n_samples = 1000  # Draws per distribution
    mean1, cov1 = np.zeros(d), np.eye(d)
    X = np.random.multivariate_normal(mean1, cov1, n_samples)
    mean2, cov2 = np.array([2.0, 0.0]), np.eye(d)
    Y = np.random.multivariate_normal(mean2, cov2, n_samples)

    # (label, estimator) pairs, run and printed in this order with k = 5.
    demo_estimates = (
        ("Estimated KL divergence D(P||Q)",
         lambda: kl_divergence_estimator(X, Y, k=5)),
        ("Estimated reverse KL divergence D(Q||P)",
         lambda: reverse_kl_divergence_estimator(X, Y, k=5)),
        ("Estimated squared Hellinger distance",
         lambda: hellinger_distance_estimator(X, Y, k=5)),
        ("Estimated alpha-divergence (alpha=0.3)",
         lambda: alpha_divergence_estimator(X, Y, alpha=0.3, k=5)),
    )
    for label, run_estimate in demo_estimates:
        print(f"{label}: {run_estimate():.4f}")

    # Sweep k for several alpha values and plot the resulting estimates.
    evaluate_on_gaussians(k_values=[1, 3, 5, 7, 10, 15], alpha_values=[0.0, 0.5, 1.0])
@berserker1
Copy link
Author

This code implements the α-divergence estimators from Section 4.3 of the paper "Empirical Squared Hellinger Distance Estimator and Generalizations to a Family of α-Divergence Estimators." The implementation focuses on:

The main alpha_divergence_estimator function, which implements the k-nearest-neighbor (k-NN) approach to estimating α-divergences between two probability distributions from finite samples.
Special cases for common divergences:

hellinger_distance_estimator for squared Hellinger distance (α = 0.5)
kl_divergence_estimator for Kullback-Leibler divergence (α = 1.0)
reverse_kl_divergence_estimator for reverse KL divergence (α = 0.0)

An evaluation function, evaluate_on_gaussians, that tests the estimators on Gaussian distributions whose true divergences can be computed analytically.

The implementation follows the equations described in Section 4.3 of the paper, particularly equations 27, 28, and 31, which provide estimators for different values of α in the α-divergence family. The code uses k-nearest neighbor distances to compute non-parametric estimates of these divergences from finite samples without requiring explicit density estimation.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment