Skip to content

Instantly share code, notes, and snippets.

Last active March 29, 2022 13:30
Show Gist options
  • Save airalcorn2/206fea0ad3384a9ec68e05d0f8f67a60 to your computer and use it in GitHub Desktop.
Save airalcorn2/206fea0ad3384a9ec68e05d0f8f67a60 to your computer and use it in GitHub Desktop.
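Demonstrates the issues that can arise when a single multivariate normal distribution is assumed for a multimodal multivariate distribution, and when each dimension is modeled separately so that cross-dimension correlations are lost. See the discussion linked from the original gist.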
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
# Shared demo parameters.
N = 1_000  # number of samples drawn for each distribution
mu = 1  # modes of the multimodal example sit at (mu, mu) and (-mu, -mu)
sd = 0.1  # noise standard deviation (the `scale` passed to np.random.normal)
# Separate dimensions example: two agents whose trajectories are strongly
# correlated, modeled one coordinate (x or y) at a time.

# Generate the first agent's trajectories: y1 follows x1 up to N(0, sd) noise,
# so the agent's x and y are strongly correlated.
x1 = np.random.normal(size=N)
y1 = x1 + np.random.normal(0, sd, size=N)
samps1 = np.column_stack([x1, y1])
# Generate the second agent's trajectories, which are strongly correlated with
# the first agent's trajectories.
x2 = x1 + np.random.normal(0, sd, size=N)
y2 = y1 + np.random.normal(0, sd, size=N)
samps2 = np.column_stack([x2, y2])
# Fit one bivariate normal to the two agents' xs and another to their ys.
# Each fit captures the cross-agent correlation within a single coordinate,
# but nothing ties a sampled x to a sampled y.
xs = np.column_stack([samps1[:, 0], samps2[:, 0]])
mu_hat_x = xs.mean(axis=0)
cov_hat_x = np.cov(xs, rowvar=False)
ys = np.column_stack([samps1[:, 1], samps2[:, 1]])
mu_hat_y = ys.mean(axis=0)
cov_hat_y = np.cov(ys, rowvar=False)
# Sample xs and ys from the two fitted distributions. The draws are independent,
# so each agent's strong x-y correlation is lost in the generated samples.
samps_hat_x = np.random.multivariate_normal(mu_hat_x, cov_hat_x, N)
samps_hat_y = np.random.multivariate_normal(mu_hat_y, cov_hat_y, N)
# Plot the true (left) and generated (right) joint distributions of the first
# agent's trajectories. seaborn >= 0.12 only accepts x/y as keyword arguments.
fig, axs = plt.subplots(nrows=1, ncols=2)
sns.kdeplot(x=samps1[:, 0], y=samps1[:, 1], fill=True, ax=axs[0])
axs[0].set_title("True")
sns.kdeplot(x=samps_hat_x[:, 0], y=samps_hat_y[:, 0], fill=True, ax=axs[1])
axs[1].set_title("Fitted (x, y independent)")
# Multimodal example: the agents' trajectories come from an equal mixture of
# two Gaussian modes centred at (mu, mu) and (-mu, -mu).
# multivariate_normal takes a covariance matrix, whose diagonal holds
# variances, so the standard deviation sd is squared (consistent with the
# np.random.normal(0, sd, ...) noise used in the first example).
mu_1 = [mu, mu]
cov_1 = [[sd**2, 0], [0, sd**2]]
samps_1 = np.random.multivariate_normal(mu_1, cov_1, N // 2)
mu_2 = [-mu, -mu]
cov_2 = [[sd**2, 0], [0, sd**2]]
# N - N // 2 keeps the total at exactly N samples even when N is odd.
samps_2 = np.random.multivariate_normal(mu_2, cov_2, N - N // 2)
samps = np.concatenate([samps_1, samps_2])
# Fit a single multivariate normal to the pooled, bimodal samples.
mu_hat = samps.mean(axis=0)
cov_hat = np.cov(samps, rowvar=False)
# Sample trajectories from the fitted unimodal distribution. Its density peaks
# at the pooled mean, which lies between the two modes.
samps_hat = np.random.multivariate_normal(mu_hat, cov_hat, N)
# Plot the true (left) and generated (right) joint distributions of the agent
# trajectories. seaborn >= 0.12 only accepts x/y as keyword arguments.
fig, axs = plt.subplots(nrows=1, ncols=2)
sns.kdeplot(x=samps[:, 0], y=samps[:, 1], fill=True, ax=axs[0])
axs[0].set_title("True (bimodal)")
sns.kdeplot(x=samps_hat[:, 0], y=samps_hat[:, 1], fill=True, ax=axs[1])
axs[1].set_title("Single multivariate normal fit")
# Display all open figures; without this, running the file as a script shows nothing.
plt.show()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment