Dirichlet Mixture Model with Collapsed Gibbs Sampler
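A self-contained demo of clustering with a Dirichlet mixture model. The script generates 2D points from three Gaussians, then runs a collapsed Gibbs sampler that reassigns one point at a time under a Chinese-restaurant-process prior: an existing cluster k is chosen with probability proportional to n_k / (n - 1 + alpha) times its Gaussian likelihood at the point, and a new cluster with probability proportional to alpha / (n - 1 + alpha). Assignments are plotted after every sweep, and a plain k-means implementation is included for comparison.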
import numpy as np
import matplotlib.pyplot as plt
from scipy.stats import multivariate_normal
# Generate synthetic data (2D points for visualization)
def generate_data():
    np.random.seed(42)
    data = np.vstack([
        np.random.multivariate_normal([5, 5], [[1, 0], [0, 1]], 50),
        np.random.multivariate_normal([2, 2], [[1, 0.2], [0.2, 1]], 100),
        np.random.multivariate_normal([8, 1], [[1, 0], [0, 1]], 50),
    ])
    return data
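
# The collapsed Gibbs update resamples one assignment z_i at a time under a
# Chinese-restaurant-process prior. With n_k points currently in cluster k
# (excluding point i), the conditional probabilities are proportional to
#   p(z_i = k   | z_-i, x_i)  ~  n_k / (n - 1 + alpha) * N(x_i; mu_k, Sigma_k)
#   p(z_i = new | z_-i)       ~  alpha / (n - 1 + alpha)
# where mu_k and Sigma_k are plug-in estimates computed from the points in
# cluster k (an approximation to the full posterior predictive).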
# Collapsed Gibbs Sampler for Dirichlet Mixture Model
def collapsed_gibbs_sampler(data, max_clusters, alpha, num_iterations):
    n, d = data.shape
    num_clusters = max_clusters
    cluster_assignments = np.random.randint(0, max_clusters, size=n)
    cluster_counts = np.zeros(max_clusters, dtype=int)
    cluster_means = np.zeros((max_clusters, d))
    cluster_covs = np.array([np.eye(d) for _ in range(max_clusters)])

    # Store cluster assignment history for visualization
    cluster_history = []

    # Initialize cluster statistics from the random assignment
    for k in range(max_clusters):
        cluster_points = data[cluster_assignments == k]
        if len(cluster_points) > 0:
            cluster_means[k] = np.mean(cluster_points, axis=0)
            cluster_covs[k] = np.cov(cluster_points.T) + np.eye(d) * 1e-5
        cluster_counts[k] = len(cluster_points)

    for it in range(num_iterations):
        for i in range(n):
            # Remove the current point from its cluster. Marking the
            # assignment as -1 keeps point i out of the statistics
            # recomputed below.
            current_cluster = cluster_assignments[i]
            cluster_assignments[i] = -1
            cluster_counts[current_cluster] -= 1
            if cluster_counts[current_cluster] > 0:
                cluster_points = data[cluster_assignments == current_cluster]
                cluster_means[current_cluster] = np.mean(cluster_points, axis=0)
                cluster_covs[current_cluster] = np.cov(cluster_points.T) + np.eye(d) * 1e-5
            else:
                cluster_means[current_cluster] = np.zeros(d)
                cluster_covs[current_cluster] = np.eye(d)

            # Compute unnormalized posterior probabilities for each cluster
            probs = []
            for k in range(num_clusters):
                assert cluster_counts[k] >= 0
                if cluster_counts[k] > 0:
                    # np.cov of a single point is NaN; fall back to a small
                    # diagonal covariance in that case
                    if np.isnan(cluster_covs[k]).any():
                        cluster_covs[k] = np.eye(d) * 1e-5
                    likelihood = multivariate_normal.pdf(data[i], mean=cluster_means[k], cov=cluster_covs[k])
                else:
                    likelihood = 1.0  # Empty cluster: the prior term below is zero anyway
                prior = cluster_counts[k] / (n - 1 + alpha)  # CRP prior for an existing cluster
                probs.append(prior * likelihood)

            # Probability of opening a new cluster, capped at max_clusters
            # (the base-measure likelihood is approximated by 1 here)
            if num_clusters < max_clusters:
                probs.append(alpha / (n - 1 + alpha))
            probs = np.array(probs)
            probs /= probs.sum()

            # Resample the cluster assignment
            new_cluster = np.random.choice(len(probs), p=probs)
            if new_cluster == num_clusters:
                num_clusters += 1
            cluster_assignments[i] = new_cluster
            cluster_counts[new_cluster] += 1
            cluster_points = data[cluster_assignments == new_cluster]
            cluster_means[new_cluster] = np.mean(cluster_points, axis=0)
            cluster_covs[new_cluster] = np.cov(cluster_points.T) + np.eye(d) * 1e-5

            # If the old cluster is now empty, remove it by swapping it with
            # the last active cluster
            if cluster_counts[current_cluster] == 0:
                cluster_assignments[cluster_assignments == num_clusters - 1] = current_cluster
                cluster_counts[current_cluster] = cluster_counts[num_clusters - 1]
                cluster_counts[num_clusters - 1] = 0
                cluster_means[current_cluster] = cluster_means[num_clusters - 1]
                cluster_covs[current_cluster] = cluster_covs[num_clusters - 1]
                num_clusters -= 1

        # Record assignments once per sweep for visualization
        cluster_history.append(cluster_assignments.copy())

    return cluster_assignments, cluster_history
# Standard k-means (Lloyd's algorithm), for comparison with the sampler
def k_means_clustering(data, num_clusters, num_iterations):
    n, d = data.shape
    cluster_assignments = np.random.randint(0, num_clusters, size=n)
    cluster_means = np.zeros((num_clusters, d))
    cluster_history = [cluster_assignments.copy()]

    for it in range(num_iterations):
        # Update step: recompute each non-empty cluster's mean
        cluster_centers = np.zeros((num_clusters, d))
        cluster_counts = np.zeros(num_clusters, dtype=int)
        for i in range(n):
            cluster_centers[cluster_assignments[i]] += data[i]
            cluster_counts[cluster_assignments[i]] += 1
        for k in range(num_clusters):
            if cluster_counts[k] > 0:
                cluster_means[k] = cluster_centers[k] / cluster_counts[k]

        # Assignment step: move each point to its nearest mean
        for i in range(n):
            distances = np.linalg.norm(cluster_means - data[i], axis=1)
            cluster_assignments[i] = np.argmin(distances)
        cluster_history.append(cluster_assignments.copy())

    return cluster_assignments, cluster_history
# Visualization of clustering evolution: one scatter panel per iteration,
# five panels per row
def visualize_clustering(data, cluster_history):
    num_iterations = len(cluster_history)
    y_size = (num_iterations + 4) // 5
    fig, axes = plt.subplots(y_size, 5, figsize=(15, 4 * y_size))
    for x, clusters in enumerate(cluster_history):
        j = x % 5
        i = x // 5
        # plt.subplots returns a 1-D array of axes when there is a single row
        ax = axes[j] if y_size == 1 else axes[i][j]
        ax.scatter(data[:, 0], data[:, 1], c=clusters, cmap='tab10', s=10)
        ax.set_title(f"Iteration {x + 1}")
        ax.axis('equal')
    plt.tight_layout()
    plt.show()
# Main function
def main():
    # Parameters
    data = generate_data()
    num_clusters = 10
    alpha = 1  # Concentration parameter of the Dirichlet process prior
    num_iterations = 30

    # Run the collapsed Gibbs sampler
    cluster_assignments, cluster_history = collapsed_gibbs_sampler(data, num_clusters, alpha, num_iterations)

    # Visualize the clustering evolution
    visualize_clustering(data, cluster_history)

    # Do the same with k-means and the same number of iterations
    # cluster_assignments, cluster_history = k_means_clustering(data, num_clusters, num_iterations)
    # visualize_clustering(data, cluster_history)

if __name__ == "__main__":
    main()
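
A quick sanity check, assuming the two lines below are appended to the end of main() after the sampler runs: the synthetic data is drawn from three Gaussians, so most points should typically end up in about three clusters.

labels, sizes = np.unique(cluster_assignments, return_counts=True)
print(f"found {len(labels)} clusters with sizes {sizes}")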
Original code proposed by O1.