Skip to content

Instantly share code, notes, and snippets.

@innat
Created November 25, 2020 13:33
Show Gist options
  • Select an option

  • Save innat/61d76fc98f92b647135254884e6916f6 to your computer and use it in GitHub Desktop.

Select an option

Save innat/61d76fc98f92b647135254884e6916f6 to your computer and use it in GitHub Desktop.
# One 12x8-inch figure with a single axes object; c_ax is the canvas the
# multi-class ROC curves are drawn onto.
fig, c_ax = plt.subplots(nrows=1, ncols=1, figsize=(12, 8))
def multiclass_roc_auc_score(
    y_test, y_pred, average="macro", class_names=None, ax=None
):
    """Plot one-vs-rest ROC curves for every class and return the ROC AUC.

    Args:
        y_test: 1-D array-like of true class labels.
        y_pred: Either a 1-D array-like of predicted class labels (one-hot
            encoded with the classes learned from ``y_test``), or a 2-D
            array-like of per-class scores of shape (n_samples, n_classes)
            whose columns follow the sorted order of the labels in
            ``y_test``. Scores give a full ROC curve per class; hard labels
            give only a single operating point per class.
        average: Averaging mode forwarded to
            ``sklearn.metrics.roc_auc_score``.
        class_names: Legend names, one per class, in the same sorted label
            order. Defaults to the module-level ``all_labels``.
        ax: Matplotlib axes to draw on. Defaults to the module-level
            ``c_ax``.

    Returns:
        The ``roc_auc_score`` of the one-hot targets against the scores.

    Raises:
        ValueError: if ``y_pred`` does not have one score column per class.
    """
    import numpy as np

    # Fall back to the module-level globals the original relied on.
    if class_names is None:
        class_names = all_labels
    if ax is None:
        ax = c_ax

    lb = LabelBinarizer()
    lb.fit(y_test)
    y_true_bin = lb.transform(y_test)

    y_score = np.asarray(y_pred)
    if y_score.ndim == 1:
        # Hard labels: one-hot encode them with the classes seen in y_test.
        y_score = lb.transform(y_score)

    # LabelBinarizer encodes a two-class problem as a single column; expand
    # to one column per class so the per-class indexing below works.
    if y_true_bin.shape[1] == 1:
        y_true_bin = np.hstack([1 - y_true_bin, y_true_bin])
        if y_score.shape[1] == 1:
            y_score = np.hstack([1 - y_score, y_score])

    if y_score.shape[1] != y_true_bin.shape[1]:
        raise ValueError(
            f"y_pred has {y_score.shape[1]} score columns but y_test has "
            f"{y_true_bin.shape[1]} classes"
        )

    for idx, c_label in enumerate(class_names):
        fpr, tpr, _ = roc_curve(y_true_bin[:, idx], y_score[:, idx])
        ax.plot(fpr, tpr, label=f"{c_label} (AUC:{auc(fpr, tpr):0.2f})")

    # Chance-level diagonal, drawn once (not once per class). Dashed black
    # keeps it distinct from the class curves.
    ax.plot([0, 1], [0, 1], "k--", label="Random Guessing")
    ax.set_xlabel("False Positive Rate")
    ax.set_ylabel("True Positive Rate")
    # Without a legend the per-class AUC labels above are never shown.
    ax.legend(loc="lower right")

    return roc_auc_score(y_true_bin, y_score, average=average)
# Score the validation set and plot its per-class ROC curves.
# NOTE(review): valid_generator is presumably a Keras directory iterator whose
# .classes are listed in file order; predictions only line up with them if the
# generator was created with shuffle=False -- TODO confirm.
valid_generator.reset()  # restart from the first batch before predicting
# Model.predict accepts generators directly; predict_generator is deprecated.
y_pred = model.predict(valid_generator, verbose=1)
# argmax turns per-class scores into hard class indices, so each ROC curve
# gets a single operating point.
y_pred = np.argmax(y_pred, axis=1)
# print() so the score is visible when run as a script, not only in a notebook.
print("ROC AUC score:", multiclass_roc_auc_score(valid_generator.classes, y_pred))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment