Skip to content

Instantly share code, notes, and snippets.

@simonjisu
Created March 22, 2020 08:18
Show Gist options
  • Save simonjisu/57c6e2b89b4c9457541809ec5b5f51c9 to your computer and use it in GitHub Desktop.
Save simonjisu/57c6e2b89b4c9457541809ec5b5f51c9 to your computer and use it in GitHub Desktop.
PRML exercise 1.4 code
import numpy as np
import matplotlib.pyplot as plt
def g(y):
    """Forward change of variables, x = g(y).

    Evaluates the shifted log-odds ``log(y) - log(1 - y) + 5``; this is the
    inverse of ``g_inv``. Works elementwise on NumPy arrays.
    """
    log_odds = np.log(y) - np.log(1 - y)
    return log_odds + 5
def g_inv(x):
    """Inverse change of variables, y = g^{-1}(x).

    A logistic sigmoid centred at x = 5: ``1 / (1 + exp(-(x - 5)))``.
    Works elementwise on NumPy arrays.
    """
    shifted = x - 5
    return 1 / (1 + np.exp(-shifted))
def gaussian(x, mu, sigma):
    """Evaluate the normal density p(x) = N(x | mu, sigma^2).

    Parameters
    ----------
    x : float or array_like
        Point(s) at which to evaluate the density.
    mu : float
        Mean of the distribution.
    sigma : float
        Standard deviation (expected to be positive).

    Returns
    -------
    float or ndarray
        ``exp(-0.5 * ((x - mu) / sigma)**2) / (sigma * sqrt(2*pi))``,
        elementwise over ``x``.
    """
    # Normalising constant is 1 / (sigma * sqrt(2*pi)). Writing
    # `1/sigma*np.sqrt(2*np.pi)` would parse as sqrt(2*pi)/sigma.
    z = (np.asarray(x) - mu) / sigma
    return np.exp(-0.5 * z ** 2) / (sigma * np.sqrt(2 * np.pi))
def dxdy(y):
    """Jacobian dx/dy of the map x = g(y).

    d/dy [log(y) - log(1 - y) + 5] = 1 / (y - y**2). Works elementwise on
    NumPy arrays.
    """
    denominator = y - y ** 2
    return 1 / denominator
def scaler(x):
    """Min-max rescale ``x`` into [0, 1] (used only for drawing).

    Parameters
    ----------
    x : array_like
        Values to rescale; must be non-empty.

    Returns
    -------
    ndarray
        ``(x - x.min()) / (x.max() - x.min())``. If all elements are equal
        the range is zero, so an all-zeros float array is returned instead
        of NaNs.
    """
    x = np.asarray(x)
    x_min = x.min()
    span = x.max() - x_min
    if span == 0:
        # Constant input: avoid 0/0 -> NaN (and the RuntimeWarning it emits).
        return np.zeros_like(x, dtype=float)
    return (x - x_min) / span
# Variables.
# Illustrates that a density's mode does not map through a nonlinear change of
# variables: the peak of p_x(g(y)) (orange) sits at g^{-1}(mu), while the
# peak of the true density p_y(y) = p_x(g(y)) |dx/dy| (blue) does not.
np.random.seed(88)
N = 50000
mu = 6.0
sigma = 1.0

# Monte Carlo samples: x ~ N(mu, sigma^2), pushed through y = g^{-1}(x).
sampled_x = np.random.normal(loc=mu, scale=sigma, size=(N,))
sampled_y = g_inv(sampled_x)

# Dense grid for the analytic curves.
x = np.linspace(0, 10, N)
y = g_inv(x)
px = gaussian(x, mu, sigma)
# Density transformed naively, WITHOUT the Jacobian factor.
py = gaussian(g(y), mu, sigma)
# Correct change of variables: p_y(y) = p_x(g(y)) * |dx/dy|.
py_real = px * np.abs(dxdy(y))

# Drawing.
fig, ax = plt.subplots(1, 1, figsize=(8, 6))

# Histogram of x, rescaled so the tallest bar has height 0.5.
_, _, patches = ax.hist(sampled_x, bins=50, alpha=0.8)
p_bins = np.array([p.get_height() for p in patches])
for patch, height in zip(patches, 0.5 * scaler(p_bins)):
    patch.set_height(height)
px_normed = 0.5 * scaler(px)  # normalize to 0~.5
ax.plot(x, px_normed, c="r", label="$p_x(x)$")
ax.plot(x, g_inv(x), c="g", label="$g^{-1}(x)$")
ax.set_ylim(0, 1)
ax.set_xlim(0, 10)

# Twin x-axis sharing y: horizontal histogram of y plus the two y-densities,
# each rescaled so its maximum is 1.
ax2 = ax.twiny()
_, _, patches = ax2.hist(sampled_y, bins=50, alpha=0.8, orientation="horizontal")
p_bins = np.array([p.get_width() for p in patches])
for patch, width in zip(patches, scaler(p_bins)):
    patch.set_width(width)
py_normed = scaler(py)
ax2.set_xlim(0, 2)
ax2.plot(py_normed, y, label="$p_x(g(y))$", c="orange")
py_real_normed = scaler(py_real)
# Raw string: "\d" is an invalid escape in a normal literal (SyntaxWarning on
# Python 3.12+); the runtime text is unchanged.
ax2.plot(py_real_normed, y, c="b", label=r"$p_x(g(y)) | \dfrac{dx}{dy} | $")
ax2.legend(loc=3)
ax.legend(loc=1)
ax.set_xlabel("$x$")
ax.set_ylabel("$y$", rotation=0)
ax.annotate("$p_y(y)$", (5.2, 0.9), fontsize=14, c="b")
ax.annotate("$p_x(x)$", (7.2, 0.35), fontsize=14, c="r")
ax.scatter(mu, px_normed.max(), c="r")
ax.scatter(mu, g_inv(mu), c="k")

# y at the peak of p_x(g(y)); up to grid resolution this is g^{-1}(mu).
y_peak = y[py_normed.argmax()]
ax.plot((mu, mu), (px_normed.max(), y_peak), "k--")
# ax x-range [0, 10] overlays ax2 x-range [0, 2], so the orange peak at
# ax2 x = 1 sits at ax x = 5; this dashed line joins it to g^{-1}(mu).
ax.plot((5, mu), (y_peak, y_peak), "k--")
ax2.scatter(py_normed.max(), y_peak, c="orange")
ax2.scatter(py_real_normed.max(), y[py_real_normed.argmax()], c="b")
plt.show()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment