Skip to content

Instantly share code, notes, and snippets.

@nyk510
Created November 10, 2019 10:43
Show Gist options
  • Save nyk510/996a6e149415464bd8902e01d1180970 to your computer and use it in GitHub Desktop.
Save nyk510/996a6e149415464bd8902e01d1180970 to your computer and use it in GitHub Desktop.
PRML Section 3, Figure 3.5 (bias–variance decomposition)
import numpy as np
import matplotlib.pyplot as plt
def gaussian_kernel(x, basis=None, gamma=250):
    """Build a Gaussian basis-function design matrix with a trailing bias column.

    Column ``j`` (for every column except the last) holds
    ``exp(-gamma * (x - basis[j]) ** 2)``. A final column of ones is appended,
    so a linear model fitted on this matrix also learns an intercept.

    Args:
        x: Input values; anything ``np.asarray`` accepts. Flattened to one
            row per value.
        basis: Centres of the Gaussian bumps. Defaults to 101 points evenly
            spaced on [-1.2, 1.2].
        gamma: Inverse-width factor shared by every bump. The default of 250
            was picked by hand for this demo.

    Returns:
        ndarray of shape ``(number of x values, len(basis) + 1)``.
    """
    if basis is None:
        basis = np.linspace(-1.2, 1.2, 101)
    column = np.asarray(x).reshape(-1, 1)
    centres = np.asarray(basis)
    phi = np.exp(-gamma * (column - centres) ** 2)
    # Bias column of ones: the last weight acts as the intercept.
    bias = np.ones((phi.shape[0], 1), dtype=phi.dtype)
    return np.hstack([phi, bias])
def estimate_ml_weight(x, t, lam, xx, basis=None):
    """Fit L2-regularised least squares on Gaussian features and predict at ``xx``.

    Despite the name, for ``lam > 0`` this is the regularised (ridge) solution
    ``w = (Phi^T Phi + lam * I)^-1 Phi^T t`` rather than the plain
    maximum-likelihood one. The penalty is applied to every weight, including
    the bias weight appended by ``gaussian_kernel``.

    Args:
        x: Training inputs.
        t: Training targets, one per element of ``x``.
        lam: Regularisation coefficient (lambda, not ln lambda).
        xx: Inputs at which to evaluate the fitted model.
        basis: Centres of the Gaussian basis functions. Defaults to 24 points
            evenly spaced on [0, 1].

    Returns:
        1-D ndarray of predictions, one per element of ``xx``.
    """
    if basis is None:
        basis = np.linspace(0, 1, 24)
    phi = gaussian_kernel(x, basis=basis)
    gram = phi.T @ phi + lam * np.eye(phi.shape[1])
    # Solve the normal equations directly instead of forming an explicit
    # inverse: cheaper and numerically more stable.
    weights = np.linalg.solve(gram, phi.T @ np.asarray(t))
    return gaussian_kernel(xx, basis=basis) @ weights
# For each regularisation strength ln(lambda), fit many independent noisy
# datasets drawn from sin(2*pi*x).
# Left column: a subset of the individual fits (spread = variance).
# Right column: true function (solid) vs. the mean fit over all datasets
# (dashed); the gap between them is the bias.
n_samples = 100  # independent datasets per lambda
n_points = 40  # observations per dataset
n_plotted = 20  # individual fits drawn in the left column
rng = np.random.default_rng(0)  # fixed seed -> reproducible figure
xx = np.linspace(0, 1, 101)  # evaluation grid, shared by every fit
fig, axes = plt.subplots(ncols=2, nrows=3, figsize=(10, 12), sharey=True, sharex=True)
for ax_row, ln_lam in zip(axes, [2.6, -.31, -2.4]):
    preds = []
    for n in range(n_samples):
        x = rng.uniform(0, 1, n_points)
        t = np.sin(x * 2 * np.pi) + .2 * rng.normal(size=len(x))
        pred = estimate_ml_weight(x, t, lam=np.exp(ln_lam), xx=xx)
        if n < n_plotted:
            ax_row[0].plot(xx, pred, c='black', alpha=.8, linewidth=1)
        preds.append(pred)
    # The loop value is ln(lambda), not lambda itself; label it accordingly.
    ax_row[1].plot(xx, np.sin(2 * xx * np.pi), c='black', label=rf'$\ln \lambda = {ln_lam}$')
    ax_row[1].plot(xx, np.mean(preds, axis=0), '--', c='black')
    ax_row[1].legend()
fig.tight_layout()
fig.savefig('bias_variance.png', dpi=120)
@nyk510
Copy link
Author

nyk510 commented Nov 10, 2019

bias_variance

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment