Created
November 23, 2021 21:38
-
-
Save jcreinhold/a430ab793ffbfef5b61ea6e86619e7a5 to your computer and use it in GitHub Desktop.
normalize by tissue mean
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/usr/bin/env python | |
# -*- coding: utf-8 -*- | |
"""Normalize the intensity of a set of images by | |
finding a tissue mean in the foreground and | |
voxel-wise dividing the image by that value | |
Author: Jacob Reinhold | |
""" | |
import os | |
import re | |
import sys | |
import warnings | |
from argparse import ArgumentParser | |
from pathlib import Path | |
from typing import List, Optional, Tuple | |
import matplotlib.pyplot as plt | |
import numpy as np | |
import torch | |
import torchio as tio | |
from intensity_normalization.util.io import split_filename | |
from intensity_normalization.util.tissue_membership import find_tissue_memberships | |
from scipy.ndimage.filters import gaussian_filter1d | |
def plot(
    data: List[np.ndarray],
    out_filename: str,
    title: Optional[str] = None,
    n_bins: int = 200,
    alpha: float = 0.5,
    lw: float = 3.0,
    log: bool = True,
    smooth: bool = True,
    figsize: Tuple[int, int] = (8, 8),
) -> None:
    """Overlay the intensity histograms of several arrays and save the figure.

    Each array in ``data`` is binned (over its flattened values) into
    ``n_bins`` bins and drawn as one line through the bin centers.

    Args:
        data: arrays whose intensity distributions are plotted, one curve each.
        out_filename: path the figure is written to.
        title: optional axes title.
        n_bins: number of histogram bins per array.
        alpha: line transparency.
        lw: line width.
        log: plot log10 of the counts; empty bins (log10(0) = -inf) are
            drawn at 0.
        smooth: smooth each curve with a 1-D Gaussian filter (sigma = 1 bin).
        figsize: figure size passed to ``plt.subplots``.
    """
    fig, ax = plt.subplots(figsize=figsize)
    try:
        for datum in data:
            # np.histogram computes over the flattened input
            hist, bin_edges = np.histogram(datum, bins=n_bins)
            bin_centers = np.diff(bin_edges) / 2 + bin_edges[:-1]
            # counts are integers; gaussian_filter1d returns the input dtype,
            # so convert first or the smoothed curve is truncated
            curve = hist.astype(np.float64)
            if log:
                # empty bins give log10(0) = -inf; silence only that warning
                with np.errstate(divide="ignore"):
                    curve = np.log10(curve)
                curve[np.isinf(curve)] = 0.0
            if smooth:
                curve = gaussian_filter1d(curve, sigma=1.0)
            ax.plot(bin_centers, curve, alpha=alpha, linewidth=lw)
        ax.set_xlabel("Intensity")
        ax.set_ylabel(r"Log$_{10}$ Count" if log else "Count")
        ax.set_ylim((0, None))
        if title is not None:
            ax.set_title(title)
        fig.savefig(out_filename)
    finally:
        # pyplot keeps every open figure alive; release this one
        plt.close(fig)
def _build_parser() -> ArgumentParser:
    """Return the command-line parser for this script."""
    parser = ArgumentParser(description="Normalize images for image processing")
    parser.add_argument(
        "top_dir", type=Path, help="directory searched recursively for images"
    )
    parser.add_argument(
        "out_dir",
        type=Path,
        help="output directory; the layout below top_dir is mirrored here",
    )
    parser.add_argument(
        "-i",
        "--include",
        type=str,
        default=[".*"],
        nargs="+",
        help="regexes; a file is processed only if its name matches one",
    )
    parser.add_argument(
        "-e",
        "--exclude",
        type=str,
        default=[None],
        nargs="+",
        help="regexes; a file whose name matches any of them is skipped",
    )
    parser.add_argument(
        "-t",
        "--tissue",
        type=int,
        default=1,
        choices=(0, 1, 2),
        help="index of the tissue-membership map used to compute the mean",
    )
    parser.add_argument("-v", "--verbose", action="store_true")
    parser.add_argument(
        "-d",
        "--dry-run",
        action="store_true",
        help="walk and report files without reading or writing image data",
    )
    parser.add_argument(
        "-s",
        "--save-tissue-mask",
        action="store_true",
        help="also save a label image (<name>_tm) of the dominant tissue",
    )
    parser.add_argument(
        "--plot",
        action="store_true",
        help="write before.png/after.png foreground histograms",
    )
    return parser


def _is_selected(
    filename: str,
    include_progs: List["re.Pattern"],
    exclude_progs: List["re.Pattern"],
) -> bool:
    """Return True if ``filename`` matches some include regex and no exclude regex."""
    if not any(inc.search(filename) is not None for inc in include_progs):
        return False
    return not any(exc.search(filename) is not None for exc in exclude_progs)


def _output_path(fn: Path, top_dir: Path, out_dir: Path) -> Path:
    """Return the path below ``out_dir`` that mirrors ``fn``'s path below ``top_dir``.

    ``Path.relative_to`` is used (rather than string replacement of a parent
    directory) so absolute and multi-component ``top_dir`` values map
    correctly and repeated directory names inside the path are preserved.
    """
    return out_dir / fn.relative_to(top_dir)


def main() -> int:
    """Normalize every selected image under ``top_dir`` by a tissue mean.

    For each file whose name passes the include/exclude regexes, voxels
    brighter than the image mean form a foreground mask, which is passed
    with the image to ``find_tissue_memberships``. The membership map
    selected by ``--tissue`` weights an average of the image. The image is
    divided voxel-wise by that average and saved under ``out_dir`` at the
    same relative path.

    Pressing Ctrl-C stops processing early. Statistics (and plots, with
    ``--plot``) are still produced for the images handled so far.

    Returns:
        0, used as the process exit status.
    """
    args = _build_parser().parse_args()
    include_progs = [re.compile(inc) for inc in args.include]
    exclude_progs = [re.compile(exc) for exc in args.exclude if exc is not None]
    means = []
    foregrounds_before = []
    foregrounds_after = []
    try:
        for root, _dirs, files in os.walk(args.top_dir):
            for _fn in files:
                if not _is_selected(_fn, include_progs, exclude_progs):
                    if args.verbose:
                        print(f"Skipping: {_fn}")
                    continue
                fn = Path(root) / _fn
                if args.verbose:
                    print(f"Normalizing: {str(fn)}")
                out_path = _output_path(fn, args.top_dir, args.out_dir)
                image = tio.ScalarImage(fn)
                if args.dry_run:
                    # placeholder so the dry run still exercises the reporting path
                    mean = np.random.randn()
                else:
                    img = image.numpy()
                    # crude foreground: voxels brighter than the global mean
                    mask = img > img.mean()
                    if args.plot:
                        foregrounds_before.append(img[mask])
                    tissue_memberships = find_tissue_memberships(img, mask)
                    # NOTE(review): assumes memberships are img.shape + (n_tissues,)
                    tissue_mem = tissue_memberships[..., args.tissue]
                    mean = np.average(img, weights=tissue_mem)
                means.append(mean)
                if args.verbose:
                    print(f"Saving normalized: {str(out_path)}; m={mean:0.3e}")
                if args.dry_run:
                    continue
                out_path.parent.mkdir(parents=True, exist_ok=True)
                normalized = img / mean
                if args.plot:
                    foregrounds_after.append(normalized[mask])
                image.set_data(torch.from_numpy(normalized))
                image.save(out_path)
                if args.save_tissue_mask:
                    # label 0 = background; foreground voxels get argmax + 1
                    tissue_mask = np.zeros(img.shape)
                    masked = tissue_memberships[mask]
                    tissue_mask[mask] = np.argmax(masked, axis=1) + 1
                    image.set_data(torch.from_numpy(tissue_mask))
                    path, base, ext = split_filename(out_path)
                    tm_path = path / (base + "_tm" + ext)
                    if args.verbose:
                        print(f"Saving tissue mask: {str(tm_path)}")
                    image.save(tm_path)
    except KeyboardInterrupt:
        # deliberate early stop: fall through and summarize what was done
        print("Interrupted; summarizing the images processed so far.", file=sys.stderr)
    if means:
        median = np.median(means)
        std = np.std(means)
        if args.verbose:
            print(f"Median: {median}; Std: {std}")
    elif args.verbose:
        # np.median/np.std of an empty list would only yield NaN + warnings
        print("No matching images were processed.")
    if args.plot and not args.dry_run:
        plot(foregrounds_before, "before.png", title="Before normalization")
        plot(foregrounds_after, "after.png", title="After normalization")
    return 0
if __name__ == "__main__":
    # script entry point: main()'s return value becomes the exit status
    raise SystemExit(main())
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
If you use this script in an academic paper, please cite the paper: