Skip to content

Instantly share code, notes, and snippets.

@theeluwin
Last active December 13, 2021 13:12
Show Gist options
  • Save theeluwin/1b649906aff902ef1d1d83c78519ffcc to your computer and use it in GitHub Desktop.
Save theeluwin/1b649906aff902ef1d1d83c78519ffcc to your computer and use it in GitHub Desktop.
HR, Recall, nDCG, and AUROC @ k (recommendation evaluation metrics)
import torch
from typing import (
Any,
List,
Dict,
)
def calc_batch_rec_metrics_per_k(ranks: torch.LongTensor,
                                 labels: torch.LongTensor,
                                 ks: List[int]
                                 ) -> Dict[str, Any]:
    """Compute HR, Recall, nDCG and AUROC at several cut-offs for one batch.

    Args:
        ranks: LongTensor, (b x C). Row ``i`` lists candidate indices ordered
            from best to worst: ``ranks[i, j]`` is the index (into
            ``labels[i]``) of the candidate ranked at position ``j``.
            Values are in 0..C-1.
        labels: LongTensor (or BoolTensor), (b x C), 1 where the candidate is
            relevant and 0 otherwise.
        ks: cut-off values. A cut-off larger than C is evaluated on all C
            candidates; its metric keys still use the requested ``k``.

    Returns:
        A dict with keys:
            'count': the batch size ``b``.
            'values': maps ``'HR@k'``, ``'Recall@k'``, ``'NDCG@k'`` and
                ``'AUROC@k'`` (for every k, largest k first) to a list with
                one Python float per row.
            'means': the same keys mapped to the batch mean of those values
                (0.0 for an empty batch).
        Rows with no relevant candidate get Recall and nDCG of 0.0 rather
        than the NaN a plain 0/0 would produce.

    ``ranks`` and ``labels`` must live on the same device; all tensor work
    stays on that device.
    """
    batch_size = ranks.size(0)
    metrics: Dict[str, Any] = {
        'count': batch_size,
        'values': {},
        'means': {},
    }
    answer_count = labels.sum(1).long()
    for k in sorted(ks, reverse=True):
        # hits[i, j] is the label of the candidate ranked j-th in row i.
        # Slicing clips to C columns when k > C, so helpers must size their
        # per-position tensors from hits, not from k.
        hits = labels.gather(1, ranks[:, :k]).float()
        per_metric = {
            f'HR@{k}': _hit_rate(hits),
            f'Recall@{k}': _recall(hits, answer_count, k),
            f'NDCG@{k}': _ndcg(hits, answer_count),
            f'AUROC@{k}': _auroc(hits),
        }
        for key, scores in per_metric.items():
            # .tolist() yields plain Python floats (JSON-friendly).
            values = scores.detach().cpu().tolist()
            metrics['values'][key] = values
            metrics['means'][key] = sum(values) / batch_size if batch_size else 0.0
    return metrics


def _hit_rate(hits: torch.Tensor) -> torch.Tensor:
    """Per row: 1.0 if any relevant candidate is in the top-k, else 0.0."""
    return (hits.sum(1) > 0).float()


def _recall(hits: torch.Tensor, answer_count: torch.Tensor, k: int) -> torch.Tensor:
    """Per row: hits in the top-k divided by min(k, #relevant); 0.0 if none are relevant."""
    # clamp(min=1) only affects rows with no relevant items, whose hit count
    # is necessarily 0, so they score 0.0 instead of NaN.
    divisor = answer_count.clamp(max=k).clamp(min=1).float()
    return hits.sum(1) / divisor


def _ndcg(hits: torch.Tensor, answer_count: torch.Tensor) -> torch.Tensor:
    """Per row: DCG of the top-k over the ideal DCG; 0.0 if none are relevant."""
    depth = hits.size(1)
    positions = torch.arange(1, depth + 1, device=hits.device, dtype=torch.float)
    weights = 1.0 / torch.log2(positions + 1)
    dcg = (hits * weights).sum(1)
    # ideal_dcg[m] is the DCG when m relevant items occupy the top m slots;
    # ideal_dcg[0] == 0 for rows without any relevant item.
    ideal_dcg = torch.cat([weights.new_zeros(1), weights.cumsum(0)])
    idcg = ideal_dcg[answer_count.clamp(max=depth)]
    # The division yields NaN where idcg == 0, but where() discards it.
    return torch.where(idcg > 0, dcg / idcg, torch.zeros_like(dcg))


def _auroc(hits: torch.Tensor) -> torch.Tensor:
    """Per row: fraction of (relevant, irrelevant) pairs in the top-k ranked in the right order.

    A row with no irrelevant candidate scores 1.0; a row with irrelevant but
    no relevant candidates scores 0.0.
    """
    flags = hits > 0
    pos_count = flags.sum(1)
    neg_count = (~flags).sum(1)
    # At an irrelevant position j, the inclusive cumsum of flags counts the
    # relevant candidates ranked above j, i.e. the correctly-ordered pairs
    # that involve this negative.
    positives_above = flags.long().cumsum(1).masked_fill(flags, 0)
    correct_pairs = positives_above.sum(1).double()
    total_pairs = (pos_count * neg_count).double().clamp(min=1)
    auroc = correct_pairs / total_pairs
    return torch.where(neg_count == 0, torch.ones_like(auroc), auroc)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment