Skip to content

Instantly share code, notes, and snippets.

@jcreinhold
Created November 23, 2021 21:38
Show Gist options
  • Save jcreinhold/a430ab793ffbfef5b61ea6e86619e7a5 to your computer and use it in GitHub Desktop.
Save jcreinhold/a430ab793ffbfef5b61ea6e86619e7a5 to your computer and use it in GitHub Desktop.
normalize by tissue mean
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""Normalize the intensity of a set of images by
finding a tissue mean in the foreground and
voxel-wise dividing the image by that value
Author: Jacob Reinhold
"""
import os
import re
import sys
import warnings
from argparse import ArgumentParser
from pathlib import Path
from typing import List, Optional, Tuple
import matplotlib.pyplot as plt
import numpy as np
import torch
import torchio as tio
from intensity_normalization.util.io import split_filename
from intensity_normalization.util.tissue_membership import find_tissue_memberships
from scipy.ndimage.filters import gaussian_filter1d
def plot(
    data: List[np.ndarray],
    out_filename: str,
    title: Optional[str] = None,
    n_bins: int = 200,
    alpha: float = 0.5,
    lw: float = 3.0,
    log: bool = True,
    smooth: bool = True,
    figsize: Tuple[int, int] = (8, 8),
) -> None:
    """Save overlaid intensity histograms of several arrays to one image file.

    Each array in ``data`` is flattened and binned into ``n_bins`` bins, and
    one curve per array is drawn at the bin centers.

    Args:
        data: arrays whose intensities are histogrammed (one curve each).
        out_filename: path the figure is saved to.
        title: optional axes title.
        n_bins: number of histogram bins per array.
        alpha: curve transparency.
        lw: curve line width.
        log: plot log10 of the counts; empty bins are drawn at 0.
        smooth: smooth each curve with a 1-D Gaussian filter (sigma = 1 bin).
        figsize: figure size passed to matplotlib.
    """
    fig, ax = plt.subplots(figsize=figsize)
    try:
        for datum in data:
            counts, bin_edges = np.histogram(np.ravel(datum), n_bins)
            centers = bin_edges[:-1] + np.diff(bin_edges) / 2
            if log:
                # log10(max(count, 1)) maps empty bins straight to 0, so no
                # divide-by-zero warning is produced (nor needs silencing).
                hist = np.log10(np.maximum(counts, 1))
            else:
                # Float so the Gaussian filter below does not return a
                # truncated integer array (it keeps the input dtype).
                hist = counts.astype(float)
            if smooth:
                hist = gaussian_filter1d(hist, sigma=1.0)
            ax.plot(centers, hist, alpha=alpha, linewidth=lw)
        ax.set_xlabel("Intensity")
        ax.set_ylabel(r"Log$_{10}$ Count" if log else "Count")
        ax.set_ylim((0, None))
        if title is not None:
            ax.set_title(title)
        fig.savefig(out_filename)
    finally:
        # pyplot keeps every figure alive until it is explicitly closed.
        plt.close(fig)
def _build_parser() -> ArgumentParser:
    """Return the command-line parser used by :func:`main`."""
    parser = ArgumentParser(description="Normalize images for image processing")
    parser.add_argument(
        "top_dir", type=Path, help="directory searched recursively for files"
    )
    parser.add_argument(
        "out_dir",
        type=Path,
        help="output directory; the layout below top_dir is mirrored here",
    )
    parser.add_argument(
        "-i",
        "--include",
        type=str,
        default=[".*"],
        nargs="+",
        help="regexes; a file is processed only if its name matches at least one",
    )
    parser.add_argument(
        "-e",
        "--exclude",
        type=str,
        default=[],
        nargs="+",
        help="regexes; a file whose name matches any of them is skipped",
    )
    parser.add_argument(
        "-t",
        "--tissue",
        type=int,
        default=1,
        choices=(0, 1, 2),
        help="index of the tissue class (last axis of the tissue memberships) "
        "whose membership-weighted mean is used as the normalizing constant",
    )
    parser.add_argument("-v", "--verbose", action="store_true")
    parser.add_argument(
        "-d",
        "--dry-run",
        action="store_true",
        help="report the files that would be processed; nothing is normalized "
        "or written",
    )
    parser.add_argument(
        "-s",
        "--save-tissue-mask",
        action="store_true",
        help="also save a tissue label map next to each output (suffix _tm)",
    )
    parser.add_argument(
        "--plot",
        action="store_true",
        help="save before.png and after.png foreground intensity histograms",
    )
    return parser


def _is_selected(name: str, include_progs: List, exclude_progs: List) -> bool:
    """Return True if ``name`` matches some include regex and no exclude regex."""
    included = any(inc.search(name) is not None for inc in include_progs)
    excluded = any(exc.search(name) is not None for exc in exclude_progs)
    return included and not excluded


def main() -> int:
    """Normalize every selected image below ``top_dir`` by a tissue mean.

    For each file (walked recursively) whose name passes the include/exclude
    regexes:

    1. Load it with torchio.
    2. Take the foreground to be the voxels above the image's mean intensity.
    3. Compute tissue memberships on that foreground.
    4. Take the membership-weighted mean of the chosen tissue class.
    5. Divide the whole image by that value and save the result under
       ``out_dir`` at the same path relative to ``top_dir``.

    Pressing Ctrl-C stops the walk early; the summary and plots are still
    produced for the images processed so far.

    Returns:
        0, used as the process exit status.
    """
    args = _build_parser().parse_args()
    include_progs = [re.compile(inc) for inc in args.include]
    exclude_progs = [re.compile(exc) for exc in args.exclude]
    means: List[float] = []
    foregrounds_before: List[np.ndarray] = []
    foregrounds_after: List[np.ndarray] = []
    try:
        for root, _, files in os.walk(args.top_dir):
            for _fn in files:
                if not _is_selected(_fn, include_progs, exclude_progs):
                    if args.verbose:
                        print(f"Skipping: {_fn}")
                    continue
                fn = Path(root) / _fn
                if args.verbose:
                    print(f"Normalizing: {str(fn)}")
                # Path relative to top_dir keeps sub-directories distinct, so
                # files with equal names in different folders cannot collide.
                out_path = args.out_dir / fn.relative_to(args.top_dir)
                image = tio.ScalarImage(fn)
                if args.dry_run:
                    # Placeholder value so the report below still runs; it is
                    # random, not a real tissue mean.
                    mean = np.random.randn()
                    means.append(mean)
                    if args.verbose:
                        print(f"Saving normalized: {str(out_path)}; m={mean:0.3e}")
                    continue
                img = image.numpy()
                # Foreground: voxels brighter than the global mean intensity.
                mask = img > img.mean()
                if args.plot:
                    foregrounds_before.append(img[mask])
                tissue_memberships = find_tissue_memberships(img, mask)
                # Membership weights of the selected class (last axis).
                tissue_mem = tissue_memberships[..., args.tissue]
                mean = np.average(img, weights=tissue_mem)
                means.append(mean)
                if args.verbose:
                    print(f"Saving normalized: {str(out_path)}; m={mean:0.3e}")
                out_path.parent.mkdir(parents=True, exist_ok=True)
                normalized = img / mean
                if args.plot:
                    foregrounds_after.append(normalized[mask])
                image.set_data(torch.from_numpy(normalized))
                image.save(out_path)
                if args.save_tissue_mask:
                    # Label map: 0 outside the foreground, k + 1 for the class k
                    # with the highest membership inside it.
                    tissue_mask = np.zeros(img.shape)
                    masked = tissue_memberships[mask]
                    tissue_mask[mask] = np.argmax(masked, axis=1) + 1
                    image.set_data(torch.from_numpy(tissue_mask))
                    path, base, ext = split_filename(out_path)
                    tm_path = path / (base + "_tm" + ext)
                    if args.verbose:
                        print(f"Saving tissue mask: {str(tm_path)}")
                    image.save(tm_path)
    except KeyboardInterrupt:
        # Deliberate: Ctrl-C ends the walk but the summary/plots still run.
        print("Interrupted; summarizing the images processed so far.", file=sys.stderr)
    if args.verbose:
        if means:
            print(f"Median: {np.median(means)}; Std: {np.std(means)}")
        else:
            print("No images were normalized.")
    if args.plot and not args.dry_run:
        plot(foregrounds_before, "before.png", title="Before normalization")
        plot(foregrounds_after, "after.png", title="After normalization")
    return 0
if __name__ == "__main__":
    # main() returns the exit status; raising SystemExit hands it to the shell.
    raise SystemExit(main())
@jcreinhold
Copy link
Author

If you use this script in an academic paper, please cite the paper:

    @inproceedings{reinhold2019evaluating,
      title={Evaluating the impact of intensity normalization on {MR} image synthesis},
      author={Reinhold, Jacob C and Dewey, Blake E and Carass, Aaron and Prince, Jerry L},
      booktitle={Medical Imaging 2019: Image Processing},
      volume={10949},
      pages={109493H},
      year={2019},
      organization={International Society for Optics and Photonics}}

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment