@fabsta
Last active November 25, 2017 13:36
Visualization #datascience

Observing Model Predictions

source: https://www.cs.utah.edu/~cmertin/dogs+cats+redux.html

First, compute the predictions on the validation set rather than the test set, because the validation labels are known and the predictions can be checked against them.

vgg.model.load_weights(latest_weights_filename)

val_batches, probs = vgg.test(VAL_PATH, batch_size = batch_size)
# Output: Found 2000 images belonging to 2 classes.

filenames = val_batches.filenames
expected_labels = val_batches.classes # true class ids, 0 or 1

# Column 0 is taken as the probability of class 0, so 1 - p rounds to the predicted class id.
our_predictions = probs[:, 0]
our_labels = np.round(1 - our_predictions)
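
To make the rounding convention concrete, here is a small sketch on made-up numbers. It assumes, as the code above appears to, that column 0 of `probs` holds the probability of class 0, so `1 - probs[:, 0]` rounds to the predicted class id. The array values and the accuracy check are invented for illustration and are not part of the original notebook.

```python
import numpy as np

# Invented two-column class probabilities: column 0 = class 0, column 1 = class 1.
probs = np.array([[0.9, 0.1],
                  [0.2, 0.8],
                  [0.6, 0.4],
                  [0.3, 0.7]])
expected_labels = np.array([0, 1, 1, 1])  # hypothetical ground truth

our_predictions = probs[:, 0]                            # P(class 0) per image
our_labels = np.round(1 - our_predictions).astype(int)   # 0 if P(class 0) > 0.5, else 1

# Fraction of images whose predicted class id matches the true one.
accuracy = np.mean(our_labels == expected_labels)
print(our_labels, accuracy)
```

Only the third image is misclassified here, because its class 0 probability of 0.6 rounds to class 0 while its true label is 1.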

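1. A few correct labels at random

# preds, probs, val_labels, n_view, permutation and plots_idx are not defined in these notes
# and are assumed to exist already.
correct = np.where(preds==val_labels[:,1])[0]
idx = permutation(correct)[:n_view]
plots_idx(idx, probs[idx])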

2. A few incorrect labels at random

incorrect = np.where(preds!=val_labels[:,1])[0]
idx = permutation(incorrect)[:n_view]
plots_idx(idx, probs[idx])
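
The selection step in items 1 and 2 can be sketched on invented data. This assumes `val_labels` is one-hot encoded, so that column 1 doubles as the class id, and it uses a seeded `numpy.random.Generator` in place of the bare `permutation` call so the result is reproducible. Plotting is left out because `plots_idx` is not defined in these notes.

```python
import numpy as np

rng = np.random.default_rng(0)  # seeded generator (an assumption; the notes call permutation directly)

preds = np.array([0, 1, 1, 0, 1, 0])  # invented predicted class ids
val_labels = np.array([[1, 0], [0, 1], [1, 0],
                       [1, 0], [0, 1], [0, 1]])  # invented one-hot labels
n_view = 2

true_ids = val_labels[:, 1]                 # column 1 of a one-hot pair is the class id
correct = np.where(preds == true_ids)[0]    # indices where the prediction matches
incorrect = np.where(preds != true_ids)[0]  # indices where it does not

# Shuffle each index set and keep the first n_view entries.
random_correct = rng.permutation(correct)[:n_view]
random_incorrect = rng.permutation(incorrect)[:n_view]
```

The resulting index arrays can then be passed to whatever plotting helper is available, together with `probs[idx]` as titles.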

3. The most correct labels of each class (i.e. correct predictions made with the highest probability); shown here for cats

correct_cats = np.where((preds==0) & (preds==val_labels[:,1]))[0]
most_correct_cats = np.argsort(probs[correct_cats])[::-1][:n_view]
plots_idx(correct_cats[most_correct_cats], probs[correct_cats][most_correct_cats])

4. The most incorrect labels of each class (i.e. incorrect predictions made with the highest probability); shown here for dogs

incorrect_dogs = np.where((preds==1) & (preds!=val_labels[:,1]))[0]
most_incorrect_dogs = np.argsort(probs[incorrect_dogs])[:n_view]
plots_idx(incorrect_dogs[most_incorrect_dogs], 1-probs[incorrect_dogs][most_incorrect_dogs])
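
Items 3 and 4 both rank a filtered subset by confidence. The sketch below runs that logic on invented data. It assumes `probs` is a 1-D array holding the probability of class 0 (cat) and that `preds = np.round(1 - probs)`, which is what the `1-probs` in item 4 suggests but the notes do not state.

```python
import numpy as np

probs = np.array([0.70, 0.10, 0.95, 0.30, 0.60, 0.02])  # invented P(cat) per image
true_ids = np.array([0, 1, 0, 0, 1, 0])                  # 0 = cat, 1 = dog (invented)
preds = np.round(1 - probs).astype(int)                  # assumed convention, see above
n_view = 2

# Predicted cat and actually cat, ranked by P(cat) from highest to lowest.
correct_cats = np.where((preds == 0) & (preds == true_ids))[0]
order = np.argsort(probs[correct_cats])[::-1][:n_view]
most_correct_cats_idx = correct_cats[order]

# Predicted dog but actually cat, ranked by P(cat) from lowest to highest,
# which is the same as ranking by P(dog) = 1 - P(cat) from highest to lowest.
incorrect_dogs = np.where((preds == 1) & (preds != true_ids))[0]
order = np.argsort(probs[incorrect_dogs])[:n_view]
most_incorrect_dogs_idx = incorrect_dogs[order]
dog_confidence = 1 - probs[most_incorrect_dogs_idx]
```

Note that `np.argsort` returns positions within the filtered subset, so they have to be mapped back through `correct_cats` or `incorrect_dogs` to get indices into the full validation set.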

5. The most uncertain labels (i.e. those with probability closest to 0.5)

most_uncertain = np.argsort(np.abs(probs-0.5))
# Slice the indices and the titles to the same n_view entries.
plots_idx(most_uncertain[:n_view], probs[most_uncertain[:n_view]])
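
As a quick check of the ranking in item 5, here is a minimal sketch on invented probabilities. Sorting by the distance from 0.5 puts the predictions the model is least sure about first.

```python
import numpy as np

probs = np.array([0.95, 0.48, 0.10, 0.55, 0.70])  # invented probabilities
n_view = 2

# Distance from 0.5 measures confidence; the smallest distances are the most uncertain.
distance = np.abs(probs - 0.5)
most_uncertain = np.argsort(distance)[:n_view]
uncertain_probs = probs[most_uncertain]
```

With these values the two selected images are the ones with probabilities 0.48 and 0.55.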
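
# Plot a grid of sample training images, one row per category.
# CATEGORIES, NUM_CATEGORIES, train (assumed to be a pandas DataFrame with
# 'category' and 'file' columns) and read_img are expected to be defined elsewhere.
import matplotlib.pyplot as plt
from mpl_toolkits.axes_grid1 import ImageGrid

fig = plt.figure(1, figsize=(NUM_CATEGORIES, NUM_CATEGORIES))
grid = ImageGrid(fig, 111, nrows_ncols=(NUM_CATEGORIES, NUM_CATEGORIES), axes_pad=0.05)
i = 0
for category in CATEGORIES:
    # Take the first NUM_CATEGORIES files of this category to fill one grid row.
    for filepath in train[train['category'] == category]['file'].values[:NUM_CATEGORIES]:
        ax = grid[i]
        img = read_img(filepath, (224, 224))
        ax.imshow(img / 255.)
        ax.axis('off')
        # Label the row with the category directory name, next to its last image.
        if i % NUM_CATEGORIES == NUM_CATEGORIES - 1:
            ax.text(250, 112, filepath.split('/')[1], verticalalignment='center')
        i += 1
plt.show()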