PyTorch Lightning KNN Classifier Evaluation Callback
# pip install torch lightning scikit-learn numpy tqdm faissknn
import lightning.pytorch as pl
import numpy as np
import torch
from faissknn import FaissKNNClassifier
from lightning.pytorch.utilities import rank_zero_only
from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score
from tqdm import tqdm


class KNNEval(pl.callbacks.Callback):
    def __init__(self, datamodule, k=5, check_every_n_epochs=1, device: str = "cuda:0"):
        super().__init__()
        self.datamodule = datamodule
        self.datamodule.setup()
        self.k = k
        self.check_every_n_epochs = check_every_n_epochs
        self.device = device

    @rank_zero_only
    def on_train_epoch_end(self, trainer, pl_module):
        if (trainer.current_epoch + 1) % self.check_every_n_epochs != 0:
            return

        # Get image encoder
        device = pl_module.device
        backbone = pl_module.image_encoder
        backbone.eval()

        # Get train set embeddings
        x_train, y_train = [], []
        dataloader = self.datamodule.train_dataloader()
        for batch in tqdm(dataloader, total=len(dataloader)):
            x, y = batch["image"].to(device), batch["label"]
            with torch.inference_mode(), torch.cuda.amp.autocast():
                emb = backbone(x)
            x_train.append(emb.cpu().numpy())
            y_train.append(y.cpu().numpy())
        x_train = np.concatenate(x_train, axis=0)
        y_train = np.concatenate(y_train, axis=0)

        # Get val set embeddings
        x_val, y_val = [], []
        dataloader = self.datamodule.val_dataloader()
        for batch in tqdm(dataloader, total=len(dataloader)):
            x, y = batch["image"].to(device), batch["label"]
            with torch.inference_mode(), torch.cuda.amp.autocast():
                emb = backbone(x)
            x_val.append(emb.cpu().numpy())
            y_val.append(y.cpu().numpy())
        x_val = np.concatenate(x_val, axis=0)
        y_val = np.concatenate(y_val, axis=0)

        # Fit knn model on train embeddings and predict on val embeddings
        knn = FaissKNNClassifier(n_neighbors=self.k, device=self.device)
        knn.fit(X=x_train, y=y_train)
        y_pred = knn.predict(x_val)

        # Compute metrics
        metrics = {
            "val_f1_weighted": f1_score(y_val, y_pred, average="weighted"),
            "val_f1_macro": f1_score(y_val, y_pred, average="macro"),
            "val_f1_micro": f1_score(y_val, y_pred, average="micro"),
            "val_precision_micro": precision_score(y_val, y_pred, average="micro"),
            "val_precision_macro": precision_score(y_val, y_pred, average="macro"),
            "val_precision_weighted": precision_score(y_val, y_pred, average="weighted"),
            "val_recall_micro": recall_score(y_val, y_pred, average="micro"),
            "val_recall_macro": recall_score(y_val, y_pred, average="macro"),
            "val_recall_weighted": recall_score(y_val, y_pred, average="weighted"),
            "val_accuracy": accuracy_score(y_val, y_pred),
        }

        # Log metrics
        pl_module.log_dict(metrics, rank_zero_only=True, on_epoch=True, sync_dist=True)
        backbone.train()
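
For reference, here is a minimal usage sketch showing how the callback might be attached to a training run. It assumes a hypothetical MyDataModule whose dataloaders yield dicts with "image" and "label" keys (as the loops above expect) and a hypothetical LightningModule that exposes its backbone as image_encoder; these names are illustrative, not part of the gist.

# Usage sketch (MyDataModule and MySelfSupervisedModel are hypothetical placeholders)
import lightning.pytorch as pl

datamodule = MyDataModule(batch_size=64)      # must yield {"image": ..., "label": ...} batches
model = MySelfSupervisedModel()               # must expose an `image_encoder` attribute

# Run a KNN evaluation on the val set every epoch using k=5 neighbors
knn_callback = KNNEval(datamodule=datamodule, k=5, check_every_n_epochs=1, device="cuda:0")

trainer = pl.Trainer(max_epochs=100, accelerator="gpu", devices=1, callbacks=[knn_callback])
trainer.fit(model, datamodule=datamodule)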
You will have to change the backbone = pl_module.image_encoder line to point at whichever backbone attribute your LightningModule uses. In this case I used
timm.create_model("resnet50", pretrained=True, num_classes=0)
as my image_encoder.
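
To illustrate that comment, the module only needs to expose the backbone under the image_encoder attribute. The sketch below is one assumed way to do that with the timm ResNet-50 mentioned above; the class name is hypothetical and the training logic is omitted.

# Sketch of a LightningModule compatible with KNNEval (hypothetical class name)
import lightning.pytorch as pl
import timm

class MyModel(pl.LightningModule):
    def __init__(self):
        super().__init__()
        # num_classes=0 removes the classification head so the backbone returns embeddings
        self.image_encoder = timm.create_model("resnet50", pretrained=True, num_classes=0)

    def forward(self, x):
        return self.image_encoder(x)

    # training_step / configure_optimizers omitted; add your own training objective here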