Created September 16, 2021 10:31
-
-
Save Arafat245/c1d526b4f0103b48232af71c8a469407 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import torch
from sklearn.metrics import accuracy_score
from sklearn.metrics import cohen_kappa_score
from sklearn.metrics import precision_recall_fscore_support
def validation_prediction(dataloader, model, loss_fn, device=None):
    """Evaluate ``model`` on ``dataloader`` and print classification metrics.

    Runs one pass over the data under ``torch.no_grad()``. It prints the
    average loss and accuracy, then accuracy, macro-averaged
    precision/recall/F1 and Cohen's kappa (via scikit-learn).

    Args:
        dataloader: Iterable of ``(x, y)`` batches. ``y`` holds integer class
            labels. Presumably 1-D (one label per sample), since it is compared
            element-wise with ``output.argmax(dim=1)``; confirm against caller.
        model: Module mapping ``x`` to per-class scores; the predicted class
            is taken as the argmax along dim 1.
        loss_fn: Callable ``loss_fn(output, y)`` returning a scalar tensor.
        device: Device each batch is moved to. When omitted, it is inferred
            from the model's first parameter, or CPU if the model has none.

    Returns:
        ``(avg_loss, validation_accuracy)``. ``avg_loss`` is the mean of the
        per-batch losses. ``validation_accuracy`` is a percentage (0-100).

    Raises:
        ValueError: If ``dataloader`` yields no samples.

    Note:
        ``model`` is left in eval mode; call ``model.train()`` before
        resuming training.
    """
    if device is None:
        # Inputs must live on the same device as the weights, so infer it
        # instead of depending on an external global ``device``.
        first_param = next(iter(model.parameters()), None)
        device = (first_param.device if first_param is not None
                  else torch.device("cpu"))

    model.eval()
    num_batches = 0
    total = 0
    correct = 0
    running_loss = 0.0
    actual_labels = []
    predicted_labels = []
    with torch.no_grad():
        for x, y in dataloader:
            output = model(x.to(device))
            running_loss += loss_fn(output, y.to(device)).item()
            num_batches += 1
            # No autograd graph is built under no_grad(), so .detach() is unnecessary.
            labels = y.cpu()
            predictions = output.argmax(dim=1).cpu()
            total += labels.size(0)
            correct += (predictions == labels).sum().item()
            # extend() flattens as we go; no nested list to flatten afterwards.
            actual_labels.extend(labels.tolist())
            predicted_labels.extend(predictions.tolist())

    if total == 0:
        raise ValueError("validation_prediction: dataloader yielded no samples")

    # Unweighted mean of per-batch losses: a short final batch counts as much
    # as a full one. Counting batches also supports loaders without __len__.
    avg_loss = running_loss / num_batches
    accuracy = correct / total
    validation_accuracy = 100 * accuracy
    print(f'\nValidation Loss = {avg_loss:.6f}', end='\t')
    print(f'Accuracy on Validation set = {validation_accuracy:.6f}% [{correct}/{total}]')

    p, r, f1, _ = precision_recall_fscore_support(
        actual_labels, predicted_labels, average='macro')
    kappa = cohen_kappa_score(actual_labels, predicted_labels)
    print('Acc: ' + str(accuracy * 100) + '%')
    print('Precision: ' + str(p * 100) + '%')
    print('Recall: ' + str(r * 100) + '%')
    print('F1 score: ' + str(f1 * 100) + '%')
    # NOTE(review): kappa lies in [-1, 1]; it is scaled by 100 and shown with
    # '%' only to match the other lines' formatting.
    print('Kappa score: ' + str(kappa * 100) + '%')
    return avg_loss, validation_accuracy
# Run one validation pass; the function prints its metrics and returns
# (average loss, accuracy in percent).
valid_loss, validation_accuracy = validation_prediction(
    dataloader=valid_dataloader,
    model=model,
    loss_fn=loss_fn,
)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.