Created
October 23, 2016 07:04
-
-
Save padamson/42044ba43a115b5f606835740122acb4 to your computer and use it in GitHub Desktop.
Plot the confusion matrix for a 10-class MNIST handwritten digit classification problem
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Plot the confusion matrix of a weighted k-nearest-neighbour classifier on a
# 10-class MNIST handwritten-digit subset.
#
# Requires: caret, kknn, RColorBrewer, cowplot and ggplot2 (>= 2.2.0, for
# `position = "top"` on discrete scales).
# Input: data/mnist_small.csv with a `label` column (read as a factor); all
# other columns are used as predictors -- presumably one column per pixel,
# TODO confirm against the data file.
library(caret)
library(kknn)
library(RColorBrewer)
library(cowplot)
library(ggplot2)

# Set to TRUE to print the count inside each cell on top of the colour fill.
showCounts <- FALSE

mnist <- read.csv("data/mnist_small.csv",
                  colClasses = c(label = "factor"))

# Fix the RNG so the train/test split (and hence the plotted matrix) is
# reproducible from run to run.
set.seed(42)
# 80/20 split on the label column (caret stratifies on a factor outcome).
trainIndex <- createDataPartition(mnist$label, p = 0.8,
                                  list = FALSE,
                                  times = 1)
mnistTrain <- mnist[trainIndex, ]
mnistTest <- mnist[-trainIndex, ]

# Weighted k-NN: Minkowski parameter distance = 1 (Manhattan distance) with a
# triangular kernel; predictions are made for every row of mnistTest.
mnist.kknn <- kknn(label ~ ., mnistTrain, mnistTest, distance = 1,
                   kernel = "triangular")

# Long format: one row per (Prediction, Reference) pair with its count `Freq`.
confusionDF <- data.frame(
  confusionMatrix(fitted(mnist.kknn), mnistTest$label)$table
)
# Reverse the y-axis levels so the first digit is drawn at the top, giving
# the conventional top-left-to-bottom-right diagonal.
confusionDF$Reference <- factor(confusionDF$Reference,
                                levels = rev(levels(confusionDF$Reference)))

# Interpolate the 9-colour BuPu Brewer palette into a smooth 256-step ramp.
jBuPuFun <- colorRampPalette(brewer.pal(n = 9, "BuPu"))
paletteSize <- 256
jBuPuPalette <- jBuPuFun(paletteSize)

confusionPlot <- ggplot(
  confusionDF, aes(x = Prediction, y = Reference, fill = Freq)) +
  theme(axis.text.x = element_text(angle = 90, hjust = 1, vjust = 0.5)) +
  geom_tile() +
  # Draw the predicted-digit axis along the top of the matrix. This replaces
  # cowplot::switch_axis_position(), which has been removed from cowplot.
  scale_x_discrete(position = "top") +
  labs(x = "Predicted digit", y = "Actual digit") +
  scale_fill_gradient2(
    low = jBuPuPalette[1],
    mid = jBuPuPalette[paletteSize / 2],
    high = jBuPuPalette[paletteSize],
    # Centre the diverging scale halfway between the smallest and largest
    # cell counts.
    midpoint = (max(confusionDF$Freq) + min(confusionDF$Freq)) / 2,
    name = "") +
  theme(legend.key.height = unit(2, "cm"))

if (showCounts) {
  # Overlay the raw count of each (prediction, reference) cell.
  confusionPlot <- confusionPlot + geom_text(aes(label = Freq))
}

ggdraw(confusionPlot)
And secondly, how can I print the values rather than the colors in the confusion matrix?
Thanks
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Hi, thanks for providing this great implementation, but I can't run the code. The errors are "object 'confusionDF' not found" and also "could not find function "switch_axis_position"".
Kindly help me out, because I am new to R programming.