Skip to content

Instantly share code, notes, and snippets.

@larsoner
Created February 5, 2020 18:12
Show Gist options
  • Save larsoner/a2be47b0489ac04834f67c883cc4d6c2 to your computer and use it in GitHub Desktop.
Save larsoner/a2be47b0489ac04834f67c883cc4d6c2 to your computer and use it in GitHub Desktop.
def get_atlas_roi_mask(stc, roi, atlas='IXI', atlas_subject=None,
                       subjects_dir=None):
    """Get ROI mask for a given subject/atlas.

    Parameters
    ----------
    stc : instance of VolSourceEstimate or VolVectorSourceEstimate
        The volumetric source estimate.
    roi : str
        The ROI to obtain a mask for. Must be a key of the mapping
        returned by ``get_atlas_mapping(atlas)``.
    atlas : str
        The atlas to use. Must be "IXI" or "LBPA40".
    atlas_subject : str | None
        Atlas subject to process. Must be one of the (unwarped) subjects
        "ANTS3-0Months3T", "ANTS6-0Months3T", or "ANTS12-0Months3T".
        If None, it will be inferred from the number of vertices.
    subjects_dir : str | None
        The subjects directory, resolved via ``get_subjects_dir`` (which
        raises if it cannot be determined).

    Returns
    -------
    mask : ndarray of bool, shape (n_vertices,)
        True for each vertex of ``stc`` lying within half a voxel diagonal
        of at least one atlas voxel labeled ``roi``.

    Raises
    ------
    RuntimeError
        If the vertices of ``stc`` are not all in the atlas subject's
        volumetric source space.
    ValueError
        If ``stc`` spans more than one volume source space, or if the ROI
        has no voxels in the atlas volume.
    """
    import nibabel as nib
    from mne.utils import _validate_type
    from mne.surface import _compute_nearest
    _validate_type(stc, (VolSourceEstimate, VolVectorSourceEstimate), 'stc')
    subjects_dir = get_subjects_dir(subjects_dir, raise_error=True)

    # Older MNE stores volume vertices as a single array; newer MNE stores a
    # list with one array per volume source space. Only a single space is
    # supported here because only src[0] is used below.
    vertices = stc.vertices
    if isinstance(vertices, (list, tuple)):
        if len(vertices) != 1:
            raise ValueError('stc must have exactly one volume source space, '
                             'got %d' % (len(vertices),))
        vertices = vertices[0]
    vertices = np.asarray(vertices)

    if atlas_subject is None:
        # The atlas subject is identified by its source-space vertex count.
        atlas_subject = _VERT_COUNT_MAP[len(vertices)]
    fname_src = op.join(subjects_dir, atlas_subject, 'bem',
                        '%s-vol5-src.fif' % (atlas_subject,))
    src = read_source_spaces(fname_src)
    mri = op.join(subjects_dir, atlas_subject, 'mri',
                  '%s_brain_ANTS_%s_atlas.mgz' % (atlas_subject, atlas))
    # np.isin replaces np.in1d, which is deprecated as of NumPy 2.0.
    if not np.isin(vertices, src[0]['vertno']).all():
        raise RuntimeError('stc does not appear to be created from %s '
                           'volumetric source space' % (atlas_subject,))
    rr = src[0]['rr'][vertices]

    # Voxel indices (i, j, k) of every atlas voxel carrying the ROI label.
    vol_id = get_atlas_mapping(atlas)[roi]
    mgz = nib.load(mri)
    vox_ijk = np.array(np.where(mgz.get_fdata() == vol_id)).T
    if len(vox_ijk) == 0:
        raise ValueError('ROI %r (label %s) has no voxels in %s'
                         % (roi, vol_id, mri))

    # Voxel -> tkr RAS transform with its first three rows scaled by 1e-3,
    # i.e. output converted from mm to m to match the source-space rr.
    vox_mri_t = mgz.header.get_vox2ras_tkr()
    vox_mri_t[:3] *= 1e-3
    rr_voi = apply_trans(vox_mri_t, vox_ijk)
    # Distance from each stc vertex to its nearest ROI voxel center.
    dists = _compute_nearest(rr_voi, rr, return_dists=True)[1]
    # Half the voxel diagonal: apply_trans maps a voxel offset d to
    # M[:3, :3] @ d, so the center-to-corner vector for d = (.5, .5, .5)
    # is half the ROW sums of the 3x3 part (sum over axis=1).
    maxdist = np.linalg.norm(vox_mri_t[:3, :3].sum(axis=1) / 2.)
    return dists <= maxdist
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment