Last active
March 29, 2022 13:30
-
-
Save airalcorn2/206fea0ad3384a9ec68e05d0f8f67a60 to your computer and use it in GitHub Desktop.
Demonstrates the issues that can arise when assuming a multivariate normal distribution for a multimodal multivariate distribution. See discussion here --> https://openreview.net/forum?id=sO4tOk2lg9I&noteId=d30Xi6n6BcJ.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns

N = 1000  # Number of samples drawn for every distribution below.
mu = 1  # Offset of the two mode centres in the multimodal example.
sd = 0.1  # Standard deviation of the Gaussian noise terms.

# Separate dimensions example.
# The agents' xs and ys are each modelled with their own bivariate normal
# (one over the two agents' xs, one over the two agents' ys). Sampling x and y
# independently from those two fits discards the strong x-y correlation that
# each agent's true trajectories have.

# Generate the first agent's trajectories: y1 is x1 plus small noise.
x1 = np.random.normal(size=N)
y1 = x1 + np.random.normal(0, sd, size=N)
samps1 = np.vstack([x1, y1]).T

# Generate second agent's trajectories, which are strongly correlated with the
# first agent's trajectories.
x2 = x1 + np.random.normal(0, sd, size=N)
y2 = y1 + np.random.normal(0, sd, size=N)
samps2 = np.vstack([x2, y2]).T

# Estimate the mean and covariance matrix of the agents' xs
# (column 0 = agent 1, column 1 = agent 2) and, separately, of their ys.
xs = np.vstack([samps1[:, 0], samps2[:, 0]]).T
mu_hat_x = xs.mean(axis=0)
cov_hat_x = np.cov(xs, rowvar=False)
ys = np.vstack([samps1[:, 1], samps2[:, 1]]).T
mu_hat_y = ys.mean(axis=0)
cov_hat_y = np.cov(ys, rowvar=False)

# Sample xs and ys for the agents using the estimated means and covariance
# matrices. The x and y draws are independent of each other.
samps_hat_x = np.random.multivariate_normal(mu_hat_x, cov_hat_x, N)
samps_hat_y = np.random.multivariate_normal(mu_hat_y, cov_hat_y, N)

# Plot the true (left) and generated (right) joint distributions of the first
# agent's trajectories. Data must be passed to kdeplot via the x=/y= keywords:
# positional data arguments were removed in seaborn 0.12.
(fig, axs) = plt.subplots(nrows=1, ncols=2)
sns.kdeplot(x=samps1[:, 0], y=samps1[:, 1], fill=True, ax=axs[0])
axs[0].set_title("True")
axs[0].axis("equal")
sns.kdeplot(x=samps_hat_x[:, 0], y=samps_hat_y[:, 0], fill=True, ax=axs[1])
axs[1].set_title("Sampled from per-dimension fits")
axs[1].axis("equal")
plt.tight_layout()
plt.show()
# Multimodal example.
# The true distribution is an equal mixture of two isotropic Gaussians centred
# at (mu, mu) and (-mu, -mu). A single multivariate normal fitted to it is
# centred near the origin, between the two modes, where the true density is
# almost zero.

# Generate the agents' trajectories. A covariance matrix holds variances on its
# diagonal, so the standard deviation sd is squared. This matches the
# np.random.normal(0, sd, ...) noise used in the separate dimensions example.
var = sd**2
mu_1 = [mu, mu]
cov_1 = [[var, 0], [0, var]]
samps_1 = np.random.multivariate_normal(mu_1, cov_1, N // 2)
mu_2 = [-mu, -mu]
cov_2 = [[var, 0], [0, var]]
samps_2 = np.random.multivariate_normal(mu_2, cov_2, N // 2)
samps = np.concatenate([samps_1, samps_2])

# Estimate the mean and covariance matrix of the pooled (bimodal) samples.
mu_hat = samps.mean(axis=0)
cov_hat = np.cov(samps, rowvar=False)

# Sample trajectories for the agents using the estimated mean and covariance
# matrix, i.e. from the single unimodal normal fit.
samps_hat = np.random.multivariate_normal(mu_hat, cov_hat, N)

# Plot the true (left) and generated (right) joint distributions of the agent
# trajectories. Data must be passed to kdeplot via the x=/y= keywords:
# positional data arguments were removed in seaborn 0.12.
(fig, axs) = plt.subplots(nrows=1, ncols=2)
sns.kdeplot(x=samps[:, 0], y=samps[:, 1], fill=True, ax=axs[0])
axs[0].set_title("True (bimodal)")
axs[0].axis("equal")
sns.kdeplot(x=samps_hat[:, 0], y=samps_hat[:, 1], fill=True, ax=axs[1])
axs[1].set_title("Sampled from single normal fit")
axs[1].axis("equal")
plt.tight_layout()
plt.show()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment