Created
September 12, 2018 16:36
-
-
Save larsoner/f34c56646edb25ee0329ea53f10d6c14 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# -*- coding: utf-8 -*- | |
""" | |
Try realtime movecomp. | |
""" | |
import time | |
import numpy as np | |
from scipy.spatial.distance import cdist | |
import mne | |
from mne.chpi import (_get_hpi_initial_fit, _setup_hpi_struct, | |
_fit_cHPI_amplitudes, _fit_magnetic_dipole, | |
_fit_chpi_quat) | |
from mne.transforms import (apply_trans, invert_transform, | |
rot_to_quat, quat_to_rot) | |
raw = mne.io.read_raw_fif(
    '../random_data/erm_and_phantom/uw/sliding/'
    'phantom_slow_up_down_raw.fif', allow_maxshield='yes')
raw.crop(3., 6.).load_data()  # for simplicity do just a few sec
# The parameters below were chosen for 1 kHz data (n_window = 200 samples
# is 0.2 s). Check explicitly rather than with ``assert`` so the check is
# not skipped when running under ``python -O``.
if raw.info['sfreq'] != 1000.:
    raise ValueError('Expected sfreq == 1000., got %s'
                     % (raw.info['sfreq'],))
n_window = 200  # number of samples per fit window
# Window start/stop sample indices (consecutive pairs define the windows)
lims = np.arange(0, len(raw.times), n_window)
destination = (0., 0., 0.04)  # destination in device coords (unused here)
# Coil positions in head coordinates and the cHPI fitting setup
hpi_dig_head_rrs = _get_hpi_initial_fit(raw.info)
hpi = _setup_hpi_struct(raw.info, n_window)
# Reference pairwise coil distances, used to reject inconsistent coil fits
hpi_coil_dists = cdist(hpi_dig_head_rrs, hpi_dig_head_rrs)
dev_head_t = raw.info['dev_head_t']['trans']
head_dev_t = invert_transform(raw.info['dev_head_t'])['trans']
# The same coil positions in device coordinates (initial dipole guesses)
hpi_dig_dev_rrs = apply_trans(head_dev_t, hpi_dig_head_rrs)
hpi['n_freqs'] = len(hpi['freqs'])
# Max allowed |fitted - reference| inter-coil distance (presumably meters,
# MNE's convention -- confirm)
dist_limit = 0.005
gof_limit = 0.85  # low for this phantom data
def _good_coil_mask(coil_dev_rrs, ref_dists, limit):
    """Flag coils whose fitted pairwise distances agree with the reference.

    Coils are dropped one at a time until every remaining pairwise distance
    error is within ``limit``. Each drop removes the coil with the largest
    summed error over ``limit``. If only two coils remain and they still
    disagree, every coil is rejected, since there is no way to tell which
    one is bad.

    Parameters
    ----------
    coil_dev_rrs : ndarray
        Fitted coil positions, one row per coil.
    ref_dists : ndarray, shape (n_coils, n_coils)
        Reference pairwise coil distances, e.g. from the digitization.
    limit : float
        Maximum allowed absolute difference between fitted and reference
        distances.

    Returns
    -------
    use_mask : ndarray of bool, shape (n_coils,)
        True for coils that passed the check.
    """
    errors = np.abs(ref_dists - cdist(coil_dev_rrs, coil_dev_rrs))
    use_mask = np.ones(len(coil_dev_rrs), bool)
    while True:
        d = errors[use_mask][:, use_mask]
        d_bad = d > limit
        if not d_bad.any():
            break
        if use_mask.sum() == 2:
            use_mask[:] = False  # failure: cannot tell which coil is bad
            break
        # exclude the currently worst coil and check again
        badness = (d * d_bad).sum(axis=0)
        use_mask[np.where(use_mask)[0][np.argmax(badness)]] = False
    return use_mask


def do_fits():
    """Run the windowed cHPI head-position fit over the loaded data.

    Each consecutive pair of sample indices in ``lims`` defines one window.
    For each window this:

    1. fits the amplitude of each cHPI sinusoid on every channel,
    2. fits a magnetic dipole per coil to obtain coil positions in device
       coordinates, then drops coils whose inter-coil distances disagree
       with the digitized geometry by more than ``dist_limit``, and
    3. fits the device->head rotation and translation (a quaternion) that
       maps the remaining coil positions onto their digitized positions.

    Each successful fit seeds the next window's initial guesses, so the fit
    can follow a moving object.

    Returns
    -------
    fits : list of (float, ndarray)
        One ``(t, dev_head_t)`` tuple per window. ``t`` is the window start
        time in seconds, relative to the first sample of ``raw``.
        ``dev_head_t`` is the fitted 4x4 device->head transform.

    Raises
    ------
    RuntimeError
        If the sinusoid fit fails, fewer than 3 coils survive the distance
        check, or the goodness of fit falls below ``gof_limit``.
    """
    sfreq = raw.info['sfreq']
    # Initial guesses: digitized coil positions and the stored transform
    last = dict(sin_fit=None, coil_dev_rrs=hpi_dig_dev_rrs,
                quat=np.concatenate([rot_to_quat(dev_head_t[:3, :3]),
                                     dev_head_t[:3, 3]]))
    fits = []
    for start, stop in zip(lims[:-1], lims[1:]):
        t = start / sfreq
        # The amplitude fit reads ``raw`` itself through ``time_sl``, so no
        # separate copy of the window's data is extracted here.
        time_sl = slice(start, stop)
        # 1. Fit amplitudes for each channel from each of the N cHPI sinusoids
        sin_fit = _fit_cHPI_amplitudes(raw, time_sl, hpi, 0., verbose=False)
        if sin_fit is None:
            # Logging has already been done by the helper. A realtime
            # version should probably reuse the previous fit instead.
            raise RuntimeError('Bad sin fit!')
        # Check whether the data have changed enough to need a refit.
        # There is nothing to compare against on the first window.
        if last['sin_fit'] is not None:
            # The sign of each sinusoid fit is arbitrary, so align each row
            # with the previous window. A zero dot product means "no flip";
            # np.sign would return 0 there and wipe out the whole row.
            dots = (sin_fit * last['sin_fit']).sum(-1, keepdims=True)
            sin_fit *= np.where(dots < 0, -1., 1.)
            corr = np.corrcoef(sin_fit.ravel(), last['sin_fit'].ravel())[0, 1]
            if corr * corr > 0.98:
                # TODO: the refit below could be skipped in this case
                print('%s: could reuse here' % t)
        # Keep the sign-aligned fit for comparison with the next window
        last['sin_fit'] = sin_fit.copy()
        #
        # 2. Fit magnetic dipole for each coil to obtain coil positions
        #    in device coordinates, starting from the previous positions
        #
        outs = [_fit_magnetic_dipole(f, pos, hpi['coils'], hpi['scale'],
                                     hpi['method'])
                for f, pos in zip(sin_fit, last['coil_dev_rrs'])]
        this_coil_dev_rrs = np.array([o[0] for o in outs])
        # Keep only coil fits that correspond to the digitization geometry
        use_mask = _good_coil_mask(this_coil_dev_rrs, hpi_coil_dists,
                                   dist_limit)
        n_good = use_mask.sum()
        if n_good < 3:
            raise RuntimeError(
                '%s/%s good HPI fits, cannot determine the transformation!'
                % (n_good, hpi['n_freqs']))
        #
        # 3. Fit the head translation and rotation params (minimize error
        #    between coil positions and the head coil digitization positions)
        #
        this_quat, g = _fit_chpi_quat(this_coil_dev_rrs[use_mask],
                                      hpi_dig_head_rrs[use_mask],
                                      last['quat'])
        if g < gof_limit:
            raise RuntimeError('Bad coil fit! (g=%7.3f)' % (g,))
        # Convert the quaternion (rotation part this_quat[:3], translation
        # this_quat[3:]) to a 4x4 transform
        this_dev_head_t = np.concatenate(
            (quat_to_rot(this_quat[:3]),
             this_quat[3:][:, np.newaxis]), axis=1)
        this_dev_head_t = np.concatenate((this_dev_head_t, [[0, 0, 0, 1.]]))
        fits.append((t, this_dev_head_t))
        # Seed the next window with this fit so tracking follows the motion.
        # Rejected coils keep their previous guess rather than a position
        # that was just found to be inconsistent.
        last['quat'] = this_quat
        last['coil_dev_rrs'] = np.where(use_mask[:, np.newaxis],
                                        this_coil_dev_rrs,
                                        last['coil_dev_rrs'])
    return fits
if __name__ == '__main__':
    # Benchmark with a monotonic high-resolution clock. Compare against the
    # span that was actually fit: only whole windows lims[0]..lims[-1] are
    # processed, so a trailing partial window must not count as data.
    t0 = time.perf_counter()
    do_fits()
    elapsed = time.perf_counter() - t0
    processed = (lims[-1] - lims[0]) / raw.info['sfreq']
    print('Ran %0.1fx realtime' % (processed / elapsed))
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment