Skip to content

Instantly share code, notes, and snippets.

@Pangoraw
Created June 15, 2022 09:25
Show Gist options
  • Save Pangoraw/c5468d15b640063376010edb49ce731e to your computer and use it in GitHub Desktop.
Save Pangoraw/c5468d15b640063376010edb49ce731e to your computer and use it in GitHub Desktop.
Online Cohen's Kappa coefficient
from typing import Dict, Tuple
import torch
from torch import Tensor
class OnlineKappa:
    """
    Computes an online version of the Cohen's Kappa Coefficient.

    Paired label tensors from two raters are fed in batches through
    ``update``. Only running counts are kept (total observations, number of
    agreements, and per-class counts for each rater), so ``value`` returns
    the kappa over everything seen so far without storing the labels.

    >>> k = OnlineKappa(n_classes=2)
    >>> k.update(torch.tensor([1, 0, 0]), torch.tensor([1, 1, 0]))
    3
    >>> k.value()
    0.4
    """

    def __init__(self, n_classes: int) -> None:
        self.n_classes = n_classes
        # Total number of label pairs seen across all updates.
        self.n_observations: int = 0
        # Number of pairs where both raters gave the same label.
        self.n_agreed: int = 0
        # Per-class counts: cls -> (count from rater 1, count from rater 2).
        self.data: Dict[int, Tuple[int, int]] = {
            cls: (0, 0) for cls in range(n_classes)
        }

    def update(self, y1: Tensor, y2: Tensor) -> int:
        """
        Accumulate a batch of paired labels.

        Args:
            y1: labels from the first rater, each in ``[0, n_classes)``.
            y2: labels from the second rater, same shape as ``y1``.

        Returns:
            The total number of observations accumulated so far.

        Raises:
            ValueError: if the shapes differ or any label is out of range.
        """
        # Explicit checks rather than ``assert`` so validation still runs
        # under ``python -O``.
        if y1.shape != y2.shape:
            raise ValueError(
                "y1 and y2 must have the same shape, got "
                f"{tuple(y1.shape)} and {tuple(y2.shape)}"
            )
        for name, y in (("y1", y1), ("y2", y2)):
            if not (torch.all(y >= 0) and torch.all(y < self.n_classes)):
                raise ValueError(
                    f"{name} contains labels outside [0, {self.n_classes})"
                )

        self.n_observations += y1.numel()
        for cls in range(self.n_classes):
            c1, c2 = self.data[cls]
            self.data[cls] = (
                c1 + int((y1 == cls).sum().item()),
                c2 + int((y2 == cls).sum().item()),
            )
        self.n_agreed += int((y1 == y2).sum())
        return self.n_observations

    def value(self) -> float:
        """
        Return Cohen's kappa over all observations accumulated so far.

        Kappa is ``(po - pe) / (1 - pe)``, where the observed agreement is
        ``po = n_agreed / N`` and the chance agreement is
        ``pe = S / N**2`` with ``S = sum_k(a_k * b_k)`` over the per-class
        counts. Multiplying numerator and denominator by ``N**2`` gives
        ``(N * n_agreed - S) / (N**2 - S)``. Evaluating that form keeps the
        computation in exact integer arithmetic until a single final
        division, so no precision is lost in intermediate ratios.

        Returns:
            The kappa coefficient, or ``nan`` when it is undefined. That
            happens when no observations have been recorded yet, or when
            chance agreement equals 1 (both raters used one and the same
            class for every observation).
        """
        n = self.n_observations
        chance = sum(a * b for a, b in self.data.values())
        denominator = n * n - chance
        if denominator == 0:
            # 0/0: kappa is undefined (empty state, or pe == 1).
            return float("nan")
        return (n * self.n_agreed - chance) / denominator
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment