Skip to content

Instantly share code, notes, and snippets.

@larsoner
Created September 12, 2018 16:36
Show Gist options
  • Save larsoner/f34c56646edb25ee0329ea53f10d6c14 to your computer and use it in GitHub Desktop.
Save larsoner/f34c56646edb25ee0329ea53f10d6c14 to your computer and use it in GitHub Desktop.
# -*- coding: utf-8 -*-
"""
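Prototype for real-time movement compensation (movecomp): fit cHPI coil
positions and the device-to-head transform window by window, and time it.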
"""
import time
import numpy as np
from scipy.spatial.distance import cdist
import mne
from mne.chpi import (_get_hpi_initial_fit, _setup_hpi_struct,
_fit_cHPI_amplitudes, _fit_magnetic_dipole,
_fit_chpi_quat)
from mne.transforms import (apply_trans, invert_transform,
rot_to_quat, quat_to_rot)
# ---------------------------------------------------------------------------
# Data: read the sliding-phantom recording (the reader is told to accept
# MaxShield data) and keep only a short, preloaded segment.
# ---------------------------------------------------------------------------
raw = mne.io.read_raw_fif(
    '../random_data/erm_and_phantom/uw/sliding/'
    'phantom_slow_up_down_raw.fif', allow_maxshield='yes')
raw.crop(3., 6.).load_data()  # for simplicity do just a few sec
# The window length below is given in samples, so pin the sampling rate.
assert raw.info['sfreq'] == 1000.

# ---------------------------------------------------------------------------
# Fit parameters
# ---------------------------------------------------------------------------
n_window = 200  # number of samples in each fitting window
dist_limit = 0.005  # max allowed mismatch of inter-coil distances
gof_limit = 0.85  # low for this phantom data
destination = (0., 0., 0.04)  # destination in device coords

# Sample indices delimiting the windows; consecutive entries form one window.
lims = np.arange(0, len(raw.times), n_window)

# ---------------------------------------------------------------------------
# Coordinate transforms taken from the measurement info
# ---------------------------------------------------------------------------
dev_head_t = raw.info['dev_head_t']['trans']
head_dev_t = invert_transform(raw.info['dev_head_t'])['trans']

# ---------------------------------------------------------------------------
# HPI coil geometry: initial coil positions in head coordinates, the same
# positions mapped into device coordinates, and their pairwise distances.
# ---------------------------------------------------------------------------
hpi_dig_head_rrs = _get_hpi_initial_fit(raw.info)
hpi_dig_dev_rrs = apply_trans(head_dev_t, hpi_dig_head_rrs)
hpi_coil_dists = cdist(hpi_dig_head_rrs, hpi_dig_head_rrs)

# cHPI fitting structure for windows of ``n_window`` samples; cache the
# number of cHPI frequencies alongside it for later use.
hpi = _setup_hpi_struct(raw.info, n_window)
hpi['n_freqs'] = len(hpi['freqs'])
def do_fits():
    """Fit the device-to-head transform for every window of ``raw``.

    The recording is processed in consecutive windows delimited by the
    module-level ``lims``. For each window:

    1. the cHPI sinusoid amplitudes are fitted,
    2. each coil is localized in device coordinates with a magnetic dipole
       fit, and coils whose inter-coil distances disagree with the
       digitized geometry by more than ``dist_limit`` are discarded,
    3. the device-to-head rotation/translation is fitted to the remaining
       coils.

    Each fit is seeded with the result of the previous window (the
    digitized geometry seeds the first one).

    Returns
    -------
    dev_head_ts : list of tuple
        One ``(t, dev_head_t)`` pair per window, where ``t`` is the window
        start in seconds from the first sample of ``raw`` and
        ``dev_head_t`` is the 4x4 transform built from the fitted
        quaternion parameters.

    Raises
    ------
    RuntimeError
        If the amplitude fit returns ``None``, if fewer than three coils
        agree with the digitized geometry, or if the value ``g`` returned
        by ``_fit_chpi_quat`` is below ``gof_limit``.
    """
    sfreq = raw.info['sfreq']
    n_coils = hpi['n_freqs']
    # State carried from one window to the next.
    last = dict(sin_fit=None, coil_dev_rrs=hpi_dig_dev_rrs,
                quat=np.concatenate([rot_to_quat(dev_head_t[:3, :3]),
                                     dev_head_t[:3, 3]]))
    dev_head_ts = []
    for start, stop in zip(lims[:-1], lims[1:]):
        t = start / sfreq
        time_sl = slice(start, stop)

        #
        # 1. Fit amplitudes for each channel from each of the N cHPI
        #    sinusoids
        #
        sin_fit = _fit_cHPI_amplitudes(raw, time_sl, hpi, 0., verbose=False)
        # logging has already been done! Maybe turn this into an Exception
        if sin_fit is None:
            raise RuntimeError('Bad sin fit!')  # should probably reuse

        # Check if the data has sufficiently changed. There is nothing to
        # compare against on the first window.
        if last['sin_fit'] is not None:
            # The sign of our fits is arbitrary, so align each coil with
            # the previous window before correlating.
            flips = np.sign((sin_fit * last['sin_fit']).sum(-1, keepdims=True))
            # A dot product of exactly zero would otherwise zero out that
            # coil's amplitudes.
            flips[flips == 0] = 1.
            sin_fit *= flips
            corr = np.corrcoef(sin_fit.ravel(), last['sin_fit'].ravel())[0, 1]
            if corr * corr > 0.98:
                # don't need to refit data (but we currently still do)
                print('%s: could reuse here' % t)
        # Keep a private copy as the reference for the next window.
        last['sin_fit'] = sin_fit.copy()

        #
        # 2. Fit magnetic dipole for each coil to obtain coil positions
        #    in device coordinates
        #
        outs = [_fit_magnetic_dipole(f, pos, hpi['coils'], hpi['scale'],
                                     hpi['method'])
                for f, pos in zip(sin_fit, last['coil_dev_rrs'])]
        this_coil_dev_rrs = np.array([o[0] for o in outs])

        # Filter coil fits based on the correspondence to the digitization
        # geometry: compare fitted inter-coil distances with digitized ones.
        use_mask = np.ones(n_coils, bool)
        these_dists = cdist(this_coil_dev_rrs, this_coil_dev_rrs)
        these_dists = np.abs(hpi_coil_dists - these_dists)
        # there is probably a better algorithm for finding the bad ones...
        while True:
            d = these_dists[use_mask][:, use_mask]
            d_bad = (d > dist_limit)
            if not d_bad.any():
                break  # all remaining coils are mutually consistent
            if use_mask.sum() == 2:
                # The last two coils disagree and we cannot tell which one
                # is wrong.
                use_mask[:] = False
                break  # failure
            # exclude next worst point
            badness = (d * d_bad).sum(axis=0)
            exclude_coil = np.where(use_mask)[0][np.argmax(badness)]
            use_mask[exclude_coil] = False
        n_good = use_mask.sum()
        if n_good < 3:
            raise RuntimeError(
                '%s/%s good HPI fits, cannot determine the transformation!'
                % (n_good, n_coils))

        #
        # 3. Fit the head translation and rotation params (minimize error
        #    between coil positions and the head coil digitization
        #    positions)
        #
        this_quat, g = _fit_chpi_quat(this_coil_dev_rrs[use_mask],
                                      hpi_dig_head_rrs[use_mask],
                                      last['quat'])
        if g < gof_limit:
            raise RuntimeError('Bad coil fit! (g=%7.3f)' % (g,))

        # Seed the next window with this result. Rejected coils keep their
        # previous starting position so a bad fit is not propagated.
        last['quat'] = this_quat
        last['coil_dev_rrs'] = np.where(use_mask[:, np.newaxis],
                                        this_coil_dev_rrs,
                                        last['coil_dev_rrs'])

        # Convert quaternion parameters to a 4x4 transform: rotation and
        # translation side by side, then the homogeneous bottom row.
        this_dev_head_t = np.concatenate(
            (quat_to_rot(this_quat[:3]),
             this_quat[3:][:, np.newaxis]), axis=1)
        this_dev_head_t = np.concatenate((this_dev_head_t, [[0, 0, 0, 1.]]))
        dev_head_ts.append((t, this_dev_head_t))
    return dev_head_ts
if __name__ == '__main__':
    # Time one full pass over the recording and report the speed as a
    # multiple of the time of the last sample in ``raw``.
    tic = time.time()
    do_fits()
    toc = time.time()
    print('Ran %0.1fx realtime' % (raw.times[-1] / (toc - tic)))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment