Skip to content

Instantly share code, notes, and snippets.

@williamFalcon
Created November 23, 2019 09:43
Show Gist options
  • Save williamFalcon/2ff96b23fe28a479ac079b8580d4d3ce to your computer and use it in GitHub Desktop.
def test_step(self, batch, batch_nb):
input_ids, attention_mask, token_type_ids, label = batch
y_hat, attn = self.forward(input_ids, attention_mask, token_type_ids)
a, y_hat = torch.max(y_hat, dim=1)
test_acc = accuracy_score(y_hat.cpu(), label.cpu())
return {'test_acc': torch.tensor(test_acc)}
def test_end(self, outputs):
avg_test_acc = torch.stack([x['test_acc'] for x in outputs]).mean()
tensorboard_logs = {'avg_test_acc': avg_test_acc}
return {'avg_test_acc': avg_test_acc, 'log': tensorboard_logs, 'progress_bar': tensorboard_logs}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment