Skip to content

Instantly share code, notes, and snippets.

Created June 22, 2018 13:20
Show Gist options
  • Save weiaicunzai/2a5ae6eac6712c70bde0630f3e76b77b to your computer and use it in GitHub Desktop.
Save weiaicunzai/2a5ae6eac6712c70bde0630f3e76b77b to your computer and use it in GitHub Desktop.
Compute top-1 and top-5 accuracy using PyTorch
from __future__ import print_function, absolute_import
__all__ = ['accuracy']
def accuracy(output, target, topk=(1,)):
    """Compute the top-k accuracy, in percent, for each k in ``topk``.

    A sample counts as correct for a given k when its target class is among
    the k highest-scoring entries of its row in ``output``.

    Args:
        output: Score tensor; the top-k search runs along dimension 1, so
            each row is expected to hold the class scores of one sample
            (i.e. shape ``(batch_size, num_classes)``).
        target: Tensor of ground-truth class indices, one per sample.
        topk: Sequence of k values to evaluate.

    Returns:
        A list with one 0-dim float tensor per entry of ``topk``, each the
        percentage (0-100) of samples classified correctly within the top k.
    """
    maxk = max(topk)
    batch_size = target.size(0)

    # Indices of the maxk largest scores per sample, sorted best-first.
    _, pred = output.topk(maxk, 1, True, True)
    pred = pred.t()  # (batch_size, maxk) -> (maxk, batch_size)
    # Broadcast the labels over the maxk rows and mark the matching guesses.
    correct = pred.eq(target.view(1, -1).expand_as(pred))

    res = []
    for k in topk:
        # reshape(), not view(): `correct` can inherit the non-contiguous
        # layout of the transposed `pred`, and view(-1) raises on such a
        # tensor as soon as k > 1.
        correct_k = correct[:k].reshape(-1).float().sum(0)
        res.append(correct_k.mul_(100.0 / batch_size))
    return res
Copy link

slmatrix commented Sep 5, 2023

The code below requires a separate invocation for each value of k.
IMO, removing the loop and the tuple makes the code clearer:

def accuracy(result, answer, topk=1):
    """Return the top-k accuracy of ``result`` against ``answer``, in percent.

    A sample counts as correct when its label is among the ``topk``
    highest-scoring classes of its row in ``result``.

    Args:
        result: Score tensor, expected shape ``(batch_size, class_cnt)``.
        answer: Ground-truth class indices, expected shape ``(batch_size,)``.
        topk: How many of the best-scoring classes per sample may contain
            the label for the sample to count as correct.

    Returns:
        The percentage (0-100) of correctly classified samples as a float.
    """
    # save the batch size before tensor mangling
    bz = answer.size(0)
    # only the indices of the k best scores matter: (sz,cnt) -> (sz,topk)
    _, indices = result.topk(topk)
    # transpose the k best indices
    pred = indices.t()  # (sz,topk) -> (topk,sz)
    # repeat the same labels topk times to match pred's shape
    labels = answer.view(1, -1)       # (sz) -> (1,sz)
    labels = labels.expand_as(pred)   # (1,sz) -> (topk,sz)

    correct = (pred == labels)      # (topk,sz) of bool vals
    correct = correct.flatten()     # (topk*sz) of bool vals
    correct = correct.float()       # (topk*sz) of 1s or 0s
    # a label can match at most one of a sample's k guesses, so the total
    # number of matches equals the number of correctly classified samples
    correct = correct.sum()
    correct = correct.mul_(100 / bz)  # convert into percentage

    return correct.item()

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment