# Regularization
# Demo: how ridge-regularized polynomial models generalize between
# a small training set and a larger test set.
# Source: GitHub gist 1372318 by @smc77, created 2011-11-17.
#
# Let's look at how the different models generalize between different datasets
#
# Problem sizes: a small training set and a larger test set.
n.training <- 10
n.test <- 100

# Sum-of-squares error, E = (1/2) * sum((y_hat - y)^2).
error.function <- function(y, y.pred) {
  0.5 * sum((y.pred - y)^2)
}

# Root-mean-square error derived from the sum-of-squares error above:
# sqrt(2 * E / N), which undoes the 1/2 factor before averaging.
e.rms <- function(y, y.pred) {
  sqrt(2 * error.function(y = y, y.pred = y.pred) / length(y))
}
# Draw n points from the noisy sine target y = sin(2*pi*x) + N(0, 0.2^2),
# with x spaced evenly over [0, 1].
#
# n: number of observations to generate.
# Returns a data.frame with columns y (noisy response) and x (predictor).
build.data <- function(n) {
  f <- function(x) sin(2 * pi * x)
  x <- seq(0, 1, length.out = n)
  y <- f(x) + rnorm(n, sd = 0.2)
  data.frame(y = y, x = x)
}
# Small training set vs. larger test set, both drawn from the same noisy
# sine target. NOTE(review): no set.seed() here, so each run of the
# script produces different data (and different error curves).
training <- build.data(n=n.training)
test <- build.data(n=n.test)
# Predict responses from an lm.ridge fit at the GCV-minimizing lambda.
#
# MASS::lm.ridge() centers y and scales the predictors internally, so to
# predict we must (1) rescale the new design matrix by fit$scales,
# (2) multiply by the coefficient column belonging to the lambda with the
# smallest GCV score, and (3) add back the mean of y (fit$ym).
#
# fit:    object returned by MASS::lm.ridge() (fields scales, coef, GCV, ym)
# test.x: matrix of predictors with the same columns used in the fit
# Returns a one-column matrix of predicted y values.
predict.ridge <- function(fit, test.x) {
  # FALSE spelled out: `F` is a reassignable alias and a known footgun.
  x.scaled <- scale(test.x, center = FALSE, scale = fit$scales)
  x.scaled %*% fit$coef[, which.min(fit$GCV)] + fit$ym
}
# Fit ridge-penalized polynomial models of increasing degree and report
# training vs. test RMS error for each degree.
#
# For each degree, lm.ridge() is fit over a lambda grid of seq(0, 50, 1)
# and predictions use the GCV-minimizing lambda (see predict.ridge).
#
# training, test: data frames with columns y and x (see build.data)
# polynomials:    vector of polynomial degrees to evaluate
# Returns a data.frame with columns polynomial, training.error, test.error.
test.poly.error <- function(training, test, polynomials=2:9) {
  # Preallocate one slot per degree. The original grew these vectors by
  # indexing with the degree itself, leaving NA holes at unused indices
  # that then had to be patched over by subsetting with [polynomials].
  errors.training <- errors.test <- numeric(length(polynomials))
  for (j in seq_along(polynomials)) {
    degree <- polynomials[j]
    fit <- lm.ridge(y ~ poly(x, degree, raw = TRUE), data = training,
                    lambda = seq(0, 50, 1))
    y.pred.training <- predict.ridge(fit, poly(training$x, degree, raw = TRUE))
    errors.training[j] <- e.rms(training$y, y.pred.training)
    y.pred.test <- predict.ridge(fit, poly(test$x, degree, raw = TRUE))
    errors.test[j] <- e.rms(test$y, y.pred.test)
  }
  data.frame(polynomial = polynomials,
             training.error = errors.training,
             test.error = errors.test)
}
library(ggplot2)
library(MASS)

errors.wide <- test.poly.error(training, test)

# Reshape wide -> long in base R. The original called melt(errors, x),
# which fails twice over: `x` is undefined at this point, and melt()
# comes from the reshape package, which is never loaded.
errors <- data.frame(
  polynomial = rep(errors.wide$polynomial, times = 2),
  dataset = rep(c("training.error", "test.error"),
                each = nrow(errors.wide)),
  error = c(errors.wide$training.error, errors.wide$test.error)
)

# `group` (not `grouping`) is the ggplot2 aesthetic name; with a discrete
# colour mapping it is implied anyway, but spell it out for clarity.
p <- ggplot(errors, aes(x = polynomial, y = error,
                        group = dataset, colour = dataset)) +
  geom_line()
p
# (GitHub page footer removed.)