Skip to content

Instantly share code, notes, and snippets.

@alonstern
Created April 13, 2020 12:39
Show Gist options
  • Save alonstern/4e0a8e6cfc1e42e23dfa724d6ea7d70c to your computer and use it in GitHub Desktop.
Save alonstern/4e0a8e6cfc1e42e23dfa724d6ea7d70c to your computer and use it in GitHub Desktop.
test the model
def test_model(model, test_dataset):
    """Evaluate ``model`` on ``test_dataset`` and print classification metrics.

    Every item yielded by the dataset is fed through the model with gradient
    tracking disabled. The per-element argmax of the model output is compared
    against the ground-truth tags, and accuracy, precision, recall and F1 are
    printed to stdout.

    Args:
        model: Callable module exposing ``eval()``; called as ``model(sample)``.
            Its output is assumed to be 2-D, one row of class scores per tag
            (``argmax(axis=1)`` is applied to the stacked outputs) -- TODO
            confirm against the model definition.
        test_dataset: Dataset yielding ``(sample, tags)`` pairs that the
            data loader can collate into tensors.

    Returns:
        None. The metrics are only printed.

    Notes:
        * The model is switched to eval mode and is not switched back.
        * ``precision_score`` / ``recall_score`` / ``f1_score`` are called
          with their default arguments; presumably the tags are binary --
          verify before using this with more than two classes.
    """
    test_loader = data.DataLoader(test_dataset)
    model.eval()

    all_tags = []
    all_tag_scores = []
    with torch.no_grad():
        for sample, tags in tqdm.tqdm(test_loader):
            # The loader is built with default arguments, so each item comes
            # wrapped in a leading batch dimension; index 0 strips it.
            sample = sample[0]
            tags = tags[0]
            tag_scores = model(sample)
            # .cpu() is a no-op for tensors already on the CPU, but without it
            # .numpy() raises TypeError for tensors that live on a GPU.
            all_tags.extend(tags.cpu().numpy())
            all_tag_scores.extend(tag_scores.cpu().numpy())

    true_tags = numpy.array(all_tags)
    # Collapse each row of class scores to the index of the highest score.
    predicted_tags = numpy.array(all_tag_scores).argmax(axis=1)

    accuracy = accuracy_score(true_tags, predicted_tags)
    pr = precision_score(true_tags, predicted_tags)
    recall = recall_score(true_tags, predicted_tags)
    f1 = f1_score(true_tags, predicted_tags)

    print("accuracy: {}".format(accuracy))
    print("pr: {}".format(pr))
    print("recall: {}".format(recall))
    print("f1: {}".format(f1))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment