@Flecart · Last active December 14, 2024
Dirichlet Mixture Model with Collapsed Gibbs Sampler
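The sampler below implements the collapsed Gibbs update for a Dirichlet process mixture model: each point $x_i$ is in turn removed from its cluster and reassigned according to the Chinese Restaurant Process conditional, where $n_{-i,k}$ is the size of cluster $k$ without point $i$ and $\alpha$ is the concentration parameter:

$$p(z_i = k \mid z_{-i}, x) \propto \frac{n_{-i,k}}{n - 1 + \alpha}\,\mathcal{N}(x_i \mid \mu_k, \Sigma_k), \qquad p(z_i = \text{new} \mid z_{-i}) \propto \frac{\alpha}{n - 1 + \alpha}.$$

For the likelihood term the code plugs in each cluster's empirical mean and covariance rather than integrating the parameters out under a conjugate prior, so it is a pragmatic approximation of the fully collapsed sampler.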
import numpy as np
import matplotlib.pyplot as plt
from scipy.stats import multivariate_normal

# Generate synthetic data (2D points for visualization): three Gaussian blobs
def generate_data():
    np.random.seed(42)
    data = np.vstack([
        np.random.multivariate_normal([5, 5], [[1, 0], [0, 1]], 50),
        np.random.multivariate_normal([2, 2], [[1, 0.2], [0.2, 1]], 100),
        np.random.multivariate_normal([8, 1], [[1, 0], [0, 1]], 50)
    ])
    return data
# Collapsed Gibbs sampler for a Dirichlet process mixture model
def collapsed_gibbs_sampler(data, max_clusters, alpha, num_iterations):
    n, d = data.shape
    num_clusters = max_clusters
    cluster_assignments = np.random.randint(0, max_clusters, size=n)
    cluster_counts = np.zeros(max_clusters, dtype=int)
    cluster_means = np.zeros((max_clusters, d))
    cluster_covs = np.array([np.eye(d) for _ in range(max_clusters)])

    # Store cluster assignments history for visualization
    cluster_history = []

    # Initialize cluster statistics from the random assignment
    for k in range(max_clusters):
        cluster_points = data[cluster_assignments == k]
        if len(cluster_points) > 0:
            cluster_means[k] = np.mean(cluster_points, axis=0)
            cluster_covs[k] = np.cov(cluster_points.T) + np.eye(d) * 1e-5
        cluster_counts[k] = len(cluster_points)

    for it in range(num_iterations):
        for i in range(n):
            # Remove point i from its cluster and refresh that cluster's
            # statistics without it; a singleton or empty cluster falls back
            # to an identity covariance (np.cov needs at least two points)
            current_cluster = cluster_assignments[i]
            cluster_assignments[i] = -1  # temporarily unassigned
            cluster_counts[current_cluster] -= 1
            cluster_points = data[cluster_assignments == current_cluster]
            if len(cluster_points) > 0:
                cluster_means[current_cluster] = np.mean(cluster_points, axis=0)
            else:
                cluster_means[current_cluster] = np.zeros(d)
            if len(cluster_points) > 1:
                cluster_covs[current_cluster] = np.cov(cluster_points.T) + np.eye(d) * 1e-5
            else:
                cluster_covs[current_cluster] = np.eye(d)

            # CRP conditional: existing clusters weighted by their size times
            # the Gaussian likelihood of point i
            probs = []
            for k in range(num_clusters):
                if cluster_counts[k] > 0:
                    likelihood = multivariate_normal.pdf(data[i], mean=cluster_means[k], cov=cluster_covs[k])
                else:
                    likelihood = 1.0  # empty slot; its prior mass is zero anyway
                assert cluster_counts[k] >= 0
                prior = cluster_counts[k] / (n - 1 + alpha)
                probs.append(prior * likelihood)

            # Add the probability of opening a new cluster (only while a slot is free)
            if num_clusters < max_clusters:
                probs.append(alpha / (n - 1 + alpha))
            probs = np.array(probs)
            probs /= probs.sum()

            # Resample the assignment of point i and update the chosen cluster
            new_cluster = np.random.choice(len(probs), p=probs)
            if new_cluster == num_clusters:
                num_clusters += 1
            cluster_assignments[i] = new_cluster
            cluster_counts[new_cluster] += 1
            cluster_points = data[cluster_assignments == new_cluster]
            cluster_means[new_cluster] = np.mean(cluster_points, axis=0)
            if len(cluster_points) > 1:
                cluster_covs[new_cluster] = np.cov(cluster_points.T) + np.eye(d) * 1e-5
            else:
                cluster_covs[new_cluster] = np.eye(d)

            # If the old cluster is now empty, free its slot by swapping in
            # the last active cluster
            if cluster_counts[current_cluster] == 0:
                cluster_assignments[cluster_assignments == num_clusters - 1] = current_cluster
                cluster_counts[current_cluster] = cluster_counts[num_clusters - 1]
                cluster_counts[num_clusters - 1] = 0
                cluster_means[current_cluster] = cluster_means[num_clusters - 1]
                cluster_covs[current_cluster] = cluster_covs[num_clusters - 1]
                num_clusters -= 1

        # Store cluster assignments for visualization
        cluster_history.append(cluster_assignments.copy())

    return cluster_assignments, cluster_history
# Plain k-means with the same interface, for comparison
def k_means_clustering(data, num_clusters, num_iterations):
    n, d = data.shape
    cluster_assignments = np.random.randint(0, num_clusters, size=n)
    cluster_means = np.zeros((num_clusters, d))
    cluster_history = [cluster_assignments.copy()]
    for it in range(num_iterations):
        # Update step: recompute each cluster's mean
        cluster_centers = np.zeros((num_clusters, d))
        cluster_counts = np.zeros(num_clusters, dtype=int)
        for i in range(n):
            cluster_centers[cluster_assignments[i]] += data[i]
            cluster_counts[cluster_assignments[i]] += 1
        for k in range(num_clusters):
            if cluster_counts[k] > 0:
                cluster_means[k] = cluster_centers[k] / cluster_counts[k]
        # Assignment step: move each point to its nearest mean
        for i in range(n):
            distances = np.linalg.norm(cluster_means - data[i], axis=1)
            cluster_assignments[i] = np.argmin(distances)
        cluster_history.append(cluster_assignments.copy())
    return cluster_assignments, cluster_history
# Visualization of clustering evolution, five panels per row
def visualize_clustering(data, cluster_history):
    num_iterations = len(cluster_history)
    y_size = (num_iterations + 4) // 5
    fig, axes = plt.subplots(y_size, 5, figsize=(15, 4 * y_size))
    for x, clusters in enumerate(cluster_history):
        j = x % 5
        i = x // 5
        # plt.subplots returns a 1-D array of axes when there is a single row
        ax = axes[j] if y_size == 1 else axes[i][j]
        ax.scatter(data[:, 0], data[:, 1], c=clusters, cmap='tab10', s=10)
        ax.set_title(f"Iteration {x + 1}")
        ax.axis('equal')
    plt.tight_layout()
    plt.show()
# Main function
def main():
    # Parameters
    data = generate_data()
    max_clusters = 10
    alpha = 1.0  # Dirichlet process concentration parameter
    num_iterations = 30

    # Run the collapsed Gibbs sampler and visualize the clustering evolution
    cluster_assignments, cluster_history = collapsed_gibbs_sampler(data, max_clusters, alpha, num_iterations)
    visualize_clustering(data, cluster_history)

    # Do the same with k-means and the same number of iterations
    # cluster_assignments, cluster_history = k_means_clustering(data, max_clusters, num_iterations)
    # visualize_clustering(data, cluster_history)

if __name__ == "__main__":
    main()
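A minimal usage sketch (assuming the functions above are in scope): count the distinct labels after each sweep to watch the number of active clusters shrink. With alpha = 1 and n = 200 the CRP prior alone expects roughly six clusters, and the likelihood should then concentrate the posterior near the three generating components.

import numpy as np

data = generate_data()
_, history = collapsed_gibbs_sampler(data, max_clusters=10, alpha=1.0, num_iterations=30)
for it, assignments in enumerate(history, start=1):
    print(f"iteration {it:2d}: {len(np.unique(assignments))} active clusters")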
@Flecart (Author) commented on Dec 14, 2024:

Original code proposed by O1:

import numpy as np
import matplotlib.pyplot as plt
from scipy.stats import multivariate_normal

# Generate synthetic data (2D points for visualization)
def generate_data():
    np.random.seed(42)
    data = np.vstack([
        np.random.multivariate_normal([5, 5], [[1, 0], [0, 1]], 50),
        np.random.multivariate_normal([0, 0], [[1, 0], [0, 1]], 50),
        np.random.multivariate_normal([8, 1], [[1, 0], [0, 1]], 50)
    ])
    return data

# Collapsed Gibbs Sampler for Dirichlet Mixture Model
def collapsed_gibbs_sampler(data, num_clusters, alpha, num_iterations):
    n, d = data.shape
    cluster_assignments = np.random.randint(0, num_clusters, size=n)
    cluster_counts = np.zeros(num_clusters, dtype=int)
    cluster_means = np.zeros((num_clusters, d))
    cluster_covs = np.array([np.eye(d) for _ in range(num_clusters)])
    
    # Initialize clusters
    for k in range(num_clusters):
        cluster_points = data[cluster_assignments == k]
        if len(cluster_points) > 0:
            cluster_means[k] = np.mean(cluster_points, axis=0)
            cluster_covs[k] = np.cov(cluster_points.T) + np.eye(d) * 1e-3
        cluster_counts[k] = len(cluster_points)

    # Store cluster assignments history for visualization
    cluster_history = []

    for it in range(num_iterations):
        for i in range(n):
            # Remove current point from its cluster
            current_cluster = cluster_assignments[i]
            cluster_counts[current_cluster] -= 1
            
            if cluster_counts[current_cluster] > 0:
                cluster_points = data[cluster_assignments == current_cluster]
                cluster_means[current_cluster] = np.mean(cluster_points, axis=0)
                cluster_covs[current_cluster] = np.cov(cluster_points.T) + np.eye(d) * 1e-3
            else:
                cluster_means[current_cluster] = np.zeros(d)
                cluster_covs[current_cluster] = np.eye(d)

            # Compute posterior probabilities for each cluster
            probs = []
            for k in range(num_clusters):
                if cluster_counts[k] > 0:
                    likelihood = multivariate_normal.pdf(data[i], mean=cluster_means[k], cov=cluster_covs[k])
                else:
                    likelihood = 1.0  # Non-informative prior
                prior = (cluster_counts[k] + alpha) / (n - 1 + num_clusters * alpha)
                probs.append(prior * likelihood)

            probs = np.array(probs)
            probs /= probs.sum()

            # Resample cluster assignment
            new_cluster = np.random.choice(num_clusters, p=probs)
            cluster_assignments[i] = new_cluster
            cluster_counts[new_cluster] += 1

            cluster_points = data[cluster_assignments == new_cluster]
            cluster_means[new_cluster] = np.mean(cluster_points, axis=0)
            cluster_covs[new_cluster] = np.cov(cluster_points.T) + np.eye(d) * 1e-3

        # Store cluster assignments for visualization
        cluster_history.append(cluster_assignments.copy())

    return cluster_assignments, cluster_history

# Visualization of clustering evolution
def visualize_clustering(data, cluster_history):
    num_iterations = len(cluster_history)
    fig, axes = plt.subplots(1, num_iterations, figsize=(15, 5))

    for i, clusters in enumerate(cluster_history):
        axes[i].scatter(data[:, 0], data[:, 1], c=clusters, cmap='tab10', s=10)
        axes[i].set_title(f"Iteration {i + 1}")
        axes[i].axis('equal')

    plt.tight_layout()
    plt.show()

# Main function
def main():
    # Parameters
    data = generate_data()
    num_clusters = 3
    alpha = 1.0
    num_iterations = 5

    # Run Collapsed Gibbs Sampler
    cluster_assignments, cluster_history = collapsed_gibbs_sampler(data, num_clusters, alpha, num_iterations)

    # Visualize Clustering Evolution
    visualize_clustering(data, cluster_history)

if __name__ == "__main__":
    main()
