Last active May 25, 2022 15:55
Save larsoner/fbe32d57996848395854d5e59dff1e10 to your computer and use it in GitHub Desktop.
Displacement field demo using matched points
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from mne.transforms import _MatchedDisplacementFieldInterpolator
import numpy as np
import matplotlib.pyplot as plt

# Demo: warp one set of matched points onto another with a nonlinear
# displacement field. The point sets mirror the 2D example at
# https://stackoverflow.com/questions/32266408, lifted into 3D.
to = np.array([[5, 4, 1], [6, 1, 0], [4, -1, 1], [3, 3, 0]], float)
fro = np.array([[0, 2, 2], [2, 2, 1], [2, 0, 2], [0, 0, 1]], float)
ndim = fro.shape[-1]
n_grid = 20

# Query grid: an n_grid x n_grid lattice over [-0.5, 2.5] in the first two
# coordinates, with every remaining coordinate fixed at 1. Rows are ordered
# with the first coordinate varying slowest.
axis_vals = np.linspace(-0.5, 2.5, n_grid)
grid_x, grid_y = np.meshgrid(axis_vals, axis_vals, indexing='ij')
grid = np.column_stack(
    [grid_x.ravel(), grid_y.ravel(), np.ones((n_grid ** 2, ndim - 2))])
assert grid.shape == (n_grid ** 2, ndim)

# Color every grid point by its normalized (x, y) position, used as the red
# and blue channels (green = 0), so the lattice can be tracked across panels.
grid_xy = grid[:, :2] - grid[:, :2].min(axis=0)
grid_xy /= grid_xy.max(axis=0)
grid_c = np.column_stack(
    [grid_xy[:, 0], np.zeros(len(grid_xy)), grid_xy[:, 1]])

# Fit the warp and push both the source points and the grid through it.
interp = _MatchedDisplacementFieldInterpolator(fro, to)
fro_t = interp(fro)
grid_t = interp(grid)

fig, axes = plt.subplots(2, figsize=(6, 6), sharex=True, sharey=True)
colors = plt.get_cmap('YlGn')(np.linspace(0.25, 1, to.shape[0]))
# Top panel: inputs as given; bottom panel: after warping. In both, filled
# dots are the targets ``to``, crosses are the source points and faint dots
# are the query grid (only the first two coordinates are drawn).
for ax, label, src, pts in zip(axes, ('Original', 'Deformed'),
                               (fro, fro_t), (grid, grid_t)):
    ax.scatter(*to[:, :2].T, c=colors, edgecolors='none', zorder=5, lw=0)
    ax.scatter(*src[:, :2].T, c=colors, marker='x', zorder=4, lw=2)
    ax.scatter(*pts[:, :2].T, c=grid_c, marker='.', alpha=0.2, lw=2)
    ax.set_ylabel(label)
# Black diamonds: the bounding-box corner points the interpolator adds
# around the data.
axes[1].scatter(*interp._extrema[:, :2].T, c='k', marker='d')
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Added to mne.transforms (the module the demo above imports this class from):
class _MatchedDisplacementFieldInterpolator:
    """Interpolate from matched points using a displacement field in 3D.

    The source points are first prealigned to the destination points with
    the transform from ``_fit_matched_points(..., scale=True)``. That
    transform is converted with ``_quat_to_affine`` and its 3x3 linear part
    is multiplied by the fitted uniform scale. The remaining, nonlinear part
    of the mapping is then interpolated piecewise-linearly with
    :class:`scipy.interpolate.LinearNDInterpolator`.

    Parameters
    ----------
    fro : array-like, shape (n_points, 3)
        Source points.
    to : array-like, shape (n_points, 3)
        Destination points, matched row-by-row to ``fro``.
    extrema : array-like, shape (2, 3) | None
        Minimum (row 0) and maximum (row 1) corners of a box, in the
        prealigned frame. The box's 8 corners are added as extra points
        that map to themselves. If None (default), the box encloses the
        prealigned ``fro`` points, padded on each axis by half the extent
        of ``to`` along that axis.

    Notes
    -----
    Because the corners map to themselves in the prealigned frame, the
    overall warp at a corner equals the global prealignment alone. Points
    whose prealigned position falls outside the convex hull of the
    prealigned ``fro`` points plus the corners lie outside the
    triangulation. They are returned as NaN, the default ``fill_value`` of
    ``LinearNDInterpolator``.
    """

    def __init__(self, fro, to, *, extrema=None):
        from scipy.interpolate import LinearNDInterpolator
        fro = np.array(fro, float)
        to = np.array(to, float)
        assert fro.shape == to.shape
        assert fro.ndim == 2
        # this restriction is only necessary because it's what
        # _fit_matched_points requires
        assert fro.shape[1] == 3

        # Prealign using the fitted transform + uniform scaling, so the
        # displacement field only has to model what the global fit leaves.
        trans, scale = _fit_matched_points(fro, to, scale=True)
        trans = _quat_to_affine(trans)
        trans[:3, :3] *= scale
        self._affine = trans
        fro = apply_trans(trans, fro)

        # Add the corners of a box as fixed points. This extends the
        # triangulation beyond the convex hull of the few matched points.
        if extrema is None:
            delta = (to.max(axis=0) - to.min(axis=0)) / 2.
            extrema = np.array([fro.min(axis=0) - delta,
                                fro.max(axis=0) + delta])
        else:
            extrema = np.array(extrema, float)
            assert extrema.shape == (2, fro.shape[1])  # rows: min, max
        # All 2 ** ndim combinations of per-axis (min, max) -> box corners
        self._extrema = np.array(
            np.meshgrid(*extrema.T)).T.reshape(-1, fro.shape[-1])
        fro_concat = np.concatenate((fro, self._extrema))
        to_concat = np.concatenate((to, self._extrema))

        # Compute the interpolator (which internally uses Delaunay)
        self._interp = LinearNDInterpolator(fro_concat, to_concat)

    def __call__(self, x):
        """Warp points from the ``fro`` frame to the ``to`` frame.

        Parameters
        ----------
        x : array-like, shape (3,) or (n_points, 3)
            Point(s) to warp. Lists and tuples are accepted as well as
            ndarrays.

        Returns
        -------
        out : ndarray, shape (3,) or (n_points, 3)
            Warped point(s), with the same leading shape as ``x``. Entries
            are NaN where the prealigned point falls outside the
            interpolator's triangulation.
        """
        # Coerce first so array-likes work (the original touched ``x.ndim``
        # directly and failed on plain lists).
        x = np.asarray(x, dtype=float)
        assert x.ndim in (1, 2) and x.shape[-1] == 3
        singleton = x.ndim == 1
        out = self._interp(apply_trans(self._affine, x))
        # The interpolator returns a 2D result; unwrap a single input point.
        return out[0] if singleton else out
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.