Created
August 9, 2016 22:57
-
-
Save vlad17/eb71cdbcc5e4f1ad891e482a84c8d82b to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Subset of the million-song dataset.
# Task: was a song released before 2002 (label 1) or in/after 2002 (label 0),
# judging only from its vocal features?
#
# gbm is deliberately handicapped here, because both restrictions mirror what
# Spark has to do:
#   - no cross-validation to choose when to stop adding trees;
#   - predictions are classified with a fixed 0.5 threshold.

library(gbm)
library(ROCR)

# Fixed seed so the random row subsampling done by gbm (bag.fraction < 1)
# is reproducible across runs.
set.seed(123)
# Download the data (once; cached under data_dir) and read it ---------------
data_dir <- "/tmp"
f <- file.path(data_dir, "YearPredictionMSD.txt")
if (!file.exists(f)) {
  z <- paste0(f, ".zip")
  # The archive is large, and download.file() aborts after
  # getOption("timeout") seconds (60 by default). Raise the limit for this
  # one download and restore it afterwards, even if the download fails.
  old_opts <- options(timeout = max(3600, getOption("timeout")))
  tryCatch(
    download.file(
      "http://archive.ics.uci.edu/ml/machine-learning-databases/00203/YearPredictionMSD.txt.zip",
      z,
      mode = "wb", # binary file: avoid text-mode translation on Windows
      quiet = TRUE
    ),
    finally = options(old_opts)
  )
  unzip(z, exdir = data_dir)
  if (!file.exists(f)) {
    stop("Expected ", f, " after unzipping ", z, call. = FALSE)
  }
}

# Every column is numeric (the year first, then the features). Declaring that
# up front skips read.csv's type guessing and yields doubles directly.
raw <- read.csv(f, header = FALSE, colClasses = "numeric")

# The data is already fairly clean (no NAs). Split as prescribed by the
# dataset's website: the first `cutoff` rows are training, the rest are test.
cutoff <- 463715
train <- raw[seq_len(cutoff), ]
test <- raw[-seq_len(cutoff), ]
rm(raw) # free the full copy; only train/test are used from here on

response_col <- colnames(train)[1]

# Convert to a binary problem: label 1 if the song's year is before
# year_split, 0 otherwise. With 2002 this splits the dataset roughly in half.
year_split <- 2002
train[[response_col]] <- as.numeric(train[[response_col]] < year_split)
test[[response_col]] <- as.numeric(test[[response_col]] < year_split)
# Train ---------------------------------------------------------------------
ntrees <- 700
shrinkage <- 0.001

# Predict the response from every other column. reformulate() builds the
# formula object directly instead of pasting a string and re-parsing it.
gbm_formula <- reformulate(
  setdiff(colnames(train), response_col),
  response = response_col
)

# system.time() evaluates the assignment in this environment, so gbm_model
# is available afterwards. `data` is passed by name: gbm()'s second formal
# argument is `distribution`, not `data`.
# NOTE(review): gbm uses n.cores to parallelise cross-validation folds; no
# cv.folds are requested here, so it likely has no effect -- confirm.
duration <- system.time(
  gbm_model <- gbm(
    gbm_formula,
    data = train,
    distribution = "bernoulli",
    n.trees = ntrees,
    shrinkage = shrinkage,
    interaction.depth = 3,
    bag.fraction = 0.75,
    n.cores = 4
  )
)

# Total training time.
print(duration)
# Predict -------------------------------------------------------------------
predictions_gbm <- predict(
  gbm_model,
  newdata = test[, setdiff(colnames(test), response_col), drop = FALSE],
  n.trees = ntrees,
  type = "response"
)
pred <- prediction(predictions_gbm, test[[response_col]])

# Evaluate ------------------------------------------------------------------
# OOB estimate of the best number of trees (gbm.perf also draws a plot). It
# is only reported: the predictions above deliberately use all ntrees.
gbm_perf <- gbm.perf(gbm_model, method = "OOB")
print("oob best n.trees")
print(gbm_perf)

plot(performance(pred, measure = "tpr", x.measure = "fpr"))

# Explicit print() so the metrics are shown even when the script is run via
# source() without echo (no top-level autoprinting there).
print("auc")
print(performance(pred, measure = "auc")@y.values)

# Value of a cutoff-indexed ROCR performance measure at a score threshold.
# ROCR evaluates measures at cutoffs equal to the observed scores (plus Inf)
# and counts a score as positive when score >= cutoff. So the smallest
# cutoff >= threshold reproduces "positive iff score >= threshold". Picking
# it by value (not position) does not rely on the order of the cutoffs.
value_at_threshold <- function(perf, threshold = 0.5) {
  cutoffs <- perf@x.values[[1]]
  values <- perf@y.values[[1]]
  candidates <- which(cutoffs >= threshold)
  values[candidates[which.min(cutoffs[candidates])]]
}

threshold <- 0.5
threshold_measures <- c(acc = "acc", precision = "prec", recall = "rec")
for (label in names(threshold_measures)) {
  print(sprintf("%s (thresh=%g)", label, threshold))
  perf <- performance(pred, measure = threshold_measures[[label]])
  print(value_at_threshold(perf, threshold))
}
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Loading required package: survival
Loading required package: lattice
Loading required package: splines
Loading required package: parallel
Loaded gbm 2.1.1
Loading required package: gplots

Attaching package: ‘gplots’

The following object is masked from ‘package:stats’:

    lowess

Loading required package: methods
    user   system  elapsed
5709.492    2.464 5715.402
Warning message:
In gbm.perf(gbm_model, method = "OOB") :
  OOB generally underestimates the optimal number of iterations although predictive performance is reasonably competitive. Using cv.folds>0 when calling gbm usually results in improved predictive performance.
[1] "auc"
[[1]]
[1] 0.6951813

[1] "acc (thresh=0.5)"
[1] 0.6398993
[1] "precision (thresh=0.5)"
[1] 0.6806732
[1] "recall (thresh=0.5)"
[1] 0.4725602
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment