Skip to content

Instantly share code, notes, and snippets.

@mick001
Last active November 26, 2023 19:12
Show Gist options
  • Save mick001/49fad7f4c6112d954aff to your computer and use it in GitHub Desktop.
Save mick001/49fad7f4c6112d954aff to your computer and use it in GitHub Desktop.
A neural network example in R. Full article at: http://datascienceplus.com/fitting-neural-network-in-r/
# Set a seed ----
# Fix the RNG state so the random train/test split below is reproducible.
set.seed(500)
library(MASS)
data <- Boston

# Check that no data is missing: count of NA values in each column
vapply(data, function(column) sum(is.na(column)), integer(1))

# Train-test random splitting for linear model: 75% of rows train, rest test.
# `index` is reused later so the neural net is evaluated on the same rows.
index <- sample(seq_len(nrow(data)), size = round(0.75 * nrow(data)))
train <- data[index, ]
test <- data[-index, ]

# Fit an ordinary linear regression (glm with its default gaussian family)
# of medv on all other columns, using the training rows only.
lm.fit <- glm(medv ~ ., data = train)
summary(lm.fit)

# Predict the held-out rows and compute the test mean squared error
pr.lm <- predict(lm.fit, newdata = test)
MSE.lm <- sum((test$medv - pr.lm)^2) / nrow(test)
#-------------------------------------------------------------------------------
# Neural net fitting ----
# Min-max scale every column to [0, 1]: (x - min) / (max - min).
# NOTE(review): min/max are taken over the full data set, including the test
# rows, so the test set leaks into preprocessing. This is kept as-is because
# the cross-validation code also reuses `scaled`.
maxs <- apply(data, 2, max)
mins <- apply(data, 2, min)
scaled <- as.data.frame(scale(data, center = mins, scale = maxs - mins))

# Train-test split: reuse the linear model's row indices so both test MSEs
# are computed on exactly the same observations.
train_ <- scaled[index, ]
test_ <- scaled[-index, ]

# NN training
library(neuralnet)
n <- names(train_)
predictors <- n[!n %in% "medv"]
# Spell the predictors out ("medv ~ crim + zn + ...") instead of "medv ~ .".
# Presumably this is because older neuralnet versions rejected `.`; TODO confirm.
f <- as.formula(paste("medv ~", paste(predictors, collapse = " + ")))
nn <- neuralnet(f, data = train_, hidden = c(5, 3), linear.output = TRUE)
# Visual plot of the model
plot(nn)

# Predict. Covariates are selected by name, in formula order, rather than by
# position (1:13), so the code does not depend on `medv` being the last column.
# compute() is namespace-qualified so that another attached package exporting
# a `compute` function (e.g. dplyr) cannot mask it.
pr.nn <- neuralnet::compute(nn, test_[, predictors])

# Results from NN are normalized (scaled): undo the min-max scaling of the
# response (value * (max - min) + min) to compare on the original scale.
medv_min <- mins[["medv"]]
medv_range <- maxs[["medv"]] - medv_min
pr.nn_ <- pr.nn$net.result * medv_range + medv_min
test.r <- test_$medv * medv_range + medv_min

# Calculating test MSE on the original medv scale
MSE.nn <- sum((test.r - pr.nn_)^2) / nrow(test_)
# Compare the two MSEs (linear model first, then neural net)
print(paste(MSE.lm, MSE.nn))
# Plot predictions ----
# Real vs predicted scatter plots for each model side by side, each with a
# y = x reference line (perfect prediction). par() returns the previous
# settings; they are restored afterwards so the combined plot below is not
# drawn into half of a new two-panel page.
old_par <- par(mfrow = c(1, 2))
plot(test$medv, pr.nn_, col = "red", main = "Real vs predicted NN",
     pch = 18, cex = 0.7)
abline(0, 1, lwd = 2)
legend("bottomright", legend = "NN", pch = 18, col = "red", bty = "n")
plot(test$medv, pr.lm, col = "blue", main = "Real vs predicted lm",
     pch = 18, cex = 0.7)
abline(0, 1, lwd = 2)
legend("bottomright", legend = "LM", pch = 18, col = "blue", bty = "n",
       cex = 0.95)
par(old_par)

# Compare predictions on the same plot. The y-axis range must cover both
# prediction sets: points() draws into the existing axes without extending
# them, so LM predictions outside the NN range would otherwise be clipped.
plot(test$medv, pr.nn_, col = "red", main = "Real vs predicted NN and LM",
     pch = 18, cex = 0.7, ylim = range(pr.nn_, pr.lm),
     xlab = "Real medv", ylab = "Predicted medv")
points(test$medv, pr.lm, col = "blue", pch = 18, cex = 0.7)
abline(0, 1, lwd = 2)
legend("bottomright", legend = c("NN", "LM"), pch = 18,
       col = c("red", "blue"))
#-------------------------------------------------------------------------------
# Cross validating ----
library(boot)
set.seed(200)
# Linear model cross validation: refit the linear model on the full data set
# (this replaces the earlier training-split `lm.fit`), then print the raw
# 10-fold cross-validation estimate of prediction error. That estimate is the
# first element of `delta`; with cv.glm's default cost it is the MSE.
lm.fit <- glm(medv ~ ., data = data)
cv.glm(data = data, glmfit = lm.fit, K = 10)$delta[[1]]
# Neural net cross validation ----
# True k-fold CV, comparable to the linear model's cv.glm(K = 10) estimate.
# Every row is assigned to exactly one of k folds, so it is used as test data
# exactly once. The previous version drew k independent random 90/10 splits,
# whose test sets overlapped and could miss rows entirely.
set.seed(450)
k <- 10
cv.error <- numeric(k)
# Random assignment of rows to k (near-)equal-sized folds
folds <- sample(rep_len(seq_len(k), nrow(scaled)))
# Covariates by name (formula order), not by position 1:13
predictors <- setdiff(names(scaled), "medv")
# Constants for undoing the min-max scaling of the response
medv_min <- min(data$medv)
medv_range <- max(data$medv) - medv_min
# Initialize progress bar
library(plyr)
pbar <- create_progress_bar("text")
pbar$init(k)
for (i in seq_len(k)) {
  test_rows <- folds == i
  train.cv <- scaled[!test_rows, ]
  test.cv <- scaled[test_rows, ]
  # Same architecture as the single-split model (hidden = c(5, 3)), so the CV
  # error estimates that model rather than a different one.
  # NOTE(review): neuralnet may fail to converge on some folds; not handled.
  nn <- neuralnet(f, data = train.cv, hidden = c(5, 3), linear.output = TRUE)
  pr.nn <- neuralnet::compute(nn, test.cv[, predictors])
  pr.nn <- pr.nn$net.result * medv_range + medv_min
  test.cv.r <- test.cv$medv * medv_range + medv_min
  cv.error[i] <- sum((test.cv.r - pr.nn)^2) / nrow(test.cv)
  pbar$step()
}
# Average MSE
mean(cv.error)
# MSE vector from CV (one value per fold)
cv.error
# Visual plot of CV results
boxplot(cv.error, xlab = "MSE CV", col = "cyan",
        border = "blue", names = "CV error (MSE)",
        main = "CV error (MSE) for NN", horizontal = TRUE)
@PauliusZaicev
Copy link

Hi Mick, this is a very helpful example! Thank you very much. Do you have any tutorials on RNNs with time-series data? Example scenario: the dependent variable is continuous numeric, and the input observations would have year, month, week of the year, fiscal day of the week, and hour and minute intervals for each hour. Many thanks.

@ConstantFuture
Copy link

ConstantFuture commented Dec 30, 2018

Thank You! - do you by any chance know why this error would be coming up when running nm <- neuralnet(...): Error in x - y : non-conformable arrays ?

@goclem
Copy link

goclem commented Oct 16, 2019

Thanks!

@f12345zxcvbnm2
Copy link

Thank you for this contribution. Good code for beginners!

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment