Skip to content

Instantly share code, notes, and snippets.

@rmsander
Created January 31, 2021 22:07
Show Gist options
Save rmsander/68aed8a0600c944bf44303d14bd5c084 to your computer and use it in GitHub Desktop.
import matplotlib.pyplot as plt
import numpy as np
# Exponents p whose unit "balls" are drawn, one subplot each (left to right).
p_values = [1, 1.5, 2]
# Subplot titles, paired positionally with ``p_values``.
# NOTE(review): the p=1.5 ball is smooth where it crosses the axes, whereas the
# ElasticNet penalty (a mix of L1 and squared L2) keeps L1-style corners there,
# so the "ElasticNet" title is only a loose visual analogy.
labels = ["Lasso (L=1)", "ElasticNet (L=1.5)", "Ridge (L=2)"]


def lp_norm_grid(xx, yy, p):
    """Evaluate the 2-D l_p "norm" of every (x, y) point of a grid.

    Args:
        xx, yy: same-shaped arrays of x and y coordinates (e.g. from np.meshgrid).
        p: exponent. ``0`` gives the count of nonzero coordinates (the l_0
            "norm"), ``np.inf`` gives max(|x|, |y|), and any other value gives
            (|x|**p + |y|**p) ** (1/p).

    Returns:
        An array shaped like ``xx``. It is an int array for p == 0 and a float
        array otherwise.
    """
    if p == 0:
        return (xx != 0).astype(int) + (yy != 0).astype(int)
    if np.isinf(p):
        return np.maximum(np.abs(xx), np.abs(yy))
    return (np.abs(xx) ** p + np.abs(yy) ** p) ** (1.0 / p)


def plot_norm_balls(p_values, labels, lim=3, num=101):
    """Plot filled contours of the l_p norm for each p, outlining the unit ball.

    Args:
        p_values: sequence of exponents, one subplot per entry.
        labels: subplot titles, same length as ``p_values``.
        lim: the grid spans [-lim, lim] on both axes.
        num: number of grid samples per axis.

    Raises:
        ValueError: if ``p_values`` and ``labels`` differ in length.
    """
    if len(p_values) != len(labels):
        raise ValueError("p_values and labels must have the same length")
    import matplotlib.pyplot as plt

    n = len(p_values)
    axis = np.linspace(-lim, lim, num=num)
    xx, yy = np.meshgrid(axis, axis)
    # squeeze=False keeps ``axes`` a 2-D array even for one panel, so .flat works;
    # the width scales so that three panels reproduce the original (28, 7) figure.
    _, axes = plt.subplots(ncols=n, figsize=(28 * n / 3, 7), squeeze=False)
    for p, ax, label in zip(p_values, axes.flat, labels):
        zz = lp_norm_grid(xx, yy, p)
        if p == 0:
            # l_0 takes only the values 0/1/2, so show it as an image, not contours.
            ax.imshow(zz, cmap="bwr",
                      extent=(xx.min(), xx.max(), yy.min(), yy.max()),
                      aspect="auto")
        else:
            ax.contourf(xx, yy, zz, 30, cmap="bwr")
            # Outline the unit ball {(x, y) : ||(x, y)||_p = 1}.
            ax.contour(xx, yy, zz, [1], colors="red", linewidths=2)
        ax.set_title(str(label))
    plt.show()


if __name__ == "__main__":
    plot_norm_balls(p_values, labels)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment