Created May 21, 2012 16:06
Save perimosocordiae/2763049 to your computer and use it in GitHub Desktop.
A user-friendly command-line interface to SciPy's `curve_fit` function.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/usr/bin/env python
"""User-friendly command-line interface to scipy.optimize.curve_fit.

Command line: [options] <function_of_x> [input_file]

The function is a Python expression in ``x`` and the parameters named by
``--args``.  ``^`` is accepted as exponentiation, and everything up to the
last ``=`` is discarded, so ``'y=a+x^b'`` and ``'a+x^b'`` are equivalent.
Math functions (exp, log, sin, ...) resolve to their NumPy versions because
``numpy`` is star-imported after ``math``.

Data is read with numpy.loadtxt from ``input_file`` (or stdin).  The last
column is y; every other column is an independent variable.  With more than
one independent column the function receives ``x`` as a (D, N) array, so
write e.g. ``a*x[0] + b*x[1]``.

By default each fitted parameter is printed as ``name = value (variance)``;
the bracketed number is the matching diagonal entry of the covariance
matrix returned by curve_fit (a variance, not a standard error).
"""
from scipy.optimize import curve_fit
import re
from math import *   # noqa: F401,F403 -- names available to the fitted expression
from numpy import *  # noqa: F401,F403 -- imported last so NumPy's vectorized exp/log/... win
import numpy as np
from sys import stdin
from optparse import OptionParser


def parse_params(args_csv):
    """Split the --args CSV into a list of parameter names.

    Surrounding whitespace is stripped, so 'a, b' gives ['a', 'b'].

    Raises:
        ValueError: if a name is empty or not a plain ASCII identifier, if a
            name is 'x' (reserved for the independent variable), or if a
            name is repeated.
    """
    params = [name.strip() for name in args_csv.split(',')]
    for name in params:
        if not re.match(r'[A-Za-z_][A-Za-z0-9_]*\Z', name):
            raise ValueError('Invalid parameter name: %r' % name)
        if name == 'x':
            raise ValueError("'x' is the independent variable, not a parameter")
    if len(set(params)) != len(params):
        raise ValueError('Duplicate parameter names in: %r' % args_csv)
    return params


def make_function(expr, params):
    """Compile the symbolic *expr* into a callable ``f(x, *params)``.

    Only the text after the last '=' is used, and '^' is rewritten to '**'.
    The lambda is evaluated in this module's global namespace, so the
    star-imported math/NumPy names (exp, log, pi, ...) are in scope.

    WARNING: this eval()s *expr* -- never pass it untrusted input.

    Raises:
        SyntaxError: if the resulting lambda does not parse.
    """
    body = expr.split('=')[-1].replace('^', '**')
    source = 'lambda x,%s: %s' % (','.join(params), body)
    return eval(source, globals())


def load_data(fh):
    """Read whitespace-delimited numeric data and split it into (x, y).

    *fh* is anything numpy.loadtxt accepts (file object, path, list of
    lines).  y is the last column.  x holds the remaining columns transposed
    to (D, N) -- the k-predictor layout curve_fit hands to the model -- and a
    single independent column is flattened to shape (N,).

    Raises:
        ValueError: if the data is not a 2-D table with at least two
            columns (loadtxt's own ValueError for unparseable text also
            propagates).
    """
    data = np.loadtxt(fh)
    if data.ndim != 2 or data.shape[1] < 2:
        raise ValueError('Data must be two-dimensional (last column is y)')
    x = data[:, :-1].T
    if x.shape[0] == 1:
        x = x.ravel()
    y = data[:, -1]
    return x, y


def initial_guess(guess_csv, n_params):
    """Return the fit's starting point: the parsed --guess values, or all ones.

    Raises:
        ValueError: if a guess is not a number, or the number of guesses
            differs from *n_params*.
    """
    if not guess_csv:
        return [1.0] * n_params
    p0 = [float(g) for g in guess_csv.split(',')]
    if len(p0) != n_params:
        raise ValueError('Got %d initial guesses for %d parameters'
                         % (len(p0), n_params))
    return p0


def substitute_params(expr, params, values):
    """Return *expr* with each whole-word parameter name replaced by its value.

    All names are substituted in one regex pass on word boundaries, so a
    parameter never matches inside a longer identifier (e.g. 'a' in 'tan'),
    and an inserted value is never rewritten by a later substitution (e.g.
    the 'e' in '1e-05').
    """
    replacements = dict(zip(params, values))
    if not replacements:
        return expr
    pattern = re.compile(r'\b(%s)\b' % '|'.join(re.escape(n) for n in replacements))
    return pattern.sub(lambda m: str(replacements[m.group(1)]), expr)


def main(argv=None):
    """Command-line entry point: parse options, fit, print, optionally plot.

    *argv* defaults to sys.argv[1:] (OptionParser's default).  Usage errors
    are reported through OptionParser.error, which exits the process.
    """
    op = OptionParser(usage='%prog [options] <function_of_x> [input_file]')
    op.add_option('-a', '--args', type='str', default=None,
                  help='CSV symbolic parameters')
    op.add_option('-g', '--guess', type='str', default=None,
                  help='CSV initial guesses for parameters')
    op.add_option('-c', '--context', action='store_true', default=False,
                  help='show the fitted parameters in context')
    op.add_option('-p', '--plot', action='store_true', default=False,
                  help='plot the fitted function over the data')
    opts, args = op.parse_args(argv)

    if not args:
        op.error('Must supply symbolic function. Example: a+x^b')
    if opts.args is None:
        # Auto-detecting parameters (e.g. every single letter in the
        # expression) would also pick up letters of named functions such as
        # exp or log, so the user must list them explicitly.
        op.error("Must supply --args. Example: -a 'a,b'")

    try:
        params = parse_params(opts.args)
        p0 = initial_guess(opts.guess, len(params))
    except ValueError as err:
        op.error(str(err))

    # Here be dragons: args[0] is eval()'d -- fine for a local CLI user,
    # but never expose this to untrusted input.
    try:
        function = make_function(args[0], params)
    except SyntaxError as err:
        op.error('Could not parse function: %s' % err)

    try:
        if len(args) >= 2:
            with open(args[1]) as fh:
                x, y = load_data(fh)
        else:
            x, y = load_data(stdin)
    except (IOError, ValueError) as err:
        op.error(str(err))

    popt, pcov = curve_fit(function, x, y, p0)

    if opts.context:
        print(substitute_params(args[0], params, popt))
    else:
        for i, param in enumerate(params):
            # pcov[i, i] is the variance of the estimate (its standard
            # error squared), reported unchanged.
            print('%s = %f (%f)' % (param, popt[i], pcov[i, i]))

    if opts.plot:
        # x is (D, N) only when there are several independent columns.
        if x.ndim > 1:
            op.error('Can only plot fitted curves for 1-d domains')
        from matplotlib import pyplot
        # Sort so the fitted curve is drawn left to right instead of
        # zig-zagging through unordered samples.
        order = np.argsort(x)
        xs = x[order]
        # broadcast_to also handles models that do not depend on x (e.g. 'a').
        ynew = np.broadcast_to(function(xs, *popt), xs.shape)
        pyplot.plot(x, y, 'b.', xs, ynew, 'r-')
        pyplot.show()


if __name__ == '__main__':
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.