Created
June 8, 2021 20:49
-
-
Save ClementC/1a33a936329b3142807e976ccc1e0e62 to your computer and use it in GitHub Desktop.
Small Python snippet to generate an illustration of the overfitting phenomenon
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import warnings

import matplotlib.pyplot as plt
import numpy as np
# Fixed seed so the plot looks nice on the first try; comment it out to draw a
# different dataset on every run.
np.random.seed(1)

# In Jupyter, figures are rendered inline by default in recent versions; on
# older setups run the `%matplotlib inline` magic in the cell first (IPython
# magics are not valid Python, so it is not part of this script).

# Main parameters
x_noise = 0.2  # std-dev of the Gaussian jitter added to the training x values
y_noise = 0.3  # std-dev of the Gaussian noise added to every target value
x_min, x_max = -2, 2  # training points start from the integers x_min..x_max
n_test = 8  # number of testing points

# Generate the training and testing data with noise (the true model is y = x ** 2).
# The random draws happen in a fixed order (train x, test x, train y, test y)
# so that the seed above always yields the same dataset.
train_x = np.arange(x_min, x_max + 1) + x_noise * np.random.normal(size=x_max - x_min + 1)
# NOTE(review): test points are drawn uniformly from [x_min, x_max + 1), i.e. they
# reach one unit past the training grid on the right side only. The plot limits
# accommodate this, but confirm the asymmetry is intended.
test_x = (x_max - x_min + 1) * np.random.rand(n_test) + x_min
# Training targets get y_noise (previously x_noise was used here by mistake).
train_y = train_x ** 2 + y_noise * np.random.normal(size=len(train_x))
test_y = test_x ** 2 + y_noise * np.random.normal(size=n_test)

# Compute a "perfect" polynomial fit that passes through every training point.
# Its degree (len(train_x) + 1) exceeds the number of points, so the
# least-squares problem is underdetermined and np.polyfit emits a RankWarning;
# it is silenced for this call only, not for the rest of the session.
try:
    from numpy.exceptions import RankWarning  # NumPy >= 1.25 (np.RankWarning is gone in 2.0)
except ImportError:  # older NumPy only exposes it at the top level
    RankWarning = np.RankWarning
with warnings.catch_warnings():
    warnings.simplefilter("ignore", RankWarning)
    p = np.poly1d(np.polyfit(train_x, train_y, 1 + len(train_x)))
# Plot everything: the noisy samples, the true model, the interpolating
# (overfitted) polynomial and a constant (underfitted) model.
fig, ax = plt.subplots(figsize=(12, 8))
ax.scatter(train_x, train_y, color="blue", edgecolor="k", s=80, alpha=0.5,
           label="Training set")
ax.scatter(test_x, test_y, color="green", edgecolor="k", s=80, alpha=0.5,
           label="Testing set")

# Dense grid extending one unit beyond the training range on each side.
x_ = np.linspace(x_min - 1, x_max + 1, 100)
ax.plot(x_, x_ ** 2, color="blue", alpha=0.5, label="Best fit")  # true model y = x ** 2
ax.plot(x_, p(x_), color="red", alpha=0.5, label="Overfitted model")
# Predicting the mean training target everywhere ignores x entirely.
ax.hlines(np.mean(train_y), x_min - 1, x_max + 1, color="orange", alpha=0.5,
          label="Underfitted model")

# Minimalist axes: bare "x"/"y" labels, no ticks, no top/right spines.
fig.text(0.1, 0.9, '$y$')
ax.set_xlabel(r"$x$")
ax.spines['right'].set_visible(False)
ax.spines['top'].set_visible(False)
ax.set_ylim(-1, (x_max + 1) ** 2 + 0.5)
ax.set_xlim(x_min - 1.5, x_max + 1.5)
# Axes methods (rather than the pyplot state machine) always target `ax`,
# even if another figure has become the current one.
ax.tick_params(left=False, right=False, labelleft=False,
               labelbottom=False, bottom=False)
ax.legend(loc="best", frameon=False)
# Needed to display the figure when run as a script; harmless in a notebook.
plt.show()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment