
@bgshih
Created October 21, 2015 09:48
A simple example of Thin Plate Spline (TPS) transformation in NumPy.
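The model below is how the script's `makeT` and `liftPts` appear to fit each output coordinate. It is reconstructed from the code and is not stated by the author. The symbols $\mathbf{c}_j$, $\mathbf{t}$, $P$, $\mathbf{a}$ and $\mathbf{w}$ are notation introduced here.

$$
f(\mathbf{p}) = a_0 + a_x p_x + a_y p_y + \sum_{j=1}^{K} w_j \, U\!\left(\lVert \mathbf{p} - \mathbf{c}_j \rVert\right),
\qquad U(r) = r^2 \log r^2, \quad U(0) = 0
$$

$$
\underbrace{\begin{bmatrix} P & R \\ 0 & P^\top \end{bmatrix}}_{T}
\begin{bmatrix} \mathbf{a} \\ \mathbf{w} \end{bmatrix}
=
\begin{bmatrix} \mathbf{t} \\ \mathbf{0} \end{bmatrix},
\qquad P = \begin{bmatrix} \mathbf{1} & \mathbf{c}_x & \mathbf{c}_y \end{bmatrix} \in \mathbb{R}^{K \times 3},
\qquad R_{ij} = U\!\left(\lVert \mathbf{c}_i - \mathbf{c}_j \rVert\right)
$$

Here $\mathbf{t}$ is `xt` (for `cx`) or `yt` (for `cy`). The coefficient vector $[\mathbf{a}; \mathbf{w}] = [a_0, a_x, a_y, w_1, \ldots, w_K]$ is what `nl.solve` returns.

- The first $K$ rows enforce exact interpolation, $f(\mathbf{c}_i) = t_i$.
- The last three rows, $P^\top \mathbf{w} = \mathbf{0}$, are the side conditions $\sum_j w_j = 0$, $\sum_j w_j c_{j,x} = 0$ and $\sum_j w_j c_{j,y} = 0$.
- `liftPts` builds the row $[1, p_x, p_y, U(\lVert \mathbf{p} - \mathbf{c}_1 \rVert), \ldots]$, so transforming a point is a dot product with the coefficient vector.

Because $r^2 \log r^2 = 2 r^2 \log r$, $U$ is the usual TPS kernel up to a constant factor that the weights absorb.

If $P$ has rank below 3, the system matrix is singular. This happens with fewer than three control points, or when all of them lie on one line. In that case any nonzero $\mathbf{a}$ with $P\mathbf{a} = \mathbf{0}$ makes $[\mathbf{a}; \mathbf{0}]$ a null vector of $T$.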
```python
import numpy as np
import numpy.linalg as nl
import matplotlib.pyplot as plt
from scipy.spatial.distance import pdist, cdist, squareform


def makeT(cp):
    # cp: [K x 2] control points
    # T: [(K+3) x (K+3)], layout [[1, x, y, R], [0, 0, 0, P^T]]:
    #    K interpolation rows followed by 3 side-condition rows
    K = cp.shape[0]
    T = np.zeros((K+3, K+3))
    T[:K, 0] = 1
    T[:K, 1:3] = cp
    T[K, 3:] = 1
    T[K+1:, 3:] = cp.T
    R = squareform(pdist(cp, metric='euclidean'))
    R = R * R
    R[R == 0] = 1  # zero distance: set r^2 = 1 so that r^2 * ln(r^2) = 0
    R = R * np.log(R)
    np.fill_diagonal(R, 0)
    T[:K, 3:] = R
    return T


def liftPts(p, cp):
    # p: [N x 2], input points
    # cp: [K x 2], control points
    # pLift: [N x (3+K)], lifted input points [1, x, y, U(|p - cp_j|)]
    N, K = p.shape[0], cp.shape[0]
    pLift = np.zeros((N, K+3))
    pLift[:, 0] = 1
    pLift[:, 1:3] = p
    R = cdist(p, cp, 'euclidean')
    R = R * R
    R[R == 0] = 1
    R = R * np.log(R)
    pLift[:, 3:] = R
    return pLift


# source control points: a 3 x 3 grid on [-1, 1]^2
x, y = np.linspace(-1, 1, 3), np.linspace(-1, 1, 3)
x, y = np.meshgrid(x, y)
xs = x.flatten()
ys = y.flatten()
cps = np.vstack([xs, ys]).T

# target control points: source points with random jitter
xt = xs + np.random.uniform(-0.3, 0.3, size=xs.size)
yt = ys + np.random.uniform(-0.3, 0.3, size=ys.size)

# construct T
T = makeT(cps)

# solve cx, cy (coefficients [a0, ax, ay, w_1..w_K] for x and y);
# the 3 appended zeros are the right-hand side of the side conditions
xtAug = np.concatenate([xt, np.zeros(3)])
ytAug = np.concatenate([yt, np.zeros(3)])
cx = nl.solve(T, xtAug)  # [K+3]
cy = nl.solve(T, ytAug)

# dense grid
N = 30
x = np.linspace(-2, 2, N)
y = np.linspace(-2, 2, N)
x, y = np.meshgrid(x, y)
xgs, ygs = x.flatten(), y.flatten()
gps = np.vstack([xgs, ygs]).T

# transform
pgLift = liftPts(gps, cps)  # [N*N x (K+3)]
xgt = np.dot(pgLift, cx)
ygt = np.dot(pgLift, cy)

# display (axis limits are set per subplot so they actually apply)
plt.subplot(1, 2, 1)
plt.title('Source')
plt.xlim(-2.5, 2.5)
plt.ylim(-2.5, 2.5)
plt.grid()
plt.scatter(xs, ys, marker='+', c='r', s=40)
plt.scatter(xgs, ygs, marker='.', c='r', s=5)

plt.subplot(1, 2, 2)
plt.title('Target')
plt.xlim(-2.5, 2.5)
plt.ylim(-2.5, 2.5)
plt.grid()
plt.scatter(xt, yt, marker='+', c='b', s=40)
plt.scatter(xgt, ygt, marker='.', c='b', s=5)
plt.show()
```
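A quick sanity check is that the fitted spline maps every source control point exactly onto its target. It should also satisfy the side conditions on the weights. The sketch below is a compact, plot-free restatement of the same math, so it is not the author's code. The names `tps_kernel`, `lift` and `fit_tps` and the fixed random seed are assumptions made for this example.

```python
import numpy as np
from scipy.spatial.distance import cdist


def tps_kernel(r):
    # U(r) = r^2 * log(r^2), with U(0) = 0 (same convention as makeT/liftPts)
    r2 = r * r
    r2[r2 == 0] = 1  # log(1) = 0, so zero distances give U = 0
    return r2 * np.log(r2)


def lift(p, cp):
    # Rows [1, x, y, U(|p - c_1|), ..., U(|p - c_K|)], like liftPts
    return np.hstack([np.ones((p.shape[0], 1)), p, tps_kernel(cdist(p, cp))])


def fit_tps(cp, target):
    # Solve [[P, R], [0, P^T]] c = [target; 0] for one output coordinate
    K = cp.shape[0]
    T = np.zeros((K + 3, K + 3))
    T[:K] = lift(cp, cp)      # interpolation rows: [1, x, y, R]
    T[K, 3:] = 1              # sum(w) = 0
    T[K + 1:, 3:] = cp.T      # sum(w * x) = 0, sum(w * y) = 0
    return np.linalg.solve(T, np.concatenate([target, np.zeros(3)]))


rng = np.random.default_rng(0)
g = np.linspace(-1, 1, 3)
cps = np.stack(np.meshgrid(g, g), axis=-1).reshape(-1, 2)  # 3 x 3 source grid
cpt = cps + rng.uniform(-0.3, 0.3, size=cps.shape)          # jittered targets

cx = fit_tps(cps, cpt[:, 0])
cy = fit_tps(cps, cpt[:, 1])
mapped = np.column_stack([lift(cps, cps) @ cx, lift(cps, cps) @ cy])
print(np.allclose(mapped, cpt))  # True: the TPS interpolates the control points
```

The same check can be run on the gist's own `cx`, `cy`, `cps`, `xt` and `yt` by comparing `liftPts(cps, cps) @ cx` with `xt`.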
@MrGiskard

Hi, I am trying to implement this on a 640 x 360 grid using 3 hard-coded source and target points:

```python
xs = [200, 400, 500]
ys = [10, 10, 10]
cps = np.vstack([xs, ys]).T
```

and

```python
xt = [250, 450, 550]
yt = [30, 40, 50]
```

However, when I run it I get a singular matrix error on the line that solves for `cx`:

```python
cx = nl.solve(T, xtAug)  # [K+3]
```

I'm guessing there is a problem with the way the control points are generated, so the matrix coming out of the makeT method is incorrect, but I cannot figure out what exactly the issue is. Any help would be greatly appreciated.
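The error in the question can be reproduced in isolation. The sketch below rebuilds `T` with the same layout as `makeT` for the three source points from the question, then tries to solve it. It does the same after moving only the third point off the line y = 10. The moved point (500, 50) is a hypothetical choice for illustration. The names `make_T` and `try_solve` are also introduced here and are not part of the gist.

```python
import numpy as np
from scipy.spatial.distance import pdist, squareform


def make_T(cp):
    # Same layout as makeT in the gist: [[1, x, y, R], [0, 0, 0, P^T]]
    K = cp.shape[0]
    R = squareform(pdist(cp)) ** 2
    R[R == 0] = 1  # log(1) = 0, so U = 0 on the diagonal
    R = R * np.log(R)
    T = np.zeros((K + 3, K + 3))
    T[:K, 0] = 1
    T[:K, 1:3] = cp
    T[:K, 3:] = R
    T[K, 3:] = 1
    T[K + 1:, 3:] = cp.T
    return T


def try_solve(cp, xt):
    # Return the coefficient vector, or None if NumPy reports a singular matrix
    rhs = np.concatenate([xt, np.zeros(3)])
    try:
        return np.linalg.solve(make_T(cp), rhs)
    except np.linalg.LinAlgError:
        return None


xt = np.array([250.0, 450.0, 550.0])
collinear = np.array([[200.0, 10.0], [400.0, 10.0], [500.0, 10.0]])  # points from the question
off_line = np.array([[200.0, 10.0], [400.0, 10.0], [500.0, 50.0]])   # hypothetical: third point moved

print(try_solve(collinear, xt) is None)  # True: singular, all y = 10
print(try_solve(off_line, xt) is None)   # False: solvable
```

With every `y` equal to 10, the `y` column of `T` is exactly 10 times the column of ones, so the matrix cannot be inverted whatever the target points are.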

@pebbie

pebbie commented Apr 24, 2018

The thin-plate spline assumes that there are at least 3 control points and that they do not all lie on the same line (i.e., they are not collinear). Your three source points all have y = 10, so they are collinear.
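Those two conditions can be checked before building `T`, which gives a clearer message than a singular-matrix error from `nl.solve`. The sketch below is one possible guard. The function name `check_tps_control_points` and the relative tolerance are hypothetical choices, not part of the gist or of any library. It tests collinearity with 2D cross products against the point farthest from the first one.

```python
import numpy as np


def check_tps_control_points(cp, rel_tol=1e-12):
    # Hypothetical helper: raise ValueError unless cp has at least 3 points
    # and they do not all lie on one line.
    cp = np.asarray(cp, dtype=float)
    if cp.ndim != 2 or cp.shape[1] != 2:
        raise ValueError("control points must be a [K x 2] array")
    if cp.shape[0] < 3:
        raise ValueError("need at least 3 control points")
    d = cp - cp[0]                                # offsets from the first point
    j = np.argmax(np.hypot(d[:, 0], d[:, 1]))     # farthest point from cp[0]
    # z-component of the 2D cross product d_i x d_j; it is zero for every i
    # exactly when all points lie on the line through cp[0] and cp[j]
    cross = d[:, 0] * d[j, 1] - d[:, 1] * d[j, 0]
    scale = d[j, 0] ** 2 + d[j, 1] ** 2
    if scale == 0 or np.abs(cross).max() <= rel_tol * scale:
        raise ValueError("control points are collinear (or all identical)")


check_tps_control_points([[200, 10], [400, 10], [500, 50]])  # passes silently
try:
    check_tps_control_points([[200, 10], [400, 10], [500, 10]])
except ValueError as e:
    print(e)  # control points are collinear (or all identical)
```

The tolerance is relative to the squared extent of the point set, so the check behaves the same for pixel coordinates and for the gist's [-1, 1] grid.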
