Skip to content

Instantly share code, notes, and snippets.

@danoneata
Last active May 5, 2020 12:58
Show Gist options
  • Save danoneata/75c5bbe8d651d4ec0e804995010a850d to your computer and use it in GitHub Desktop.
Save danoneata/75c5bbe8d651d4ec0e804995010a850d to your computer and use it in GitHub Desktop.
Plot a 2D Gaussian
import numpy as np
import pdb
from matplotlib import pyplot as plt
from scipy.stats import multivariate_normal
def gauss2d(mu, sigma, to_plot=False, w=100, h=100, n_std=3.0):
    """Evaluate (and optionally plot) a 2D Gaussian density on a regular grid.

    The grid spans ``n_std`` standard deviations either side of the mean
    along each axis, with ``w`` points along x and ``h`` points along y.

    Args:
        mu: Mean of the distribution, a length-2 sequence ``(mu_x, mu_y)``.
        sigma: Covariance. Either a 2x2 matrix, a length-2 sequence of
            per-axis variances (read as a diagonal covariance, the same
            convention ``scipy.stats.multivariate_normal`` uses for a 1-D
            ``cov``), or a scalar variance shared by both axes.
        to_plot: If true, draw a filled contour plot and call ``plt.show()``.
        w: Number of grid points along x.
        h: Number of grid points along y.
        n_std: Half-width of the grid along each axis, in standard deviations.

    Returns:
        Array of shape ``(w, h)`` where ``z[j, i]`` is the density at
        ``(x[j], y[i])`` -- i.e. indexed x-first.

    Raises:
        ValueError: If ``mu`` does not have two entries or ``sigma`` cannot be
            interpreted as a 2x2 covariance.
    """
    mu = np.asarray(mu, dtype=float)
    if mu.shape != (2,):
        raise ValueError(f"mu must have shape (2,), got {mu.shape}")

    cov = np.asarray(sigma, dtype=float)
    if cov.ndim == 0:
        # A single variance shared by both (independent) axes.
        cov = cov * np.eye(2)
    elif cov.ndim == 1:
        # A vector of per-axis variances means a diagonal covariance.
        cov = np.diag(cov)
    if cov.shape != (2, 2):
        raise ValueError(
            "sigma must be a scalar, a length-2 vector or a 2x2 matrix, "
            f"got shape {np.shape(sigma)}"
        )

    std = np.sqrt(np.diag(cov))
    xs = np.linspace(mu[0] - n_std * std[0], mu[0] + n_std * std[0], w)
    ys = np.linspace(mu[1] - n_std * std[1], mu[1] + n_std * std[1], h)
    # Default "xy" meshgrid indexing yields (h, w) arrays: rows follow y.
    x, y = np.meshgrid(xs, ys)
    xy = np.column_stack((x.ravel(), y.ravel()))
    z = multivariate_normal(mu, cov).pdf(xy)
    # The flat pdf follows the (h, w) meshgrid layout; transpose so the result
    # is indexed [x, y] with shape (w, h), as this function has always returned.
    z = z.reshape(h, w).T
    if to_plot:
        # contourf wants Z in the same (h, w) layout as the meshgrid arrays.
        plt.contourf(x, y, z.T)
        plt.show()
    return z
# Example: a Gaussian centred at (50, 70) with independent x and y axes.
# gauss2d indexes sigma as a 2x2 covariance matrix, so the per-axis values go
# on the diagonal; a plain list here would fail on sigma[0, 0].
# NOTE(review): 75.0 / 90.0 are treated as variances (not standard
# deviations) -- confirm that is what was intended.
MU = [50, 70]
SIGMA = np.diag([75.0, 90.0])
z = gauss2d(MU, SIGMA, True)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment