Skip to content

Instantly share code, notes, and snippets.

@nden
Last active August 29, 2015 13:56
Show Gist options
  • Save nden/9329645 to your computer and use it in GitHub Desktop.
Save nden/9329645 to your computer and use it in GitHub Desktop.
Fitting a straight line example
import numpy as np
from astropy.modeling.fitting import (_validate_model, _fitter_to_model_params, Fitter, _convert_input)
from astropy.modeling.optimizers import *
def chi_line(measured_vals, updated_model, x_sigma, y_sigma, x):
    """
    Chi^2 statistic for fitting a straight line with uncertainties in x and y.

    Parameters
    ----------
    measured_vals : array
        measured y values
    updated_model : an instance of `~astropy.modeling.ParametricModel`
        model with parameters set by the current iteration of the optimizer
    x_sigma, y_sigma : array or None
        uncertainties in x / y; ``None`` means "not supplied"
    x : array
        x coordinates at which ``updated_model`` is evaluated

    Returns
    -------
    chi2 : float
        sum of squared residuals, each divided by its variance; a plain sum
        of squared residuals when neither uncertainty is given
    """
    residuals = updated_model(x) - measured_vals
    if x_sigma is None and y_sigma is None:
        return np.sum(residuals ** 2)

    if x_sigma is not None and y_sigma is not None:
        # Effective variance: x errors propagate into y through the slope.
        # NOTE(review): assumes parameters[1] is the slope of the line
        # (e.g. c1 of a degree-1 Polynomial1D) -- confirm for the model used.
        variance = y_sigma ** 2 + updated_model.parameters[1] ** 2 * x_sigma ** 2
    elif x_sigma is not None:
        # NOTE(review): x uncertainties alone are used directly as the
        # residual variance; scaling by slope**2 would match the branch above
        # but divides by zero when the slope is exactly 0 -- confirm intent.
        variance = x_sigma ** 2
    else:
        variance = y_sigma ** 2

    # chi^2 weights each squared residual by 1/sigma**2 (not 1/sigma**4).
    return np.sum(residuals ** 2 / variance)
class LineFitter(Fitter):
    """
    Fit a straight line with uncertainties in both variables.

    The statistic minimized is `chi_line`.

    Parameters
    ----------
    optimizer : class or callable
        one of the classes in optimizers.py (default: Simplex)
    """

    def __init__(self, optimizer=Simplex):
        self.statistic = chi_line
        super(LineFitter, self).__init__(optimizer, statistic=self.statistic)

    def __call__(self, model, x, y, x_sigma=None, y_sigma=None, **kwargs):
        """
        Fit data to this model.

        Parameters
        ----------
        model : `ParametricModel`
            model to fit to x, y
        x : array
            input coordinates
        y : array
            measured values at ``x``
        x_sigma : array-like or None
            uncertainties in x
        y_sigma : array-like or None
            uncertainties in y
        kwargs : dict
            optional keyword arguments to be passed to the optimizer or the
            statistic

        Returns
        -------
        model_copy : `ParametricModel`
            a copy of the input model with parameters set by the fitter
        """
        # Accept lists/tuples and integer inputs: chi_line squares and divides
        # by these, which needs float ndarrays.
        if x_sigma is not None:
            x_sigma = np.asarray(x_sigma, dtype=float)
        if y_sigma is not None:
            y_sigma = np.asarray(y_sigma, dtype=float)

        model_copy = _validate_model(model, self._opt_method.supported_constraints)
        farg = _convert_input(x, y)
        # Order matters: presumably the base Fitter forwards these so that
        # chi_line receives (measured_vals, updated_model, x_sigma, y_sigma, x)
        # -- confirm against Fitter.objective_function.
        farg = (model_copy, x_sigma, y_sigma) + farg
        p0, _ = model_copy._model_to_fit_params()
        fitparams, self.fit_info = self._opt_method(
            self.objective_function, p0, farg, **kwargs)
        _fitter_to_model_params(model_copy, fitparams)
        return model_copy
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment