Created March 14, 2021 05:44
-
-
Save ketanhdoshi/a1181fb8af0c0eb083b7646dc7afb053 to your computer and use it in GitHub Desktop.
Sound Classification Inference
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# ----------------------------
# Inference
# ----------------------------
def inference(model, val_dl, device=None):
    """Evaluate ``model`` on ``val_dl``, print and return classification accuracy.

    Each batch is normalized with its own mean and standard deviation before
    being passed to the model (presumably mirroring the normalization used at
    training time -- confirm against the training loop). The predicted class
    is the index of the highest score along dimension 1 of the model output.

    Args:
        model: a ``torch.nn.Module``. It is temporarily switched to eval mode;
            its previous train/eval mode is restored before returning.
        val_dl: an iterable of batches where ``batch[0]`` is the input tensor
            and ``batch[1]`` the integer target labels (e.g. a ``DataLoader``).
        device: device each batch is moved to. Defaults to the device of the
            model's first parameter, or CPU if the model has no parameters.

    Returns:
        The fraction of correctly classified items, or 0.0 if ``val_dl``
        yielded no items.
    """
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device("cpu")

    correct_prediction = 0
    total_prediction = 0

    # Eval mode disables dropout and makes batch-norm use its running
    # statistics; remember the caller's mode so it can be restored.
    was_training = model.training
    model.eval()
    try:
        # Disable gradient tracking: not needed for evaluation.
        with torch.no_grad():
            for data in val_dl:
                # Get the input features and target labels on the target device
                inputs, labels = data[0].to(device), data[1].to(device)
                # Normalize the inputs with per-batch statistics; clamp the std
                # so a constant batch does not divide by zero.
                inputs_m, inputs_s = inputs.mean(), inputs.std()
                inputs = (inputs - inputs_m) / inputs_s.clamp(min=1e-8)
                # Get predictions
                outputs = model(inputs)
                # Index of the highest score along dim 1 is the predicted class
                _, prediction = torch.max(outputs, 1)
                # Count of predictions that matched the target label
                correct_prediction += (prediction == labels).sum().item()
                total_prediction += prediction.shape[0]
    finally:
        model.train(was_training)

    # Guard against an empty loader instead of raising ZeroDivisionError.
    acc = correct_prediction / total_prediction if total_prediction else 0.0
    print(f'Accuracy: {acc:.2f}, Total items: {total_prediction}')
    return acc
# Evaluate the trained model on the held-out validation loader;
# the resulting accuracy is printed by inference().
inference(
    model=myModel,
    val_dl=val_dl,
)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.