@GarrettMooney
Last active April 30, 2018 00:18
multi-class logistic regression
Multiclass classifier
----
```{r, include=FALSE}
knitr::opts_chunk$set(
  comment = "#>",
  collapse = TRUE,
  cache = TRUE,
  echo = FALSE,
  out.width = "70%",
  fig.align = "center",
  fig.width = 6,
  fig.asp = 0.618,
  fig.show = "hold"
)
library(tidyverse)
library(arm)
`%>%` <- magrittr::`%>%`
`%<>%` <- magrittr::`%<>%`
```
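One-vs-all:
Fit three separate binary logistic regressions, one per cylinder class (4, 6, 8), each predicting that class against the other two.
For each car, predict the class whose model is most confident, i.e. the one with the largest linear predictor (log-odds).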
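
The "most confident model" rule can be sketched on its own. The block below is a minimal illustration, not the gist's code: the `log_odds` matrix is made-up data standing in for the link-scale predictions of three one-vs-all models. Because the inverse logit (`plogis`) is strictly increasing, taking the row-wise maximum of the log-odds picks the same class as taking the maximum of the probabilities.

```r
# Hypothetical log-odds from three one-vs-all models
# (rows = observations, columns = classes). Made-up numbers for illustration.
log_odds <- rbind(
  c( 2.0, -1.0, -3.0),
  c(-1.5, -0.5,  0.5),
  c(-4.0,  1.2, -0.2)
)
colnames(log_odds) <- c("cyl4", "cyl6", "cyl8")

# plogis() is the inverse logit; it is strictly increasing, so the
# row-wise argmax is identical on the link and probability scales.
probs <- plogis(log_odds)

by_link <- apply(log_odds, 1, which.max)
by_prob <- apply(probs, 1, which.max)
stopifnot(identical(by_link, by_prob))

colnames(log_odds)[by_link]
#> [1] "cyl4" "cyl8" "cyl6"
```

This is why predicting with `type = "link"` is enough for the class decision; probabilities are only needed if the scores themselves are to be reported.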
```{r, warning = FALSE}
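cyl_levels <- levels(factor(mtcars$cyl))

# plot helper: draw one cylinder class (colour/symbol 1) against the rest (2)
plot_by_cyl <- function(cyl_level, ...) {
  with(mtcars,
       plot(mpg, disp,
            col = ifelse(cyl == cyl_levels[cyl_level], 1, 2),
            pch = ifelse(cyl == cyl_levels[cyl_level], 1, 2),
            ...
       ))
}

# dummy vars: add one 0/1 indicator column per class (cyl4, cyl6, cyl8)
mtcars %<>%
  mutate(i = 1, cyl_str = str_c('cyl', cyl)) %>%
  spread(cyl_str, i, fill = 0)

# model helpers: fit a binomial bayesglm on mtcars,
# then return fitted values on the link (log-odds) scale
fit_model <- partial(arm::bayesglm, data = mtcars, family = binomial)
pred <- partial(predict, type = "link")
# compose() applies right to left: fit first, then predict
fit_and_pred <- compose(pred, fit_model)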
# 4 cyl
plot_by_cyl(1, main = "4 cyl")
p.1 <- fit_and_pred(cyl4 ~ mpg + disp)
# 6 cyl
plot_by_cyl(2, main = "6 cyl")
p.2 <- fit_and_pred(cyl6 ~ mpg + disp)
# 8 cyl
plot_by_cyl(3, main = "8 cyl")
p.3 <- fit_and_pred(cyl8 ~ mpg + disp)
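# Results: column 1 is the true class index, columns 2-4 the log-odds from each model
ll <- cbind(as.numeric(factor(mtcars$cyl)), p.1, p.2, p.3)
# predicted class = index of the model with the largest log-odds
results <- cbind(ll[, 1], apply(ll[, -1], 1, which.max))
# confusionMatrix() takes the predictions first (data), then the truth (reference)
caret::confusionMatrix(data = factor(results[, 2], levels = 1:3),
                       reference = factor(results[, 1], levels = 1:3))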
```
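
The same procedure can be written for any number of classes. The block below is a rough sketch under stated assumptions rather than the gist's method: it swaps `arm::bayesglm` for base R's `glm`, uses `iris` instead of `mtcars`, and the helper `one_vs_all()` with its arguments is a hypothetical name introduced here. Plain `glm` may warn when a class is (nearly) perfectly separable, which is one common reason to prefer `bayesglm`; the sketch simply suppresses those warnings.

```r
# Hypothetical helper (not part of the original gist): fit one binary GLM per
# class and predict the class whose model gives the largest log-odds.
one_vs_all <- function(data, response, predictors) {
  classes <- levels(factor(data[[response]]))
  # One column of link-scale scores per class.
  scores <- sapply(classes, function(k) {
    d <- data[predictors]
    d$is_k <- as.integer(data[[response]] == k)  # 1 = class k, 0 = the rest
    # Assumption: plain glm stands in for arm::bayesglm; separation warnings
    # are suppressed here rather than handled.
    fit <- suppressWarnings(glm(is_k ~ ., data = d, family = binomial))
    predict(fit, type = "link")
  })
  # Row-wise argmax over the per-class scores.
  factor(classes[max.col(scores, ties.method = "first")], levels = classes)
}

ova_pred <- one_vs_all(iris, "Species", c("Sepal.Length", "Sepal.Width"))
table(predicted = ova_pred, actual = iris$Species)
```

The block is fenced as plain `r` rather than as a knitr chunk, so it is displayed but not evaluated when the document is knit.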