Created February 5, 2020 18:12
Save larsoner/a2be47b0489ac04834f67c883cc4d6c2 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def get_atlas_roi_mask(stc, roi, atlas='IXI', atlas_subject=None,
                       subjects_dir=None):
    """Get ROI mask for a given subject/atlas.

    Parameters
    ----------
    stc : instance of mne.VolSourceEstimate or mne.VolVectorSourceEstimate
        The volumetric source estimate. Its vertices must all belong to the
        atlas subject's ``<atlas_subject>-vol5-src.fif`` source space.
    roi : str
        The ROI to obtain a mask for (a key of ``get_atlas_mapping(atlas)``).
    atlas : str
        The atlas to use. Must be "IXI" or "LBPA40".
    atlas_subject : str | None
        Atlas subject to process. Must be one of the (unwarped) subjects
        "ANTS3-0Months3T", "ANTS6-0Months3T", or "ANTS12-0Months3T".
        If None, it will be inferred from the number of vertices.
    subjects_dir : str | None
        The subjects directory. If None, it is resolved with
        ``get_subjects_dir`` and an error is raised if it cannot be found.

    Returns
    -------
    mask : ndarray of bool, shape (n_vertices,)
        True for each ``stc`` vertex lying within half a voxel diagonal of
        some voxel labeled ``roi`` in the atlas volume. All False if the
        label occurs in no voxel.

    Raises
    ------
    RuntimeError
        If ``stc`` has vertices not present in the atlas source space.
    """
    import nibabel as nib
    from mne.utils import _validate_type
    from mne.surface import _compute_nearest
    _validate_type(stc, (VolSourceEstimate, VolVectorSourceEstimate), 'stc')
    subjects_dir = get_subjects_dir(subjects_dir, raise_error=True)
    if atlas_subject is None:
        atlas_subject = _VERT_COUNT_MAP[len(stc.vertices)]
    fname_src = op.join(subjects_dir, atlas_subject, 'bem',
                        '%s-vol5-src.fif' % (atlas_subject,))
    src = read_source_spaces(fname_src)
    fname_mri = op.join(subjects_dir, atlas_subject, 'mri',
                        '%s_brain_ANTS_%s_atlas.mgz' % (atlas_subject, atlas))
    # np.isin supersedes np.in1d (deprecated and removed in NumPy 2.x).
    if not np.isin(stc.vertices, src[0]['vertno']).all():
        raise RuntimeError('stc does not appear to be created from %s '
                           'volumetric source space' % (atlas_subject,))
    rr = src[0]['rr'][stc.vertices]
    vol_id = get_atlas_mapping(atlas)[roi]
    mgz = nib.load(fname_mri)
    # (n_roi_voxels, 3) integer voxel indices carrying the ROI label.
    vox_ijk = np.array(np.where(mgz.get_fdata() == vol_id)).T
    if len(vox_ijk) == 0:
        # No voxel carries this label, so no vertex can fall inside the ROI;
        # avoid handing an empty point set to the nearest-neighbor search.
        return np.zeros(len(rr), bool)
    vox_mri_t = mgz.header.get_vox2ras_tkr()
    # Scale the three spatial rows (linear part + translation) by 1e-3,
    # presumably mm -> m to match the units of src['rr'].
    vox_mri_t *= np.array([[1e-3, 1e-3, 1e-3, 1]]).T
    rr_voi = apply_trans(vox_mri_t, vox_ijk)
    # Distance from each stc vertex to its nearest ROI voxel center.
    dists = _compute_nearest(rr_voi, rr, return_dists=True)[1]
    # Radius = half of one voxel's diagonal. The diagonal vector is the
    # linear part applied to (1, 1, 1), i.e. the row sums (axis=1).
    maxdist = np.linalg.norm(vox_mri_t[:3, :3].sum(axis=1) / 2.)
    return dists <= maxdist
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.