@mingjiphd
Created January 28, 2026 02:15
Machine Learning using R: Regression and Classification using Shallow Neural Networks
This R script is a step-by-step demonstration of how to do regression and classification with a shallow neural network (one with only a single hidden layer) in R. Several R packages (nnet, neuralnet, caret, and NeuralNetTools) are used to show how to perform regression or classification with both continuous and categorical predictors; how to assess accuracy; how to plot the network architecture; and how to plot relative importance.
A step-by-step video demonstration can be found at: https://youtu.be/TVlpUJTZYT8?si=cRo93o_DqAb_ghm5
###########################################################################################
# Machine Learning using R: Regression and Classification using Shallow Neural Networks #
###########################################################################################
# Load necessary libraries
library(nnet) # For shallow neural networks
library(neuralnet) # For visualizing and custom training
library(caret) # For data splitting and evaluation
install.packages("NeuralNetTools")
library(NeuralNetTools) # For plotting relative importance
# Set seed for reproducibility
set.seed(1262026)
# Generate Data
n <- 500
df <- data.frame(
  # Independent Variables (3 Continuous)
  cont1 = rnorm(n, 50, 10),
  cont2 = runif(n, 0, 100),
  cont3 = rgamma(n, shape = 2, scale = 5),
  # Independent Variables (2 Categorical)
  cat_pred1 = factor(sample(c("A", "B"), n, replace = TRUE)),
  cat_pred2 = factor(sample(c("Low", "High"), n, replace = TRUE))
)
# Create Dependent Variables
# 1. Continuous DV (for Regression)
df$target_reg <- with(df, (0.5 * cont1) + (0.2 * cont2) - (0.1 * cont3) + rnorm(n, 0, 5))
# 2. 3-Level Categorical DV (for Classification)
df$target_class <- factor(ifelse(df$target_reg > 35, "High",
                          ifelse(df$target_reg > 25, "Medium", "Low")))
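# Optional sanity check: inspect the class balance of the 3-level DV
# produced by the cut points above.
table(df$target_class)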
# Preprocessing: Scale continuous variables
preProcValues <- preProcess(df[, 1:3], method = c("range")) ## preProcess from caret
### It estimates transformations (such as scaling, centering, or range
### normalization) from the data; predict() below then applies them.
df_scaled <- predict(preProcValues, df)
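# Optional check: after range scaling, the three continuous predictors
# should all fall within [0, 1].
summary(df_scaled[, 1:3])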
# Split into training and testing
trainIndex <- createDataPartition(df$target_reg, p = .8, list = FALSE)
train_data <- df_scaled[trainIndex, ]
test_data <- df_scaled[-trainIndex, ]
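# Optional check: createDataPartition() samples within percentile groups of
# the outcome, so its distribution should look similar in both sets.
summary(train_data$target_reg)
summary(test_data$target_reg)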
#### Regression with Shallow Neural Network
# Train shallow neural network for regression
# size = number of units in hidden layer
nn_reg <- nnet(target_reg ~ cont1 + cont2 + cont3 + cat_pred1 + cat_pred2,
               data = train_data,
               size = 5, ## 5 neurons; rule of thumb: (# of inputs + # of outputs)/2, or grid search
               linout = TRUE, # TRUE for regression (linear output)
               trace = FALSE,
               maxit = 500)
# Prediction
pred_reg <- predict(nn_reg, test_data)
# Accuracy Assessment (RMSE and R-squared)
postResample(pred_reg, test_data$target_reg) # postResample from caret for automatic
# assessment of regression or classification accuracy
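# For reference, a by-hand sketch of what postResample() reports here:
# RMSE is the root mean squared error, and R-squared (in caret) is the
# squared correlation between predictions and observations.
pred_vec <- as.numeric(pred_reg) # predict.nnet returns an n x 1 matrix
rmse <- sqrt(mean((pred_vec - test_data$target_reg)^2))
rsq <- cor(pred_vec, test_data$target_reg)^2
c(RMSE = rmse, Rsquared = rsq)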
#### Classification using Shallow Neural Network
# Train shallow neural network for classification
nn_class <- nnet(target_class ~ cont1 + cont2 + cont3 + cat_pred1 + cat_pred2,
                 data = train_data,
                 size = 5,
                 linout = FALSE, # FALSE for classification
                 trace = FALSE,
                 maxit = 500)
# Prediction
pred_class <- predict(nn_class, test_data, type = "class")
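# Optional: type = "raw" returns the softmax class probabilities underlying
# the hard class labels above.
head(predict(nn_class, test_data, type = "raw"))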
# Accuracy Assessment
postResample(pred_class, test_data$target_class)
# Accuracy Assessment (Confusion Matrix)
confusionMatrix(factor(pred_class), test_data$target_class)
## Note: NIR = No Information Rate, which is the accuracy you would
# get by simply guessing the most frequent class
# (in this case, "High") every single time.
# Prevalence = relative frequency of each class
# Detection Rate = Sensitivity * Prevalence
# Detection Prevalence = predicted relative frequency
# Balanced Accuracy = the average of Sensitivity and Specificity <- most important
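# The confusionMatrix() output is a list, so the per-class metrics above can
# also be extracted programmatically, e.g. Balanced Accuracy per class:
cm <- confusionMatrix(factor(pred_class), test_data$target_class)
cm$byClass[, "Balanced Accuracy"]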
#### Visualizing the Network and Relative Importance
# 1. Plot the Network Architecture
# Using NeuralNetTools to visualize weights
plotnet(nn_reg)
#### Input Layer -> Hidden Layer -> Output Layer
#### B1, B2 are biases (like intercepts)
#### cat_pred1B and cat_pred2Low are the dummy variables for the two factors
#### Thicker lines = larger weight magnitudes; black = positive, grey = negative weights
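# (Optional) The neuralnet package loaded above can fit a similar model and has
# its own plot() method. A minimal sketch, not a tuned fit: neuralnet needs
# numeric inputs, so the two factors are expanded to dummy variables first,
# and convergence may require adjusting threshold/stepmax.
train_mm <- data.frame(model.matrix(~ cont1 + cont2 + cont3 + cat_pred1 + cat_pred2,
                                    data = train_data)[, -1],
                       target_reg = train_data$target_reg)
f <- as.formula(paste("target_reg ~",
                      paste(setdiff(names(train_mm), "target_reg"), collapse = " + ")))
nn_reg2 <- neuralnet(f, data = train_mm, hidden = 5, linear.output = TRUE)
plot(nn_reg2)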
# 2. Relative Importance (Garson's Algorithm)
# This identifies which variables are most influential in the model
garson(nn_reg) + labs(title = "Variable Importance in Regression NN")
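# Garson's algorithm reports only the magnitude of each variable's importance;
# olden(), also from NeuralNetTools, keeps the sign of each contribution and
# can be a useful complement (a suggested extra step).
olden(nn_reg) + labs(title = "Signed Variable Importance (Olden) in Regression NN")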