Created
April 30, 2025 06:50
-
-
Save syrte/2d59285666612b693d87fbc8c44a2153 to your computer and use it in GitHub Desktop.
Fit a Gaussian process and a smooth piecewise-linear function.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import numpy as np | |
import matplotlib.pyplot as plt | |
def fit_gp(x, y, y_err, y_avg=None, x_pred=None, scale=1, *, verbose=True):
    """Fit a 1-D Gaussian process to noisy data and predict on a grid.

    The kernel is ``ConstantKernel * RBF + WhiteKernel``. The RBF length scale
    is held fixed at *scale*; the amplitude and the extra white-noise level are
    optimised by sklearn with 10 random optimiser restarts.

    Parameters
    ----------
    x, y : array_like, shape (n,)
        Training abscissae and ordinates.
    y_err : array_like, shape (n,), or scalar
        1-sigma measurement uncertainty of *y*, in the same units as *y*.
        Its square is used as per-point noise variance on the kernel diagonal.
    y_avg : scalar, optional
        Known mean subtracted before fitting and added back to the prediction.
        If None, sklearn's ``normalize_y`` is used instead.
        NOTE(review): an array-valued y_avg only lines up with the output when
        x_pred has the same length as x -- confirm intended use.
    x_pred : array_like, optional
        Points at which to predict. Defaults to 500 evenly spaced points
        spanning ``[min(x), max(x)]``.
    scale : float
        Fixed RBF length scale, in units of *x*.
    verbose : bool, keyword-only
        If True (default, as before), print the fitted kernel.

    Returns
    -------
    x_pred, y_pred, y_sig : ndarray
        Prediction points, posterior mean and posterior standard deviation.
    """
    from sklearn.gaussian_process import GaussianProcessRegressor
    from sklearn.gaussian_process.kernels import RBF, ConstantKernel, WhiteKernel
    from sklearn.exceptions import ConvergenceWarning
    import warnings

    x = np.asarray(x, dtype=float)
    y = np.asarray(y, dtype=float)
    y_err = np.asarray(y_err, dtype=float)

    amplitude = ConstantKernel(1.0, (1e-3, 1e3))
    # "fixed" is sklearn's documented way to freeze a hyperparameter. Equal
    # lower/upper bounds also freeze it, but they still add an optimiser
    # dimension and emit "close to the bound" warnings.
    smooth = RBF(length_scale=scale, length_scale_bounds="fixed")
    noise = WhiteKernel(noise_level=0.1, noise_level_bounds=(1e-5, 1e1))
    kernel = amplitude * smooth + noise

    if y_avg is None:
        y_avg = 0
        normalize_y = True
    else:
        normalize_y = False

    resid = y - y_avg
    alpha = y_err**2
    if normalize_y:
        # With normalize_y=True sklearn divides the targets by their standard
        # deviation, but adds `alpha` to the kernel diagonal unchanged, i.e. in
        # normalised units. Rescale so y_err keeps the units of y. (sklearn
        # leaves a zero std as 1, so do the same here.)
        y_std = np.std(resid)
        if y_std > 0:
            alpha = alpha / y_std**2

    gp = GaussianProcessRegressor(kernel=kernel,
                                  alpha=alpha,
                                  normalize_y=normalize_y,
                                  n_restarts_optimizer=10)
    with warnings.catch_warnings():
        # Amplitude / white-noise hyperparameters may still end up on a bound.
        warnings.filterwarnings("ignore", category=ConvergenceWarning)
        gp.fit(x.reshape(-1, 1), resid)
    if verbose:
        print(gp.kernel_)

    if x_pred is None:
        x_pred = np.linspace(np.min(x), np.max(x), 500)
    else:
        x_pred = np.asarray(x_pred, dtype=float)
    y_pred, y_sig = gp.predict(x_pred.reshape(-1, 1), return_std=True)
    return x_pred, y_pred + y_avg, y_sig
def piecewise_linear(x, a1, a2, x0, y0, k):
    """Evaluate a smoothly joined two-segment linear function.

    Computes ``y0 + a1*(x - x0) + (a2 - a1)/k * log(1 + exp(k*(x - x0)))``.
    Far to the left of *x0* the slope tends to *a1*; far to the right it tends
    to *a2*. Larger *k* gives a sharper transition. At ``x == x0`` the value
    is ``y0 + (a2 - a1) * ln(2) / k``.

    Parameters
    ----------
    x : array_like or scalar
        Evaluation points.
    a1, a2 : float
        Asymptotic slopes below and above the transition.
    x0 : float
        Transition location.
    y0 : float
        Offset of the left asymptote at x0.
    k : float
        Transition sharpness (must be non-zero).

    Returns
    -------
    ndarray or numpy scalar
        Function values, broadcast like *x*.
    """
    dx = np.asarray(x) - x0
    # logaddexp(0, t) == log1p(exp(t)), but it stays finite for large t.
    # exp(t) overflows for t > ~709 and would turn the result into inf
    # (or nan when a1 == a2, since 0 * inf is nan).
    return y0 + a1 * dx + (a2 - a1) / k * np.logaddexp(0, k * dx)
def fit_piecewise(x, y, x_pred=None, *, p0=None, verbose=True):
    """Least-squares fit of ``piecewise_linear`` to (x, y) and evaluate it.

    Parameters
    ----------
    x, y : array_like, shape (n,)
        Data to fit.
    x_pred : array_like, optional
        Points at which to evaluate the fitted curve. Defaults to 500 evenly
        spaced points spanning ``[min(x), max(x)]``.
    p0 : sequence of 5 floats, keyword-only, optional
        Initial guess for ``(a1, a2, x0, y0, k)``. Defaults to the previously
        hard-coded ``[0, 4, 8.5, -2.3, 2]``, which presumably suits one
        particular data set. ``curve_fit`` is sensitive to the starting point,
        so pass a guess matched to your data.
    verbose : bool, keyword-only
        If True (default, as before), print the fitted parameters.

    Returns
    -------
    x_pred, y_pred : ndarray
        Evaluation points and fitted model values.
    """
    # Imported locally, mirroring fit_gp; the original referenced curve_fit
    # without ever importing it.
    from scipy.optimize import curve_fit

    if p0 is None:
        p0 = [0, 4, 8.5, -2.3, 2]
    popt, _ = curve_fit(piecewise_linear, x, y, p0=p0)
    a1, a2, x0, y0, k = popt
    if verbose:
        print(f"{a1:.2g}, {a2:.2g}, {x0:.2g}, {y0:.2g}, {k:.2g}")

    if x_pred is None:
        x_pred = np.linspace(np.min(x), np.max(x), 500)
    else:
        x_pred = np.asarray(x_pred, dtype=float)
    y_pred = piecewise_linear(x_pred, a1, a2, x0, y0, k)
    return x_pred, y_pred
# Demo: plot the smooth piecewise-linear model for one parameter set.
a1, a2 = 1.0, 5   # slopes far left / far right of the transition
x0, y0 = 0.0, 0.0  # transition location and offset; note the curve passes
                   # through y0 + (a2 - a1) * ln(2) / k at x0, not through y0
k = 5.0           # sharpness: larger k gives a tighter corner

x = np.linspace(-5, 5, 500)
y = piecewise_linear(x, a1=a1, a2=a2, x0=x0, y0=y0, k=k)

plt.plot(x, y)
plt.xlabel("x")
plt.ylabel("f(x)")
plt.title("Smooth Piecewise Linear Function")
plt.grid(True)
plt.show()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment