
@dgrtwo
Created May 31, 2017 18:56
Comparing pairs of MNIST digits based on one pixel
```r
library(tidyverse)

# Data is downloaded from here:
# https://www.kaggle.com/c/digit-recognizer
kaggle_data <- read_csv("~/Downloads/train.csv")

# One row per (image, pixel): `instance` identifies the image,
# `pixel` is the pixel index 0-783 and `value` its intensity 0-255
pixels_gathered <- kaggle_data %>%
  mutate(instance = row_number()) %>%
  gather(pixel, value, -label, -instance) %>%
  extract(pixel, "pixel", "(\\d+)", convert = TRUE)

# For each pair of digits 0-4 and each pixel, build an ROC curve that uses
# the pixel's intensity as the score for "this image is digit compare2"
roc_by_pixel <- pixels_gathered %>%
  filter(instance %% 20 == 0) %>%          # keep every 20th image (5%)
  crossing(compare1 = 0:4, compare2 = 0:4) %>%
  filter(label == compare1 | label == compare2, compare1 != compare2) %>%
  group_by(compare1, compare2, pixel, value) %>%
  summarize(positive = sum(label == compare2),
            negative = n() - positive) %>%
  # Still grouped by compare1, compare2, pixel; sorting by descending
  # intensity makes each group's cumulative sums sweep the threshold
  # from the darkest value down
  arrange(desc(value)) %>%
  mutate(tpr = cumsum(positive) / sum(positive),
         fpr = cumsum(negative) / sum(negative)) %>%
  filter(n() > 1)                          # drop pixels with only one value

roc_by_pixel %>%
  # Trapezoidal area under the ROC curve, including the segment from (0, 0)
  summarize(auc = sum(diff(c(0, fpr)) * (c(0, tpr) + lag(c(0, tpr)))[-1]) / 2) %>%
  arrange(desc(auc)) %>%
  mutate(row = pixel %/% 28, column = pixel %% 28) %>%
  ggplot(aes(column, 28 - row, fill = auc)) +
  geom_tile() +
  scale_fill_gradient2(low = "blue", high = "red", mid = "white", midpoint = .5) +
  facet_grid(compare2 ~ compare1) +
  labs(title = "AUC for distinguishing pairs of MNIST digits by one pixel",
       subtitle = "Red means pixel is predictive of the row, blue predictive of the column",
       fill = "AUC") +
  theme_void()
```
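The plotting step turns each pixel index into a tile position. A minimal base-R sketch of that mapping, assuming Kaggle's documented layout (pixelx with x = i * 28 + j for row i and column j, zero-indexed); pixel_to_tile() is a hypothetical helper, not part of the gist:

```r
# Hypothetical helper, not part of the gist: map a Kaggle pixel index
# (0-783, row-major over a 28 x 28 image) to the tile the plot draws.
pixel_to_tile <- function(pixel) {
  row <- pixel %/% 28      # 0 = top row of the image
  column <- pixel %% 28    # 0 = leftmost column
  # ggplot draws larger y values higher up, so y = 28 - row keeps the
  # digit upright (row 0 lands at y = 28, row 27 at y = 1)
  data.frame(pixel = pixel, row = row, column = column, y = 28 - row)
}

pixel_to_tile(c(0, 27, 28, 783))
```

Pixel 27 is the last tile of the top row and pixel 28 wraps to the start of the second row, which is why the gist can use integer division and the remainder directly.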
@dgrtwo
Author

dgrtwo commented Sep 10, 2019

Hmm, that's strange: the first column of the Kaggle data was label the last time I checked.

Can you check that your train.csv has a label column, and then that pixels_gathered still has one after running that line?
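Both checks can be scripted. A base-R sketch, assuming Kaggle's header layout (a label column followed by pixel0 to pixel783); check_mnist_columns() is a hypothetical helper, not part of the gist:

```r
# Hypothetical helper, not part of the gist: stop early with a clear
# message if the data lacks the columns the pipeline relies on.
check_mnist_columns <- function(df) {
  if (!"label" %in% names(df)) {
    stop("no `label` column found")
  }
  pixel_cols <- grep("^pixel\\d+$", names(df), value = TRUE)
  if (length(pixel_cols) != 784) {
    stop("expected 784 columns named pixel0..pixel783, found ", length(pixel_cols))
  }
  invisible(TRUE)
}

# With the gist's objects loaded:
# check_mnist_columns(kaggle_data)
# "label" %in% names(pixels_gathered)
```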

@mathematicalmichael

mathematicalmichael commented Sep 11, 2019

Oh goodness, thank you so much. I didn't realize I had to use Kaggle's version of the dataset (I didn't want to sign up just to download it, so I found the dataset elsewhere). I suspect that's exactly the problem. I'll give it another go later today once I have a stable internet connection, and post an update. It's annoying that the dataset can't be downloaded from the command line.

It looks like the CSV I got has no labels at all: just comma-separated pixel values, one image per line (which explains the "renaming" part of the stack trace). If you have a suggestion for adding the required label column in R after the data is loaded, I'd appreciate it (as I do your prompt reply); it's been a while since I've written any R (these days it's all Python for me).
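The pixels alone can't tell you which digit each row is, so the labels have to come from a separate source. Assuming you have them as a vector in the same order as the rows (digit_labels is a hypothetical name), a base-R sketch that gives the data the layout the gist expects. It also renames the columns: readr's read_csv(path, col_names = FALSE) names header-less columns X1 to X784, which the gist's extract() step would read as pixels 1 to 784 instead of 0 to 783, shifting every tile by one.

```r
# Hypothetical helper, not part of the gist: attach externally sourced
# labels to a header-less pixel table and apply Kaggle's column names.
add_labels <- function(pixels, digit_labels) {
  stopifnot(nrow(pixels) == length(digit_labels))
  # Kaggle's names: pixel0 ... pixel783, in the same row-major order
  names(pixels) <- paste0("pixel", seq_len(ncol(pixels)) - 1)
  data.frame(label = digit_labels, pixels)
}

# Toy example: 2 images with 3 pixels each
toy <- data.frame(matrix(0:5, nrow = 2))
add_labels(toy, c(7, 1))
```

The result has columns label, pixel0, pixel1, pixel2, so the rest of the gist can run unchanged on it.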

Once I have the Kaggle dataset downloaded, would it be appropriate to upload it to a public server I rent and make it accessible via wget?

@mathematicalmichael

mathematicalmichael commented Sep 11, 2019

IT WORKED (the environment I used had all the dependencies)! Thanks so much for your help, @dgrtwo.

I'm not sure I understand why the figure plotted at the end shows how predictive a single pixel is. Does it have to do with sharp boundaries? There aren't really any comments to help. What do the four rows/columns represent?
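On the layout: facet_grid(compare2 ~ compare1) puts compare2 on the rows and compare1 on the columns, each running over the digits 0 to 4, so the grid is 5 x 5 with an empty diagonal (the pipeline drops compare1 == compare2). A base-R sketch of which images feed a given panel; panel_subset() is a hypothetical helper that mirrors the gist's two filter() calls and its positive = sum(label == compare2) line:

```r
# The 20 off-diagonal panels: one per ordered pair of distinct digits 0-4
panels <- expand.grid(compare1 = 0:4, compare2 = 0:4)
panels <- panels[panels$compare1 != panels$compare2, ]

# Hypothetical helper, not part of the gist: the images used in one panel
# (only the two digits being compared) and which of them count as positive
# (the row digit, compare2)
panel_subset <- function(digit_labels, compare1, compare2) {
  keep <- digit_labels == compare1 | digit_labels == compare2
  list(index = which(keep), is_positive = digit_labels[keep] == compare2)
}

panel_subset(c(3, 0, 1, 3, 4, 1), compare1 = 1, compare2 = 3)
# $index:       1 3 4 6
# $is_positive: TRUE FALSE TRUE FALSE
```

Each pair of digits therefore appears twice (for example row 3, column 1 and row 1, column 3). Swapping which digit is positive turns an AUC of a into 1 - a, so mirrored panels carry the same information with red and blue swapped.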

@hot9cups

Any update on what @mathematicalmichael asked? I'm still not sure how the figure shows the predictive potential of a single pixel.
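Each tile answers one question: among images of the row digit (compare2) and the column digit (compare1), how well does that one pixel's intensity, used alone as a score, separate them? The pipeline treats the row digit as the positive class, sweeps a threshold over the intensity (higher means more ink), and reports the area under the resulting ROC curve. That AUC equals the probability that a randomly chosen row-digit image has a higher value at that pixel than a randomly chosen column-digit image, counting ties as one half. So white (0.5) means the pixel carries no information, red (near 1) means ink there points to the row digit, and blue (near 0) means ink there points to the column digit. A base-R restatement for one pixel and one pair, with pixel_auc() as a hypothetical name and the curve started at (0, 0):

```r
# Hypothetical helper, not part of the gist: AUC of one pixel's intensity
# as a score for "this image is the positive (row) digit", with one ROC
# point per distinct intensity and the trapezoid rule.
pixel_auc <- function(value, is_positive) {
  thresholds <- sort(unique(value), decreasing = TRUE)   # darkest first
  pos <- vapply(thresholds, function(t) sum(value == t & is_positive), integer(1))
  neg <- vapply(thresholds, function(t) sum(value == t & !is_positive), integer(1))
  tpr <- c(0, cumsum(pos) / sum(pos))                    # curve starts at (0, 0)
  fpr <- c(0, cumsum(neg) / sum(neg))
  sum(diff(fpr) * (head(tpr, -1) + tail(tpr, -1))) / 2
}

# Perfect separator: only the positive images have ink at this pixel
pixel_auc(c(0, 0, 255, 255), c(FALSE, FALSE, TRUE, TRUE))  # 1
# Uninformative: every image has the same value
pixel_auc(c(0, 0, 0, 0), c(FALSE, TRUE, FALSE, TRUE))      # 0.5
```

The pairwise form gives the same number, which is what makes the colors readable as "how often this pixel alone ranks the row digit above the column digit":

```r
set.seed(42)
value <- sample(c(0, 128, 255), 40, replace = TRUE)
is_positive <- rep(c(TRUE, FALSE), 20)
pairs_gt <- outer(value[is_positive], value[!is_positive], ">")
pairs_eq <- outer(value[is_positive], value[!is_positive], "==")
mean(pairs_gt + 0.5 * pairs_eq)  # equals pixel_auc(value, is_positive)
```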
