Created
January 15, 2016 20:38
-
-
Save primaryobjects/96703184d05ee76bba17 to your computer and use it in GitHub Desktop.
Calculating parity (even or odd number of 1 bits) with a neural network.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Packages ----
# caret drives training and evaluation; its method = "nnet" relies on the nnet
# package, so it is listed alongside the packages the script attaches.
# RSNNS is kept from the original script even though nothing below calls it.
packages <- c("caret", "RSNNS", "nnet")

# Query the installed library once and install only what is missing.
missing_packages <- setdiff(packages, rownames(installed.packages()))
if (length(missing_packages) > 0) {
  install.packages(missing_packages)
}

# RSNNS also defines train() and confusionMatrix(). Attaching it *before* caret
# means the unqualified names resolve to caret's versions rather than being
# masked by RSNNS.
library(RSNNS)
library(caret)
# Build the data set ----
# Fix the RNG so the permutation below, caret's resampling and nnet's random
# starting weights are reproducible from run to run.
set.seed(2016)

# A random permutation of 1..1000. Its order later decides which rows end up in
# the training set and which in the cross-validation set.
numbers <- sample(seq_len(1000L), 1000L)

# Encode each number as its 32-bit integer bit pattern (least-significant bit
# first, as returned by intToBits()). vapply() pins the result to a
# 32 x length(numbers) numeric matrix: one column per number.
bits <- vapply(numbers, function(n) as.numeric(intToBits(n)), numeric(32))
# To recover the numbers from the bit matrix (one column per number):
# apply(bits, 2, function(d) packBits(as.raw(d), "integer")) == numbers

# One row per number, one column (V1..V32) per bit. Values <= 1000 only use the
# low 10 bits, so columns V11..V32 are constant 0.
data <- as.data.frame(t(bits))

# Label each row with the parity of its 1 bits: "0" = even, "1" = odd.
data$y <- as.factor(rowSums(data == 1) %% 2)
# Split into balanced training and cross-validation sets ----
# Training gets the first `n_per_class` rows of each parity class. Cross-
# validation gets the next `n_per_class` rows of each class. Because the rows
# of `data` follow a random permutation, this is a random, class-balanced split.
n_per_class <- 250L

even_rows <- which(data$y == "0")
odd_rows <- which(data$y == "1")

# Indexing past the end of a class would silently produce all-NA rows, so fail
# loudly instead. (1..1000 contains exactly 500 numbers of each parity, so
# n_per_class = 250 uses every row.)
stopifnot(
  length(even_rows) >= 2L * n_per_class,
  length(odd_rows) >= 2L * n_per_class
)

train_idx <- seq_len(n_per_class)
cv_idx <- n_per_class + seq_len(n_per_class)

# Even-parity rows first, then odd-parity rows, in both sets.
train <- data[c(even_rows[train_idx], odd_rows[train_idx]), ]
cv <- data[c(even_rows[cv_idx], odd_rows[cv_idx]), ]
# Train the network ----
# A single hyper-parameter combination: weight decay 0.01, 8 hidden units.
# caret's default trainControl() resampling still runs to estimate performance.
# The call is qualified as caret::train because RSNNS also defines a train()
# generic that can mask caret's. Qualifying it also avoids confusion with the
# data frame named `train`.
tune_grid <- expand.grid(decay = 0.01, size = 8)
fit <- caret::train(
  y ~ .,
  data = train,
  method = "nnet",
  tuneGrid = tune_grid
)
# Evaluate the model ----

# Predict `dataset` with `model`, print caret's confusion matrix, and return the
# fraction of rows whose predicted parity equals the true label `dataset$y`.
# mean() yields NA if any prediction or label is NA, rather than silently
# counting that row as a miss.
#
# model:   a fitted model accepted by predict(), e.g. the caret fit above.
# dataset: a data.frame with the training bit columns and a factor column `y`.
evaluate_parity <- function(model, dataset) {
  predicted <- predict(model, newdata = dataset)
  # Qualified call: RSNNS also defines confusionMatrix() and can mask caret's.
  print(caret::confusionMatrix(predicted, dataset$y))
  mean(predicted == dataset$y)
}

# Training accuracy: how well the network fits the rows it was trained on.
evaluate_parity(fit, train)

# Cross-validation accuracy: how well it generalizes to held-out numbers.
evaluate_parity(fit, cv)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment