Created
December 5, 2014 04:41
-
-
Save DomNomNom/11d9a982d04d186f8cf8 to your computer and use it in GitHub Desktop.
Creating a spline with just a linear system
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import numpy as np | |
import matplotlib.pyplot as plt | |
import scipy.sparse as sparse | |
import scipy.sparse.linalg | |
# Demo data: the (x, y) points the curves should pass through, written as
# pairs. x values are sample indices in [0, numSamples); spline() gives each
# point its own constraint row, so the points do not need to be sorted.
xs, ys = map(list, zip(
    (25.0, 1.0),
    (50.0, -1.0),
    (60.0, 0.3),
))
numSamples = 100  # number of samples in each demo curve
def spline(xs, ys, boundary=None, boundaryTargetLeft=0.0, boundaryTargetRight=0.0,
           numSamples=100, closenessFactor=1.0, boundaryFactor=1.0):
    """Return y coordinates of the smoothest curve through the given points.

    The curve is sampled at the integer positions 0 .. numSamples-1. It is the
    least-squares solution of one sparse linear system whose rows penalise:
    the discrete second derivative (smoothness), the optional boundary
    condition, and the distance of the curve from each (x, y) point.

    Parameters:
        xs: x coordinates of the points, in sample units. Each is truncated
            to an integer sample index, which must lie in [0, numSamples).
        ys: y coordinates of the points (same length as xs).
        boundary: one of
            None: unconstrained; the curve becomes a straight line towards
                the ends.
            'neumann': boundaryTargetLeft/Right give the slope (forward
                difference per sample) at the ends.
            'dirichlet': boundaryTargetLeft/Right give the values at the ends.
        boundaryTargetLeft: see boundary.
        boundaryTargetRight: see boundary.
        numSamples: the length of the output.
        closenessFactor: decreasing this prioritises smoothness over accuracy.
        boundaryFactor: increasing this prioritises the boundary constraints
            over the other constraints (has no effect when boundary is None).

    Returns:
        A 1-D numpy array of length numSamples.

    Raises:
        ValueError: if xs and ys differ in length, boundary is not one of the
            supported choices, or a point lies outside the sampled range.
    """
    if len(xs) != len(ys):
        raise ValueError("xs and ys must have the same length")

    # Row i holds the discrete second derivative y[i-1] - 2*y[i] + y[i+1];
    # entries that would fall outside the matrix are simply dropped.
    derivativeMatrix = sparse.lil_matrix((numSamples, numSamples))

    def setVal(i, j, value):
        if 0 <= i < numSamples and 0 <= j < numSamples:
            derivativeMatrix[i, j] = value

    for i in range(numSamples):
        setVal(i, i - 1, 1.0)
        setVal(i, i, -2.0)
        setVal(i, i + 1, 1.0)

    # Rows 0 and numSamples-1 are overwritten by the boundary constraint.
    # Coefficients are [row 0 col 0, row 0 col 1, last row col N-2, last row col N-1].
    # With boundary=None both rows become all-zero, i.e. unconstrained.
    bf = boundaryFactor
    boundaryConstraints = {
        None: [0.0] * 4,
        'neumann': [-bf, bf, -bf, bf],    # bf*(y[1]-y[0]), bf*(y[N-1]-y[N-2])
        'dirichlet': [bf, 0.0, 0.0, bf],  # bf*y[0], bf*y[N-1]
    }
    if boundary not in boundaryConstraints:
        raise ValueError(
            "boundary must be None, 'neumann' or 'dirichlet', got %r" % (boundary,))
    left0, left1, right0, right1 = boundaryConstraints[boundary]
    setVal(0, 0, left0)
    setVal(0, 1, left1)
    setVal(numSamples - 1, numSamples - 2, right0)
    setVal(numSamples - 1, numSamples - 1, right1)

    # One row per point, pulling the curve at that sample towards its y value.
    constraintMatrix = sparse.lil_matrix((len(xs), numSamples))
    for row, x in enumerate(xs):
        col = int(x)
        # Reject out-of-range points explicitly: a negative index would
        # otherwise silently wrap around to the far end of the curve.
        if not 0 <= col < numSamples:
            raise ValueError(
                "x=%r is outside the sampled range [0, %d)" % (x, numSamples))
        constraintMatrix[row, col] = closenessFactor

    A = sparse.vstack((derivativeMatrix, constraintMatrix), format='csr')

    # Right-hand side: zero for the smoothness rows, weighted ys for the points.
    rhs = np.concatenate((np.zeros(numSamples), np.asarray(ys, dtype=float) * closenessFactor))
    # Boundary targets are scaled like their matrix rows, so boundaryFactor
    # changes only how strongly the boundary is enforced, not its target.
    rhs[0] = boundaryTargetLeft * boundaryFactor
    rhs[numSamples - 1] = boundaryTargetRight * boundaryFactor

    # Least squares via the normal equations (A^T A) y = A^T rhs, which form a
    # square, symmetric sparse system.
    normalMatrix = A.T.dot(A).tocsc()
    out = scipy.sparse.linalg.spsolve(normalMatrix, A.T.dot(rhs))
    return out
def registerEscape():
    """Make the current matplotlib figure close when the Escape key is pressed."""
    def onKeyPress(evt):
        # Ignore every key other than Escape.
        if evt.key != 'escape':
            return
        plt.close(evt.canvas.figure)

    figureCanvas = plt.gcf().canvas
    figureCanvas.mpl_connect('key_press_event', onKeyPress)
if __name__ == '__main__':
    registerEscape()
    xPoints = np.arange(0, numSamples)
    # Pass numSamples explicitly so every curve has the same length as xPoints,
    # even if the module-level numSamples differs from spline()'s default.
    # The smoothest curve through the points:
    plt.plot(xPoints, spline(xs, ys, numSamples=numSamples), 'g--')
    # The smoothest curve through the points that is flat at the edges:
    plt.plot(xPoints, spline(xs, ys, 'neumann', 0.0, 0.0, numSamples=numSamples), 'b-')
    # The smoothest curve through the points that hits -1 and 1 at the left/right edges:
    plt.plot(xPoints, spline(xs, ys, 'dirichlet', -1.0, 1.0, numSamples=numSamples), 'k--')
    plt.plot(xs, ys, 'ro')
    plt.show()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment