SORDINO2NII, Bruker SORDINO-ZTE image reconstruction
#!/usr/bin/env python
"""
Script Name: sordino2nii
Description: This script converts raw SORDINO data (a type of Zero TE (ZTE) image acquired
    from the Bruker Biospin preclinical scanner) into NIfTI-1 format by performing an
    adjoint NUFFT reconstruction. A distance-based Density Compensation Function (DCF)
    is applied, where `dk = 1/(x² + y² + z²)` and `DCF = 1/dk`.
Usage:
    ./sordino2nii -i [brkraw rawdata] -o [output prefix] [options]
Author: SungHo Lee
Date: 2024-11-12

Changes:
    1. Resolved Orientation Issue: Addressed and corrected orientation inconsistencies
       in the output images.
    2. Added --legacy-pose Option: Enable this option if the subject's pose is Prone
       but was set to Supine in the console.
    3. Image Centering: The image center is now aligned with the scanner's iso-center,
       following the scanner's coordinate system.
       - Consistent Orientation with Extension Factor: Orientation is now maintained
         consistently, even when an extension factor is used.
       - Exception: For dual-brain imaging, the image center is set to the center of
         the image array for ease of use.
"""
import os
import sys
import argparse
import tempfile
import numpy as np
from tqdm.auto import tqdm
from pathlib import Path

__version__ = '24.11.12'

# default configuration
tqdm_ncols = 100
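
# Illustrative sketch (not part of the pipeline): the distance-based DCF described in
# the module docstring weights each sample by its squared distance from the k-space
# origin, i.e. DCF = 1/dk with dk = 1/(x² + y² + z²). This mirrors the `dcf` computed
# later in `nufft_adjoint`; the function below exists only as a worked example.
def _demo_distance_dcf(traj):
    """Return the normalized distance-based DCF for a (n_pro, n_samp, 3) trajectory."""
    dcf = np.square(traj).sum(-1).flatten()  # x² + y² + z² per sample (= 1/dk)
    return dcf / dcf.max()                   # normalized so the edge of k-space is 1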
## argparse utility
def validate_float_or_list(value):
    # argparse applies `type` to each token individually, so only the numeric
    # conversion is validated here; the 1-or-3 element count is checked after
    # parsing (see load_data).
    try:
        return float(value)
    except ValueError:
        raise argparse.ArgumentTypeError(f"Expected a number, got {value!r}.")
## Parameter converter
def recon_output_shape(params):
    """ Apply extension factors to the output shape.
    The order of extension factors follows RAS+ orientation;
    therefore, they need to be reoriented (reordered) according to the gradient orientation.
    """
    ext_factors = params['ext_factors']
    matrix_size = np.array(params['matrix_size'])
    if not all(matrix_size == matrix_size[0]):
        axis_order = params['axis_order']
        if params['subj_type'] != 'Biped':
            matrix_size = matrix_size[axis_order][[0, 2, 1]]
    # apply extension factors to the (possibly reordered) matrix size
    oshape = (matrix_size * ext_factors).astype(int).tolist()
    return oshape
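
# Illustrative sketch: for an isotropic 64³ matrix with extension factors (1, 1, 1.5),
# the reconstruction grid becomes 64 x 64 x 96. The parameter values below are hypothetical.
def _demo_recon_output_shape():
    params = {'ext_factors': np.array([1.0, 1.0, 1.5]),
              'matrix_size': [64, 64, 64],
              'axis_order': [0, 1, 2],
              'subj_type': 'Quadruped'}
    return recon_output_shape(params)  # -> [64, 64, 96]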
def recon_n_frames(args, params):
    """ Return the number of data frames that need to be reconstructed.
    """
    total_frames = params['n_frames']
    offset = args.offset or 0
    avail_frames = total_frames - offset
    set_frames = args.num_frames or total_frames
    if set_frames > avail_frames:
        diff = set_frames - avail_frames
        set_frames -= diff
    return set_frames

def recon_buffer_offset(args, params):
    """ Return the file-buffer offset used to skip initial volumes.
    """
    offset = args.offset or 0
    return offset * params['buffer_size']

def get_vol_scantime(params):
    """ Scan time of one volume: TR times the number of projections. """
    return params['repetition_time'] * params['fid_shape'][3]

## Trajectory calculation
# number of angular steps for a given matrix size and sampling factor
radial_angles = lambda n, factor: np.ceil((np.pi * n * factor) / 2).astype(int)
# i-th angle out of n, offset by half a step to avoid the poles
radial_angle = lambda i, n: np.pi * (i + 0.5) / n
def calc_n_pro(matrix_size, under_sampling):
    """
    Calculate the number of half projections for a given matrix size and undersampling factor.
    Args:
        matrix_size (int): Size of the matrix.
        under_sampling (float): Undersampling factor.
    Returns:
        int: Number of half projections.
    """
    usamp = np.sqrt(under_sampling)
    n_theta = radial_angles(matrix_size, 1 / usamp)
    n_pro = 0
    for i_theta in range(n_theta):
        theta = radial_angle(i_theta, n_theta)
        n_phi = radial_angles(matrix_size, np.sin(theta) / usamp)
        n_pro += n_phi
    return n_pro
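
# Illustrative usage: the projection count decreases monotonically as the
# undersampling factor grows, which is what find_undersamp (below) exploits.
def _demo_projection_counts(matrix_size=128):
    full = calc_n_pro(matrix_size, 1.0)  # fully sampled half-sphere
    half = calc_n_pro(matrix_size, 2.0)  # roughly half as many projections
    assert half < full
    return full, half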
def find_undersamp(matrix_size, n_pro_target):
    """
    Finds the undersampling factor that results in the desired number of projections.
    Args:
        matrix_size (int): The size of the matrix.
        n_pro_target (int): The desired number of half projections.
    Returns:
        float: The undersampling factor that yields n_pro_target projections.
    """
    from scipy.optimize import brentq

    def func(under_sampling):
        n_pro = calc_n_pro(matrix_size, under_sampling)
        return n_pro - n_pro_target  # we want this to be zero

    max_val = calc_n_pro(matrix_size, 1)
    start = 1e-6
    end = max_val / matrix_size
    if func(start) * func(end) > 0:
        raise ValueError("The function does not change sign over the interval. Adjust the bounds.")
    undersamp_solution = brentq(func, start, end, xtol=1e-6)
    return undersamp_solution
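
# Illustrative round trip: recover an undersampling factor from a projection count
# produced by calc_n_pro. Because calc_n_pro is integer-valued, brentq lands on a step
# boundary, so the recovered factor reproduces the input only up to the step resolution.
def _demo_undersamp_round_trip(matrix_size=128, under_sampling=1.5):
    n_pro = calc_n_pro(matrix_size, under_sampling)
    return find_undersamp(matrix_size, n_pro)  # close to 1.5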
def calc_radial_traj3d(grad_array, matrix_size, use_origin, over_sampling,
                       ramp_time_corr=False, traj_offset=None, traj_denom=None):
    """
    Calculate trajectory for SORDINO imaging.
    For each projection, a vector with sampling positions (trajectory) is created.
    Args:
        grad_array (ndarray): Gradient vector profile for each projection, sized (3 x n_pro).
        matrix_size (int): Matrix size of the final image.
        use_origin (bool): If True, the first projection is the k-space origin and is kept at zero.
        over_sampling (float): Oversampling factor.
        ramp_time_corr (bool, optional): Apply ramp-time correction using the next projection's gradient.
        traj_offset (float, optional): Trajectory offset ratio compared to ADC sampling.
        traj_denom (float, optional): If given, replaces the default denominator (num_samples - 1).
    Returns:
        ndarray: Calculated trajectory for each projection, sized (n_pro, num_samples, 3).
    """
    pro_offset = 1 if use_origin else 0
    g = grad_array.copy()
    n_pro = g.shape[-1]
    traj_offset = traj_offset or 0
    num_samples = int(matrix_size / 2 * over_sampling)
    traj = np.zeros([n_pro, num_samples, 3])
    scale_factor = (num_samples - 1 + traj_offset) / (num_samples - 1)
    print('++ Processing trajectory calculation...')
    print(' + Input arguments')
    print(f' - Size of Matrix: {matrix_size}')
    print(f' - OverSampling: {over_sampling}')
    print(f' - Trajectory offset ratio: {traj_offset}')
    print(f' - Ramp-time Correction: {str(ramp_time_corr)}')
    print(f' - Size of Output Trajectory: {traj.shape}\n')
    print(f' - Image Scaling Factor (*Subject to be corrected in future version): {scale_factor}\n')
    # the origin projection (index 0 when use_origin is True) stays all zeros
    for i_pro in tqdm(range(pro_offset, n_pro), desc=' - NPro', file=sys.stdout, ncols=tqdm_ncols):
        for i_samp in range(num_samples):
            if traj_denom:
                samp = (i_samp + traj_offset) / traj_denom / 2
            else:
                samp = (i_samp + traj_offset) / (num_samples - 1) / 2
            if not ramp_time_corr or i_pro == n_pro - 1:
                # no next projection available; this deactivates ramp-time correction
                correction = np.zeros(3)
            else:
                i_next_pro = i_pro + 1
                correction = (g[:, i_next_pro] - g[:, i_pro]) / num_samples * i_samp
            traj[i_pro, i_samp, :] = samp * (g[:, i_pro] + correction)
    return traj
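
# Illustrative sketch: with three orthogonal unit gradients and no corrections, each
# spoke runs from the k-space origin out to radius 0.5 along its gradient direction.
def _demo_radial_traj():
    g = np.eye(3)  # three projections along x, y, z (shape 3 x 3)
    traj = calc_radial_traj3d(g, matrix_size=8, use_origin=False, over_sampling=2)
    assert traj.shape == (3, 8, 3)               # 8 samples = 8/2 * 2 (oversampling)
    assert np.isclose(np.abs(traj).max(), 0.5)   # spokes end at the Nyquist edge
    return traj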
def calc_radial_grad3d(matrix_size, n_pro_target, half_sphere, use_origin, reorder):
    """
    Generate a 3D radial gradient profile based on the input parameters.
    Args:
        matrix_size (int): Target matrix size.
        n_pro_target (int): Target number of projections.
        half_sphere (bool): If True, only generate gradients for half the sphere.
        use_origin (bool): If True, add a center point at the start.
        reorder (bool): Use the reorder scheme provided by the Bruker ZTE sequence.
    Returns:
        ndarray: The gradient profile as an array.
    """
    print('\n++ Processing SORDINO 3D Radial Gradient Calculation...')
    print(' + Input arguments')
    print(f' - Matrix size: {matrix_size}')
    print(f' - Number of Projections: {n_pro_target}')
    print(f' - Half sphere only: {half_sphere}')
    print(f' - Use origin: {use_origin}')
    print(f' - Reorder Gradient: {reorder}\n')
    n_pro = int(n_pro_target / (1 if half_sphere else 2) - (1 if use_origin else 0))
    usamp = np.sqrt(find_undersamp(matrix_size, n_pro))
    grad = {'r': [], 'p': [], 's': []}
    radial_n_phi = []
    print(' + Calculating Gradient Vectors...', end='')
    n_theta = radial_angles(matrix_size, 1.0 / usamp)
    for i_theta in range(n_theta):
        theta = radial_angle(i_theta, n_theta)
        n_phi = radial_angles(matrix_size, np.sin(theta) / usamp)
        radial_n_phi.append(n_phi)
        for i_phi in range(n_phi):
            phi = radial_angle(i_phi, n_phi)
            grad['r'].append(np.sin(theta) * np.cos(phi))
            grad['p'].append(np.sin(theta) * np.sin(phi))
            grad['s'].append(np.cos(theta))
    print('done\n')
    # convert the grad object to a numpy array
    grad_array = np.stack([grad['r'], grad['p'], grad['s']], axis=0)
    n_pro_created = grad_array.shape[-1] * (1 if half_sphere else 2) + (1 if use_origin else 0)
    if n_pro_created != n_pro_target:
        raise ValueError(f"Target number of projections can't be reached. Suggested adjusted value: {n_pro_created}.")
    # reorder projections
    if reorder:
        print(' + Reordering projections...', end='')
        grad_array = reorder_projections(n_theta, radial_n_phi, grad_array, reorder)
        print('done\n')
    # add gradients for the other hemisphere
    if not half_sphere:
        grad_array = np.concatenate([grad_array, -1 * grad_array], axis=1)
    if use_origin:
        # prepend the k-space origin as the first projection
        grad_array = np.concatenate([[[0, 0, 0]], grad_array.T], axis=0).T
    return grad_array
def reorder_projections(n_theta, radial_n_phi, grad_array, reorder):
    """
    Reorder radial projections for improved image spoiling.
    Args:
        n_theta (int): Number of theta angles.
        radial_n_phi (list): Number of phi angles for each theta.
        grad_array (ndarray): Gradient array.
        reorder (bool): Whether to apply the reordering scheme.
    Returns:
        ndarray: Reordered gradient array.
    """
    g = grad_array.copy()
    if reorder:
        def reorder_incr_index(n, i, d):
            """Increment index and switch direction at boundaries."""
            if (i + d > n - 1 or i + d < 0):
                d *= -1
            i += d
            return i, d

        n_pro = g.shape[-1]
        n_phi_max = max(radial_n_phi)
        r_g = np.zeros_like(g)  # template for the reordered g
        r_mask = np.zeros([n_theta, n_phi_max])
        for i_theta in range(n_theta):
            for i_phi in range(radial_n_phi[i_theta], n_phi_max):
                r_mask[i_theta][i_phi] = 1
        # indices
        i_theta = 0  # index for angle Theta
        d_theta = 1  # step for Theta index
        i_phi = 0    # index for angle Phi
        d_phi = 1    # step for Phi index
        # loop until each projection has a new position
        for i in range(n_pro):
            # find the next Theta with an unused Phi
            while not any(r_mask[i_theta] == 0):
                i_theta, d_theta = reorder_incr_index(n_theta, i_theta, d_theta)
            # find the next Phi
            while r_mask[i_theta][i_phi] == 1:
                i_phi, d_phi = reorder_incr_index(n_phi_max, i_phi, d_phi)
            new_i = sum(radial_n_phi[:i_theta]) + i_phi
            r_g[:, i] = g[:, new_i]
            r_mask[i_theta][i_phi] = 1
            # update to the next theta
            i_theta, d_theta = reorder_incr_index(n_theta, i_theta, d_theta)
            i_phi, d_phi = reorder_incr_index(n_phi_max, i_phi, d_phi)
        return r_g
    else:
        i = 0
        for i_theta in range(n_theta):
            if i_theta % 2 == 1:
                # reverse the phi order on every other theta ring
                for i_phi in range(int(radial_n_phi[i_theta] / 2)):
                    i0 = i + i_phi
                    i1 = i + radial_n_phi[i_theta] - 1 - i_phi
                    # swap values at the given indices
                    g[:, i0], g[:, i1] = g[:, i1].copy(), g[:, i0].copy()
            i += radial_n_phi[i_theta]
        return g
def generate_hash(*args):
    """Generate a hash from the input arguments."""
    import hashlib
    hash_input = "".join(str(arg) for arg in args)
    return hashlib.md5(hash_input.encode()).hexdigest()
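
# Illustrative usage: the hash serves as a cache key for precomputed trajectories, so
# identical reconstruction options always map to the same .npy file under tmpdir.
def _demo_cache_key():
    key_a = generate_hash(128, 2.0, True, 'reorder')
    key_b = generate_hash(128, 2.0, True, 'reorder')
    assert key_a == key_b  # deterministic for identical arguments
    return key_a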
def get_trajectory(args, params):
    matrix_size = params['matrix_size'][0]
    eff_bandwidth = params['eff_bandwidth']
    over_sampling = params['over_sampling']
    n_pro = params['fid_shape'][3]
    half_acquisition = params['half_acquisition']
    use_origin = params['use_origin']
    reorder = params['reorder']
    ext_factors = params['ext_factors']
    traj_offset_time = args.traj_offset or params['traj_offset']
    print(f' + Extension factors applied to matrix: {ext_factors}')
    grad = calc_radial_grad3d(matrix_size, n_pro, half_acquisition, use_origin, reorder)
    offset_factor = traj_offset_time * (10**-6) * eff_bandwidth * over_sampling
    option_for_hashs = (traj_offset_time, matrix_size, eff_bandwidth, over_sampling, n_pro,
                        ext_factors[0], half_acquisition, use_origin, reorder, args.ramp_time)
    hash = generate_hash(*option_for_hashs)
    tmpdir = Path(params['tmpdir'])
    if not tmpdir.exists():
        tmpdir.mkdir(parents=True, exist_ok=True)
    traj_path = tmpdir / f'{hash}.npy'
    if traj_path.exists():
        # reuse the cached trajectory computed with identical options
        traj = np.load(traj_path)
    else:
        traj = calc_radial_traj3d(grad, matrix_size, use_origin, over_sampling,
                                  ramp_time_corr=args.ramp_time,
                                  traj_offset=offset_factor, traj_denom=args.traj_denom)
        np.save(traj_path, traj)
    return traj
## Spoke timing correction method
def correct_spoketiming(fid_f, stc_f, args, params, stc_params):
    """ Correct the timing of each spoke to align with the center of the volume scan time
    (the same concept as slice timing correction, but applied to the FID signal).
    """
    from scipy.interpolate import interp1d
    pro_loc = 0
    stc_dtype = None
    stc_buffer_size = None
    target_timestamps = stc_params['target_timestamps']
    for seg_size in tqdm(stc_params['segs'], desc=' - Segments',
                         file=sys.stdout, ncols=tqdm_ncols):
        # load data
        pro_offset = pro_loc * stc_params['buffer_size_per_pro']
        seg_buffer_size = stc_params['buffer_size_per_pro'] * seg_size  # total buffer size for the current segment
        seg = []
        for t in range(recon_n_frames(args, params)):
            frame_offset = t * params['buffer_size']
            seek_loc = recon_buffer_offset(args, params) + frame_offset + pro_offset
            fid_f.seek(seek_loc)
            seg.append(fid_f.read(seg_buffer_size))
        seg_data = np.frombuffer(b''.join(seg), dtype=params['dtype_code'])
        seg_data = seg_data.reshape([2, np.prod(params['fid_shape'][1:3]),
                                     seg_size, recon_n_frames(args, params)],
                                    order='F')
        # interpolation step
        corrected_seg_data = np.empty_like(seg_data)
        # All samples of a projection (spokes x receivers) are interpolated together,
        # so spoke timing within a projection is not corrected. Instead, each projection
        # is corrected as a whole, since a single TR represents the time of one projection.
        for pro_id in range(seg_size):
            cur_pro = pro_loc + pro_id
            ref_timestamps = stc_params['base_timestamps'] + (cur_pro * params['repetition_time'])
            for c in range(2):  # real and imaginary (complex)
                for e in range(np.prod(params['fid_shape'][1:3])):
                    try:
                        data_feed = seg_data[c, e, pro_id, :]
                        interp_func = interp1d(ref_timestamps,
                                               data_feed,
                                               kind='linear',
                                               fill_value='extrapolate')
                        corrected_seg_data[c, e, pro_id, :] = interp_func(target_timestamps)
                    except Exception as exc:
                        print("****Debugging****")
                        print(f"RefTimeStamps: {ref_timestamps}")
                        print(f"DataFeed: {data_feed}")
                        raise exc
        # Store data
        for t in range(recon_n_frames(args, params)):
            frame_offset = t * params['buffer_size']
            stc_f.seek(frame_offset + pro_offset)
            stc_f.write(corrected_seg_data[:, :, :, t].flatten(order='F').tobytes())
        if not stc_dtype:
            stc_dtype = corrected_seg_data.dtype
            stc_buffer_size = np.prod(params['fid_shape']) * stc_dtype.itemsize
        pro_loc += seg_size
    stc_params['buffer_size'] = stc_buffer_size
    stc_params['dtype'] = stc_dtype
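
# Illustrative sketch of the interpolation above: each spoke's time series, sampled at
# its own acquisition times, is resampled onto a common grid at the middle of each
# volume's scan time. The numbers below are hypothetical (TR = 3 ms, 5 volumes).
def _demo_spoke_resampling(n_frames=5, tr=0.003, vol_scantime=1.0):
    from scipy.interpolate import interp1d
    signal = np.arange(n_frames, dtype=float)             # one spoke across volumes
    ref_t = np.arange(n_frames) * vol_scantime + 42 * tr  # acquired at spoke #42's offset
    target_t = np.arange(n_frames) * vol_scantime + vol_scantime / 2
    return interp1d(ref_t, signal, kind='linear', fill_value='extrapolate')(target_t)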
def run_spoketiming_correction(recon_f, objects, args, params):
    print("\n++ Running spoke timing correction for --spoketiming")
    ## Run spoke timing correction
    # parameters for spoke timing correction
    pvobj = objects['pvobj']
    n_pro = params['fid_shape'][3]
    vol_scantime = get_vol_scantime(params)
    base_timestamps = np.arange(recon_n_frames(args, params)) * vol_scantime
    stc_buffer_size = int(params['buffer_size'] / n_pro)
    print(f" + Buffer Size for each Projection: {stc_buffer_size}")
    stc_params = dict(
        base_timestamps=base_timestamps,
        target_timestamps=base_timestamps + (vol_scantime / 2),
        buffer_size_per_pro=stc_buffer_size,
    )
    with tempfile.NamedTemporaryFile(mode='w+b',
                                     delete=False,
                                     dir=objects['tmpdir']) as stc_f:
        scan_id = args.scanid
        with pvobj._open_object(pvobj._fid[scan_id]) as fid_f:
            try:
                # ZipFileObject case
                file_size = pvobj.filelist[pvobj._fid[scan_id]].file_size
            except Exception:
                # FileObj case
                file_size = os.path.getsize(pvobj._fid[scan_id])
            file_size *= (recon_n_frames(args, params) / params['n_frames'])  # process only selected frames
            file_size /= 1024 ** 3  # convert unit to GB
            print(f' + Size: {file_size:.3f} GB')
            # for safety, cut the data into segments of the size defined by --mem-limit (in GB)
            num_segs = np.ceil(file_size / args.mem_limit).astype(int)
            print(f' + Split data into {num_segs} segments to save memory.')
            n_pro_per_seg = int(np.ceil(n_pro / num_segs))
            if residual_pro := n_pro % n_pro_per_seg:
                segs = [n_pro_per_seg for _ in range(num_segs - 1)] + [residual_pro]
            else:
                segs = [n_pro_per_seg for _ in range(num_segs)]
            print(f" + The number of projections for each segment: \n {segs}")
            stc_params['segs'] = segs
            correct_spoketiming(fid_f, stc_f, args, params, stc_params)
        del fid_f
        print("\n++ Reconstruction (FID -> Image[complex])")
        with open(stc_f.name, 'r+b') as fid_f:
            recon_params = reconstruct_image(fid_f, recon_f, objects, args, params, stc_params)
            print(' + Success')
    params['cache'].append(stc_f.name)
    # end of spoke timing correction
    return recon_params
def correct_offreso(k, shift_freq, params):
    """ Demodulate an off-resonance frequency shift from the FID.
    Args:
        k: FID data for one channel, sized (n_pro, n_samples).
        shift_freq: frequency shift in Hz.
        params: parameter dict providing the effective bandwidth and oversampling.
    """
    bw = params['eff_bandwidth'] * params['over_sampling']
    m_k = k.copy()
    num_samp = m_k.shape[1]
    for samp_id in tqdm(range(num_samp)):
        # multiply each sample by a linear phase ramp exp(-i·2π·Δf·t), with t = (samp_id + 1)/bw
        m_k[:, samp_id] *= np.exp(-1j * 2 * shift_freq * np.pi * ((samp_id + 1) / bw))
    return m_k
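
# Illustrative check of the demodulation above: applying the phase ramp for +f and
# then for -f returns the original samples (pure numpy, hypothetical values).
def _demo_offreso_round_trip():
    params = {'eff_bandwidth': 100_000, 'over_sampling': 2}  # hypothetical bandwidth
    k = np.ones((4, 16), dtype=complex)                      # 4 spokes, 16 samples
    shifted = correct_offreso(k, 500.0, params)              # demodulate +500 Hz
    restored = correct_offreso(shifted, -500.0, params)      # modulate it back
    assert np.allclose(restored, k)
    return restored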
def reconstruct_image(fid_f, recon_f, objects, args, params, stc_params=None):
    recon_f.seek(0)
    if stc_params:
        fid_f.seek(0)
        buffer_size = stc_params['buffer_size']
        dtype = stc_params['dtype']
    else:
        fid_f.seek(recon_buffer_offset(args, params))
        buffer_size = params['buffer_size']
        dtype = params['dtype_code']
    traj = objects['traj'][:, :-args.pass_samples, ...]
    for n in tqdm(range(recon_n_frames(args, params)), desc=' - Frames',
                  file=sys.stdout, ncols=tqdm_ncols):
        buffer = fid_f.read(buffer_size)
        v = np.frombuffer(buffer, dtype=dtype).reshape(params['fid_shape'], order='F')
        v = (v[0] + 1j * v[1])[np.newaxis, ...]
        # crop the initial (underpowered) samples from each spoke
        ksp = v.squeeze().T[..., args.pass_samples:]
        n_receivers = params['fid_shape'][2]
        if n_receivers > 1:
            recon_vol = []
            for ch_id in range(n_receivers):
                # select one receiver channel
                k = ksp[:, ch_id, :]
                if args.offreso_ch:
                    if ch_id == args.offreso_ch - 1:
                        print(f" - Correcting OffResonance Frequency for Channel {args.offreso_ch}: {args.offreso_freq}Hz")
                        k = correct_offreso(k, args.offreso_freq, params)
                recon_vol.append(nufft_adjoint(params, k, traj))
            recon_vol = np.stack(recon_vol, axis=0)
        else:
            recon_vol = nufft_adjoint(params, ksp, traj)
        if n == 0:
            recon_dtype = recon_vol.dtype
        recon_f.write(recon_vol.T.flatten(order='C').tobytes())
    return dict(dtype=recon_dtype)
def nufft_adjoint(params, kspace, traj, operator='finufft'):
    """Run the adjoint NUFFT and return the reconstructed image."""
    from mrinufft import get_operator
    output_shape = recon_output_shape(params)
    # distance-based DCF: each sample is weighted by its squared distance from the origin
    dcf = np.sqrt(np.square(traj).sum(-1)).flatten() ** 2
    dcf /= dcf.max()
    # rescale the trajectory from [-0.5, 0.5] to the [-pi, pi] convention
    traj = traj.copy() / 0.5 * np.pi
    nufft_op = get_operator(operator)(traj, shape=output_shape, density=dcf)
    complex_img = nufft_op.adj_op(kspace.flatten())
    return complex_img
def run_reconstruction(objects, args, params):
    params['cache'] = []
    with tempfile.NamedTemporaryFile(mode='w+b',
                                     delete=False,
                                     dir=objects['tmpdir']) as recon_f:
        if args.spoketiming:
            recon_params = run_spoketiming_correction(recon_f, objects, args, params)
        else:
            pvobj = objects['pvobj']
            print("\n++ Reconstruction (FID -> Image[complex])")
            with pvobj._open_object(pvobj._fid[args.scanid]) as fid_f:
                recon_params = reconstruct_image(fid_f, recon_f, objects, args, params)
            print(" + Success")
        recon_params['name'] = recon_f.name
    return recon_params
## NIFTI conversion ##
def trim_empty_space(img, chid, args):
    # split the volume at the middle along the second axis
    slice_loc = int(img.shape[1] / 2)
    f_slicer = []
    b_slicer = []
    for i in range(3):
        if i == 1:
            f_slicer.append(slice(None, slice_loc, None))
            b_slicer.append(slice(slice_loc, None, None))
        else:
            ns = slice(None, None, None)
            f_slicer.append(ns)
            b_slicer.append(ns)
    f_img = img.copy()[tuple(f_slicer)]
    b_img = img.copy()[tuple(b_slicer)]
    # keep the half with the higher mean intensity
    trimmed_img = f_img if f_img.mean() > b_img.mean() else b_img
    if chid == args.reversed_ch - 1:
        # rotate the subject by 180 degrees
        trimmed_img = np.flip(trimmed_img, (0, 1))
    return trimmed_img
def apply_angle(rotate, rad_x=0, rad_y=0, rad_z=0):
    """ Apply rotations (in radians) about the x, y, and z axes to the given matrix. """
    rmat = dict(x=np.array([[1, 0, 0],
                            [0, np.cos(rad_x), -np.sin(rad_x)],
                            [0, np.sin(rad_x), np.cos(rad_x)]]).astype('float'),
                y=np.array([[np.cos(rad_y), 0, np.sin(rad_y)],
                            [0, 1, 0],
                            [-np.sin(rad_y), 0, np.cos(rad_y)]]).astype('float'),
                z=np.array([[np.cos(rad_z), -np.sin(rad_z), 0],
                            [np.sin(rad_z), np.cos(rad_z), 0],
                            [0, 0, 1]]).astype('float'))
    rotate = rmat['z'].dot(rmat['y'].dot(rmat['x'].dot(rotate)))
    return rotate
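
# Illustrative check: a 180-degree rotation about z (as used by the legacy pose
# correction below) flips the signs of the x and y axes.
def _demo_apply_angle():
    rz180 = apply_angle(np.eye(3), rad_z=np.pi)
    assert np.allclose(np.round(rz180), np.diag([-1, -1, 1]))
    return rz180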
def swap_axis_order(img, affine, axis_order):
    """ Swap the matrix axes while keeping the affine consistent.
    """
    swapped_img = np.transpose(img, axis_order + [d for d in range(img.ndim - 1, img.ndim)])
    swapped_affine = np.eye(4)
    swapped_affine[:3, :3] = affine[:3, :3][axis_order, :][:, axis_order]
    swapped_affine[:3, 3] = affine[:3, 3][axis_order]
    return swapped_img, swapped_affine
def rotate_img_axis(img, rotate, inv=False):
    rotate = prep_rotate(rotate, inv)
    axis_order = np.nonzero(rotate)[1].tolist()
    corr_img = transpose_img(img, axis_order)
    corr_img = corr_img[tuple([slice(None, None, d) for d in rotate.sum(1)])]
    return corr_img

def update_affine(affine, rotate, inv=False):
    """ Update the affine matrix with the given rotation.
    """
    rotate = prep_rotate(rotate, inv)
    affine = affine.copy()
    affine[:3, :3] = rotate.dot(affine[:3, :3])
    affine[:3, 3] = rotate.dot(affine[:3, 3])
    return affine

def prep_rotate(rotate, inv):
    rotate = rotate.copy()
    if inv:
        rotate = np.linalg.inv(rotate)
    rotate = np.round(rotate, decimals=0).astype(np.int16)
    return rotate

def transpose_img(img, axis_order):
    tp_img = np.transpose(img, axis_order + [i for i in range(img.ndim - 1, img.ndim)])
    return tp_img
def legacy_pose_correction(img, affine, resol, orient):
    img = img.copy()
    affine = affine.copy()
    inv_resol = np.linalg.inv(np.diag(resol[np.nonzero(orient)[1].tolist()]))
    reset_tf = np.linalg.inv(affine[:3, :3].dot(inv_resol))
    legacy_tf = apply_angle(np.eye(3), 0, 0, np.pi)  # 180-degree rotation about z
    img = rotate_img_axis(img, legacy_tf)
    origin_ijk = [img.shape[i] - 1 if d < 0 else 0 for i, d in enumerate(legacy_tf.sum(-1))] + [1]
    affine[:3, 3] = affine.dot(origin_ijk)[:3].dot(legacy_tf.dot(reset_tf))
    affine[:3, :3] = affine[:3, :3].dot(reset_tf)
    return img, np.round(affine, decimals=3)
def update_orient(img, args, params):
    print(" + Correcting data orientation...")
    # axis order (read, phase, slice order to 2dseq order)
    axis_order = params['axis_order']
    _2dseq = transpose_img(img, axis_order)
    orient, is_reversed_slice, position = params['orient_info']
    if is_reversed_slice:
        _2dseq = _2dseq[:, :, ::-1]
    # get affine
    resol = np.array(params['resol'])[axis_order]
    affine = np.eye(4)
    affine[:3, :3] = np.linalg.inv(orient).dot(np.diag(resol))
    affine[:3, 3] = position
    # apply extension factor
    matrix_size = np.array(params['matrix_size'])[axis_order]
    center = (matrix_size - 1) / 2
    scaled_matrix = (matrix_size * params['ext_factors']).astype(np.int16)
    origin = center - (scaled_matrix - 1) / 2
    affine[:, 3] = affine.dot(origin.tolist() + [1])
    # reorient to the subject coordinate system
    ortho_affine = update_affine(affine, orient, inv=True)
    ortho_img = rotate_img_axis(_2dseq, orient, inv=True)
    # legacy pose correction
    if args.legacy_pose:
        ortho_img, ortho_affine = legacy_pose_correction(ortho_img, ortho_affine, resol, orient)
    # Reorient to LAS+
    if params['subj_type'] != 'Biped':
        las_img, las_affine = swap_axis_order(ortho_img, ortho_affine, [0, 2, 1])
    else:
        flip_tf = np.diag([1, -1, 1])
        las_img = rotate_img_axis(ortho_img, flip_tf)
        las_affine = update_affine(ortho_affine, flip_tf)
    # Reorient to RAS+
    ras_affine = update_affine(las_affine, np.diag([-1, 1, 1]))
    return las_img, np.round(ras_affine, decimals=3)
def calc_slope_inter(data):
    print(" + Converting dtype to UINT16")
    inter = np.min(data)
    dmax = np.max(data)
    # use 2**16 - 1 so that the maximum scaled value still fits in the uint16 range
    slope = (dmax - inter) / (2**16 - 1)
    if data.ndim > 3:
        converted_data = []
        for t in tqdm(range(data.shape[-1]), desc=' - Frame', file=sys.stdout, ncols=tqdm_ncols):
            converted_data.append(((data[..., t] - inter) / slope).round().astype(np.uint16)[..., np.newaxis])
        converted_data = np.concatenate(converted_data, axis=-1)
    else:
        converted_data = ((data - inter) / slope).round().astype(np.uint16)
        converted_data = converted_data.squeeze()
    # Print out results
    print(f" - Slope: {slope:.3f}")
    print(f" - Intercept: {inter:.3f}")
    print(f" - Min: {converted_data.min()}")
    print(f" - Max: {converted_data.max()}")
    return converted_data, slope, inter
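
# Illustrative round trip for the scaling above: stored uint16 values recover the
# original intensities via value * slope + inter, which is exactly what the NIfTI
# scl_slope/scl_inter header fields encode.
def _demo_slope_inter_round_trip():
    data = np.random.rand(4, 4, 4).astype(np.float32) * 100
    stored, slope, inter = calc_slope_inter(data)
    assert np.allclose(stored.astype(np.float64) * slope + inter, data, atol=slope)
    return slope, inter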
def save_to_nifti(img, args, params, output_fpath, n_receivers, chid):
    import nibabel as nib
    img, affine = update_orient(img, args, params)
    if n_receivers > 1:
        # Centered FOV for the dual-head dataset for convenience
        origin = -1 * params['fov'] * params['ext_factors'] / 2
        origin[1] /= 2
        origin[0] *= -1
        affine[:3, 3] = origin
    iimg, slope, inter = calc_slope_inter(img)
    print(f" - Saving {output_fpath}...", end='')
    niiobj = nib.Nifti1Image(iimg, affine)
    niiobj.set_qform(affine, 1)
    niiobj.set_sform(affine, 0)
    niiobj.header.set_slope_inter(slope, inter)
    niiobj.header['pixdim'][4] = get_vol_scantime(params)
    niiobj.to_filename(output_fpath)
    print("success")
def convert_to_numpy(recon_params, args, params):
    # Load the reconstructed binary and take the magnitude
    n_receivers = params['fid_shape'][2]
    n_frames = recon_n_frames(args, params)
    with open(recon_params['name'], 'r+b') as img_f:
        print(" + Converting dtype (complex -> float32)...", end='')
        if n_receivers > 1:
            oshape = [n_receivers] + recon_output_shape(params)
        else:
            oshape = recon_output_shape(params)
        imgs = np.abs(np.frombuffer(img_f.read(),
                                    dtype=recon_params['dtype']).reshape(oshape + [n_frames], order='F'))
        print('success')
    print_recon_info(args, params)
    return imgs
def convert_to_nifti(imgs, recon_params, args, params):
    n_receivers = params['fid_shape'][2]
    print("\n++ Converting to Nifti image")
    print(f" + {n_receivers} images will be created.")
    if n_receivers == 1:
        # add a new axis to work with the loop below
        imgs = imgs[np.newaxis, ...]
    for chid, img in enumerate(imgs):
        if n_receivers > 1:
            print(f" + Receiver Channel ID: {chid + 1}")
            output_fpath = f'{args.prefix}_ch-{str(chid + 1)}.nii.gz'
            img = trim_empty_space(img, chid, args)
        else:
            output_fpath = f'{args.prefix}.nii.gz'
        save_to_nifti(img, args, params, output_fpath, n_receivers, chid)
        print('Done...')
    params['cache'].append(recon_params['name'])
## Command line Argument Parser and Processing Sequences
def init():
    parser = argparse.ArgumentParser(prog="sordino2nii", description='Reconstruction tool for the SORDINO MRI')
    parser.add_argument("-i", "--input", help="Path to the input raw data", type=str, default=None)
    parser.add_argument("-o", "--prefix", help="Prefix for the output reconstructed image", type=str, default='output')
    parser.add_argument("-s", "--scanid", help="Scan ID", type=int, default=None)
    parser.add_argument("-e", "--extention", help="FOV regridding extension factors (RAS+)", nargs="+",
                        type=validate_float_or_list, default=[1, 1, 1])
    parser.add_argument("--pass-samples", help="Exclude a given number of initial samples from each spoke's FID to remove underpowered signals (default: 1)",
                        type=int, default=1)  # 2024.09.24 LS
    parser.add_argument("--offset", help="Index of the starting frame for reconstruction", type=int, default=0)
    parser.add_argument("--num-frames", help='Number of frames to reconstruct (starting from the offset)', type=int, default=None)
    parser.add_argument("--traj-offset", help='Offset (in microseconds) to define the true center of the trajectory (default is determined by the AcqDelayTotal parameter)',
                        type=float, default=None)  # 2024.10.09 LS
    parser.add_argument("--spoketiming", help="Apply spoke timing correction if this option is enabled", action='store_true')
    parser.add_argument("--ramp-time", help='Apply ramp time correction if this option is enabled', action='store_true')
    parser.add_argument("--legacy-pose", help='Enable if the subject pose is Prone, but set to Supine in the console (following Paravision convention)', action='store_true')
    parser.add_argument("--reversed-ch", help='Channel number for the subject scanned in the opposite position (default: 1) during simultaneous dual-brain scanning', type=int, default=1)
    parser.add_argument("--offreso-ch", help='Channel affected by off-resonance frequency shifts', type=int, default=None)
    parser.add_argument("--offreso-freq", help='Frequency shift (in Hz) to correct off-resonance for the specified channel (works only with the offreso-ch option)', type=float, default=0)
    parser.add_argument("--tmpdir", help="Directory for storing temporary files", type=str, default=None)
    parser.add_argument("--clear-cache", help='Delete intermediate binary files generated during reconstruction', action='store_true')
    parser.add_argument("--mem-limit", help='Set memory usage limit when loading data (in GB, default: 0.5)', type=float, default=0.5)
    parser.add_argument("--traj-denom", help='If given, change the denominator to the given value when calculating the trajectory (default: None)', type=float, default=None)
    print(f"++ sordino2nii(v{__version__}): Reconstruction CLI for the SORDINO MRI")
    print("++ Authored by: SungHo Lee (email: [email protected])")
    return parser
def load_data(args):
    import brkraw as brk
    print(f"\n++ Loading input Bruker rawdata {args.input}")
    raw = brk.load(args.input)
    pvobj = raw._pvobj
    # create a temporary folder
    tmpdir = args.tmpdir or os.path.join(os.curdir, '.tmp')
    os.makedirs(tmpdir, exist_ok=True)
    if len(args.extention) not in (1, 3):
        raise ValueError("The extension factor must have either 1 or 3 elements.")
    ext_factors = args.extention * 3 if len(args.extention) == 1 else args.extention
    objects = dict(raw=raw, pvobj=pvobj, tmpdir=tmpdir)
    params = parse_params(brk.lib.reference, raw, args.scanid)
    params['ext_factors'] = np.array(ext_factors)
    params['tmpdir'] = tmpdir
    return objects, params
def get_gradient_axis_order(method):
    axis_decoder = {'axial':    'L_R',
                    'sagittal': 'A_P',
                    'coronal':  'L_R'}
    slice_orient = method['PVM_SPackArrSliceOrient']
    read_orient = method['PVM_SPackArrReadOrient']
    axis_order = [1, 0, 2] if axis_decoder[slice_orient] != read_orient else [0, 1, 2]
    return axis_order

def get_fid_shape(method):
    n_receivers = method['PVM_EncNReceivers']
    n_points = method['NPoints']
    n_pro = method['NPro']
    fid_shape = np.array([2, n_points, n_receivers, n_pro])
    return fid_shape

def get_orient_info(rawobj, scan_id):
    visu_pars = rawobj.get_visu_pars(scan_id, 1).parameters
    orient_matrix = np.squeeze(np.round(visu_pars['VisuCoreOrientation'])).reshape([3, 3])
    reversed_slice = 'reverse' in visu_pars['VisuCoreDiskSliceOrder']
    # Coordinate of the first voxel in the subject coordinate system: X (R->L), Y (V->D), Z (Cd->Ro)
    position = np.squeeze(visu_pars['VisuCorePosition'])
    return orient_matrix, reversed_slice, position

def get_dtype_code(reference, acqp):
    wordtype = reference.WORDTYPE[f'_{"".join(acqp["ACQ_word_size"].split("_"))}_SGN_INT']
    byteorder = reference.BYTEORDER[f'{acqp["BYTORDA"]}Endian']
    dtype_code = np.dtype(f'{byteorder}{wordtype}')
    return dtype_code
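
# Illustrative sketch: the dtype code above composes a numpy dtype string from the
# byte order ('<' little-endian, '>' big-endian) and the word type. The mapping
# tables themselves live in brkraw's reference module; the value below is hypothetical.
def _demo_dtype_code():
    dtype = np.dtype('<i4')  # little-endian 32-bit signed integer
    assert dtype.itemsize == 4
    return dtype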
def parse_params(reference, rawobj, scan_id):
    pvobj = rawobj.pvobj
    acqp = rawobj.get_acqp(scan_id).parameters
    method = rawobj.get_method(scan_id).parameters
    fid_shape = get_fid_shape(method)
    n_frames = method['PVM_NRepetitions']
    dtype_code = get_dtype_code(reference, acqp)
    buffer_size = np.prod(fid_shape) * dtype_code.itemsize
    dwell_time = 1 / (method['PVM_EffSWh'] * method['OverSampling']) * 10**6
    orient_info = get_orient_info(rawobj, scan_id)
    return dict(
        subj_type        = pvobj.subj_type,
        subj_pose        = f'{pvobj.subj_entry}_{pvobj.subj_pose}',
        orient_info      = orient_info,
        axis_order       = get_gradient_axis_order(method),
        fid_shape        = fid_shape,
        n_frames         = n_frames,
        matrix_size      = method['PVM_Matrix'],     # Read, Phase, Slice
        eff_bandwidth    = method['PVM_EffSWh'],
        over_sampling    = method['OverSampling'],
        resol            = method['PVM_SpatResol'],  # Read, Phase, Slice
        fov              = method['PVM_Fov'],        # Read, Phase, Slice
        half_acquisition = method['HalfAcquisition'] == 'Yes',
        use_origin       = method['UseOrigin'] == 'Yes',
        reorder          = method['Reorder'] == 'Yes',
        dtype_code       = dtype_code,
        buffer_size      = buffer_size,
        repetition_time  = method.get('RepetitionTime') / 1000,  # ms -> s
        dwell_time       = dwell_time,
        traj_offset      = method['AcqDelayTotal'] + dwell_time
    )
def print_recon_info(args, params):
    n_frames = recon_n_frames(args, params)
    output = recon_output_shape(params)
    output = ' x '.join([str(s) for s in output])
    n_receivers = params['fid_shape'][2]
    ndim = 4 if n_frames > 1 else 3
    print(f" + Output dim: {ndim}")
    print(f" + Input Extension Factor: {params['ext_factors']}")
    print(f" + Output Matrix Size: {output}")
    if n_receivers > 1:
        print(f" + Number of Receiver Coils: {n_receivers}")
    if args.offset:
        print(f" + Frame offset: {args.offset}")
    if ndim == 4:
        print(f" + Output num of frames: {n_frames}")
def finalize(args, params):
    if args.clear_cache:
        print("\n++ Clearing cache for --clear-cache")
        for f in params['cache']:
            os.remove(f)
    else:
        cache_fpath = f'{args.prefix}_cache.log'
        print(f"\n++ Saving cache file: {cache_fpath}")
        with open(cache_fpath, 'w+t') as log_f:
            for f in params['cache']:
                log_f.write(f + '\n')
def main():
    parser = init()
    args = parser.parse_args()
    try:
        # loading data
        objects, params = load_data(args)
    except Exception:
        parser.print_help()
        sys.exit(1)
    ## calculate trajectory
    objects['traj'] = get_trajectory(args, params)
    ## run reconstruction
    recon_params = run_reconstruction(objects, args, params)
    imgs = convert_to_numpy(recon_params, args, params)
    convert_to_nifti(imgs, recon_params, args, params)
    finalize(args, params)

if __name__ == "__main__":
    main()