Skip to content

Instantly share code, notes, and snippets.

@syrte
Created April 30, 2025 06:50
Show Gist options
  • Save syrte/2d59285666612b693d87fbc8c44a2153 to your computer and use it in GitHub Desktop.
Save syrte/2d59285666612b693d87fbc8c44a2153 to your computer and use it in GitHub Desktop.
Fit a Gaussian process and a smooth piecewise-linear function.
import numpy as np
import matplotlib.pyplot as plt
def fit_gp(x, y, y_err, y_avg=None, x_pred=None, scale=1):
    """Fit a 1-D Gaussian-process regression and evaluate it on a grid.

    The kernel is ``ConstantKernel * RBF + WhiteKernel``. The RBF length
    scale is held fixed at *scale*; the amplitude and white-noise level are
    optimized (with 10 random optimizer restarts).

    Parameters
    ----------
    x, y : array_like
        1-D sample locations and values.
    y_err : array_like or float
        Uncertainty of *y* (per point or common). Its square is passed to
        the regressor as ``alpha``.
    y_avg : float, optional
        Offset subtracted from *y* before fitting and added back to the
        prediction. If None, no offset is used and ``normalize_y=True`` is
        passed to the regressor instead. An array only works if it
        broadcasts against both *y* and the prediction grid.
    x_pred : array_like, optional
        Prediction locations. Defaults to 500 points spanning
        ``[min(x), max(x)]``.
    scale : float
        Fixed RBF length scale.

    Returns
    -------
    x_pred, y_pred, y_sig : ndarray
        Prediction grid, predicted mean (with *y_avg* added back) and
        predictive standard deviation.

    The fitted kernel is printed as a side effect.
    """
    from sklearn.gaussian_process import GaussianProcessRegressor
    from sklearn.gaussian_process.kernels import RBF, ConstantKernel, WhiteKernel
    from sklearn.exceptions import ConvergenceWarning
    import warnings

    # Accept lists as well as ndarrays (.reshape and subtraction need arrays).
    x = np.asarray(x)
    y = np.asarray(y)

    amplitude = ConstantKernel(1.0, (1e-3, 1e3))
    # "fixed" removes the length scale from the optimizer entirely, rather
    # than relying on degenerate (scale, scale) bounds.
    smooth = RBF(length_scale=scale, length_scale_bounds="fixed")
    noise = WhiteKernel(noise_level=0.1, noise_level_bounds=(1e-5, 1e1))
    kernel = amplitude * smooth + noise

    if y_avg is None:
        # No explicit offset: let sklearn standardize y instead.
        y_avg = 0
        normalize_y = True
    else:
        normalize_y = False

    gp = GaussianProcessRegressor(kernel=kernel,
                                  alpha=np.square(y_err),  # works for lists too
                                  normalize_y=normalize_y,
                                  n_restarts_optimizer=10)
    with warnings.catch_warnings():
        # The optimizer may emit ConvergenceWarning (e.g. hyperparameters at
        # their bounds); it is deliberately silenced here.
        warnings.filterwarnings("ignore", category=ConvergenceWarning)
        gp.fit(x.reshape(-1, 1), y - y_avg)
    print(gp.kernel_)

    if x_pred is None:
        x_pred = np.linspace(np.min(x), np.max(x), 500)
    else:
        x_pred = np.asarray(x_pred)
    y_pred, y_sig = gp.predict(x_pred.reshape(-1, 1), return_std=True)
    return x_pred, y_pred + y_avg, y_sig
def piecewise_linear(x, a1, a2, x0, y0, k):
    """Smoothly broken line with slope *a1* on one side of *x0* and *a2* on the other.

    f(x) = y0 + a1*(x - x0) + (a2 - a1)/k * log(1 + exp(k*(x - x0)))

    For k > 0 the slope tends to *a1* for x << x0 and to *a2* for x >> x0.
    Larger |k| gives a sharper transition. At x = x0 the value is
    ``y0 + (a2 - a1) * log(2) / k``, so *y0* is the value of the low-x
    asymptote at x0, not f(x0) itself.

    Parameters
    ----------
    x : float or array_like
        Evaluation points (broadcast with NumPy rules).
    a1, a2 : float
        Asymptotic slopes before and after the break.
    x0, y0 : float
        Break location and asymptote offset.
    k : float
        Transition sharpness (must be non-zero).

    Returns
    -------
    float or ndarray
        f evaluated at *x*.
    """
    t = k * (x - x0)
    # logaddexp(0, t) == log1p(exp(t)) mathematically, but it stays finite for
    # large t, where exp overflows float64 (t > ~709). That can happen when an
    # optimizer tries large k.
    return y0 + a1 * (x - x0) + (a2 - a1) / k * np.logaddexp(0, t)
def fit_piecewise(x, y, x_pred=None, p0=(0, 4, 8.5, -2.3, 2)):
    """Least-squares fit of ``piecewise_linear`` to (x, y) and evaluate it.

    Parameters
    ----------
    x, y : array_like
        Data to fit.
    x_pred : array_like, optional
        Evaluation locations. Defaults to 500 points spanning
        ``[min(x), max(x)]``.
    p0 : sequence of 5 floats
        Starting guess ``(a1, a2, x0, y0, k)`` for the optimizer. The default
        is the originally hard-coded guess (break near x0 = 8.5); pass your
        own for other data.

    Returns
    -------
    x_pred, y_pred : ndarray
        Evaluation grid and fitted model values.

    The fitted parameters are printed as a side effect.
    """
    # Local import: the module never imported curve_fit, so this function
    # previously raised NameError on every call.
    from scipy.optimize import curve_fit

    params, _ = curve_fit(piecewise_linear, x, y, p0=list(p0))
    a1, a2, x0, y0, k = params
    print(f"{a1:.2g}, {a2:.2g}, {x0:.2g}, {y0:.2g}, {k:.2g}")
    if x_pred is None:
        x_pred = np.linspace(np.min(x), np.max(x), 500)
    y_pred = piecewise_linear(x_pred, a1, a2, x0, y0, k)
    return x_pred, y_pred
if __name__ == "__main__":
    # Demo: plot the smooth piecewise-linear function for one parameter set.
    # Guarded so that importing this module does not open a blocking plot.
    a1 = 1.0  # slope well below x0
    a2 = 5    # slope well above x0
    x0 = 0.0  # transition x
    y0 = 0.0  # value of the low-x asymptote y0 + a1*(x - x0) at x0
    k = 5.0   # sharpness: larger k gives a narrower transition

    # Plot
    x = np.linspace(-5, 5, 500)
    y = piecewise_linear(x, a1, a2, x0, y0, k)
    plt.plot(x, y)
    plt.xlabel("x")
    plt.ylabel("f(x)")
    plt.title("Smooth Piecewise Linear Function")
    plt.grid(True)
    plt.show()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment