Analyze HoverNet outputs
import json
import os

import anndata as ad
import h5py
import numpy as np
import pandas as pd
import torch
from joblib import Parallel, delayed
from openslide import OpenSlide
from scipy.stats import mannwhitneyu
from statsmodels.stats.multitest import multipletests

from features import extract_features_single_tile
TILE_SIZE = 224  # STAMP tile edge length, in pixels
MPP = 1.1428571429  # microns per pixel of the STAMP tiles
ATTN_DIR = "slide_attn_rollouts"
SLIDE_DIR = "slides"
EMB_DIR = "STAMP_output/h5_files"
SIG = 0.0001  # adjusted-p significance threshold used for plots and examples
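
# Worked example of the tile-size conversion used below (the 0.25 um/px figure
# is illustrative; each slide reports its own openslide.mpp-x): a 224 px tile
# at MPP = 1.1428571429 um/px spans 224 * 1.1428571429 = 256 um, which on a
# 0.25 um/px (40x) slide is 256 / 0.25 = 1024 native pixels, i.e.
# int(TILE_SIZE * MPP / 0.25) == 1024.
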
def _identify_top_attn_and_min_attn(row):
    attn_path, emb_path, slide_path = (
        row["attn_file"],
        row["emb_file"],
        row["slide_file"],
    )
    # get native microns per pixel
    slide = OpenSlide(slide_path)
    native_mpp_of_slide = float(slide.properties["openslide.mpp-x"])
    slide.close()
    # load attention
    try:
        attention = torch.load(attn_path)
    except FileNotFoundError:
        print(f"File not found: {attn_path}")
        return []
    attention = attention[-1, 0, 1:]  # last layer, wrt cls token, no cls token self-attn
    attention = attention.flatten().cpu().numpy()
    # load coords and convert to native resolution coords
    with h5py.File(emb_path, "r") as f:
        coords = f["coords"][: len(attention)]
    native_scaled_coords = (coords * MPP / native_mpp_of_slide).astype(int)
    native_scaled_tile_size = int(TILE_SIZE * MPP / native_mpp_of_slide)
    # identify top 3 attention values
    top_ind = np.argpartition(attention, -3)[-3:]
    top_coords_svs_system = native_scaled_coords[top_ind]
    top_coords_stamp_system = coords[top_ind]
    # identify lowest attention value
    min_ind = np.argmin(attention)
    min_coord_svs_system = native_scaled_coords[min_ind]
    min_coord_stamp_system = coords[min_ind]
    # create records with coord0, coord1, tile_size, top/bottom
    records = []
    for coord_svs, coord_stamp in zip(top_coords_svs_system, top_coords_stamp_system):
        records.append(
            {
                "attn_file": attn_path,
                "emb_file": emb_path,
                "slide_file": slide_path,
                "oncotype_score": row["oncotype_score"],
                "image_id": row["image_id"],
                "regex_grade": row["regex_grade"],
                "regex_ihc_pr": row["regex_ihc_pr"],
                "coord0_svs_system": coord_svs[0],
                "coord1_svs_system": coord_svs[1],
                "coord0_stamp_res": coord_stamp[0],
                "coord1_stamp_res": coord_stamp[1],
                "tile_size_svs_system": native_scaled_tile_size,
                "tile_size_stamp_res": TILE_SIZE,
                "top_bottom": "top",
            }
        )
    records.append(
        {
            "attn_file": attn_path,
            "emb_file": emb_path,
            "slide_file": slide_path,
            "oncotype_score": row["oncotype_score"],
            "image_id": row["image_id"],
            "regex_grade": row["regex_grade"],
            "regex_ihc_pr": row["regex_ihc_pr"],
            "coord0_svs_system": min_coord_svs_system[0],
            "coord1_svs_system": min_coord_svs_system[1],
            "coord0_stamp_res": min_coord_stamp_system[0],
            "coord1_stamp_res": min_coord_stamp_system[1],
            "tile_size_svs_system": native_scaled_tile_size,
            "tile_size_stamp_res": TILE_SIZE,
            "top_bottom": "bottom",
        }
    )
    return records
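
# Each slide thus contributes four records: its three highest-attention tiles
# ("top") and its single lowest-attention tile ("bottom").
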
def define_df():
    df = pd.read_csv("../dataframes/holdout_df_with_regex_feats.csv")
    df = df[
        [
            "oncotype_score",
            "image_id",
            "regex_grade",
            "regex_ihc_pr",
        ]
    ]
    df["attn_file"] = df["image_id"].apply(lambda x: os.path.join(ATTN_DIR, f"{x}.pt"))
    df["emb_file"] = df["image_id"].apply(lambda x: os.path.join(EMB_DIR, f"{x}.h5"))
    df["slide_file"] = df["image_id"].apply(lambda x: os.path.join(SLIDE_DIR, f"{x}.svs"))
    # records = []
    # for idx, row in df.iterrows():
    #     records.extend(_identify_top_attn_and_min_attn(row))
    records = Parallel(n_jobs=36)(
        delayed(_identify_top_attn_and_min_attn)(row) for idx, row in df.iterrows()
    )
    records = [item for sublist in records for item in sublist]
    tilewise_df = pd.DataFrame.from_records(records)
    return tilewise_df
def _load_tile_to_dir(row):
    slide_path, coord0, coord1, tile_size, save_path = (
        row["slide_file"],
        row["coord0"],
        row["coord1"],
        row["tile_size"],
        row["tile_path"],
    )
    slide = OpenSlide(slide_path)
    tile = slide.read_region((coord0, coord1), 0, (tile_size, tile_size))
    # normalize magnification: 40x tiles are kept as-is, 20x tiles are
    # upsampled 2x so every saved tile sits on a 40x pixel grid
    mag = int(slide.properties["aperio.AppMag"])
    if mag == 40:
        pass
    elif mag == 20:
        tile = tile.resize((tile_size * 2, tile_size * 2))
    else:
        raise ValueError(f"Mag {mag} not supported")
    tile.save(save_path)
    slide.close()
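
# Note on the OpenSlide call above: read_region takes (x, y) in level-0
# coordinates, so the SVS-system coords computed earlier can be passed
# directly when reading at level 0.
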
def load_tiles_to_dir(df):
    df["tile_path"] = df.apply(
        lambda x: f"tiles_for_hovernet/{x['image_id']}_{x['coord0']}_{x['coord1']}.png",
        axis=1,
    )
    # for idx, row in df.iterrows():
    #     _load_tile_to_dir(row)
    Parallel(n_jobs=36)(delayed(_load_tile_to_dir)(row) for idx, row in df.iterrows())
    df.to_csv("tilewise_df_tiled.csv", index=False)
def extract_features_from_tiles(df):
    # HoverNet writes one JSON per tile, mirroring the tile filenames
    df["hovernet_output"] = df["tile_path"].apply(
        lambda x: x.replace("tiles_for_hovernet", "preds_by_hovernet/json").replace(".png", ".json")
    )
    records = Parallel(n_jobs=36)(
        delayed(_extract_features_from_tile)(row) for idx, row in df.iterrows()
    )
    feats_df = pd.DataFrame.from_records(records)
    feats_df.to_csv("tilewise_df_tiled_feats.csv", index=False)
def _extract_features_from_tile(row):
    hovernet_output = row["hovernet_output"]
    try:
        with open(hovernet_output, "r") as f:
            data = json.load(f)
        feats, feature_names, graph_feats_dict = extract_features_single_tile(data)
        # create dict of features to add to df
        record = dict(zip(feature_names, feats))
    except FileNotFoundError:
        record = dict()
    record["hovernet_output"] = hovernet_output
    record["top_bottom"] = row["top_bottom"]
    record["oncotype_score"] = row["oncotype_score"]
    record["pred"] = row["pred"]
    record["hi_or_lo"] = row["hi_or_lo"]
    return record
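
# Tiles whose HoverNet JSON is missing yield metadata-only records here; the
# dropna() in the main block removes them before testing.
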
def test_top_vs_bottom_attn(df):
    df = df.set_index("hovernet_output")
    top_df = df[df["top_bottom"] == "top"].copy()
    top_df = top_df.drop(columns=["oncotype_score", "top_bottom"])
    bottom_df = df[df["top_bottom"] == "bottom"].copy()
    bottom_df = bottom_df.drop(columns=["oncotype_score", "top_bottom"])
    feat_names = top_df.columns
    feat_names = [f for f in feat_names if f not in ["pred", "oncotype_score", "hi_or_lo"]]
    records = []
    # one Mann-Whitney U test per feature, then Benjamini-Hochberg FDR
    # correction across all features
    for feat_name in feat_names:
        top_vals = top_df[feat_name].values
        top_mean = np.mean(top_vals)
        bottom_vals = bottom_df[feat_name].values
        bottom_mean = np.mean(bottom_vals)
        stat, p = mannwhitneyu(top_vals, bottom_vals)
        records.append(
            {
                "feat_name": feat_name,
                "top_mean": top_mean,
                "bottom_mean": bottom_mean,
                "stat": stat,
                "p": p,
            }
        )
    records = pd.DataFrame.from_records(records)
    records["p_adj"] = multipletests(records["p"], method="fdr_bh")[1]
    records = records.sort_values("p_adj")
    records.to_csv("results_top_vs_bottom.csv", index=False)
    print(records)
def test_hi_vs_lo_onco(df):
    # compare high- vs low-prediction cases on their top-attention tiles only
    df = df.set_index("hovernet_output")
    df = condition_features(df)
    df = df[df["top_bottom"] == "top"].copy().drop(columns=["top_bottom"])
    hi_df = df[df["hi_or_lo"] == "hi"].copy()
    lo_df = df[df["hi_or_lo"] == "lo"].copy()
    feat_names = hi_df.columns
    feat_names = [f for f in feat_names if f not in ["pred", "oncotype_score", "hi_or_lo"]]
    records = []
    for feat_name in feat_names:
        hi_vals = hi_df[feat_name].dropna().values
        hi_mean = np.mean(hi_vals)
        lo_vals = lo_df[feat_name].dropna().values
        lo_mean = np.mean(lo_vals)
        stat, p = mannwhitneyu(hi_vals, lo_vals)
        records.append(
            {
                "feat_name": feat_name,
                "hi_mean": hi_mean,
                "lo_mean": lo_mean,
                "stat": stat,
                "p": p,
            }
        )
    records = pd.DataFrame.from_records(records)
    records["p_adj"] = multipletests(records["p"], method="fdr_bh")[1]
    records = records.sort_values("p_adj")
    records.to_csv("results_hi_vs_lo.csv", index=False)
    print(records)
def plot_volcano(df, comp_str):
    import matplotlib.pyplot as plt
    import seaborn as sns

    cell_name_mapping = {
        "nolabe": "unlab.",
        "no-neo": "non-neo.",
        "inflam": "inflam.",
        "neopla": "neo.",
        "connec": "connec.",
        "necros": "necros.",
        "total": "total",
    }
    assert comp_str in ["hi_vs_lo", "top_vs_bottom"]
    with open("type_info.json", "r") as f:
        palette = json.load(f)
    palette = dict([(cell_name_mapping[v[0]], [x / 255.0 for x in v[1]]) for v in palette.values()])
    palette["total"] = [165 / 255.0, 42 / 255.0, 42 / 255.0]
    fig, ax = plt.subplots(figsize=(3, 3))
    df["-log10(p)"] = -np.log10(df["p_adj"])
    if comp_str == "hi_vs_lo":
        df["log2(FC)"] = np.log2(df["hi_mean"] / df["lo_mean"])
    else:
        df["log2(FC)"] = np.log2(df["top_mean"] / df["bottom_mean"])
    df["cell"] = df["feat_name"].apply(lambda x: x.split("_")[-1])
    df["cell"] = df["cell"].map(cell_name_mapping)
    sns.scatterplot(
        data=df,
        x="log2(FC)",
        y="-log10(p)",
        hue="cell",
        palette=palette,
        ax=ax,
        s=17,
        linewidth=0,
    )
    sns.move_legend(ax, "upper left", labelspacing=0, borderpad=0.2, title=None, fontsize=10)
    if comp_str == "hi_vs_lo":
        ax.set_xlabel("log2(FC) (High vs low score)", fontsize=10)
    else:
        ax.set_xlabel("log2(FC) (Top vs bottom attn.)", fontsize=10)
    ax.set_ylabel("-log10(p)", fontsize=10)
    plt.tick_params(axis="both", which="major", labelsize=10)
    ax.axhline(-np.log10(SIG), linestyle=":", color="black", alpha=0.5)
    ax.text(0.2, -np.log10(SIG) + 0.2, f"p < {SIG}", fontsize=6)
    ax.axvline(0, linestyle=":", color="black", alpha=0.5)
    # remove upper and right spines
    ax.spines["top"].set_visible(False)
    ax.spines["right"].set_visible(False)
    plt.tight_layout()
    plt.savefig(f"../../plots/volcano_{comp_str}_unlabeled.png", dpi=300)
    # label all the points with p_adj below the SIG threshold
    for idx, row in df.iterrows():
        if row["-log10(p)"] > -np.log10(SIG):
            label = " ".join(row["feat_name"].split("_")[:-1])
            label = label.replace("abundance relative", "rel. abund.")
            label = label.replace("std", "std. dev.")
            label = label.replace("median", "med.")
            label = label.replace("solidity", "solid.")
            label = label.replace("perimeter", "perim.")
            if row["log2(FC)"] > 0.2:
                x_pos = row["log2(FC)"] - 0.05
                h_al = "right"
            else:
                x_pos = row["log2(FC)"] + 0.05
                h_al = "left"
            ax.text(
                x_pos,
                row["-log10(p)"] + 0.5,
                label,
                horizontalalignment=h_al,
                verticalalignment="bottom",
                fontsize=6,
            )
    plt.tight_layout()
    plt.savefig(f"../../plots/volcano_{comp_str}_labeled.png", dpi=300)
    plt.show()
def get_example_extrema_of_signif_feats(results_df, feats_df):
    import shutil

    results_df = results_df[results_df["p_adj"] < SIG]
    feats_df = condition_features(feats_df)
    feats_df["tile_path"] = feats_df["hovernet_output"].apply(
        lambda x: x.replace("preds_by_hovernet/json", "preds_by_hovernet/overlay").replace(
            ".json", ".png"
        )
    )
    for _, row in results_df.iterrows():
        print(f"{row['feat_name']}: {row['p_adj']}")
        feat_name = row["feat_name"]
        # copy the five lowest- and five highest-valued tiles for this feature
        _f_ = feats_df.sort_values(feat_name).dropna(subset=[feat_name])
        for tile_idx, tile_row in _f_.head().iterrows():
            shutil.copy(
                tile_row["tile_path"],
                f"../../plots/tile_feat_examples/{feat_name}_bottom{tile_idx}.png",
            )
        for tile_idx, tile_row in _f_.tail().iterrows():
            shutil.copy(
                tile_row["tile_path"],
                f"../../plots/tile_feat_examples/{feat_name}_top{tile_idx}.png",
            )
def condition_features(df):
    df = df[df["abundance_total"] > 50].copy()  # require more than 50 nuclei per tile
    for cell in ["nolabe", "no-neo", "inflam", "neopla", "connec", "necros"]:
        df["abundance_absolute_" + cell] = (
            df["abundance_relative_" + cell] * df["abundance_total"]
        )
    # replace any std dev values with nan if the total number of that cell type is < 10
    for col_name in df.columns:
        if "std" in col_name:
            cell = col_name.split("_")[-1]
            df.loc[df["abundance_absolute_" + cell] < 10, col_name] = np.nan
    df = df[[c for c in df.columns if "absolute" not in c]]
    print(df)
    return df
def make_emb_ad(df):
    df = df[
        [
            "image_id",
            "coord0_stamp_res",
            "coord1_stamp_res",
            "tile_size_stamp_res",
            "top_bottom",
            "oncotype_score",
            "emb_file",
        ]
    ].copy()
    df["coord0_stamp_res"] = df["coord0_stamp_res"].astype(int)
    df["coord1_stamp_res"] = df["coord1_stamp_res"].astype(int)
    embs = []
    for row_idx, row in df.iterrows():
        with h5py.File(row["emb_file"], "r") as f:
            feats = f["feats"][:]
            coords = f["coords"][:]
        # locate the embedding whose tile coordinates match this record
        feats_idx = np.where(
            (coords[:, 0] == row["coord0_stamp_res"]) & (coords[:, 1] == row["coord1_stamp_res"])
        )[0][0]
        embs.append(feats[feats_idx])
    embs = np.stack(embs)
    # make anndata
    obs = df[
        [
            "image_id",
            "coord0_stamp_res",
            "coord1_stamp_res",
            "tile_size_stamp_res",
            "top_bottom",
            "oncotype_score",
        ]
    ].copy()
    obs["image_id"] = obs["image_id"].astype(str)
    frame = ad.AnnData(X=embs, obs=obs)
    print(frame)
    return frame
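
# Pipeline sketch (the script is run in stages, re-reading its own CSV outputs
# so earlier stages can be commented out once complete):
#   1. define_df: pick each slide's top-3 / bottom-1 attention tiles.
#   2. make_emb_ad: export the matching STAMP embeddings to emb_df.h5ad.
#   3. load_tiles_to_dir: save the selected tiles as PNGs for HoverNet.
#   4. (external) run HoverNet on tiles_for_hovernet/.
#   5. extract_features_from_tiles, the test_* functions, and plot_volcano:
#      nuclear features, Mann-Whitney U tests with FDR correction, volcano
#      plots, and example tiles for significant features.
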
if __name__ == "__main__":
    tilewise_df = define_df()
    tilewise_df.to_csv("tilewise_df.csv", index=False)
    tilewise_df = pd.read_csv("tilewise_df.csv")
    emb_ad = make_emb_ad(tilewise_df)
    emb_ad.write("emb_df.h5ad")
    tilewise_df = pd.read_csv("tilewise_df.csv")
    tilewise_df["coord0"] = tilewise_df["coord0_svs_system"]
    tilewise_df["coord1"] = tilewise_df["coord1_svs_system"]
    tilewise_df["tile_size"] = tilewise_df["tile_size_svs_system"]
    # add predicted_score to tilewise_df
    pred_df = pd.read_csv("df.csv")
    pred_df = pred_df[["pred", "image_id"]].set_index("image_id")
    tilewise_df = tilewise_df.join(pred_df, on="image_id", how="inner")
    # sort by pred
    tilewise_df = tilewise_df.sort_values("pred")
    image_ids_unique = tilewise_df["image_id"].drop_duplicates().tolist()
    top_100_img_ids = image_ids_unique[-100:]
    bottom_100_img_ids = image_ids_unique[:100]
    # create column hi_or_lo
    tilewise_df["hi_or_lo"] = "--"
    # the 100 highest-prediction images are "hi", the 100 lowest are "lo",
    # and the rest are dropped
    tilewise_df.loc[tilewise_df.image_id.isin(top_100_img_ids), "hi_or_lo"] = "hi"
    tilewise_df.loc[tilewise_df.image_id.isin(bottom_100_img_ids), "hi_or_lo"] = "lo"
    tilewise_df = tilewise_df[tilewise_df["hi_or_lo"] != "--"]
    print(tilewise_df)
    load_tiles_to_dir(tilewise_df)
    # stop here, run HoverNet
    tilewise_df = pd.read_csv("tilewise_df_tiled.csv")
    extract_features_from_tiles(tilewise_df)
    feat_df = pd.read_csv("tilewise_df_tiled_feats.csv")
    feat_df = feat_df.dropna()
    test_top_vs_bottom_attn(feat_df)
    test_hi_vs_lo_onco(feat_df)
    results_df = pd.read_csv("results_hi_vs_lo.csv")
    plot_volcano(results_df, "hi_vs_lo")
    results_df = pd.read_csv("results_hi_vs_lo.csv")
    feat_df = pd.read_csv("tilewise_df_tiled_feats.csv")
    get_example_extrema_of_signif_feats(results_df, feat_df)
    top_v_bottom_df = pd.read_csv("results_top_vs_bottom.csv")
    plot_volcano(top_v_bottom_df, "top_vs_bottom")

features.py
import numpy as np
import pandas as pd
import scipy.spatial.distance as dist
from shapely import Polygon
HOVERNET_KEY = {
    "0": ["nolabe", [0, 0, 0]],
    "1": ["neopla", [255, 0, 0]],
    "2": ["inflam", [0, 255, 0]],
    "3": ["connec", [0, 0, 255]],
    "4": ["necros", [255, 255, 0]],
    "5": ["no-neo", [255, 165, 0]],
}  # modify based on classes and colors of choice
MAPPING = dict((int(k), v[0]) for (k, v) in HOVERNET_KEY.items())
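
# Shape of the HoverNet JSON consumed below, inferred from the accesses in
# this module (not an exhaustive spec of HoverNet's output):
#
#   {"nuc": {"<nuc_id>": {"centroid": [x, y],          # pixel coordinates
#                         "contour": [[x, y], ...],    # polygon vertices
#                         "type": 0-5,                 # keys of HOVERNET_KEY
#                         "type_prob": float},
#            ...}}
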
def extract_features_single_tile(data):
    nuclei = list(data["nuc"].items())
    # nuclei = filter_low_prob(nuclei)
    graph_feats = dict()
    graph_feats.update(extract_centroid_and_type(nuclei))
    graph_feats.update(extract_nuclear_geometry(nuclei))
    # graph_feats.update(extract_nuclear_pixels(graph_feats))
    simple_tile_wise_feats = extract_abundance(nuclei, relative=True)
    simple_tile_wise_feats.update(aggregate_cell_geom_features_tilewise(graph_feats))
    # convert simple_tile_wise_feats to a list of (name, value) tuples, sorted by name
    simple_tile_wise_feats = sorted(simple_tile_wise_feats.items(), key=lambda x: x[0])
    simple_feature_names, simple_feature_values = zip(*simple_tile_wise_feats)
    return simple_feature_values, simple_feature_names, graph_feats


def extract_centroid_and_type(nuclei):
    features = {"centroid": [], "type": [], "type_prob": [], "nuc_id": []}
    for nuc_id, nuc in nuclei:
        features["centroid"].append(nuc["centroid"])
        features["type"].append(nuc["type"])
        features["type_prob"].append(nuc["type_prob"])
        features["nuc_id"].append(nuc_id)
    return features
def filter_low_prob(nuclei):
    # nuclei is a list of (nuc_id, nuc_dict) tuples, so unpack before indexing
    return [(nuc_id, nuc) for nuc_id, nuc in nuclei if nuc["type_prob"] > 0.5]
def extract_abundance(nuclei, relative=False):
    features = dict()
    counts = dict([(v, 0) for v in MAPPING.values()])
    for nuc_id, nuc in nuclei:
        nuc_type = MAPPING[nuc["type"]]
        counts[nuc_type] += 1
    root_name = "abundance_relative_" if relative else "abundance_"
    feat_names = [root_name + k for k in counts.keys()]
    features["abundance_total"] = len(nuclei)
    for k, v in zip(feat_names, counts.values()):
        if relative:
            try:
                features[k] = v / len(nuclei)
            except ZeroDivisionError:
                features[k] = 0
        else:
            features[k] = v
    return features
def extract_nuclear_geometry(nuclei):
    geom = {"area": [], "perimeter": [], "solidity": [], "type": []}
    for nuc_id, nuc in nuclei:
        # if nuc_type not in ['connec', 'inflam', 'neopla', 'no-neo']:
        #     continue
        p = Polygon(nuc["contour"])
        geom["area"].append(p.area * 0.0625)  # convert from pix_40x^2 to um^2
        geom["perimeter"].append(p.length * 0.25)  # convert from pix_40x to um
        geom["solidity"].append(p.area / p.convex_hull.area)  # dimensionless
        geom["type"].append(nuc["type"])
    return geom
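
# The conversion factors above assume 40x scans at 0.25 um/px (consistent with
# the inline comments, not read from the slide): lengths scale by 0.25 and
# areas by 0.25 ** 2 = 0.0625.
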
def aggregate_cell_geom_features_tilewise(graph_feat_dict):
    features = dict()
    # default every metric_stat_cell combination to 0, then overwrite with the
    # observed aggregates below
    vars = [
        f"{metric}_{stat}_{cell}"
        for metric in ("area", "perimeter", "solidity")
        for stat in ("mean", "std", "median")
        for cell in ("nolabe", "neopla", "inflam", "connec", "necros", "no-neo")
    ]
    for var in vars:
        features[var] = 0
    df = pd.DataFrame.from_dict(graph_feat_dict)
    df = df[["area", "perimeter", "solidity", "type"]]
    df = df.groupby("type").agg(["mean", "std", "median"])
    df = pd.melt(df.reset_index(), id_vars=["type"])
    df["variable"] = df["variable_0"] + "_" + df["variable_1"] + "_" + df["type"].map(MAPPING)
    df = df.drop(columns=["variable_0", "variable_1", "type"])
    for _, row in df.iterrows():
        features[row["variable"]] = row["value"]
    return features
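
# Example of the resulting keys: "area_mean_neopla" is the mean neoplastic
# nuclear area (in um^2) over the tile; combinations with no nuclei of that
# type keep their default of 0.
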
def convert_tile_to_geometric_format(graph_feats_dict):
    features = dict()
    features["pos"] = np.array(graph_feats_dict["centroid"]) * 0.25  # convert pixels to microns
    features["x"], features["nuc_id"], features["feat_names"] = _format_vertex_features(graph_feats_dict)
    features["edge_index"], features["edge_attr"] = _format_edge_features(features["pos"], 50)
    return features


def _format_vertex_features(graph_feats_dict):
    df = pd.DataFrame.from_dict(graph_feats_dict)
    # one-hot-like type columns, weighted by the predicted type probability
    for k, v in MAPPING.items():
        df["type_" + v] = df["type_prob"] * (df["type"] == k)
    nuc_ids = df["nuc_id"].values
    df = df.drop(columns=["centroid", "type", "nuc_id", "type_prob"])
    df = df.sort_index(axis=1)
    return df.values, nuc_ids, df.columns


def _format_edge_features(centroids, threshold):
    if len(centroids) == 0:
        return np.zeros((2, 0)), np.zeros((0, 1))
    distances = dist.cdist(centroids, centroids)  # in microns
    distances += np.diag(np.ones(len(distances)) * np.inf)  # don't include self loops
    connected = distances <= threshold
    connected = np.argwhere(connected)
    edge_len = distances[connected[:, 0], connected[:, 1]]
    edge_len = np.expand_dims(edge_len, axis=1)
    connected = np.ascontiguousarray(connected.T)
    return connected, edge_len
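
# Minimal illustration of the edge construction above (hypothetical inputs):
# three nuclei at (0, 0), (30, 0) and (100, 0) microns with a 50 um threshold
# connect only the first pair, in both directions:
#
#   edge_index, edge_attr = _format_edge_features(
#       np.array([[0.0, 0.0], [30.0, 0.0], [100.0, 0.0]]), threshold=50
#   )
#   # edge_index -> [[0, 1], [1, 0]]  (shape (2, 2))
#   # edge_attr  -> [[30.0], [30.0]]  (edge lengths in microns)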