Skip to content

Instantly share code, notes, and snippets.

@primaryobjects
Last active January 2, 2016 16:34
Show Gist options
  • Save primaryobjects/5a28e0c27fd433123f1a to your computer and use it in GitHub Desktop.
Save primaryobjects/5a28e0c27fd433123f1a to your computer and use it in GitHub Desktop.
Neural Network Sort
library(neuralnet)
# Helper method to generate a training set of `size` rows. Each row holds three
# random integers (a, b, c) drawn uniformly with replacement from 1..max, plus
# the same three numbers in ascending order (x <= y <= z) as target columns.
#
# size: number of rows to generate (0 yields an empty 6-column data frame).
# max:  largest value that can be drawn; values range over 1..max.
# Returns a data.frame with columns a, b, c, x, y, z.
generateSet <- function(size = 100, max = 100) {
  # sample.int(max, ...) draws from 1..max exactly like sample(1:max, ...).
  training <- data.frame(a = sample.int(max, size, replace = TRUE),
                         b = sample.int(max, size, replace = TRUE),
                         c = sample.int(max, size, replace = TRUE))

  # Sort every row at once with vectorized min / median / max instead of
  # growing a data frame with rbind() inside a loop (quadratic, needs `<<-`).
  lo <- pmin(training$a, training$b)
  hi <- pmax(training$a, training$b)
  output <- data.frame(x = pmin(lo, training$c),
                       # Median of three: larger of min(a, b) and min(max(a, b), c).
                       y = pmax(lo, pmin(hi, training$c)),
                       z = pmax(hi, training$c))

  cbind(training, output)
}
# Helper method to undo scale(): every row of x is multiplied by the
# 'scaled:scale' attribute of orig and then shifted by its 'scaled:center'
# attribute.
# x:    matrix (or data frame) of scaled values, one observation per row.
# orig: the object returned by scale(), whose attributes hold the parameters.
# Returns a matrix with one row per row of x.
# Note: each row is recycled against the attribute vectors, so a row shorter
# than the attributes (e.g. 3 columns vs. 6 attributes) produces a row as long
# as the attributes; only the positions whose attributes belong to x's columns
# carry meaningful values.
unscale <- function(x, orig) {
  spread <- attr(orig, 'scaled:scale')
  offset <- attr(orig, 'scaled:center')
  restoreRow <- function(row) row * spread + offset
  t(apply(x, 1, restoreRow))
}
# Helper method to run Neural Network Sort manually on the numbers a, b, c.
# Usage: nnsort(fit, data, 20, 77, 18)
# fit:      a network trained on scaled data (passed to compute()).
# scaleVal: the scale() result whose 'scaled:center' / 'scaled:scale'
#           attributes were used to normalize the training data.
# Returns the network's rounded, unscaled predictions for x, y, z.
nnsort <- function(fit, scaleVal, a, b, c) {
  centers <- attr(scaleVal, 'scaled:center')
  spreads <- attr(scaleVal, 'scaled:scale')
  # x, y, z are zero placeholders so the frame has one column per attribute.
  query <- data.frame(a = a, b = b, c = c, x = 0, y = 0, z = 0)
  queryScaled <- as.data.frame(scale(query, centers, spreads))
  prediction <- compute(fit, queryScaled[, 1:3])$net.result
  # unscale() recycles the 3 predicted values over all 6 attributes; columns
  # 4:6 are the ones unscaled with the x, y, z parameters.
  restored <- unscale(prediction, scaleVal)
  round(restored)[, 4:6]
}
# Generate training and cross-validation (cv) data.
data <- generateSet(1500)
# Normalize data. scale() returns a matrix carrying 'scaled:center' and
# 'scaled:scale' attributes, which unscale() uses to restore original values.
data <- scale(data)
# Split for a training and cv set. floor() keeps the split on a whole row when
# nrow(data) is odd, so no row is silently dropped from the cv set.
half <- floor(nrow(data) / 2)
training <- data[seq_len(half), ]
cv <- data[seq.int(half + 1, nrow(data)), ]
# Train neural network.
fit <- neuralnet(x + y + z ~ a + b + c,
                 data = training,
                 hidden = c(40, 40),
                 threshold = 0.001,
                 rep = 1,
                 learningrate = 0.6,
                 stepmax = 9999999,
                 lifesign = 'full')
# Check results. unscale() recycles the six scaled:* attributes across each
# 3-column row, so only columns 4:6 of its result are the unscaled x, y, z.
# Columns 1:3 apply the a/b/c parameters to x/y/z values and must be ignored.
results1 <- round(unscale(compute(fit, training[, 1:3])$net.result,
                          data)[, 4:6, drop = FALSE])
results2 <- round(unscale(compute(fit, cv[, 1:3])$net.result,
                          data)[, 4:6, drop = FALSE])
# The targets were integers before scaling, so round() removes the floating
# point error introduced by scale()/unscale() and makes == comparison exact.
truth1 <- round(unscale(training[, 4:6], data)[, 4:6, drop = FALSE])
truth2 <- round(unscale(cv[, 4:6], data)[, 4:6, drop = FALSE])
# Count rows where all three sorted outputs are correct (NA rows not counted).
correct1 <- sum(rowSums(truth1 == results1) == 3, na.rm = TRUE)
correct2 <- sum(rowSums(truth2 == results2) == 3, na.rm = TRUE)
# Calculate accuracy.
print(paste('Training:', correct1 / nrow(training) * 100, '% / CV:', correct2 / nrow(cv) * 100, '%'))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment