Last active
July 9, 2024 10:10
-
-
Save jphdotam/b4b6b582cf75f7c873b1891627d45eb3 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import os | |
import math | |
from glob import glob | |
import pydicom | |
import numpy as np | |
import onnxruntime | |
from loguru import logger | |
from matplotlib import pyplot as plt | |
from skimage.transform import resize | |
# Folder that holds the three exported DICOM series used below.
_STUDY_DIR = r"C:\Users\James\Desktop\flow_maps_perfusion_1_"

# One directory per series; each is globbed for *.dcm files further down.
# NOTE(review): the names suggest magnitude / "magn" / phase images of the
# same acquisition -- confirm what series0087 ("magn") actually contains.
MAG_DIR = _STUDY_DIR + r"\series0086-Body"
MAGN_DIR = _STUDY_DIR + r"\series0087-Body"
PHASE_DIR = _STUDY_DIR + r"\series0088-Body"

# ONNX export of the trained segmentation checkpoint, loaded with onnxruntime.
ONNX_MODEL_PATH = (
    "../deploy/models/"
    "024.yaml__1lujhsy3__epoch=39__loss_val=0.004__iou_val=0.879.ckpt_cuda.onnx"
)
def pad_video_to_square(video, n_channels=None):
    """Zero-pad a video so every frame is square, keeping the content centred.

    The shorter spatial axis is grown to the length of the longer one. The
    original frames are placed in the middle of that axis and the added
    border is filled with zeros.

    Args:
        video: Array whose first three axes are (frames, height, width). A
            fourth, trailing axis is treated as channels.
        n_channels: Size of the channel axis of the output when ``video`` is
            4-D. Optional: when omitted it is read from ``video.shape[3]``.

    Returns:
        ``video`` itself (not a copy) when the frames are already square;
        otherwise a new array of the same dtype with shape
        ``(frames, dim, dim[, n_channels])`` where ``dim = max(height, width)``.
    """
    n_frames, h_orig, w_orig = video.shape[:3]
    if h_orig == w_orig:
        return video

    new_dim = max(h_orig, w_orig)
    if video.ndim == 4:
        # The channel count is already known from the array; only fall back
        # to the explicit argument when the caller supplied one.
        if n_channels is None:
            n_channels = video.shape[3]
        padded_shape = (n_frames, new_dim, new_dim, n_channels)
    else:
        padded_shape = (n_frames, new_dim, new_dim)
    new_video = np.zeros(padded_shape, dtype=video.dtype)

    # Only the shorter axis gets an offset; the longer one starts at 0.
    h_from = 0 if h_orig > w_orig else (new_dim // 2 - h_orig // 2)
    w_from = 0 if h_orig < w_orig else (new_dim // 2 - w_orig // 2)
    new_video[:, h_from:h_from + h_orig, w_from:w_from + w_orig] = video
    return new_video
def depad_video_from_square(video, h_orig, w_orig, n_channels=None):
    """Crop a centre-padded square video back to its original height/width.

    Inverse of centre padding to a square: the same offsets are recomputed
    from ``h_orig``/``w_orig`` and the matching window is sliced out.

    Args:
        video: Array of shape ``(frames, dim, dim)`` or
            ``(frames, dim, dim, channels)`` with square frames.
        h_orig: Height of the frames before padding.
        w_orig: Width of the frames before padding.
        n_channels: Optional expected size of the channel axis for 4-D input.
            When given it is checked against ``video``; when omitted the
            channel axis is accepted as-is.

    Returns:
        A slice of ``video`` with spatial shape ``(h_orig, w_orig)``; any
        trailing channel axis is kept untouched.

    Raises:
        AssertionError: If the frames are not square, or if ``n_channels`` is
            given and does not match the channel axis of ``video``.
    """
    _, new_h, new_w, *rest = video.shape
    assert new_h == new_w, f"Expected square video, got {video.shape}"
    new_dim = new_h

    if len(rest) == 1 and n_channels is not None:
        assert rest[0] == n_channels, f"Expected {n_channels} channels, got {video.shape}"

    # Same offsets as used when padding: only the shorter axis was shifted.
    h_from = 0 if h_orig > w_orig else (new_dim // 2 - h_orig // 2)
    w_from = 0 if h_orig < w_orig else (new_dim // 2 - w_orig // 2)

    # Slicing the first three axes leaves any trailing channel axis intact,
    # so one expression covers both 3-D and 4-D input.
    return video[:, h_from:h_from + h_orig, w_from:w_from + w_orig]
def load_and_normalise_dicoms(*args, normalise=True):
    """Load one image stack per series and optionally rescale each to 0-1.

    Args:
        *args: One iterable of DICOM file paths per series. Every file is
            read with ``pydicom.dcmread`` and its ``pixel_array`` is stacked
            along a new leading axis, in the order the paths are given.
        normalise: When true, each series is min-max rescaled on its own
            (using the min/max over the whole stack, not per frame).

    Returns:
        A list with one ``float32`` array per positional argument, in the
        same order.

    Raises:
        ValueError: If ``normalise`` is true and a series contains no files.
    """
    out = []
    for dicom_paths in args:
        dcms = [pydicom.dcmread(dicom_path) for dicom_path in dicom_paths]
        imgs = np.array([dcm.pixel_array for dcm in dcms]).astype(np.float32)

        # rescale 0 - 1
        if normalise:
            if imgs.size == 0:
                # min()/max() on an empty array raise an opaque reduction
                # error; say what actually went wrong instead.
                raise ValueError("Cannot normalise an empty DICOM series (no files were given)")
            imgs -= imgs.min()
            peak = imgs.max()
            # A constant-valued stack has zero range; dividing would fill it
            # with NaN, so leave it as all zeros.
            if peak > 0:
                imgs /= peak
        out.append(imgs)
    return out
def mag_magn_phase_to_predicted_mask(mag_magn_phase: np.ndarray,
                                     ort_session: onnxruntime.InferenceSession,
                                     inference_dim=(320, 320)):
    """Run the segmentation model on a 3-series stack and return per-pixel classes.

    Pipeline: centre-pad every frame to a square, resize to ``inference_dim``,
    run the ONNX model with the three series as input channels, resize the
    logits back to the padded size, crop the padding away and take the argmax
    over the class axis.

    Args:
        mag_magn_phase: Array of shape ``(3, n_frames, H, W)``; the three
            series must share frame count and frame size.
        ort_session: ONNX Runtime session whose input is named ``'input'``
            and receives a ``(n_frames, 3, *inference_dim)`` batch. Its first
            output is taken as logits of shape ``(n_frames, n_classes, h, w)``.
        inference_dim: ``(height, width)`` the frames are resized to before
            the forward pass.

    Returns:
        Integer array of shape ``(n_frames, H, W)`` holding the index of the
        highest-scoring class for every pixel.

    Raises:
        AssertionError: If the resized stack or the final mask does not have
            the expected shape.
    """
    # store original H/W
    n_frames = len(mag_magn_phase[0])
    orig_h, orig_w = mag_magn_phase.shape[-2:]
    orig_hw = max(orig_h, orig_w)

    # centre pad
    logger.debug(f"Pre padding: {mag_magn_phase.shape=}")
    mag_magn_phase = np.array([pad_video_to_square(series) for series in mag_magn_phase])

    # resize
    logger.debug(f"Pre resize: {mag_magn_phase.shape=}")
    mag_magn_phase = np.array([[resize(frame, inference_dim) for frame in series] for series in mag_magn_phase])

    # shape check
    logger.debug(f"Post resize: {mag_magn_phase.shape=}")
    assert mag_magn_phase.shape == (3, n_frames, *inference_dim), f"Expected shape {(3, n_frames, *inference_dim)}, got {mag_magn_phase.shape}"

    # prepare batch - (mag, magn, phase) == 3 * N_FRAMES * H * W -> N_FRAMES * 3 * H * W
    # The transpose is only a strided view, and depending on the skimage
    # version `resize` may hand back float64. Give ONNX Runtime a contiguous
    # float32 tensor so the dtype never depends on the installed skimage.
    # NOTE(review): assumes the model's input is tensor(float) -- confirm.
    x = np.ascontiguousarray(np.transpose(mag_magn_phase, (1, 0, 2, 3)), dtype=np.float32)

    # forward pass
    logger.debug(f"Pre forward pass: {x.shape=} {x.min()=} {x.max()=} {x.mean()=} {x.std()=}")
    for i_input_channel in range(3):
        input_channel = x[:, i_input_channel]
        logger.debug(f"\tChannel {i_input_channel} min={input_channel.min()} max={input_channel.max()} mean={input_channel.mean()} std={input_channel.std()}")
    pred_logit = ort_session.run(None, {'input': x})[0].transpose(1, 0, 2, 3)  # N_CLASSES * N_FRAMES * H * W

    # de-resize - let's do this before argmax so we dont get anti-aliasing issues etc.
    logger.debug(f"Pre de-resize: {pred_logit.shape=}")
    pred_logit = np.array([[resize(frame, (orig_hw, orig_hw)) for frame in channel] for channel in pred_logit])

    # de-pad
    logger.debug(f"Pre de-pad: {pred_logit.shape=}")
    pred_logit = np.array([depad_video_from_square(channel, orig_h, orig_w) for channel in pred_logit])

    # argmax
    logger.debug(f"Pre argmax: {pred_logit.shape=}")
    pred_cls = np.argmax(pred_logit, axis=0)  # N_FRAMES * H * W
    assert pred_cls.shape == (n_frames, orig_h, orig_w), f"Expected shape {n_frames, orig_h, orig_w}, got {pred_cls.shape=}"
    logger.debug(f"Post argmax: {pred_cls.shape=}")
    return pred_cls
# Collect the DICOM files of each series; sorting by file name fixes the
# frame order (assumes the names sort in acquisition order -- confirm).
dicom_paths_mag = sorted(glob(os.path.join(MAG_DIR, '*.dcm')))
dicom_paths_magn = sorted(glob(os.path.join(MAGN_DIR, '*.dcm')))
dicom_paths_phase = sorted(glob(os.path.join(PHASE_DIR, '*.dcm')))

mag, magn, phase = load_and_normalise_dicoms(dicom_paths_mag, dicom_paths_magn, dicom_paths_phase)
# NOTE(review): the middle channel is fed to the model as all zeros instead
# of the loaded `magn` stack -- presumably deliberate; confirm.
mag_magn_phase = np.array([mag, np.zeros_like(magn), phase])

ort_session = onnxruntime.InferenceSession(ONNX_MODEL_PATH)
pred_cls = mag_magn_phase_to_predicted_mask(mag_magn_phase, ort_session)

# create a grid of images and plot pred cls for each frame in each
n_frames = len(mag_magn_phase[0])
n_sqrt = math.ceil(np.sqrt(n_frames))
# squeeze=False keeps `axs` a 2-D array even for a 1x1 grid; without it a
# single-frame series yields a bare Axes object and `.flatten()` fails.
fig, axs = plt.subplots(n_sqrt, n_sqrt, figsize=(20, 20), squeeze=False)
for i, ax in enumerate(axs.flatten()):
    if i < n_frames:
        ax.imshow(pred_cls[i], cmap='gray')
        ax.set_title(f"Frame {i}")
    else:
        # Hide the unused cells of the square grid.
        ax.axis('off')
# When run as a plain script the figure is never displayed unless shown explicitly.
plt.show()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment