Analyze HoverNet outputs
import pandas as pd
import torch
import h5py
from openslide import OpenSlide
import os
import numpy as np
from joblib import Parallel, delayed
from features import extract_features_single_tile
import json
from scipy.stats import mannwhitneyu
from statsmodels.stats.multitest import multipletests
import anndata as ad
TILE_SIZE = 224  # tile edge length (pixels) in the embedding/STAMP coordinate system
MPP = 1.1428571429  # microns per pixel of the tile/embedding coordinate system
ATTN_DIR = "slide_attn_rollouts"
SLIDE_DIR = "slides"
EMB_DIR = "STAMP_output/h5_files"
SIG = 0.0001  # adjusted-p significance threshold used for the volcano plots
def _identify_top_attn_and_min_attn(row):
    attn_path, emb_path, slide_path = (
        row["attn_file"],
        row["emb_file"],
        row["slide_file"],
    )

    # get native microns per pixel
    slide = OpenSlide(slide_path)
    native_mpp_of_slide = float(slide.properties["openslide.mpp-x"])
    del slide

    # load attention
    try:
        attention = torch.load(attn_path)
    except FileNotFoundError:
        print(f"File not found: {attn_path}")
        return []
    attention = attention[-1, 0, 1:]  # last layer, wrt cls token, no cls token self-attn
    attention = attention.flatten()

    # load coords and convert to native resolution coords
    with h5py.File(emb_path, "r") as f:
        coords = f["coords"][: len(attention)]
    native_scaled_coords = (coords * MPP / native_mpp_of_slide).astype(int)
    native_scaled_tile_size = int(TILE_SIZE * MPP / native_mpp_of_slide)

    # identify top 3 attention values
    top_ind = np.argpartition(attention, -3)[-3:]
    top_coords_svs_system = native_scaled_coords[top_ind]
    top_coords_stamp_system = coords[top_ind]

    # identify lowest attention value
    min_ind = np.argmin(attention)
    min_coord_svs_system = native_scaled_coords[min_ind]
    min_coord_stamp_system = coords[min_ind]

    # create records with coord0, coord1, tile_size, top/bottom
    records = []
    for coord_svs, coord_stamp in zip(top_coords_svs_system, top_coords_stamp_system):
        records.append(
            {
                "attn_file": attn_path,
                "emb_file": emb_path,
                "slide_file": slide_path,
                "oncotype_score": row["oncotype_score"],
                "image_id": row["image_id"],
                "regex_grade": row["regex_grade"],
                "regex_ihc_pr": row["regex_ihc_pr"],
                "coord0_svs_system": coord_svs[0],
                "coord1_svs_system": coord_svs[1],
                "coord0_stamp_res": coord_stamp[0],
                "coord1_stamp_res": coord_stamp[1],
                "tile_size_svs_system": native_scaled_tile_size,
                "tile_size_stamp_res": TILE_SIZE,
                "top_bottom": "top",
            }
        )
    records.append(
        {
            "attn_file": attn_path,
            "emb_file": emb_path,
            "slide_file": slide_path,
            "oncotype_score": row["oncotype_score"],
            "image_id": row["image_id"],
            "regex_grade": row["regex_grade"],
            "regex_ihc_pr": row["regex_ihc_pr"],
            "coord0_svs_system": min_coord_svs_system[0],
            "coord1_svs_system": min_coord_svs_system[1],
            "coord0_stamp_res": min_coord_stamp_system[0],
            "coord1_stamp_res": min_coord_stamp_system[1],
            "tile_size_svs_system": native_scaled_tile_size,
            "tile_size_stamp_res": TILE_SIZE,
            "top_bottom": "bottom",
        }
    )
    return records
def define_df():
    df = pd.read_csv("../dataframes/holdout_df_with_regex_feats.csv")
    df = df[
        [
            "oncotype_score",
            "image_id",
            "regex_grade",
            "regex_ihc_pr",
        ]
    ]
    df["attn_file"] = df["image_id"].apply(lambda x: os.path.join(ATTN_DIR, f"{x}.pt"))
    df["emb_file"] = df["image_id"].apply(lambda x: os.path.join(EMB_DIR, f"{x}.h5"))
    df["slide_file"] = df["image_id"].apply(lambda x: os.path.join(SLIDE_DIR, f"{x}.svs"))

    # records = []
    # for idx, row in df.iterrows():
    #     records.extend(_identify_top_attn_and_min_attn(row))
    records = Parallel(n_jobs=36)(
        delayed(_identify_top_attn_and_min_attn)(row) for idx, row in df.iterrows()
    )
    records = [item for sublist in records for item in sublist]
    tilewise_df = pd.DataFrame.from_records(records)
    return tilewise_df
def _load_tile_to_dir(row):
    slide_path, coord0, coord1, tile_size, save_path = (
        row["slide_file"],
        row["coord0"],
        row["coord1"],
        row["tile_size"],
        row["tile_path"],
    )
    slide = OpenSlide(slide_path)
    tile = slide.read_region((coord0, coord1), 0, (tile_size, tile_size))

    # get magnification
    mag = int(slide.properties["aperio.AppMag"])
    if mag == 40:
        pass
    elif mag == 20:
        tile = tile.resize((tile_size * 2, tile_size * 2))
    else:
        raise ValueError(f"Mag {mag} not supported")
    tile.save(save_path)
    slide.close()


def load_tiles_to_dir(df):
    df["tile_path"] = df.apply(
        lambda x: f"tiles_for_hovernet/{x['image_id']}_{x['coord0']}_{x['coord1']}.png",
        axis=1,
    )
    # for idx, row in df.iterrows():
    #     _load_tile_to_dir(row)
    Parallel(n_jobs=36)(delayed(_load_tile_to_dir)(row) for idx, row in df.iterrows())
    df.to_csv("tilewise_df_tiled.csv", index=False)


def extract_features_from_tiles(df):
    df["hovernet_output"] = df["tile_path"].apply(
        lambda x: x.replace("tiles_for_hovernet", "preds_by_hovernet/json").replace(".png", ".json")
    )
    records = Parallel(n_jobs=36)(
        delayed(_extract_features_from_tile)(row) for idx, row in df.iterrows()
    )
    feats_df = pd.DataFrame.from_records(records)
    feats_df.to_csv("tilewise_df_tiled_feats.csv", index=False)
def _extract_features_from_tile(row):
    hovernet_output = row["hovernet_output"]
    try:
        with open(hovernet_output, "r") as f:
            data = json.load(f)
        feats, feature_names, graph_feats_dict = extract_features_single_tile(data)
        # create dict of features to add to df
        record = dict(zip(feature_names, feats))
    except FileNotFoundError:
        record = dict()
    record["hovernet_output"] = hovernet_output
    record["top_bottom"] = row["top_bottom"]
    record["oncotype_score"] = row["oncotype_score"]
    record["pred"] = row["pred"]
    record["hi_or_lo"] = row["hi_or_lo"]
    return record
def test_top_vs_bottom_attn(df):
    df = df.set_index("hovernet_output")
    top_df = df[df["top_bottom"] == "top"].copy()
    top_df = top_df.drop(columns=["oncotype_score", "top_bottom"])
    bottom_df = df[df["top_bottom"] == "bottom"].copy()
    bottom_df = bottom_df.drop(columns=["oncotype_score", "top_bottom"])

    feat_names = top_df.columns
    feat_names = [f for f in feat_names if f not in ["pred", "oncotype_score", "hi_or_lo"]]

    records = []
    for feat_name in feat_names:
        top_vals = top_df[feat_name].values
        top_mean = np.mean(top_vals)
        bottom_vals = bottom_df[feat_name].values
        bottom_mean = np.mean(bottom_vals)
        stat, p = mannwhitneyu(top_vals, bottom_vals)
        records.append(
            {
                "feat_name": feat_name,
                "top_mean": top_mean,
                "bottom_mean": bottom_mean,
                "stat": stat,
                "p": p,
            }
        )
    records = pd.DataFrame.from_records(records)
    records["p_adj"] = multipletests(records["p"], method="fdr_bh")[1]
    records = records.sort_values("p_adj")
    records.to_csv("results_top_vs_bottom.csv", index=False)
    print(records)


def test_hi_vs_lo_onco(df):
    df = df.set_index("hovernet_output")
    df = condition_features(df)
    df = df[df["top_bottom"] == "top"].copy().drop(columns=["top_bottom"])
    hi_df = df[df["hi_or_lo"] == "hi"].copy()
    lo_df = df[df["hi_or_lo"] == "lo"].copy()

    feat_names = hi_df.columns
    feat_names = [f for f in feat_names if f not in ["pred", "oncotype_score", "hi_or_lo"]]

    records = []
    for feat_name in feat_names:
        hi_vals = hi_df[feat_name].dropna().values
        hi_mean = np.mean(hi_vals)
        lo_vals = lo_df[feat_name].dropna().values
        lo_mean = np.mean(lo_vals)
        stat, p = mannwhitneyu(hi_vals, lo_vals)
        records.append(
            {
                "feat_name": feat_name,
                "hi_mean": hi_mean,
                "lo_mean": lo_mean,
                "stat": stat,
                "p": p,
            }
        )
    records = pd.DataFrame.from_records(records)
    records["p_adj"] = multipletests(records["p"], method="fdr_bh")[1]
    records = records.sort_values("p_adj")
    records.to_csv("results_hi_vs_lo.csv", index=False)
    print(records)
def plot_volcano(df, comp_str):
    import matplotlib.pyplot as plt
    import seaborn as sns

    cell_name_mapping = {
        "nolabe": "unlab.",
        "no-neo": "non-neo.",
        "inflam": "inflam.",
        "neopla": "neo.",
        "connec": "connec.",
        "necros": "necros.",
        "total": "total",
    }
    assert comp_str in ["hi_vs_lo", "top_vs_bottom"]

    with open("type_info.json", "r") as f:
        palette = json.load(f)
    palette = dict([(cell_name_mapping[v[0]], [x / 255.0 for x in v[1]]) for v in palette.values()])
    palette["total"] = [165 / 255.0, 42 / 255.0, 42 / 255.0]

    fig, ax = plt.subplots(figsize=(3, 3))
    df["-log10(p)"] = -np.log10(df["p_adj"])
    if comp_str == "hi_vs_lo":
        df["log2(FC)"] = np.log2(df["hi_mean"] / df["lo_mean"])
    else:
        df["log2(FC)"] = np.log2(df["top_mean"] / df["bottom_mean"])
    df["cell"] = df["feat_name"].apply(lambda x: x.split("_")[-1])
    df["cell"] = df["cell"].map(cell_name_mapping)

    sns.scatterplot(
        data=df,
        x="log2(FC)",
        y="-log10(p)",
        hue="cell",
        palette=palette,
        ax=ax,
        s=17,
        linewidth=0,
    )
    sns.move_legend(ax, "upper left", labelspacing=0, borderpad=0.2, title=None, fontsize=10)
    if comp_str == "hi_vs_lo":
        ax.set_xlabel("log2(FC) (High vs low score)", fontsize=10)
    else:
        ax.set_xlabel("log2(FC) (Top vs bottom attn.)", fontsize=10)
    ax.set_ylabel("-log10(p)", fontsize=10)
    plt.tick_params(axis="both", which="major", labelsize=10)
    ax.axhline(-np.log10(SIG), linestyle=":", color="black", alpha=0.5)
    ax.text(0.2, -np.log10(SIG) + 0.2, f"p < {SIG}", fontsize=6)
    ax.axvline(0, linestyle=":", color="black", alpha=0.5)

    # remove upper and right spines
    ax.spines["top"].set_visible(False)
    ax.spines["right"].set_visible(False)
    plt.tight_layout()
    plt.savefig(f"../../plots/volcano_{comp_str}_unlabeled.png", dpi=300)

    # label all the points with adjusted p below the significance threshold SIG
    for idx, row in df.iterrows():
        if row["-log10(p)"] > -np.log10(SIG):
            label = " ".join(row["feat_name"].split("_")[:-1])
            label = label.replace("abundance relative", "rel. abund.")
            label = label.replace("std", "std. dev.")
            label = label.replace("median", "med.")
            label = label.replace("solidity", "solid.")
            label = label.replace("perimeter", "perim.")
            if row["log2(FC)"] > 0.2:
                x_pos = row["log2(FC)"] - 0.05
                h_al = "right"
            else:
                x_pos = row["log2(FC)"] + 0.05
                h_al = "left"
            ax.text(
                x_pos,
                row["-log10(p)"] + 0.5,
                label,
                horizontalalignment=h_al,
                verticalalignment="bottom",
                fontsize=6,
            )
    plt.tight_layout()
    plt.savefig(f"../../plots/volcano_{comp_str}_labeled.png", dpi=300)
    plt.show()
def get_example_extrema_of_signif_feats(r, f):
    import shutil

    r = r[r["p_adj"] < SIG]
    f = condition_features(f)
    f["tile_path"] = f["hovernet_output"].apply(
        lambda x: x.replace("preds_by_hovernet/json", "preds_by_hovernet/overlay").replace(
            ".json", ".png"
        )
    )
    for idx, row in r.iterrows():
        if row["p_adj"] > SIG:
            continue
        print(f"{row['feat_name']}: {row['p_adj']}")
        feat_name = row["feat_name"]

        # copy the bottom-five and top-five example tiles for this feature
        _f_ = f.sort_values(feat_name).dropna(subset=[feat_name])
        for tile_idx, tile_row in _f_.head().iterrows():
            shutil.copy(
                tile_row["tile_path"],
                f"../../plots/tile_feat_examples/{feat_name}_bottom{tile_idx}.png",
            )
        for tile_idx, tile_row in _f_.tail().iterrows():
            shutil.copy(
                tile_row["tile_path"],
                f"../../plots/tile_feat_examples/{feat_name}_top{tile_idx}.png",
            )


def condition_features(df):
    df = df[df["abundance_total"] > 50]
    for cell in ["nolabe", "no-neo", "inflam", "neopla", "connec", "necros"]:
        df["abundance_absolute_" + cell] = None
        df.loc[:, "abundance_absolute_" + cell] = (
            df["abundance_relative_" + cell] * df["abundance_total"]
        )

    # replace any std dev values with nan if the total number of that cell type is < 10
    for col_name in df.columns:
        if "std" in col_name:
            cell = col_name.split("_")[-1]
            df.loc[df["abundance_absolute_" + cell] < 10, col_name] = np.nan
    df = df[[c for c in df.columns if "absolute" not in c]]
    print(df)
    return df
def make_emb_ad(df):
    df = df[
        [
            "image_id",
            "coord0_stamp_res",
            "coord1_stamp_res",
            "tile_size_stamp_res",
            "top_bottom",
            "oncotype_score",
            "emb_file",
        ]
    ]
    df["coord0_stamp_res"] = df["coord0_stamp_res"].astype(int)
    df["coord1_stamp_res"] = df["coord1_stamp_res"].astype(int)

    embs = []
    for row_idx, row in df.iterrows():
        with h5py.File(row["emb_file"], "r") as f:
            print(f.keys())
            feats = f["feats"][:]
            coords = f["coords"][:]
        feats_idx = np.where(
            (coords[:, 0] == row["coord0_stamp_res"]) & (coords[:, 1] == row["coord1_stamp_res"])
        )[0][0]
        embs.append(feats[feats_idx])
    embs = np.stack(embs)

    # make anndata
    obs = df[
        [
            "image_id",
            "coord0_stamp_res",
            "coord1_stamp_res",
            "tile_size_stamp_res",
            "top_bottom",
            "oncotype_score",
        ]
    ]
    obs["image_id"] = obs["image_id"].astype(str)
    frame = ad.AnnData(X=embs, obs=obs)
    print(frame)
    return frame
if __name__ == "__main__":
    tilewise_df = define_df()
    tilewise_df.to_csv("tilewise_df.csv", index=False)

    tilewise_df = pd.read_csv("tilewise_df.csv")
    emb_ad = make_emb_ad(tilewise_df)
    emb_ad.write("emb_df.h5ad")

    tilewise_df = pd.read_csv("tilewise_df.csv")
    tilewise_df["coord0"] = tilewise_df["coord0_svs_system"]
    tilewise_df["coord1"] = tilewise_df["coord1_svs_system"]
    tilewise_df["tile_size"] = tilewise_df["tile_size_svs_system"]

    # add predicted score to tilewise_df
    pred_df = pd.read_csv("df.csv")
    pred_df = pred_df[["pred", "image_id"]].set_index("image_id")
    tilewise_df = tilewise_df.join(pred_df, on="image_id", how="inner")

    # sort by pred
    tilewise_df = tilewise_df.sort_values("pred")
    image_ids_unique = tilewise_df["image_id"].drop_duplicates().tolist()
    top_100_img_ids = image_ids_unique[-100:]
    bottom_100_img_ids = image_ids_unique[:100]

    # create column hi_or_lo: the 100 images with the highest predicted scores are "hi",
    # the 100 lowest are "lo", and the rest are dropped
    tilewise_df["hi_or_lo"] = "--"
    tilewise_df.loc[tilewise_df.image_id.isin(top_100_img_ids), "hi_or_lo"] = "hi"
    tilewise_df.loc[tilewise_df.image_id.isin(bottom_100_img_ids), "hi_or_lo"] = "lo"
    tilewise_df = tilewise_df[tilewise_df["hi_or_lo"] != "--"]
    print(tilewise_df)
    load_tiles_to_dir(tilewise_df)

    # stop here, run HoverNet
    tilewise_df = pd.read_csv("tilewise_df_tiled.csv")
    extract_features_from_tiles(tilewise_df)

    feat_df = pd.read_csv("tilewise_df_tiled_feats.csv")
    feat_df = feat_df.dropna()
    test_top_vs_bottom_attn(feat_df)
    test_hi_vs_lo_onco(feat_df)

    results_df = pd.read_csv("results_hi_vs_lo.csv")
    plot_volcano(results_df, "hi_vs_lo")

    results_df = pd.read_csv("results_hi_vs_lo.csv")
    feat_df = pd.read_csv("tilewise_df_tiled_feats.csv")
    get_example_extrema_of_signif_feats(results_df, feat_df)

    top_v_bottom_df = pd.read_csv("results_top_vs_bottom.csv")
    plot_volcano(top_v_bottom_df, "top_vs_bottom")
# features.py (imported above as `from features import extract_features_single_tile`)
import numpy as np
from shapely import Polygon
import pandas as pd
import scipy.spatial.distance as dist

HOVERNET_KEY = {
    "0": ["nolabe", [0, 0, 0]],
    "1": ["neopla", [255, 0, 0]],
    "2": ["inflam", [0, 255, 0]],
    "3": ["connec", [0, 0, 255]],
    "4": ["necros", [255, 255, 0]],
    "5": ["no-neo", [255, 165, 0]],
}  # modify based on classes and colors of choice
MAPPING = dict((int(k), v[0]) for (k, v) in HOVERNET_KEY.items())
def extract_features_single_tile(data):
    nuclei = list(data["nuc"].items())
    # nuclei = filter_low_prob(nuclei)

    graph_feats = dict()
    graph_feats.update(extract_centroid_and_type(nuclei))
    graph_feats.update(extract_nuclear_geometry(nuclei))
    # graph_feats.update(extract_nuclear_pixels(graph_feats))

    simple_tile_wise_feats = extract_abundance(nuclei, relative=True)
    simple_tile_wise_feats.update(aggregate_cell_geom_features_tilewise(graph_feats))

    # convert simple_tile_wise_feats to list of (name, value) tuples
    simple_tile_wise_feats = [(k, v) for k, v in simple_tile_wise_feats.items()]
    # sort by name
    simple_tile_wise_feats = sorted(simple_tile_wise_feats, key=lambda x: x[0])
    simple_feature_names, simple_feature_values = zip(*simple_tile_wise_feats)
    return simple_feature_values, simple_feature_names, graph_feats


def extract_centroid_and_type(nuclei):
    features = {"centroid": [], "type": [], "type_prob": [], "nuc_id": []}
    for nuc_id, nuc in nuclei:
        features["centroid"].append(nuc["centroid"])
        features["type"].append(nuc["type"])
        features["type_prob"].append(nuc["type_prob"])
        features["nuc_id"].append(nuc_id)
    return features
def filter_low_prob(nuclei):
    # nuclei is a list of (nuc_id, nuc) tuples, so unpack before checking the class probability
    return [(nuc_id, nuc) for nuc_id, nuc in nuclei if nuc["type_prob"] > 0.5]
def extract_abundance(nuclei, relative=False):
    features = dict()
    counts = dict([(v, 0) for v in MAPPING.values()])
    for nuc_id, nuc in nuclei:
        nuc_type = MAPPING[nuc["type"]]
        counts[nuc_type] += 1
    root_name = "abundance_relative_" if relative else "abundance_"
    feat_names = [root_name + k for k in counts.keys()]
    features["abundance_total"] = len(nuclei)
    for k, v in zip(feat_names, counts.values()):
        if relative:
            try:
                features[k] = v / len(nuclei) if k != "abundance_total" else v
            except ZeroDivisionError:
                features[k] = 0
        else:
            features[k] = v
    return features


def extract_nuclear_geometry(nuclei):
    geom = {"area": [], "perimeter": [], "solidity": [], "type": []}
    for nuc_id, nuc in nuclei:
        # if nuc_type not in ['connec', 'inflam', 'neopla', 'no-neo']:
        #     continue
        p = Polygon(nuc["contour"])
        geom["area"].append(p.area * 0.0625)  # convert from pix_40x^2 to um^2
        geom["perimeter"].append(p.length * 0.25)  # convert from pix_40x to um
        geom["solidity"].append(p.area / p.convex_hull.area)  # dimensionless
        geom["type"].append(nuc["type"])
    return geom
def aggregate_cell_geom_features_tilewise(graph_feat_dict):
    features = dict()
    # initialize every metric/statistic/cell-type combination to 0 so that
    # tiles missing a cell type still carry the full feature set
    vars = [
        f"{metric}_{stat}_{cell}"
        for metric in ("area", "perimeter", "solidity")
        for stat in ("mean", "std", "median")
        for cell in ("nolabe", "neopla", "inflam", "connec", "necros", "no-neo")
    ]
    for var in vars:
        features[var] = 0

    df = pd.DataFrame.from_dict(graph_feat_dict)
    df = df[["area", "perimeter", "solidity", "type"]]
    df = df.groupby("type").agg(["mean", "std", "median"])
    df = pd.melt(df.reset_index(), id_vars=["type"])
    df["variable"] = df["variable_0"] + "_" + df["variable_1"] + "_" + df["type"].map(MAPPING)
    df = df.drop(columns=["variable_0", "variable_1", "type"])
    for _, row in df.iterrows():
        features[row["variable"]] = row["value"]
    return features
def convert_tile_to_geometric_format(graph_feats_dict):
    features = dict()
    features["pos"] = np.array(graph_feats_dict["centroid"]) * 0.25  # convert pixels to microns
    features["x"], features["nuc_id"], features["feat_names"] = _format_vertex_features(graph_feats_dict)
    features["edge_index"], features["edge_attr"] = _format_edge_features(features["pos"], 50)
    return features


def _format_vertex_features(graph_feats_dict):
    df = pd.DataFrame.from_dict(graph_feats_dict)
    for k, v in MAPPING.items():
        df["type_" + v] = df["type_prob"] * (df["type"] == k)
    nuc_ids = df["nuc_id"].values
    df = df.drop(columns=["centroid", "type", "nuc_id", "type_prob"])
    df = df.sort_index(axis=1)
    return df.values, nuc_ids, df.columns


def _format_edge_features(centroids, threshold):
    if len(centroids) == 0:
        return np.zeros((2, 0)), np.zeros((0, 1))
    distances = dist.cdist(centroids, centroids)  # in microns
    distances += np.diag(np.ones(len(distances)) * np.inf)  # don't include self loops
    connected = distances <= threshold
    connected = np.argwhere(connected)
    edge_len = distances[connected[:, 0], connected[:, 1]]
    edge_len = np.expand_dims(edge_len, axis=1)
    connected = np.ascontiguousarray(connected.T)
    return connected, edge_len
kmboehm commented Feb 22, 2024

Code used to generate and analyze nuclear features from HoverNet outputs, together with tile-wise attention scores from Orpheus.
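
For reference, `features.py` assumes the per-tile JSON produced by HoverNet: a top-level `"nuc"` mapping from nucleus id to a dict whose `centroid`, `contour`, `type`, and `type_prob` fields are the ones read by the extraction functions. A minimal sketch of calling the extractor on such a file (the path below is illustrative, not part of this code):

```python
import json
from features import extract_features_single_tile

# Illustrative path; HoverNet writes one JSON per tile into its json/ output folder.
with open("preds_by_hovernet/json/example_tile.json", "r") as f:
    data = json.load(f)

# Each entry of data["nuc"] is expected to look roughly like:
# {
#     "centroid": [x, y],                    # pixel coordinates within the tile
#     "contour": [[x0, y0], [x1, y1], ...],  # nuclear boundary polygon
#     "type": 1,                             # integer class id, see HOVERNET_KEY
#     "type_prob": 0.93,
# }
values, names, graph_feats = extract_features_single_tile(data)
print(dict(zip(names, values)))  # tile-level abundance and nuclear-geometry features
```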

Disclaimer: Author makes no warranties or representations, and hereby disclaims any warranties, express or implied, with respect to any of the Content, including as to the present accuracy, completeness, timeliness, adequacy, or usefulness of any of the Content. The entire risk as to the quality and performance of the Content is with you. By using this Content, you agree that Memorial Sloan Kettering Cancer Center will not be liable for any losses or damages arising from your use of or reliance on the Content, or other websites or information to which this Content may be linked, including any general, special, incidental or consequential damages arising out of the use or inability to use the Content including, but not limited to, loss of data or data being rendered inaccurate or losses sustained by you or third parties or a failure of the Content to operate with any other software, programs, source code, etc.
