Created
April 4, 2024 11:18
-
-
Save chrishanretty/ba0b46b767a169fa1f82ddfe69246f71 to your computer and use it in GitHub Desktop.
Predictions from predict.bam and marginaleffects::predictions with and without discretization
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
### PURPOSE OF THIS CODE: estimate a large model with mgcv::bam and
### speed-test predictions from mgcv::predict.bam and
### marginaleffects::predictions, with and without discretization and
### with and without parallelization
### ##################################################################
### load libraries
### ##################################################################
## Packages used below: mgcv (bam and its predict method),
## marginaleffects (predictions), nycflights13 (the flights data) and
## tictoc (tic/toc timers)
library(mgcv)
library(marginaleffects)
library(nycflights13)
library(tictoc)

## Load the flights data set into the workspace
data("flights", package = "nycflights13")

### Some small config
## Fix the RNG state so repeated runs start from the same seed
set.seed(3)
## Thread count handed to every `nthreads` argument further down
my_threads <- 12
### ##################################################################
### transform data to help in modelling departure delay
### ##################################################################

## Handle date: build a Date from year/month/day, then convert it to
## the number of days since the earliest date so we can use it as a
## smooth term
flights <- flights |>
  transform(date = as.Date(paste(year, month, day, sep = "/"))) |>
  transform(date.num = as.numeric(date - min(date)))

### Handle wday: convert to numeric (POSIXlt convention: 0 = Sunday,
### ..., 6 = Saturday)
flights <- flights |>
  transform(wday = as.POSIXlt(date)$wday)

## Handle time of departure, again converting to numeric (minutes after
## midnight).
## The clock time is parsed in UTC on purpose: with only "%H:%M" given,
## the date defaults to the current day, and in a local time zone with
## daylight saving the result could be NA or shifted by an hour on a
## changeover day. For the same reason `time.num` is computed directly
## from the hour and minute columns, so it never depends on the day or
## time zone in which the script is run.
flights <- flights |>
  transform(time = as.POSIXct(paste(hour, minute, sep = ":"),
                              format = "%H:%M", tz = "UTC")) |>
  transform(time.dt = difftime(time,
                               as.POSIXct("00:00", format = "%H:%M",
                                          tz = "UTC"),
                               units = "mins")) |>
  transform(time.num = 60 * hour + minute)

### Handle the outcome, specifically early and missing data: negative
### delays (early departures) are clamped to zero, and missing delays
### are then recoded as zero as well
flights <- flights |>
  transform(dep_delay = pmax(dep_delay, 0)) |>
  transform(dep_delay = ifelse(is.na(dep_delay), 0, dep_delay))

### Transform some things to factors (needed for the random-effect
### smooths and the parametric origin term in the model formula)
flights <- flights |>
  transform(carrier = factor(carrier),
            dest = factor(dest),
            origin = factor(origin))
### ##################################################################
### Estimate models with and without discretization
### ##################################################################

## Both fits use exactly the same specification, so it is written once:
## cubic regression splines for date, time of day and distance, a
## cyclic cubic spline for day of week, random effects for carrier and
## destination, and a parametric term for origin.
## NOTE(review): with bs = "cc" and no explicit knots the cyclic smooth
## appears to join at the ends of the observed range, which would force
## wday 0 and wday 6 to share a value -- confirm this is intended
## (passing knots = list(wday = c(-0.5, 6.5)) would avoid it).
delay_formula <- dep_delay ~
  s(date.num, bs = "cr") +
  s(wday, bs = "cc", k = 3) +
  s(time.num, bs = "cr") +
  s(carrier, bs = "re") +
  origin +
  s(distance, bs = "cr") +
  s(dest, bs = "re")

### Baseline fit without discretization
### Takes about 135 seconds single-threaded
tic()
m_base <- bam(
  delay_formula,
  data = flights,
  family = poisson,
  discrete = FALSE
)
toc()

### Discretized fit, multi-threaded
### Takes about 8 seconds
tic()
m_discrete <- bam(
  delay_formula,
  data = flights,
  family = poisson,
  discrete = TRUE,
  nthreads = my_threads
)
toc()
### ##################################################################
### generate predictions, w/ and w/o SEs, w/ and w/o discretization
### ##################################################################

## Stop the timer started by the most recent tic() and return the
## elapsed time as toc - tic. toc() still prints its usual message.
stop_timer <- function() {
  timing <- toc()
  timing$toc - timing$tic
}

### ---- Models fitted without discretization (m_base) ----

### Case 1: mgcv, w/o SEs, w/o discretization
### takes around 3 seconds
tic()
p1 <- predict(m_base, se.fit = FALSE)
e_1 <- stop_timer()

### Case 2: marginaleffects, w/o SEs, w/o discretization
### takes around 3 seconds
tic()
p2 <- predictions(m_base, vcov = FALSE)
e_2 <- stop_timer()

### Case 3: mgcv, w/ SEs, w/o discretization
### takes around 17 seconds
tic()
p3 <- predict(m_base, se.fit = TRUE)
e_3 <- stop_timer()

### Case 4: marginaleffects, w/ SEs, w/o discretization
### takes around 356 seconds
tic()
p4 <- predictions(m_base, vcov = TRUE)
e_4 <- stop_timer()

### ---- Models fitted with discretization (m_discrete) ----

### Case 5: mgcv, w/o SEs, w/ discretization
### takes around 1/3rd of a second
tic()
p5 <- predict(
  m_discrete,
  se.fit = FALSE,
  discrete = TRUE,
  nthreads = my_threads
)
e_5 <- stop_timer()

### Case 6: marginaleffects, w/o SEs, w/ discretization
### takes around 0.8 seconds
tic()
p6 <- predictions(
  m_discrete,
  vcov = FALSE,
  discrete = TRUE,
  nthreads = my_threads
)
e_6 <- stop_timer()

### Case 7: mgcv, w/ SEs, w/ discretization
### takes around 3.6 seconds
tic()
p7 <- predict(
  m_discrete,
  se.fit = TRUE,
  discrete = TRUE,
  nthreads = my_threads
)
e_7 <- stop_timer()

### Case 8: marginaleffects, w/ SEs, w/ discretization
### takes around 82 seconds
tic()
p8 <- predictions(
  m_discrete,
  vcov = TRUE,
  discrete = TRUE,
  nthreads = my_threads
)
e_8 <- stop_timer()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment