Last active: November 21, 2023 17:46
-
-
Save afrendeiro/54b7e767e45e836227e06c192061507f to your computer and use it in GitHub Desktop.
Use torch dataloaders with nuclei coordinates for training.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
""" | |
Use dataloaders with nuclei coordinates for training. | |
""" | |
from functools import partial | |
import requests | |
import h5py | |
from tqdm import tqdm | |
import numpy as np | |
import pandas as pd | |
import torch | |
from torch.utils.data import Dataset, DataLoader | |
from wsi_core import WholeSlideImage | |
from wsi_core.utils import Path | |
from wsi_core.utils import collate_features | |
# Fix torch's global RNG seed so that runs are reproducible: the random slide
# choice in ConcatDataset draws from it via torch.randint (and, presumably,
# so does the DataLoader's shuffle order).
torch.manual_seed(seed=42)
class ConcatDataset(Dataset):
    """Dataset that serves every item from a randomly chosen member dataset.

    Unlike ``torch.utils.data.ConcatDataset`` (which chains datasets end to
    end), each ``__getitem__(i)`` call picks one member dataset uniformly at
    random with ``torch.randint`` and returns that member's item ``i``. A
    shuffled ``DataLoader`` over this dataset therefore mixes items from all
    members within a batch.

    The length is that of the *smallest* member. A sampler that respects
    ``len()`` never requests indices beyond it, so items past that point in
    larger members are never drawn.

    The member choice uses torch's global RNG. Indexing the same ``i`` twice
    may therefore return items from different members, and results are
    reproducible only under ``torch.manual_seed``.

    Raises:
        ValueError: if no datasets are given.
    """

    def __init__(self, *datasets):
        if not datasets:
            # Fail here rather than later with an obscure error from
            # min() of an empty sequence or torch.randint(0, 0, ...).
            raise ValueError("ConcatDataset requires at least one dataset")
        self.datasets = datasets
        self.n_slides = len(self.datasets)

    def __getitem__(self, i: int):
        # Keep the same draw shape, size (1,), so the RNG stream (and any
        # seed-based reproducibility check) is unchanged. Convert it to a
        # plain int instead of indexing the tuple with a tensor.
        which = int(torch.randint(0, self.n_slides, (1,)).item())
        return self.datasets[which][i]

    def __len__(self):
        return min(len(d) for d in self.datasets)
def download_slide(slide_file: Path, overwrite: bool = False, timeout: float = 60.0) -> None:
    """Download the slide named by ``slide_file.stem`` to ``slide_file``.

    The slide ID is taken from the file name (``<slide_id>.svs`` ->
    ``<slide_id>``) and fetched from the brd.nci.nih.gov image-download
    endpoint. Nothing is done if the file already exists, unless
    ``overwrite`` is true.

    The response is streamed to ``<name>.part`` and renamed into place only
    after the whole body has been written. A failed or interrupted download
    therefore never leaves a truncated ``slide_file`` that a later call would
    mistake for a complete one.

    Args:
        slide_file: Destination path; its stem is used as the slide ID.
        overwrite: Re-download even if ``slide_file`` already exists.
        timeout: Seconds passed to ``requests.get`` as the connect/read timeout.

    Raises:
        requests.HTTPError: if the server answers with an error status.
        requests.RequestException: on connection errors or timeouts.
    """
    import os

    if slide_file.exists() and not overwrite:
        return

    # Derive the ID from the argument instead of relying on a global variable.
    slide_id = slide_file.stem
    url = f"https://brd.nci.nih.gov/brd/imagedownload/{slide_id}"
    tmp_file = slide_file.with_name(slide_file.name + ".part")

    # stream=True avoids buffering the whole (multi-GB) slide in memory.
    with requests.get(url, stream=True, timeout=timeout) as req:
        # Don't save an HTML error page as a .svs.
        req.raise_for_status()
        with open(tmp_file, "wb") as handle:
            for block in req.iter_content(chunk_size=1 << 20):
                handle.write(block)
    os.replace(tmp_file, slide_file)
def set_nuclei_coordinates(s: WholeSlideImage, tile_width: int = 32) -> None:
    """Write nucleus-centred tile coordinates for slide ``s`` to an HDF5 file.

    Steps:
    - Read nucleus centroids (``centroid_x``, ``centroid_y``) from
      ``<slide_id>_nuclei.csv``, where ``slide_id`` is ``s.path.stem``.
    - Shift every centroid by ``tile_width // 2`` to turn it into the corner
      of a ``tile_width``-wide tile around the nucleus.
    - Store the result as the ``coords`` dataset of ``<slide_id>.nuclei.h5``.
    - Attach level-0 patching metadata as attributes (downsample 1,
      patch level 0, patch size ``tile_width``).
    - Point ``s.hdf5_file`` at that file, presumably so that
      ``s.get_tile_coordinates()`` / ``s.as_tile_bag()`` pick these tiles
      up; confirm against wsi_core.

    Args:
        s: Slide to annotate; ``s.path`` and ``s.wsi.dimensions`` are read.
        tile_width: Tile edge length in level-0 pixels.
    """
    slide_id = s.path.stem
    # NOTE(review): these centroids come from HoVer-Net output; point this at
    # the StarDist files instead if that is what you have. Also double-check
    # that centroid_x/centroid_y use the same x/y order as the slide's tile
    # coordinates (swapped dimensions are a common pitfall).
    centroids = pd.read_csv(Path(f"{slide_id}_nuclei.csv"))[['centroid_x', 'centroid_y']]

    s.hdf5_file = Path(f"{slide_id}.nuclei.h5")
    half_width = tile_width // 2
    with h5py.File(s.hdf5_file, "w") as h5:
        dset = h5.create_dataset("coords", data=centroids.values - half_width)
        # NOTE(review): 'save_path' is hard-coded and differs from the working
        # directory the script downloads slides into; confirm whether wsi_core
        # actually reads it.
        metadata = [
            ('downsample', np.array([1., 1.])),
            ('downsampled_level_dim', np.array(s.wsi.dimensions)),
            ('level_dim', np.array(s.wsi.dimensions)),
            ('name', slide_id),
            ('patch_level', 0),
            ('patch_size', tile_width),
            ('save_path', 'data/gtex/svs'),
        ]
        for key, value in metadata:
            dset.attrs[key] = value
# Download a couple of slides and prepare them (register nucleus-centred tiles).
slide_ids = ['GTEX-SNMC-0626', 'GTEX-12ZZW-2726']
slides = list()
_coords = list()
for slide_id in slide_ids:
    slide_file = Path(f"{slide_id}.svs")
    download_slide(slide_file)
    s = WholeSlideImage(slide_file)
    set_nuclei_coordinates(s)
    slides.append(s)
    # Keep each slide's tile coordinates, labelled by slide, for later checks.
    c = pd.DataFrame(s.get_tile_coordinates(), columns=['x', 'y']).assign(slide=slide_id)
    _coords.append(c)
coords = pd.concat(_coords)

# Tiles per slide. In the original run this printed:
#   GTEX-12ZZW-2726    110836
#   GTEX-SNMC-0626     378617
print(coords.groupby('slide').size())

# Check that no slide lists the same tile coordinate twice. This detects exact
# duplicates only, not spatial overlap: tiles centred on nearby nuclei can
# overlap. Putting 'slide' in the key makes the check per-slide, which is
# equivalent to the former groupby/apply without pandas' deprecated
# apply-over-grouping-columns behaviour.
n_duplicated = int(coords.duplicated(subset=['slide', 'x', 'y']).sum())
if n_duplicated:
    raise AssertionError(f"found {n_duplicated} duplicated tile coordinates within slides")

# Index by (x, y) so batch coordinates can be looked up directly below.
coords = coords.set_index(['x', 'y'])
# Build one dataset over all slides; each item comes from a randomly chosen slide.
ds = ConcatDataset(*[s.as_tile_bag() for s in slides])

# DataLoader over tiles from all slides. with_coords=True only makes the
# collate function also return tile coordinates for the check below; it is
# not needed for training.
collate_fn = partial(collate_features, with_coords=True)
dl = DataLoader(ds, batch_size=64, shuffle=True, collate_fn=collate_fn)
for batch, batch_coords in tqdm(dl):
    # Check that every batch contains tiles from every slide. Look the
    # coordinates up as (x, y) *pairs*: indexing the (x, y) MultiIndex with
    # separate x and y lists selects their Cartesian product, which matches
    # tiles that are not in the batch and can make this check pass spuriously.
    # NOTE(review): assumes batch_coords is an (N, 2) array of (x, y) in the
    # same dtype as get_tile_coordinates(). An (x, y) that exists in more than
    # one slide matches all of them, so such a tile cannot be attributed.
    batch_xy = np.asarray(batch_coords)
    batch_index = pd.MultiIndex.from_arrays([batch_xy[:, 0], batch_xy[:, 1]], names=['x', 'y'])
    n_samples = coords.loc[batch_index].groupby('slide').size()
    if len(n_samples) != ds.n_slides:
        raise AssertionError(f"batch has tiles from {len(n_samples)} of {ds.n_slides} slides")
    # To test seed reproducibility, stop after the first batch and compare its
    # coordinates with a saved reference:
    # break
    # ground_truth = np.load('test_batch_coords.npy')
    # assert np.all(batch_coords == ground_truth)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.