Skip to content

Instantly share code, notes, and snippets.

@dvm-shlee
Last active November 14, 2024 07:03
Show Gist options
  • Save dvm-shlee/0126c8b0f37b50a5ab4176f7af03e42d to your computer and use it in GitHub Desktop.
Save dvm-shlee/0126c8b0f37b50a5ab4176f7af03e42d to your computer and use it in GitHub Desktop.
SORDINO2NII, Bruker SORDINO-ZTE image reconstruction
#!/usr/bin/env python
"""
Script Name: sordino2nii
Description: This script converts raw SORDINO data (a type of Zero TE (ZTE) image acquired from the Bruker Biospin preclinical scanner) into NIfTI-1 format by performing an adjoint NUFFT reconstruction.
A distance-based Density Compensation Function (DCF) is applied, where `dk = 1/(x² + y² + z²)` and `DCF = 1/dk`.
Usage:
./sordino2nii -i [brkraw rawdata] -o [output prefix] [options]
Author: SungHo Lee
Date: 2024-11-12
# Changes
1. Resolved Orientation Issue: Addressed and corrected orientation inconsistencies in the output images.
2. Added --legacy-pose Option: This option should be enabled if the subject's pose is Prone but is set to Supine in the console.
3. Image Centering: The image center is now aligned with the scanner's iso-center, following the scanner's coordinate system.
- Consistent Orientation with Extension Factor: Orientation is now consistently maintained, even when using an extension factor, ensuring no orientation issues.
- Exception: For dual-brain imaging, the image center is set to the center of the image array for ease of use.
"""
import os
import sys
import argparse
import tempfile
import numpy as np
from tqdm.auto import tqdm
from pathlib import Path
# Script version string; printed in the CLI banner by init().
__version__ = '24.11.12'
# default configuration
# Width (in characters) passed as ``ncols`` to the tqdm progress bars in this script.
tqdm_ncols = 100
## argparse utility
def validate_float_or_list(value):
    """Parse one ``-e/--extention`` token into a float extension factor.

    argparse calls this once per token (the option uses ``nargs="+"``), so
    ``value`` is always a single string such as ``"1.25"``. The previous
    implementation tested ``len(value)`` -- the number of *characters* in
    the token -- and therefore rejected valid numbers such as ``"1.25"`` or
    ``"10"``. The number of factors (1 or 3) can only be checked on the
    whole parsed list, not per token.

    Args:
        value (str): One command-line token.
    Returns:
        float: The parsed extension factor.
    Raises:
        argparse.ArgumentTypeError: If ``value`` is not a positive, finite number.
    """
    try:
        factor = float(value)
    except ValueError:
        raise argparse.ArgumentTypeError(
            f"Invalid extension factor {value!r}: must be a number.") from None
    # Extension factors scale the output matrix, so they must be > 0; the
    # chained comparison also rejects NaN and infinity.
    if not 0 < factor < float('inf'):
        raise argparse.ArgumentTypeError(
            f"Invalid extension factor {value!r}: must be a positive, finite number.")
    return factor
## Parameter converter
def recon_output_shape(params):
    """ Apply extension factor to output shape
    The order of extension factors follows RAS+ orientation,
    therefore, need to reorient (reorder) according to gradient orient

    Args:
        params (dict): Scan parameters; reads 'ext_factors', 'matrix_size',
            and, for anisotropic matrices, 'axis_order' and 'subj_type'.
    Returns:
        list[int]: ``params['matrix_size'] * params['ext_factors']``,
        truncated to int.

    NOTE(review): the reordered ``matrix_size`` computed for anisotropic
    matrices below is never used -- the return value is built from
    ``params['matrix_size']`` directly, so no reordering actually happens.
    Confirm whether the ``oshape`` line was meant to use ``matrix_size``.
    """
    ext_factors = params['ext_factors']
    matrix_size = np.array(params['matrix_size'])
    # Only anisotropic matrices would need reordering.
    if not all(matrix_size == matrix_size[0]):
        axis_order = params['axis_order']
        if params['subj_type'] != 'Biped':
            matrix_size = matrix_size[axis_order][[0, 2, 1]]
    oshape = (params['matrix_size'] * ext_factors).astype(int).tolist()
    return oshape
def recon_n_frames(args, params):
    """Return how many frames will be reconstructed.

    The requested count (``args.num_frames``, or every frame when unset/0)
    is capped at the number of frames left after skipping ``args.offset``.
    """
    first_frame = args.offset or 0
    requested = args.num_frames or params['n_frames']
    remaining = params['n_frames'] - first_frame
    return min(requested, remaining)
def recon_buffer_offset(args, params):
    """Return the byte offset of the first frame to reconstruct.

    Equals ``args.offset`` frames (0 when unset) times the per-frame buffer
    size ``params['buffer_size']``.
    """
    skipped_frames = args.offset if args.offset else 0
    return params['buffer_size'] * skipped_frames
def get_vol_scantime(params):
    """Return the time to acquire one volume: repetition time x number of projections."""
    tr = params['repetition_time']
    n_projections = params['fid_shape'][3]
    return tr * n_projections
## Trajectory calculation
def radial_angles(n, factor):
    """Number of angular steps: ``ceil(pi * n * factor / 2)`` as an integer."""
    return np.ceil((np.pi * n * factor)/2).astype(int)


def radial_angle(i, n):
    """Angle (radians) at the centre of the i-th of ``n`` equal steps over [0, pi]."""
    return np.pi * (i + 0.5) / n
def calc_n_pro(matrix_size, under_sampling):
    """
    Count the half projections produced for a matrix size and undersampling factor.

    The polar range is split into ``radial_angles(matrix_size, 1/sqrt(u))``
    rings. Each ring at polar angle theta contributes
    ``radial_angles(matrix_size, sin(theta)/sqrt(u))`` azimuthal steps, and
    the steps of all rings are summed.

    Args:
        matrix_size (int): Size of the matrix.
        under_sampling (float): Undersampling factor ``u``.
    Returns:
        int: Number of half projections.
    """
    usamp = np.sqrt(under_sampling)
    n_theta = radial_angles(matrix_size, 1/usamp)
    phi_counts = (
        radial_angles(matrix_size, np.sin(radial_angle(i_theta, n_theta)) / usamp)
        for i_theta in range(n_theta)
    )
    return sum(phi_counts, 0)
def find_undersamp(matrix_size, n_pro_target):
    """
    Solve for the undersampling factor that yields ``n_pro_target`` half projections.

    Brent's method is applied to ``calc_n_pro(matrix_size, u) - n_pro_target``
    on the bracket ``[1e-6, calc_n_pro(matrix_size, 1) / matrix_size]``.

    Args:
        matrix_size (int): The size of the matrix.
        n_pro_target (int): The desired number of half projections.
    Returns:
        float: Undersampling factor returned by ``scipy.optimize.brentq`` (xtol 1e-6).
    Raises:
        ValueError: If the residual has the same sign at both ends of the bracket.
    """
    from scipy.optimize import brentq

    def residual(usamp):
        # Zero when calc_n_pro hits the target exactly.
        return calc_n_pro(matrix_size, usamp) - n_pro_target

    lower = 1e-6
    upper = calc_n_pro(matrix_size, 1) / matrix_size
    if residual(lower) * residual(upper) > 0:
        raise ValueError("The function does not change sign over the interval. Adjust the bounds.")
    return brentq(residual, lower, upper, xtol=1e-6)
def calc_radial_traj3d(grad_array, matrix_size, use_origin, over_sampling,
                       ramp_time_corr=False, traj_offset=None, traj_denom=None):
    """
    Calculate trajectory for SORDINO imaging.

    For each projection, sample ``s`` is placed at
    ``(s + traj_offset) / denom / 2`` times the projection's gradient
    vector. ``denom`` is ``traj_denom`` when given (and non-zero), otherwise
    ``num_samples - 1``.

    Args:
        grad_array (ndarray): Gradient vector profile for each projection, sized (3 x n_pro).
            When ``use_origin`` is True, column 0 is the origin (zero-gradient) projection.
        matrix_size (int): Matrix size of the final image.
        use_origin (bool): If True, column 0 is skipped and its trajectory stays all zeros.
        over_sampling (float): Oversampling factor; ``num_samples = int(matrix_size / 2 * over_sampling)``.
        ramp_time_corr (bool): If True, linearly blend each projection's direction toward
            the next one across its samples (the last projection is not corrected).
        traj_offset (float, optional): Trajectory offset ratio compared to ADC sampling (default 0).
        traj_denom (float, optional): Override for the sample-position denominator.
    Returns:
        ndarray: Trajectory of shape (n_pro, num_samples, 3).
    """
    pro_offset = 1 if use_origin else 0
    g = np.asarray(grad_array, dtype=float)
    n_pro = g.shape[-1]
    traj_offset = traj_offset or 0
    num_samples = int(matrix_size / 2 * over_sampling)
    traj = np.zeros([n_pro, num_samples, 3])
    scale_factor = (num_samples - 1 + traj_offset) / (num_samples-1)
    print('++ Processing trajectory calculation...')
    print(' + Input arguments')
    print(f' - Size of Matrix: {matrix_size}')
    print(f' - OverSampling: {over_sampling}')
    print(f' - Trajectory offset ratio: {traj_offset}')
    print(f' - Ramp-time Correction: {str(ramp_time_corr)}')
    print(f' - Size of Output Trajectory: {traj.shape}\n')
    print(f' - Image Scailing Factor (*Subject to be corrected in future version): {scale_factor}\n')

    # Sample positions along a spoke; identical for every projection.
    samp_idx = np.arange(num_samples)
    denom = traj_denom if traj_denom else (num_samples - 1)
    samp = (samp_idx + traj_offset) / denom / 2

    # Projections to compute: skip the origin column (if any) and stop at the
    # last column. The former loop ran to ``n_pro + pro_offset - 1`` and so
    # indexed past the end of ``traj``/``g`` whenever use_origin was True.
    g_pro = g[:, pro_offset:]  # (3, n_selected)

    # Per-sample step toward the next projection's direction; zero for the
    # last projection and whenever ramp-time correction is disabled.
    step = np.zeros_like(g_pro)
    if ramp_time_corr:
        step[:, :-1] = (g_pro[:, 1:] - g_pro[:, :-1]) / num_samples
    correction = step.T[:, np.newaxis, :] * samp_idx[np.newaxis, :, np.newaxis]

    traj[pro_offset:] = samp[np.newaxis, :, np.newaxis] * (g_pro.T[:, np.newaxis, :] + correction)
    return traj
def calc_radial_grad3d(matrix_size, n_pro_target, half_sphere, use_origin, reorder):
    """
    Generate 3D radial gradient profile based on input parameters.

    Args:
        matrix_size (int): Target matrix size.
        n_pro_target (int): Target number of projections, counting both hemispheres
            (unless ``half_sphere``) and the origin projection (if ``use_origin``).
        half_sphere (bool): If True, only generate for half the sphere.
        use_origin (bool): If True, prepend a zero-gradient (origin) projection.
        reorder (bool): Use reorder scheme provided by Bruker ZTE sequence.
    Returns:
        ndarray: Gradient directions of shape (3, n_pro_target); rows are the r, p, s components.
    Raises:
        ValueError: If exactly ``n_pro_target`` projections cannot be produced (the message
            suggests the reachable count), or if ``find_undersamp`` cannot bracket a solution.
    """
    print('\n++ Processing SORDINO 3D Radial Gradient Calculation...')
    print(' + Input arguments')
    print(f' - Matrix size: {matrix_size}')
    print(f' - Number of Projections: {n_pro_target}')
    print(f' - Half sphere only: {half_sphere}')
    print(f' - Use origin: {use_origin}')
    print(f' - Reorder Gradient: {reorder}\n')
    n_hemispheres = 1 if half_sphere else 2
    n_origin = 1 if use_origin else 0
    # Projections per hemisphere: the exact inverse of the count checked below
    # (n_pro_target = n_half * n_hemispheres + n_origin). The former
    # ``n_pro_target / n_hemispheres - n_origin`` halved before removing the
    # origin, so full-sphere targets with an origin could never be matched.
    n_pro = int((n_pro_target - n_origin) / n_hemispheres)
    usamp = np.sqrt(find_undersamp(matrix_size, n_pro))
    grad = {'r': [], 'p': [], 's': []}
    radial_n_phi = []
    print(' + Calculating Gradient Vectors...', end='')
    n_theta = radial_angles(matrix_size, 1.0 / usamp)
    for i_theta in range(n_theta):
        theta = radial_angle(i_theta, n_theta)
        n_phi = radial_angles(matrix_size, np.sin(theta) / usamp)
        radial_n_phi.append(n_phi)
        for i_phi in range(n_phi):
            phi = radial_angle(i_phi, n_phi)
            grad['r'].append(np.sin(theta) * np.cos(phi))
            grad['p'].append(np.sin(theta) * np.sin(phi))
            grad['s'].append(np.cos(theta))
    print('done\n')
    # convert grad object to numpy array (3, n_half)
    grad_array = np.stack([grad['r'], grad['p'], grad['s']], axis=0)
    n_pro_created = grad_array.shape[-1] * n_hemispheres + n_origin
    if n_pro_created != n_pro_target:
        raise ValueError(f"Target number of projections can't be reached. Suggested adjusted value: {n_pro_created}.")
    # reorder projections
    if reorder:
        print(' + Reordering projections...', end='')
        grad_array = reorder_projections(n_theta, radial_n_phi, grad_array, reorder)
        print('done\n')
    # add gradients for other hemisphere
    if not half_sphere:
        grad_array = np.concatenate([grad_array, -1 * grad_array], axis=1)
    if use_origin:
        # Prepend the zero-gradient (k-space origin) projection as column 0.
        # This previously referenced the ``grad`` dict and raised AttributeError.
        grad_array = np.concatenate([np.zeros((3, 1)), grad_array], axis=1)
    return grad_array
def reorder_projections(n_theta, radial_n_phi, grad_array, reorder):
    """
    Reorder radial projections for improved image spoiling.

    Args:
        n_theta (int): Number of theta angles.
        radial_n_phi (list): Number of phi angles for each theta.
        grad_array (ndarray): Gradient array of shape (3, sum(radial_n_phi)), with the
            columns of each theta ring stored contiguously, ring by ring.
        reorder (bool): If True, visit projections with the bouncing theta/phi scheme;
            otherwise only reverse the phi order of every odd theta ring.
    Returns:
        ndarray: Reordered copy of ``grad_array`` (the input is not modified).
    """
    g = grad_array.copy()
    if reorder:
        def reorder_incr_index(n, i, d):
            """Increment index and switch direction at boundaries."""
            if (i + d > n - 1 or i + d < 0):
                d *= -1
            i += d
            return i, d
        n_pro = g.shape[-1]
        n_phi_max = max(radial_n_phi)
        r_g = np.zeros_like(g)  # template for the reordered g
        # 1 marks a (theta, phi) cell that is already used or does not exist.
        r_mask = np.zeros([n_theta, n_phi_max])
        for i_theta in range(n_theta):
            r_mask[i_theta, radial_n_phi[i_theta]:] = 1
        # Column where each theta ring starts in g. Precomputed once instead of
        # re-summing radial_n_phi[:i_theta] for every projection.
        ring_start = np.concatenate([[0], np.cumsum(radial_n_phi)[:-1]]).astype(int)
        # indices
        i_theta = 0  # index for angle Theta
        d_theta = 1  # step for Theta index
        i_phi = 0  # index for angle Phi
        d_phi = 1  # step for Phi index
        # loop until each projection has a new position
        for i in range(n_pro):
            # find next Theta with unused Phi
            while not (r_mask[i_theta] == 0).any():
                i_theta, d_theta = reorder_incr_index(n_theta, i_theta, d_theta)
            # find next Phi
            while r_mask[i_theta][i_phi] == 1:
                i_phi, d_phi = reorder_incr_index(n_phi_max, i_phi, d_phi)
            new_i = ring_start[i_theta] + i_phi
            r_g[:, i] = g[:, new_i]
            r_mask[i_theta][i_phi] = 1
            # update to next theta
            i_theta, d_theta = reorder_incr_index(n_theta, i_theta, d_theta)
            i_phi, d_phi = reorder_incr_index(n_phi_max, i_phi, d_phi)
        return r_g
    else:
        i = 0  # first column of the current theta ring
        for i_theta in range(n_theta):
            if i_theta % 2 == 1:
                for i_phi in range(int(radial_n_phi[i_theta] / 2)):
                    i0 = i + i_phi
                    i1 = i + radial_n_phi[i_theta] - 1 - i_phi
                    # swap between values in given index
                    g[:, i0], g[:, i1] = g[:, i1].copy(), g[:, i0].copy()
            i += radial_n_phi[i_theta]
        return g
def generate_hash(*args):
    """Generate a cache key from the input arguments.

    Each argument is converted with ``str()`` and the pieces are joined with
    a ``|`` separator before hashing. Plain concatenation would map distinct
    argument lists to the same key, e.g. ``(1.5, 64)`` and ``(1.56, 4)``
    both become ``"1.564"``. MD5 is used only to name cache files, not for
    security.

    Returns:
        str: 32-character hexadecimal MD5 digest.
    """
    import hashlib
    hash_input = "|".join(str(arg) for arg in args)
    return hashlib.md5(hash_input.encode()).hexdigest()
def get_trajectory(args, params):
    """Return the k-space trajectory for this scan, reusing an on-disk cache.

    The trajectory is cached in ``params['tmpdir']`` as ``<md5>.npy``, keyed
    on the inputs that determine it. On a cache miss, the gradient table is
    built with ``calc_radial_grad3d``, converted with ``calc_radial_traj3d``
    and saved.

    Args:
        args (argparse.Namespace): CLI options; reads ``traj_offset``,
            ``ramp_time`` and ``traj_denom``.
        params (dict): Scan parameters (matrix size, bandwidth, projection
            settings, default ``traj_offset``, ``ext_factors``, ``tmpdir``).
    Returns:
        ndarray: Trajectory of shape (n_pro, n_samples, 3).
    """
    matrix_size = params['matrix_size'][0]
    eff_bandwidth = params['eff_bandwidth']
    over_sampling = params['over_sampling']
    n_pro = params['fid_shape'][3]
    half_acquisition = params['half_acquisition']
    use_origin = params['use_origin']
    reorder = params['reorder']
    ext_factors = params['ext_factors']
    # ``--traj-offset 0`` is a legitimate request. Fall back to the scan
    # default only when the option was not given (``or`` treated 0 as unset).
    if args.traj_offset is not None:
        traj_offset_time = args.traj_offset
    else:
        traj_offset_time = params['traj_offset']
    print(f' + Extension factors applied to matrix: {ext_factors}')
    # Convert the offset from microseconds to ADC samples (assuming
    # eff_bandwidth * over_sampling is the sampling rate in Hz).
    offset_factor = traj_offset_time * (10**-6) * eff_bandwidth * over_sampling
    option_for_hashs = (traj_offset_time, matrix_size, eff_bandwidth, over_sampling, n_pro,
                        ext_factors[0], half_acquisition, use_origin, reorder, args.ramp_time)
    # traj_denom changes the trajectory, so it must be part of the cache key.
    # It is appended only when it takes effect, so the key for runs without
    # --traj-denom is unaffected.
    if args.traj_denom:
        option_for_hashs += (args.traj_denom,)
    cache_key = generate_hash(*option_for_hashs)
    tmpdir = Path(params['tmpdir'])
    tmpdir.mkdir(parents=True, exist_ok=True)
    traj_path = tmpdir / f'{cache_key}.npy'
    if traj_path.exists():
        print(f' + Loading cached trajectory: {traj_path}')
        return np.load(traj_path)
    # The gradient table is only needed when the trajectory must be computed.
    grad = calc_radial_grad3d(matrix_size, n_pro, half_acquisition, use_origin, reorder)
    traj = calc_radial_traj3d(grad, matrix_size, use_origin, over_sampling,
                              ramp_time_corr=args.ramp_time,
                              traj_offset=offset_factor, traj_denom=args.traj_denom)
    np.save(traj_path, traj)
    return traj
## Spoke timing correction method
def correct_spoketiming(fid_f, stc_f, args, params, stc_params):
    """ Correct timing of each spoke to align center of scan time
    (Same concept as slice timing correction, but applied to FID signal)

    For every projection, the time series (across frames) of each FID sample
    is linearly interpolated, with extrapolation, from the projection's
    acquisition times ``base_timestamps + projection_index * TR`` onto
    ``stc_params['target_timestamps']``. Projections are processed in the
    segments listed in ``stc_params['segs']``: they are read from ``fid_f``
    and written to the same byte positions of ``stc_f``.

    On return, ``stc_params`` gains 'buffer_size' (bytes per corrected frame)
    and 'dtype'.

    NOTE: corrected values are stored with the input dtype (``np.empty_like``),
    so integer raw data are truncated after interpolation.
    """
    from scipy.interpolate import interp1d
    # Loop-invariant values, computed once.
    n_frames = recon_n_frames(args, params)
    base_offset = recon_buffer_offset(args, params)
    frame_size = params['buffer_size']
    n_elements = int(np.prod(params['fid_shape'][1:3]))  # samples x receivers per projection
    pro_size = stc_params['buffer_size_per_pro']
    target_timestamps = stc_params['target_timestamps']
    pro_loc = 0
    stc_dtype = None
    stc_buffer_size = None
    for seg_size in tqdm(stc_params['segs'], desc=' - Segments',
                         file=sys.stdout, ncols=tqdm_ncols):
        # load data
        pro_offset = pro_loc * pro_size
        seg_buffer_size = pro_size * seg_size  # total buffer size for current segment
        seg = []
        for t in range(n_frames):
            fid_f.seek(base_offset + t * frame_size + pro_offset)
            seg.append(fid_f.read(seg_buffer_size))
        seg_data = np.frombuffer(b''.join(seg), dtype=params['dtype_code'])
        seg_data = seg_data.reshape([2, n_elements, seg_size, n_frames], order='F')
        # interpolation step
        corrected_seg_data = np.empty_like(seg_data)
        # A single TR represents one projection, so every sample and receiver
        # of a projection shares the same timing: interpolate the whole
        # projection at once along the frame axis.
        for pro_id in range(seg_size):
            cur_pro = pro_loc + pro_id
            ref_timestamps = stc_params['base_timestamps'] + (cur_pro * params['repetition_time'])
            data_feed = seg_data[:, :, pro_id, :]
            try:
                interp_func = interp1d(ref_timestamps,
                                       data_feed,
                                       kind='linear',
                                       axis=-1,
                                       fill_value='extrapolate')
                corrected_seg_data[:, :, pro_id, :] = interp_func(target_timestamps)
            except Exception:
                print("****Debugging****")
                print(f"RefTimeStamps: {ref_timestamps}")
                print(f"DataFeed: {data_feed}")
                raise
        # Store data
        for t in range(n_frames):
            stc_f.seek(t * frame_size + pro_offset)
            stc_f.write(corrected_seg_data[:, :, :, t].flatten(order='F').tobytes())
        # ``is None`` rather than truthiness: a plain np.dtype can be falsy.
        if stc_dtype is None:
            stc_dtype = corrected_seg_data.dtype
            stc_buffer_size = np.prod(params['fid_shape']) * stc_dtype.itemsize
        pro_loc += seg_size
    stc_params['buffer_size'] = stc_buffer_size
    stc_params['dtype'] = stc_dtype
def run_spoketiming_correction(recon_f, objects, args, params):
    """Apply spoke timing correction to the raw FID, then reconstruct the corrected data.

    The corrected FID goes to a temporary file in ``objects['tmpdir']``, whose
    path is appended to ``params['cache']``. Projections are processed in
    segments sized so that roughly ``args.mem_limit`` GB of raw data is
    loaded at a time.

    Returns:
        dict: Recon parameters returned by ``reconstruct_image``.
    """
    print("\n++ Running spoke timing correction for --spoketiming")
    ## Run spoke timing correction
    # parameters for spoke timing correction
    pvobj = objects['pvobj']
    n_pro = params['fid_shape'][3]
    n_frames = recon_n_frames(args, params)
    vol_scantime = get_vol_scantime(params)
    base_timestamps = np.arange(n_frames) * vol_scantime
    stc_buffer_size = int(params['buffer_size'] / n_pro)
    print(f" + Buffer Size for each Projection: {stc_buffer_size}")
    stc_params = dict(
        base_timestamps=base_timestamps,
        target_timestamps=base_timestamps + (vol_scantime / 2),
        buffer_size_per_pro=stc_buffer_size,
    )
    scan_id = args.scanid
    fid_path = pvobj._fid[scan_id]
    with tempfile.NamedTemporaryFile(mode='w+b',
                                     delete=False,
                                     dir=objects['tmpdir']) as stc_f:
        with pvobj._open_object(fid_path) as fid_f:
            try:
                # ZipFileObject case
                file_size = pvobj.filelist[fid_path].file_size
            except Exception:
                # FileObj case
                file_size = os.path.getsize(fid_path)
            file_size *= (n_frames / params['n_frames'])  # process only selected frames
            file_size /= 1024 ** 3  # unit to GB
            print(f' + Size: {file_size:.3f} GB')
            # for safety reason, cut data into the size defined at limit_mem_size (in GB)
            num_segs = max(1, int(np.ceil(file_size / args.mem_limit)))
            n_pro_per_seg = int(np.ceil(n_pro / num_segs))
            # Chunk the projections so the segment sizes always sum to n_pro.
            # The former "(num_segs - 1) full segments + residual" scheme could
            # add up to more than n_pro (e.g. 10 projections in 6 segments ->
            # 6 x 2 = 12), writing past the end of each corrected frame.
            segs = [min(n_pro_per_seg, n_pro - start)
                    for start in range(0, n_pro, n_pro_per_seg)]
            print(f' + Split data into {len(segs)} segments for saving memory.')
            print(f" + The number of projections for each segments: \n {segs}")
            stc_params['segs'] = segs
            correct_spoketiming(fid_f, stc_f, args, params, stc_params)
    # Reopen only after the temporary file is closed, so every buffered write
    # has been flushed to disk before it is read back.
    print("\n++ Reconstruction (FID -> Image[complex])")
    with open(stc_f.name, 'r+b') as stc_in:
        recon_params = reconstruct_image(stc_in, recon_f, objects, args, params, stc_params)
    print(' + Success')
    params['cache'].append(stc_f.name)
    # end of spoketiming correction
    return recon_params
def correct_offreso(k, shift_freq, params):
    """
    Remove a constant off-resonance frequency shift from FID samples.

    Sample ``s`` (0-based index along axis 1) is multiplied by
    ``exp(-1j * 2 * pi * shift_freq * (s + 1) / bw)`` with
    ``bw = params['eff_bandwidth'] * params['over_sampling']``.

    Args:
        k (ndarray): FID samples with the sample index on axis 1
            (e.g. shape (n_pro, n_samples)).
        shift_freq (float): Frequency shift to remove (Hz, per the --offreso-freq help).
        params (dict): Reads 'eff_bandwidth' and 'over_sampling'.
    Returns:
        ndarray: Corrected copy of ``k``; complex inputs keep their dtype. ``k`` is not modified.
    """
    bw = params['eff_bandwidth'] * params['over_sampling']
    num_samp = k.shape[1]
    sample_times = np.arange(1, num_samp + 1) / bw
    phase = np.exp(-1j * 2 * shift_freq * np.pi * sample_times)
    # Broadcast the per-sample phase along axis 1 for any number of trailing axes.
    phase = phase.reshape((1, num_samp) + (1,) * (k.ndim - 2))
    m_k = k * phase
    if np.iscomplexobj(k):
        m_k = m_k.astype(k.dtype, copy=False)
    return m_k
def reconstruct_image(fid_f, recon_f, objects, args, params, stc_params=None):
    """Reconstruct the selected frames and write the complex images to ``recon_f``.

    Each frame is read from ``fid_f`` (the raw FID, or the spoke-timing
    corrected FID when ``stc_params`` is given). The first
    ``args.pass_samples`` samples of every spoke are dropped, and
    ``nufft_adjoint`` is run once per receiver channel. Frames are written to
    ``recon_f`` sequentially from byte 0.

    Args:
        fid_f: Readable binary file; positioned by this function.
        recon_f: Writable binary file for the complex image data.
        objects (dict): Reads 'traj' (trajectory of shape (n_pro, n_samples, 3)).
        args (argparse.Namespace): Reads ``pass_samples``, ``offreso_ch``,
            ``offreso_freq`` and the frame-selection options.
        params (dict): Scan parameters.
        stc_params (dict, optional): 'buffer_size' and 'dtype' of the corrected
            FID file; when given, reading starts at byte 0.
    Returns:
        dict: ``{'dtype': dtype of the reconstructed volumes}`` (None if no frame was processed).
    """
    recon_f.seek(0)
    if stc_params:
        fid_f.seek(0)
        buffer_size = stc_params['buffer_size']
        dtype = stc_params['dtype']
    else:
        fid_f.seek(recon_buffer_offset(args, params))
        buffer_size = params['buffer_size']
        dtype = params['dtype_code']
    pass_samples = args.pass_samples
    full_traj = objects['traj']
    # Drop as many points from the end of each trajectory spoke as FID samples
    # are dropped from the start below. An explicit stop index is used because
    # the former ``[:, :-pass_samples]`` gives an EMPTY trajectory for
    # ``--pass-samples 0``.
    traj = full_traj[:, :full_traj.shape[1] - pass_samples, ...]
    n_receivers = params['fid_shape'][2]
    recon_dtype = None
    for n in tqdm(range(recon_n_frames(args, params)), desc=' - Frames',
                  file=sys.stdout, ncols=tqdm_ncols):
        buffer = fid_f.read(buffer_size)
        v = np.frombuffer(buffer, dtype=dtype).reshape(params['fid_shape'], order='F')
        # axis 0 holds the real and imaginary parts
        v = (v[0] + 1j*v[1])[np.newaxis, ...]
        ksp = v.squeeze().T[..., pass_samples:]
        if n_receivers > 1:
            recon_vol = []
            for ch_id in range(n_receivers):
                # crop initial samples
                k = ksp[:, ch_id, :]
                if args.offreso_ch and ch_id == args.offreso_ch - 1:
                    print(f" - Correcting OffResonance Frequency for Channel {args.offreso_ch}: {args.offreso_freq}Hz")
                    k = correct_offreso(k, args.offreso_freq, params)
                recon_vol.append(nufft_adjoint(params, k, traj))
            recon_vol = np.stack(recon_vol, axis=0)
        else:
            recon_vol = nufft_adjoint(params, ksp, traj)
        if n == 0:
            recon_dtype = recon_vol.dtype
        recon_f.write(recon_vol.T.flatten(order='C').tobytes())
    return dict(dtype=recon_dtype)
def nufft_adjoint(params, kspace, traj, operator='finufft'):
    """Run the adjoint NUFFT and return the reconstructed complex image.

    Args:
        params (dict): Scan parameters; the output shape comes from ``recon_output_shape``.
        kspace (ndarray): k-space samples, flattened to match the flattened trajectory.
        traj (ndarray): Trajectory of shape (..., 3).
        operator (str): mrinufft backend name passed to ``get_operator``.
    Returns:
        ndarray: Complex image from ``adj_op``.
    """
    from mrinufft import get_operator
    output_shape = recon_output_shape(params)
    # Density compensation equal to |k|^2 = x^2 + y^2 + z^2, normalised to a
    # maximum of 1 (computed directly rather than as sqrt(...) ** 2).
    dcf = np.square(traj).sum(-1).flatten()
    dcf = dcf / dcf.max()
    # Scale coordinates by 2*pi (presumably from [-0.5, 0.5] to the [-pi, pi]
    # range the operator expects).
    traj = traj / 0.5 * np.pi
    nufft_op = get_operator(operator)(traj, shape=output_shape, density=dcf)
    complex_img = nufft_op.adj_op(kspace.flatten())
    return complex_img
def run_reconstruction(objects, args, params):
    """Reconstruct every selected frame into a temporary binary file.

    Resets ``params['cache']``. Then it either runs spoke timing correction
    followed by reconstruction (``--spoketiming``) or reconstructs straight
    from the raw FID.

    Returns:
        dict: Recon parameters, with 'name' set to the temporary file path.
    """
    params['cache'] = []
    temp_options = dict(mode='w+b', delete=False, dir=objects['tmpdir'])
    with tempfile.NamedTemporaryFile(**temp_options) as recon_f:
        if not args.spoketiming:
            pvobj = objects['pvobj']
            print("\n++ Reconstruction (FID -> Image[complex])")
            with pvobj._open_object(pvobj._fid[args.scanid]) as fid_f:
                recon_params = reconstruct_image(fid_f, recon_f, objects, args, params)
            print(" + Success")
        else:
            recon_params = run_spoketiming_correction(recon_f, objects, args, params)
        recon_params['name'] = recon_f.name
    return recon_params
## NIFTI conversion ##
def trim_empty_space(img, chid, args):
    """Keep the half of ``img`` (split at the middle of axis 1) with the higher mean.

    For the 0-based channel ``args.reversed_ch - 1``, the kept half is also
    flipped along axes 0 and 1 (an in-plane 180-degree rotation).

    Returns:
        ndarray: An independent copy of the selected half.
    """
    mid = int(img.shape[1] / 2)
    front = img[:, :mid, :]
    back = img[:, mid:, :]
    # choose image with higher mean contrast
    kept = (front if front.mean() > back.mean() else back).copy()
    if chid == args.reversed_ch - 1:
        # rotate subject 180 degree
        kept = np.flip(kept, (0, 1))
    return kept
def apply_angle(rotate, rad_x=0, rad_y=0, rad_z=0):
    """Rotate ``rotate`` by angles (radians) about x, then y, then z.

    Returns ``Rz @ Ry @ Rx @ rotate``, where ``R*`` are the elementary
    rotation matrices about each axis.
    """
    cos_x, sin_x = np.cos(rad_x), np.sin(rad_x)
    cos_y, sin_y = np.cos(rad_y), np.sin(rad_y)
    cos_z, sin_z = np.cos(rad_z), np.sin(rad_z)
    about_x = np.array([[1, 0, 0],
                        [0, cos_x, -sin_x],
                        [0, sin_x, cos_x]], dtype=float)
    about_y = np.array([[cos_y, 0, sin_y],
                        [0, 1, 0],
                        [-sin_y, 0, cos_y]], dtype=float)
    about_z = np.array([[cos_z, -sin_z, 0],
                        [sin_z, cos_z, 0],
                        [0, 0, 1]], dtype=float)
    return about_z @ (about_y @ (about_x @ rotate))
def swap_axis_order(img, affine, axis_order):
    """Permute the spatial image axes and apply the same permutation to the affine.

    The first ``len(axis_order)`` image axes are transposed by ``axis_order``;
    any remaining (e.g. time) axes keep their positions. The same permutation
    is applied to both the rows and the columns of the affine's 3x3 part and
    to its translation.

    Args:
        img (ndarray): Image with at least ``len(axis_order)`` dimensions.
        affine (ndarray): 4x4 affine.
        axis_order (list[int]): Permutation of the three spatial axes.
    Returns:
        tuple: (transposed image, new 4x4 affine).
    """
    axis_order = list(axis_order)
    # Keep every trailing axis in place. The former code appended only
    # ``img.ndim - 1``, which duplicated an axis for 3-D input and dropped
    # axes for inputs with more than 4 dimensions.
    trailing = list(range(len(axis_order), img.ndim))
    swapped_img = np.transpose(img, axis_order + trailing)
    swapped_affine = np.eye(4)
    swapped_affine[:3, :3] = affine[:3, :3][axis_order, :][:, axis_order]
    swapped_affine[:3, 3] = affine[:3, 3][axis_order]
    return swapped_img, swapped_affine
def rotate_img_axis(img, rotate, inv=False):
    """Reorient ``img`` with a signed permutation matrix.

    ``rotate`` is rounded to integers (inverted first when ``inv``) by
    ``prep_rotate``. The image axes are reordered to the column index of each
    row's non-zero entry, and each row sum is used as that axis' slice step
    (-1 reverses the axis).
    """
    perm = prep_rotate(rotate, inv)
    _, source_axes = np.nonzero(perm)
    reordered = transpose_img(img, source_axes.tolist())
    steps = tuple(slice(None, None, step) for step in perm.sum(1))
    return reordered[steps]
def update_affine(affine, rotate, inv=False):
    """Return a copy of ``affine`` with the rounded ``rotate`` applied.

    ``rotate`` (inverted first when ``inv``) left-multiplies both the 3x3
    linear part and the translation vector.
    """
    rot = prep_rotate(rotate, inv)
    updated = affine.copy()
    linear = updated[:3, :3]
    translation = updated[:3, 3]
    updated[:3, :3] = rot @ linear
    updated[:3, 3] = rot @ translation
    return updated
def prep_rotate(rotate, inv):
    """Return ``rotate`` (inverted first if ``inv``) rounded to an int16 matrix.

    The input array is never modified.
    """
    matrix = np.linalg.inv(rotate) if inv else rotate.copy()
    return np.round(matrix, decimals=0).astype(np.int16)
def transpose_img(img, axis_order):
    """Permute the first ``len(axis_order)`` axes of ``img``; later axes keep their position.

    The former version appended only ``img.ndim - 1``, which raised a
    "repeated axis" error for 3-D input and dropped axes beyond the 4th.

    Args:
        img (ndarray): Image with at least ``len(axis_order)`` dimensions.
        axis_order (sequence[int]): Permutation of the leading axes.
    Returns:
        ndarray: Transposed view of ``img``.
    """
    axis_order = list(axis_order)
    trailing = list(range(len(axis_order), img.ndim))
    return np.transpose(img, axis_order + trailing)
def legacy_pose_correction(img, affine, resol, orient):
    """Rotate image and affine by 180 degrees about z for the --legacy-pose option.

    Per the CLI help, this is for subjects scanned Prone while the console
    pose was set to Supine.

    Args:
        img (ndarray): Image already reoriented to the subject frame.
        affine (ndarray): 4x4 affine matching ``img``.
        resol (ndarray): Voxel sizes; indexed by the column positions of the
            non-zero entries of ``orient``.
        orient (ndarray): 3x3 orientation matrix.
    Returns:
        tuple: (rotated image, affine rounded to 3 decimals). The inputs are copied, not modified.
    """
    img = img.copy()
    affine = affine.copy()
    # Inverse of the voxel-size diagonal, with sizes permuted to follow orient's axes.
    inv_resol = np.linalg.inv(np.diag(resol[np.nonzero(orient)[1].tolist()]))
    # Inverse of the affine's linear part after dividing out the voxel sizes
    # (presumably a pure rotation -- TODO confirm).
    reset_tf = np.linalg.inv(affine[:3, :3].dot(inv_resol))
    # 180-degree rotation about z (approximately diag(-1, -1, 1)).
    legacy_tf = apply_angle(np.eye(3), 0, 0, np.pi)
    img = rotate_img_axis(img, legacy_tf)
    # Homogeneous voxel index of the corner that becomes the new first voxel:
    # last index on axes whose row of legacy_tf sums negative, 0 otherwise.
    origin_ijk = [img.shape[i] - 1 if d < 0 else 0 for i, d in enumerate(legacy_tf.sum(-1))] + [1]
    affine[:3, 3] = affine.dot(origin_ijk)[:3].dot(legacy_tf.dot(reset_tf))
    affine[:3, :3] = affine[:3, :3].dot(reset_tf)
    return img, np.round(affine, decimals=3)
def update_orient(img, args, params):
    """Reorient a reconstructed image into the output frame and build its affine.

    Steps:
      1. Transpose from gradient order to 2dseq order with ``axis_order``.
      2. Reverse the slice axis when the disk slice order is reversed.
      3. Build an affine from the orientation matrix, resolution and position.
      4. Shift the affine origin for the extension factors.
      5. Reorient into the subject frame.
      6. Optionally apply the legacy pose correction.
      7. Reorder/flip to LAS+.
      8. Negate the affine's first (x) row to label the result RAS+.

    Args:
        img (ndarray): Magnitude image; its first three axes are permuted by params['axis_order'].
        args (argparse.Namespace): Reads ``legacy_pose``.
        params (dict): Reads 'axis_order', 'orient_info', 'resol',
            'matrix_size', 'ext_factors' and 'subj_type'.
    Returns:
        tuple: (image array, 4x4 affine rounded to 3 decimals). The image
        data is not changed by the final RAS+ step; only the affine is.
    """
    print(f" + Correcting data orientation...")
    # axis order (read, phase, slice order to 2dseq order)
    axis_order = params['axis_order']
    _2dseq = transpose_img(img, axis_order)
    orient, is_reversed_slice, position = params['orient_info']
    if is_reversed_slice:
        _2dseq = _2dseq[:,:,::-1]
    # get_affine: inverse orientation with columns scaled by the voxel sizes,
    # translated to the stored position.
    resol = np.array(params['resol'])[axis_order]
    affine = np.eye(4)
    affine[:3, :3] = np.linalg.inv(orient).dot(np.diag(resol))
    affine[:3, 3] = position
    # apply extension factor: move the origin so that the centre of the
    # (possibly enlarged) output grid maps to the centre of the original matrix.
    # NOTE(review): matrix_size is permuted by axis_order here before being
    # multiplied by ext_factors, while recon_output_shape multiplies the
    # unpermuted matrix size. The two agree for an isotropic matrix with
    # uniform factors -- confirm for the other cases.
    matrix_size = np.array(params['matrix_size'])[axis_order]
    center = (matrix_size - 1) / 2
    scaled_matrix = (matrix_size * params['ext_factors']).astype(np.int16)
    origin = center - (scaled_matrix - 1)/2
    affine[:, 3] = affine.dot(origin.tolist() + [1])
    # reorient to subject coordinate system
    ortho_affine = update_affine(affine, orient, inv=True)
    ortho_img = rotate_img_axis(_2dseq, orient, inv=True)
    # legacy pose correction
    if args.legacy_pose:
        ortho_img, ortho_affine = legacy_pose_correction(ortho_img, ortho_affine, resol, orient)
    # Reorient to LAS+ (non-Biped: swap axes 1 and 2; Biped: flip axis 1)
    if params['subj_type'] != 'Biped':
        las_img, las_affine = swap_axis_order(ortho_img, ortho_affine, [0, 2, 1])
    else:
        flip_tf = np.diag([1, -1, 1])
        las_img = rotate_img_axis(ortho_img, flip_tf)
        las_affine = update_affine(ortho_affine, flip_tf)
    # Orient to RAS+
    ras_affine = update_affine(las_affine, np.diag([-1, 1, 1]))
    return las_img, np.round(ras_affine, decimals=3)
def calc_slope_inter(data):
    """Quantise ``data`` to uint16 and return the NIfTI scaling that recovers it.

    The full range ``[min, max]`` maps onto ``[0, 65535]``, so
    ``data ≈ converted * slope + inter``. The former divisor ``2**16`` mapped
    the maximum to 65536, which does not fit in uint16 and wrapped the
    brightest voxels around. Constant input (zero range) used to divide by
    zero; it now uses ``slope = 1`` and all-zero data.

    Args:
        data (ndarray): Real-valued image. With more than 3 dimensions, it is
            converted one frame (last axis) at a time.
    Returns:
        tuple: (uint16 array with singleton axes squeezed, slope, intercept).
    """
    print(f" + convert dtype to UINT16")
    uint16_max = np.iinfo(np.uint16).max  # 65535
    inter = np.min(data)
    dmax = np.max(data)
    slope = (dmax - inter) / uint16_max
    if slope == 0:
        # Constant image: every value equals ``inter``, so any non-zero slope is exact.
        slope = 1.0
    if data.ndim > 3:
        converted_data = []
        for t in tqdm(range(data.shape[-1]), desc=' - Frame', file=sys.stdout, ncols=tqdm_ncols):
            converted_data.append(((data[..., t] - inter) / slope).round().astype(np.uint16)[..., np.newaxis])
        converted_data = np.concatenate(converted_data, axis=-1)
    else:
        converted_data = ((data - inter) / slope).round().astype(np.uint16)
    converted_data = converted_data.squeeze()
    # Print out results
    print(f" - Slope: {slope:.3f}")
    print(f" - Intercept: {inter:.3f}")
    print(f" - Min: {converted_data.min()}")
    print(f" - Max: {converted_data.max()}")
    return converted_data, slope, inter
def save_to_nifti(img, args, params, output_fpath, n_receivers, chid):
    """Reorient ``img``, quantise it to uint16 and write it as a NIfTI-1 file.

    Args:
        img (ndarray): Magnitude image for one receiver channel.
        args (argparse.Namespace): Passed through to ``update_orient``.
        params (dict): Scan parameters; for multi-receiver data, reads 'fov'
            and 'ext_factors' to centre the FOV.
        output_fpath (str): Destination path.
        n_receivers (int): Number of receiver channels in the scan.
        chid (int): 0-based channel index (not used in this function).
    """
    import nibabel as nib
    img, affine = update_orient(img, args, params)
    if n_receivers > 1:
        # Centered FOV for dual-head dataset for convenience.
        # np.asarray(float) guards against 'fov' arriving as a plain list, for
        # which ``-1 * fov`` is list repetition (an empty list) rather than
        # negation. It also makes the in-place divisions below valid.
        fov = np.asarray(params['fov'], dtype=float)
        origin = -1 * fov * params['ext_factors'] / 2
        origin[1] /= 2
        origin[0] *= -1
        affine[:3, 3] = origin
    iimg, slope, inter = calc_slope_inter(img)
    print(f" - Saving {output_fpath}...", end='')
    niiobj = nib.Nifti1Image(iimg, affine)
    niiobj.set_qform(affine, 1)
    niiobj.set_sform(affine, 0)
    niiobj.header.set_slope_inter(slope, inter)
    # Time between volumes, from get_vol_scantime.
    niiobj.header['pixdim'][4] = get_vol_scantime(params)
    niiobj.to_filename(output_fpath)
    print(f"success")
def convert_to_numpy(recon_params, args, params):
    """Load the reconstructed complex data from disk as magnitude images.

    The shape is ``[n_receivers] + output_shape + [n_frames]`` (the receiver
    axis only when there are several receivers), read in Fortran order.
    Also prints the reconstruction summary.

    Returns:
        ndarray: Magnitude (``np.abs``) of the reconstructed images.
    """
    n_receivers = params['fid_shape'][2]
    n_frames = recon_n_frames(args, params)
    oshape = recon_output_shape(params)
    if n_receivers > 1:
        oshape = [n_receivers] + oshape
    with open(recon_params['name'], 'r+b') as img_f:
        print(f" + Converting dtype (complex -> float32)...", end='')
        complex_data = np.frombuffer(img_f.read(), dtype=recon_params['dtype'])
        imgs = np.abs(complex_data.reshape(oshape + [n_frames], order='F'))
    print('success')
    print_recon_info(args, params)
    return imgs
def convert_to_nifti(imgs, recon_params, args, params):
    """Write one NIfTI file per receiver channel and register the recon file for cleanup.

    Multi-channel images are trimmed with ``trim_empty_space`` and saved as
    ``<prefix>_ch-<n>.nii.gz``. Single-channel images are saved as ``<prefix>.nii.gz``.
    """
    n_receivers = params['fid_shape'][2]
    multi_channel = n_receivers > 1
    print(f"\n++ Converting to Nifti image")
    print(f" + {n_receivers} images will be created.")
    # Give single-channel data a leading channel axis so both cases share one loop.
    channel_imgs = imgs[np.newaxis, ...] if n_receivers == 1 else imgs
    for chid, img in enumerate(channel_imgs):
        if not multi_channel:
            output_fpath = f'{args.prefix}.nii.gz'
        else:
            print(f" + Receiver Channel ID: {chid+1}")
            output_fpath = f'{args.prefix}_ch-{str(chid+1)}.nii.gz'
            img = trim_empty_space(img, chid, args)
        save_to_nifti(img, args, params, output_fpath, n_receivers, chid)
    print('Done...')
    params['cache'].append(recon_params['name'])
## Command line Argument Parser and Processing Sequences
def init():
    """Build the sordino2nii argument parser and print the CLI banner.

    Returns:
        argparse.ArgumentParser: Parser for every sordino2nii option.
    """
    parser = argparse.ArgumentParser(prog="sordino2nii", description='Reconstruction tool for the SORDINO MRI')
    option_specs = [
        (("-i", "--input"),
         dict(help="Path to the input raw data", type=str, default=None)),
        (("-o", "--prefix"),
         dict(help="Prefix for the output reconstructed image", type=str, default='output')),
        (("-s", "--scanid"),
         dict(help="Scan ID", type=int, default=None)),
        (("-e", "--extention"),
         dict(help="FOV regridding extension factors (RAS+)", nargs="+",
              type=validate_float_or_list, default=[1, 1, 1])),
        (("--pass-samples",),  # 2024.09.24 LS
         dict(help="Exclude a given number of initial samples from each spoke's FID to remove underpowered signals (default: 1)",
              type=int, default=1)),
        (("--offset",),
         dict(help="Index of the starting frame for reconstruction", type=int, default=0)),
        (("--num-frames",),
         dict(help='Number of frames to reconstruct (starting from the offset)', type=int, default=None)),
        (("--traj-offset",),  # 2024.10.09 LS
         dict(help='Offset (in microseconds) to define the true center of the trajectory (default is determined by the AcqDelayTotal parameter)',
              type=float, default=None)),
        (("--spoketiming",),
         dict(help="Apply spoke timing correction if this option is enabled", action='store_true')),
        (("--ramp-time",),
         dict(help='Apply ramp time correction if this option is enabled', action='store_true')),
        (("--legacy-pose",),
         dict(help='Enable if the subject pose is Prone, but set to Supine in the console (following Paravision convention)', action='store_true')),
        (("--reversed-ch",),
         dict(help='Channel number for the subject scanned in the opposite position (default: 1) during simultaneous dual-brain scanning', type=int, default=1)),
        (("--offreso-ch",),
         dict(help='Channel affected by off-resonance frequency shifts', type=int, default=None)),
        (("--offreso-freq",),
         dict(help='Frequency shift (in Hz) to correct off-resonance for the specified channel (works only with the offreso-ch option)', type=float, default=0)),
        (("--tmpdir",),
         dict(help="Directory for storing temporary files", type=str, default=None)),
        (("--clear-cache",),
         dict(help='Delete intermediate binary files generated during reconstruction', action='store_true')),
        (("--mem-limit",),
         dict(help='Set memory usage limit when loading data (in GB, default: 0.5)', type=float, default=0.5)),
        (("--traj-denom",),
         dict(help='If given, change denominator to given value while calculating trajectory (default: None)', type=float, default=None)),
    ]
    for flags, options in option_specs:
        parser.add_argument(*flags, **options)
    print(f"++ sordino2nii(v{__version__}): Reconstruction CLI for the SORDINO MRI")
    print("++ Authored by: SungHo Lee (email: [email protected])")
    return parser
def load_data(args):
    """Load the Bruker raw dataset and collect reconstruction parameters.

    Args:
        args (argparse.Namespace): Reads ``extention``, ``input``, ``tmpdir`` and ``scanid``.
    Returns:
        tuple: (objects, params). ``objects`` holds the brkraw dataset ('raw'),
        its ``_pvobj`` ('pvobj') and the temp directory ('tmpdir').
        ``params`` is the ``parse_params`` dict plus 'ext_factors'
        (ndarray of 3) and 'tmpdir'.
    Raises:
        ValueError: If ``--extention`` received a number of factors other than 1 or 3.
    """
    # Validate the factor count before the (potentially slow) data load. A
    # wrong count previously surfaced much later as a broadcasting error.
    n_ext = len(args.extention)
    if n_ext not in (1, 3):
        raise ValueError(f"-e/--extention takes either 1 or 3 factors (got {n_ext}).")
    import brkraw as brk
    print(f"\n++ Loading input Bruker rawdata {args.input}")
    raw = brk.load(args.input)
    pvobj = raw._pvobj
    # create temporary folder
    tmpdir = args.tmpdir or os.path.join(os.curdir, '.tmp')
    os.makedirs(tmpdir, exist_ok=True)
    # A single factor applies to all three axes.
    ext_factors = args.extention * 3 if n_ext == 1 else args.extention
    objects = dict(raw=raw, pvobj=pvobj, tmpdir=tmpdir)
    params = parse_params(brk.lib.reference, raw, args.scanid)
    params['ext_factors'] = np.array(ext_factors)
    params['tmpdir'] = tmpdir
    return objects, params
def get_gradient_axis_order(method):
    """Return the axis permutation from gradient order to image order.

    Read and phase are swapped (``[1, 0, 2]``) when the read orientation
    differs from the default read direction of the slice orientation.
    Otherwise the identity order ``[0, 1, 2]`` is returned.
    """
    default_read = {'axial': 'L_R',
                    'sagittal': 'A_P',
                    'coronal': 'L_R'}
    slice_orient = method['PVM_SPackArrSliceOrient']
    read_orient = method['PVM_SPackArrReadOrient']
    if default_read[slice_orient] == read_orient:
        return [0, 1, 2]
    return [1, 0, 2]
def get_fid_shape(method):
    """Return the raw FID shape ``[2, NPoints, PVM_EncNReceivers, NPro]`` as an array.

    The leading 2 holds the real and imaginary parts.
    """
    receivers = method['PVM_EncNReceivers']
    samples = method['NPoints']
    projections = method['NPro']
    return np.array([2, samples, receivers, projections])
def get_orient_info(rawobj, scan_id):
    """Read orientation metadata from the scan's visu_pars (reco 1).

    Returns:
        tuple: (3x3 rounded VisuCoreOrientation matrix,
        True when VisuCoreDiskSliceOrder contains 'reverse',
        squeezed VisuCorePosition).
    """
    visu = rawobj.get_visu_pars(scan_id, 1).parameters
    orientation = np.squeeze(np.round(visu['VisuCoreOrientation'])).reshape([3, 3])
    is_reversed = bool('reverse' in visu['VisuCoreDiskSliceOrder'])
    # Coordinate of first voxel in subject coordinate system X (R->L), Y (V->D), Z (Cd->Ro)
    first_voxel = np.squeeze(visu['VisuCorePosition'])
    return orientation, is_reversed, first_voxel
def get_dtype_code(reference, acqp):
    """Build the numpy dtype of the raw FID from ACQ_word_size and BYTORDA.

    The word type is looked up in ``reference.WORDTYPE`` under
    ``'_<word size without underscores>_SGN_INT'``, and the byte order in
    ``reference.BYTEORDER`` under ``'<BYTORDA>Endian'``.
    """
    word_size = acqp["ACQ_word_size"].replace("_", "")
    wordtype = reference.WORDTYPE[f'_{word_size}_SGN_INT']
    byteorder = reference.BYTEORDER[f'{acqp["BYTORDA"]}Endian']
    return np.dtype(f'{byteorder}{wordtype}')
def parse_params(reference, rawobj, scan_id):
    """Collect the scan parameters used throughout the reconstruction.

    Args:
        reference: brkraw reference tables (passed to ``get_dtype_code``).
        rawobj: Loaded brkraw dataset; its acqp, method and visu_pars for
            ``scan_id`` are read.
        scan_id (int): Scan number.
    Returns:
        dict: Parameter name -> value.
    """
    pvobj = rawobj.pvobj
    acqp = rawobj.get_acqp(scan_id).parameters
    method = rawobj.get_method(scan_id).parameters

    def is_yes(key):
        # Bruker 'Yes'/'No' method flag -> bool
        return bool(method[key] == 'Yes')

    fid_shape = get_fid_shape(method)
    n_frames = method['PVM_NRepetitions']
    dtype_code = get_dtype_code(reference, acqp)
    buffer_size = np.prod(fid_shape) * dtype_code.itemsize  # bytes per frame
    # Microseconds, matching the unit of the --traj-offset option.
    sampling_rate = method['PVM_EffSWh'] * method['OverSampling']
    dwell_time = 1 / sampling_rate * 10**6
    orient_info = get_orient_info(rawobj, scan_id)
    return {
        'subj_type': pvobj.subj_type,
        'subj_pose': f'{pvobj.subj_entry}_{pvobj.subj_pose}',
        'orient_info': orient_info,
        'axis_order': get_gradient_axis_order(method),
        'fid_shape': fid_shape,
        'n_frames': n_frames,
        'matrix_size': method['PVM_Matrix'],  # Read, Phase, Slice
        'eff_bandwidth': method['PVM_EffSWh'],
        'over_sampling': method['OverSampling'],
        'resol': method['PVM_SpatResol'],  # Read, Phase, Slice
        'fov': method['PVM_Fov'],  # Read, Phase, Slice
        'half_acquisition': is_yes('HalfAcquisition'),
        'use_origin': is_yes('UseOrigin'),
        'reorder': is_yes('Reorder'),
        'dtype_code': dtype_code,
        'buffer_size': buffer_size,
        'repetition_time': method.get('RepetitionTime') / 1000,
        'dwell_time': dwell_time,
        'traj_offset': method['AcqDelayTotal'] + dwell_time,
    }
def print_recon_info(args, params):
    """Print a summary of the output geometry, receivers and frame selection."""
    n_frames = recon_n_frames(args, params)
    shape_text = ' x '.join(map(str, recon_output_shape(params)))
    n_receivers = params['fid_shape'][2]
    ndim = 3 if n_frames <= 1 else 4
    print(f" + Output dim: {ndim}")
    print(f" + Input Extension Factor: {params['ext_factors']}")
    print(f" + Output Matrix Size: {shape_text}")
    optional_lines = [
        (n_receivers > 1, f" + Number of Receiver Coils: {n_receivers}"),
        (bool(args.offset), f" + Frame offset: {args.offset}"),
        (ndim == 4, f" + Output num of frames: {n_frames}"),
    ]
    for enabled, line in optional_lines:
        if enabled:
            print(line)
def finalize(args, params):
    """Delete cached intermediate files (--clear-cache), or list them in ``<prefix>_cache.log``."""
    cached = params['cache']
    if not args.clear_cache:
        cache_fpath = f'{args.prefix}_cache.log'
        print(f"\n++ Saving cache file: {cache_fpath}")
        with open(cache_fpath, 'w+t') as log_f:
            log_f.writelines(path + '\n' for path in cached)
        return
    print("\n++ Clear cache for --clear-cache")
    for path in cached:
        os.remove(path)
def main():
    """Entry point of the sordino2nii CLI.

    Steps:
      1. Parse the command line and load the dataset.
      2. Compute (or load the cached) trajectory.
      3. Reconstruct.
      4. Write the NIfTI output.
      5. Handle the cache files.

    Exits with status 1 when the input data cannot be loaded.
    """
    parser = init()
    args = parser.parse_args()
    try:
        # loading data
        objects, params = load_data(args)
    except Exception as err:
        # The former bare ``except:`` hid the real cause (and also swallowed
        # KeyboardInterrupt) and exited with status 0. Report the error and
        # fail with a non-zero status so calling scripts can detect it.
        print(f"\n** Failed to load input data: {err}", file=sys.stderr)
        parser.print_help()
        sys.exit(1)
    ## calculate trajectory
    objects['traj'] = get_trajectory(args, params)
    ## run reconstruction
    recon_params = run_reconstruction(objects, args, params)
    imgs = convert_to_numpy(recon_params, args, params)
    convert_to_nifti(imgs, recon_params, args, params)
    finalize(args, params)


if __name__ == "__main__":
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment