# Skip to content
#
# Instantly share code, notes, and snippets.
#
# @Kornel
# Created March 24, 2016 18:19
# Show Gist options
#   • Save Kornel/c72d99bf1a6fb307ed15 to your computer and use it in GitHub Desktop.
# Decision boundary of an nnet (single-hidden-layer neural network) classifier
library(nnet)

# Fit a single-hidden-layer neural network (nnet) to two standard-normal
# features with a random binary label, then draw the fitted decision boundary
# over a scatter plot of the training points.

# Parameters ----
seed <- 1    # fixes both the simulated data and nnet's random initial weights
n <- 20      # number of training points
lim <- 4     # plot and grid cover [-lim, lim] on both axes
step <- 0.1  # grid resolution used to evaluate the network

# Simulate data and fit ----
set.seed(seed)
x1 <- rnorm(n)
x2 <- rnorm(n)
y <- as.factor(rbinom(n, 1, 0.5))
# NOTE: `data` masks utils::data() in the global env; name kept for compatibility.
data <- data.frame(x1, x2, y)
# trace = FALSE suppresses the optimiser's per-iteration log.
fit <- nnet(y ~ ., data, size = 10, trace = FALSE)

# Optional: draw the network architecture with fawda123's plot.nnet helper
# (needs devtools and network access).
#library(devtools)
#source_url('https://gist.githubusercontent.com/fawda123/7471137/raw/466c1474d0a505ff044412703516c34f1a4684a5/nnet_plot_update.r')
#plot.nnet(fit)

# Evaluate the network on a regular grid ----
py <- seq(-lim, lim, step)
px <- seq(-lim, lim, step)
# expand.grid varies its first argument fastest, which matches the
# column-major fill of matrix(..., nrow = length(px)) used below.
grid <- expand.grid(px, py)
colnames(grid) <- c("x1", "x2")
grid$pred <- predict(fit, grid, type = "class")
# With a two-level factor response nnet fits one output unit, so the raw
# prediction is the fitted probability of the second level ("1").
grid$prob <- as.vector(predict(fit, grid[c("x1", "x2")], type = "raw"))

# Plot ----
data.0 <- data[data$y == 0, ]
data.1 <- data[data$y == 1, ]
plot(data.0$x1, data.0$x2,
     xlim = c(-lim, lim), ylim = c(-lim, lim),
     col = "red", xlab = "x1", ylab = "x2")
points(data.1$x1, data.1$x2, col = "blue")
# The decision boundary is the 0.5 level set of the predicted probability.
# Contouring the hard 0/1 class labels instead stacks many grid-quantised
# lines inside a single grid cell.
prob_matrix <- matrix(grid$prob, nrow = length(px), ncol = length(py))
contour(px, py, prob_matrix,
        levels = 0.5,
        drawlabels = FALSE,
        add = TRUE,
        col = "green")
# Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment