Skip to content

Instantly share code, notes, and snippets.

@uncountablecat
Created November 17, 2018 05:12
Show Gist options
  • Save uncountablecat/dc835e624f3164bec6cf8954e379c6ca to your computer and use it in GitHub Desktop.
Save uncountablecat/dc835e624f3164bec6cf8954e379c6ca to your computer and use it in GitHub Desktop.
Tune the alpha and lambda parameters of an elastic-net Tweedie model (HDtweedie) using caret.
library(caret)
library(HDtweedie)
# Skeleton of the custom model object: the package it relies on and the
# problem type. The remaining components are attached further down the script.
elasticTweedie <- list(
  library = "HDtweedie",
  type = "Regression"
)

# Tuning-parameter table: one row per tuned parameter, giving its name,
# its class and a display label.
prm <- data.frame(
  parameter = c("alpha", "lambda"),
  class = c("numeric", "numeric"),
  label = c("alpha", "lambda_vec")
)
# Build the candidate tuning grid for the custom model (registered as its
# `grid` component further down the script).
#
# Arguments:
#   x, y   - training predictors and outcome; not used by this function.
#   len    - requested grid size; not used, the grid below is fixed.
#   search - search strategy; only "grid" is supported.
#
# Returns a data frame with columns `alpha` (0 to 1 in steps of 0.1) and
# `lambda` (0 to 100 in steps of 10), one row per combination.
# Any other value of `search` raises an error.
elasticTweedieGrid <- function(x, y, len = NULL, search = "grid") {
  # Guard first. identical() is used instead of `==` so that a NULL or
  # non-scalar `search` is rejected cleanly rather than breaking the `if`.
  if (!identical(search, "grid")) {
    stop("Not implemented yet!")
  }
  # We only tune alpha and lambda. Nothing here calls into HDtweedie, so the
  # package is deliberately not attached: building a grid should not change
  # the search path or fail because a modelling package is missing.
  expand.grid(alpha = seq(0, 1, 0.1), lambda = seq(0, 100, 10))
}
# Fit a single HDtweedie model for one (alpha, lambda) candidate.
#
# Arguments:
#   x     - predictors; coerced to a matrix before fitting.
#   y     - response vector.
#   param - object holding the candidate values in `alpha` and `lambda`.
#   last  - accepted but not used.
#   ...   - accepted but intentionally NOT forwarded to HDtweedie.
#
# Returns whatever HDtweedie::HDtweedie() returns for that single lambda.
elasticTweedieFit <- function(x, y, param, last, ...) {
  HDtweedie::HDtweedie(
    x = as.matrix(x),
    y = y,
    alpha = param$alpha,
    nlambda = 1,
    lambda = c(param$lambda),
    eps = 1e-10
  )
}
# Prediction hook: delegate to the predict() method of the fitted model.
#
# Arguments:
#   modelFit - fitted model object.
#   newdata  - new observations, passed positionally as the second argument
#              of predict().
#   ...      - accepted but not used.
#
# Returns the result of predict(modelFit, newdata) unchanged.
elasticTweediePred <- function(modelFit, newdata, ...) {
  predictions <- predict(modelFit, newdata)
  predictions
}
# Class-probability hook. The model is declared with type "Regression", so
# there is nothing to compute: every argument is ignored and NULL is returned.
elasticTweedieProb <- function(modelFit, newdata, preProc = NULL, submodels = NULL) {
  NULL
}
# Attach the tuning-parameter table and the four hook functions to the
# custom model object under the component names used by the script.
elasticTweedie[["parameters"]] <- prm
elasticTweedie[["grid"]] <- elasticTweedieGrid
elasticTweedie[["fit"]] <- elasticTweedieFit
elasticTweedie[["predict"]] <- elasticTweediePred
elasticTweedie[["prob"]] <- elasticTweedieProb
# Resampling scheme: 3-fold cross-validation.
fitControl <- trainControl(method = "cv", number = 3)

# Candidate values to evaluate: alpha in 0..1 (step 0.2) crossed with
# lambda in 0..100 (step 20).
tuningGrid <- expand.grid(alpha = seq(0, 1, 0.2), lambda = seq(0, 100, 20))

# Tune the custom model. `raw_data` is not defined in this script and must
# already exist; it is expected to have a response column named "y", with
# every other column used as a predictor.
# drop = FALSE keeps the predictors as a data frame even when only a single
# predictor column remains; without it the subset collapses to a bare vector.
elasticTweedieModel <- train(
  x = raw_data[, colnames(raw_data) != "y", drop = FALSE],
  y = raw_data$y,
  method = elasticTweedie,
  trControl = fitControl,
  tuneGrid = tuningGrid
)

# Inspect the model refit with the selected parameters.
fm <- elasticTweedieModel$finalModel
print(fm$b0) # get intercept
print(fm$beta) # get coefficients
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment