@SumanSudhir
Created May 19, 2020 13:51
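
import numpy as np
import torch


def kappa_loss(p, y, n_classes=6, eps=1e-10):
    """
    QWK loss function as described in https://arxiv.org/pdf/1612.00775.pdf

    Arguments:
        p: a tensor with probability predictions, [batch_size, n_classes]
        y: a tensor with one-hot encoded class labels, [batch_size, n_classes]
        n_classes: number of ordinal classes
        eps: small constant that guards against division by zero

    Returns:
        QWK loss, equal to 1 - kappa (0 for perfect agreement), so
        minimizing it maximizes the quadratic weighted kappa
    """
    y = y.float()
    p = p.float()
    # Quadratic disagreement weights: W[i, j] = (i - j)^2
    W = np.zeros((n_classes, n_classes))
    for i in range(n_classes):
        for j in range(n_classes):
            W[i, j] = (i - j) ** 2
    # Put W on the same device as the inputs (no global `device` needed)
    W = torch.from_numpy(W.astype(np.float32)).to(p.device)
    # Observed (soft) confusion matrix: rows = true class, columns = predicted class
    O = torch.matmul(y.t(), p)
    # Expected matrix under independence: outer product of the true and
    # predicted class histograms, scaled to the same total as O
    E = torch.matmul(y.sum(dim=0).view(-1, 1), p.sum(dim=0).view(1, -1)) / O.sum()
    # Ratio of observed to expected weighted disagreement (= 1 - kappa)
    return (W * O).sum() / ((W * E).sum() + eps)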
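
Below is a minimal torch-only sketch of the same soft-QWK ratio. It assumes raw logits and integer class labels as inputs. The names `kappa_loss_vectorized` and `hard_qwk` are hypothetical helpers and are not part of the gist. The sketch differs from the gist in three ways:

- It builds W with `torch.arange` instead of nested loops.
- It creates W directly on the inputs' device.
- It converts labels with `torch.nn.functional.one_hot`.

`hard_qwk` is a plain reference implementation of quadratic weighted kappa on hard labels. With near-one-hot predictions, the loss should be close to `1 - hard_qwk(...)`.

```python
import torch
import torch.nn.functional as F


def kappa_loss_vectorized(logits, labels, n_classes, eps=1e-10):
    """Hypothetical sketch: soft QWK ratio (1 - kappa) from logits.

    Assumes `logits` has shape [batch_size, n_classes] and `labels`
    holds int64 class indices in [0, n_classes).
    """
    p = torch.softmax(logits.float(), dim=1)              # probability predictions
    y = F.one_hot(labels, num_classes=n_classes).float()  # one-hot targets
    idx = torch.arange(n_classes, dtype=torch.float32, device=p.device)
    W = (idx.view(-1, 1) - idx.view(1, -1)) ** 2           # W[i, j] = (i - j)^2
    O = y.t() @ p                                          # observed soft confusion matrix
    E = torch.outer(y.sum(dim=0), p.sum(dim=0)) / O.sum()  # expected under independence
    return (W * O).sum() / ((W * E).sum() + eps)


def hard_qwk(true, pred, n_classes):
    """Hypothetical reference: quadratic weighted kappa on hard integer labels."""
    idx = torch.arange(n_classes, dtype=torch.float64)
    W = (idx.view(-1, 1) - idx.view(1, -1)) ** 2
    O = torch.zeros(n_classes, n_classes, dtype=torch.float64)
    for t, q in zip(true.tolist(), pred.tolist()):
        O[t, q] += 1                                       # hard confusion matrix
    E = torch.outer(O.sum(dim=1), O.sum(dim=0)) / O.sum()
    return 1 - (W * O).sum() / (W * E).sum()


if __name__ == "__main__":
    torch.manual_seed(0)
    labels = torch.tensor([0, 1, 2, 3, 4, 5, 2, 3])
    logits = torch.randn(8, 6, requires_grad=True)
    loss = kappa_loss_vectorized(logits, labels, n_classes=6)
    loss.backward()  # the loss is differentiable w.r.t. the logits
    print(float(loss), tuple(logits.grad.shape))
```

Two details explain the design. First, the quadratic weights are left unnormalized, as in the gist; any constant scale on W cancels in the ratio. Second, E is scaled so that it has the same total as O, which is what makes the returned value equal 1 - kappa.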