Created
July 26, 2024 16:29
-
-
Save blakeNaccarato/8cb6a335bc5784eaff487083aa681ff1 to your computer and use it in GitHub Desktop.
Fitting experimental data to model functions using `scipy.optimize.curve_fit`
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
"""Fit experimental data to a model function and report fit parameters with errors."""
from functools import partial
from warnings import catch_warnings, simplefilter

from numpy import array, diagonal, full, inf, isinf, linspace, nan, sqrt, where
from scipy.optimize import OptimizeWarning, curve_fit
from scipy.stats import t
# Docs: https://docs.scipy.org/doc/scipy/reference/generated/scipy.optimize.curve_fit.html
# Module-level variables are written in ALL_CAPS. In your own implementation they may
# become function arguments, e.g. if you wrap all of this into a proper function.
# The model we want to fit. Treat it as if it came from an external library: we
# cannot change its signature, including the order of its arguments.
def fun(b, c, a, const, x, other_const):
    """Return ``a * x**2 + b * x + c + const + other_const``.

    The argument order is deliberately scrambled to mimic an external function whose
    signature we do not control. ``scipy.optimize.curve_fit`` needs the independent
    variable first and the fitted parameters right after it, so this function has to
    be wrapped in a re-ordered signature before it can be fitted.
    """
    # Accumulate left to right, the same order as the single original expression.
    value = a * x**2 + b * x + c
    value = value + const
    return value + other_const
# `fun` also takes `const` and `other_const`, which we hold fixed instead of fitting.
# They are bound to the model with `functools.partial` further down.
CONST, OTHER_CONST = 0.25, 0.75

# Experimental data: independent variable `x` and measured response `y`
EXPERIMENTAL_X = array([0.0, 1.2, 2.5, 3.0, 4.2, 5.1])
EXPERIMENTAL_Y = array([2.0, 4.5, 10.9, 13.1, 25.6, 31.8])

# Uncertainty of each measured `y` value, obtained from your own uncertainty analysis.
# `curve_fit` propagates it into the covariance of the fitted parameters.
# Optional. Interpreted as absolute when `SIGMAS_ARE_ABSOLUTE` is True; set that to
# False to treat the values as relative weights instead.
SIGMA_Y = [0.01, 0.003, 0.1, 0.2, 0.01, 0.01]
SIGMAS_ARE_ABSOLUTE = True

# In this example every `y` value (e.g. 4.5) is the mean of 3 samples taken at the same
# `x` (for instance, repeated at experiment time). Set this to 1 if there is only one
# sample per `x` value, which is also common.
NUMBER_OF_SAMPLES_FOR_EACH_Y = 3

# Upper bound of the two-sided Student-t interval at the requested confidence level
# (0.95 means a 95% confidence interval). Multiplying a parameter's standard error by
# this value gives the half-width of that parameter's confidence interval.
# NOTE(review): the degrees of freedom passed here equal the sample count. A mean of
# n samples usually has n - 1 degrees of freedom, and fit-parameter intervals often use
# (number of data points - number of fitted parameters); confirm which is intended.
CONFIDENCE_INTERVAL_THRESH = 0.95
CONFIDENCE_INTERVAL_95 = t.interval(
    CONFIDENCE_INTERVAL_THRESH,
    NUMBER_OF_SAMPLES_FOR_EACH_Y,
)[-1]
# `curve_fit` wants the independent variable first, followed by the parameters to fit,
# so re-expose `fun` with exactly that argument order.
def my_fun(x, a, b, c, const, other_const):
    """Evaluate ``fun`` with arguments in the order ``curve_fit`` expects.

    Argument groups, in order:
        independent variable (e.g. time or x): x
        parameters to fit: a, b, c
        fixed parameters: const, other_const
    """
    # Pass by keyword so the mapping onto `fun`'s scrambled order is explicit.
    return fun(
        b=b,
        c=c,
        a=a,
        const=const,
        x=x,
        other_const=other_const,
    )
# Pin the fixed parameters with `functools.partial`, leaving `MODEL(x, a, b, c)` for
# `curve_fit`. The constants could instead be baked into `my_fun`, but `partial`
# makes it easy to fit a different subset of parameters later, when the parameters
# held constant are not always the same ones.
MODEL = partial(
    my_fun,
    const=CONST,
    other_const=OTHER_CONST,
)
# Perform the fit. If it fails, or if the parameter covariance cannot be estimated,
# fall back to "nan" for every fit parameter and every covariance entry.
with catch_warnings():
    # `OptimizeWarning` is a *warning*, not an exception: by default `curve_fit` only
    # emits it (e.g. when the parameter covariance cannot be estimated), so the
    # `except` clause below would never see it. Escalate it to an error, scoped to
    # this `with` block (the previous warning filters are restored on exit), so that
    # case takes the "nan" fallback as intended.
    simplefilter("error", OptimizeWarning)
    try:
        # `curve_fit` takes guesses and bounds as plain sequences ordered like the
        # model's fitted parameters, so it relies entirely on the argument order of
        # the model function. That is why `fun`, whose order we may not control, is
        # wrapped by `my_fun` (and then `MODEL`) before fitting.
        fits, pcov = curve_fit(
            f=MODEL,
            p0=[1, 1, 1],  # Optional, initial guesses for [a, b, c]
            # Expects ([a_lower, b_lower, c_lower], [a_upper, b_upper, c_upper])
            bounds=([0, -inf, 0], [inf, inf, inf]),  # Optional
            xdata=EXPERIMENTAL_X,  # Independent variable of the experimental data
            ydata=EXPERIMENTAL_Y,  # Measured response, i.e. what `MODEL` should return
            sigma=SIGMA_Y,  # Optional, uncertainty of each `y` value
            absolute_sigma=SIGMAS_ARE_ABSOLUTE,  # Optional
            method="trf",  # Optional, fitting algorithm
        )
    except (RuntimeError, OptimizeWarning):
        # RuntimeError: the least-squares minimization failed.
        # OptimizeWarning: escalated above; covariance could not be estimated.
        dim = 3  # Number of fitted parameters: a, b, c
        fits = full(dim, nan)
        pcov = full((dim, dim), nan)

# Confidence-interval half-width of each parameter: its standard error (square root of
# the covariance diagonal) times the Student-t critical value.
standard_errors = sqrt(diagonal(pcov))
errors = standard_errors * CONFIDENCE_INTERVAL_95
# Defensive second check: an infinite error means the parameter is not determined by
# the data, so report both the fit value and its error as "nan".
fits = where(isinf(errors), nan, fits)
errors = where(isinf(errors), nan, errors)
# Bind the fitted parameters into the model, so `fitted_model` varies only in `x`.
# Tuple unpacking assigns the three fit parameters (and their errors) in one statement.
a_fit, b_fit, c_fit = fits
a_err, b_err, c_err = errors
fitted_model = partial(MODEL, a=a_fit, b=b_fit, c=c_fit)
# `fitted_model` can be evaluated at any `x`, e.g. `fitted_model(arbitrary_x)`
arbitrary_x = linspace(0, 5, 6)
# Build the report by joining lines. Inside f-string braces, `{name = :fmt}` renders
# as "name = <value>". The interval label is derived from
# `CONFIDENCE_INTERVAL_THRESH` instead of a hard-coded "95%", so it stays truthful
# if the threshold changes; for the default 0.95 it renders exactly as "95% CI".
print(  # noqa: T201
    "\n".join([
        "",
        f"fitted_model: {a_fit:.4f} * x**2 + {b_fit:.4f} * x + {c_fit:.4f}"
        f" + {CONST} + {OTHER_CONST}",
        f"{100 * CONFIDENCE_INTERVAL_THRESH:g}% CI:"
        f"\n\t{a_fit = :.4f} ± {a_err:.4f} "  # noqa: E203
        f"\n\t{b_fit = :.4f} ± {b_err:.4f} "  # noqa: E203
        f"\n\t{c_fit = :.4f} ± {c_err:.4f}",  # noqa: E203
    ])
)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment