Skip to content

Instantly share code, notes, and snippets.

@sbfnk
Created October 12, 2023 09:48
Show Gist options
  • Save sbfnk/5f3c3281abad80c4b3cd5b7d39e0c95a to your computer and use it in GitHub Desktop.
Save sbfnk/5f3c3281abad80c4b3cd5b7d39e0c95a to your computer and use it in GitHub Desktop.
Weekly forecast plotting with EpiNow2
#' Aggregate daily forecasts to weekly forecasts
#'
#' Aggregate daily case forecasts from regional_epinow2 to weekly forecasts.
#' Observed cases are summed per region and week; forecast samples are summed
#' per region, week and sample, then summarised by quantiles across samples.
#' A week is flagged as `full_week` when it contains exactly seven daily rows;
#' only full weeks are kept in the summary.
#'
#' @param inf List output from `EpiNow2::regional_epinow()`. No default.
#' @param week_start Week start day, passed to `lubridate::floor_date()`.
#'   Default is 1 (Monday).
#' @param summary_quantiles Numeric vector of quantiles (each between 0 and 1)
#'   for the forecast summary. Supplied values are mirrored around the median
#'   (`q` and `1 - q`), the median is always added, and values are rounded to
#'   three decimal places. Default is median, plus upper and lower bound for
#'   20\%, 50\% and 90\% forecasts.
#' @importFrom dplyr %>% group_by summarise filter mutate select case_when contains n
#' @importFrom tidyr pivot_wider unnest
#' @importFrom purrr map_df
#' @importFrom lubridate floor_date
#' @importFrom stats quantile
#' @author Sophie Meakin
#' @export
#' @return A list with elements `latest_date` (forecast date), `observed`
#'   (weekly observed data), `samples` (weekly forecast samples) and `summary`
#'   (weekly forecast summary: median and specified quantiles, one column per
#'   quantile, named `median`, `lower_<width>` and `upper_<width>`).
#'
aggregate_weekly <- function(inf,
                             week_start = 1,
                             summary_quantiles) {
  if (missing(inf)) {
    stop("inf not specified, no daily forecast to summarise. Please give a list output from EpiNow2::regional_epinow().")
  }

  if (missing(summary_quantiles)) {
    summary_quantiles <- c(0.05, 0.25, 0.4, 0.5, 0.6, 0.75, 0.95)
  } else {
    # Fail early with a clear message instead of deep inside the pipeline.
    # NULL is tolerated (it yields a median-only summary).
    if (!is.null(summary_quantiles) &&
          (!is.numeric(summary_quantiles) ||
             anyNA(summary_quantiles) ||
             any(summary_quantiles < 0 | summary_quantiles > 1))) {
      stop("summary_quantiles must be a numeric vector of values between 0 and 1.")
    }
    # Mirror each quantile around the median so that every interval has both
    # a lower and an upper bound, and always include the median itself.
    summary_quantiles <- sort(
      unique(
        round(c(summary_quantiles, 1 - summary_quantiles, 0.5), 3)
      )
    )
  }

  # Column labels for the summary, e.g. 0.05 -> "lower_90", 0.95 -> "upper_90".
  # Rounding keeps the interval width free of floating-point noise so that
  # the lower and upper label of the same interval always match.
  quantile_labels <- dplyr::case_when(
    summary_quantiles == 0.5 ~ "median",
    summary_quantiles < 0.5 ~
      paste0("lower_", round(100 * (1 - 2 * summary_quantiles), 1)),
    summary_quantiles > 0.5 ~
      paste0("upper_", round(100 * (2 * summary_quantiles - 1), 1))
  )

  # latest date
  latest_date <- inf$summary$latest_date

  # observed data, summed per region and week
  obs_weekly <- inf$summary$reported_cases %>%
    dplyr::group_by(
      region,
      week = lubridate::floor_date(date, unit = "week", week_start = week_start)
    ) %>%
    dplyr::summarise(
      confirm = sum(confirm),
      # namespaced so this works without dplyr being attached
      full_week = dplyr::n() == 7,
      .groups = "drop"
    )

  # vector of region names
  regions <- inf$summary$summarised_results$regions_by_inc

  # forecast samples, summed per region, week and sample
  samples_weekly <- purrr::map_df(
    .x = regions,
    .f = function(region_name) {
      inf$regional[[region_name]]$estimates$samples %>%
        dplyr::filter(variable == "reported_cases") %>%
        dplyr::mutate(region = region_name) %>%
        dplyr::group_by(
          region,
          week = lubridate::floor_date(
            date, unit = "week", week_start = week_start
          ),
          sample
        ) %>%
        dplyr::summarise(
          value = sum(value),
          full_week = (dplyr::n() == 7),
          # A week takes the "latest" type of any of its days: forecast beats
          # partial-data estimate, which beats estimate.
          week_type = dplyr::case_when(
            ("forecast" %in% unique(type)) ~ "forecast",
            ("estimate based on partial data" %in% unique(type)) ~ "estimate based on partial data",
            ("estimate" %in% unique(type)) ~ "estimate"
          ),
          .groups = "drop"
        ) %>%
        dplyr::mutate(
          week_type = ordered(
            week_type,
            c("estimate", "estimate based on partial data", "forecast")
          )
        )
    }
  )

  # forecast summary: quantiles across samples, one column per quantile.
  # The per-group quantiles are held in list-columns and then unnested, which
  # avoids returning several rows per group from summarise().
  summary_weekly <- samples_weekly %>%
    dplyr::group_by(region, week, full_week, week_type) %>%
    dplyr::summarise(
      quantile_label = list(quantile_labels),
      cases_quantile = list(
        unname(stats::quantile(value, summary_quantiles))
      ),
      .groups = "drop"
    ) %>%
    tidyr::unnest(cols = c(quantile_label, cases_quantile)) %>%
    tidyr::pivot_wider(
      names_from = quantile_label,
      values_from = cases_quantile
    ) %>%
    dplyr::filter(full_week) %>%
    dplyr::select(
      region, week, type = week_type, median,
      dplyr::contains("lower_"), dplyr::contains("upper_")
    )

  list(
    latest_date = latest_date,
    observed = obs_weekly,
    samples = samples_weekly,
    summary = summary_weekly
  )
}
#' Plot weekly forecasts
#'
#' Plot weekly forecast summary from output of aggregate_weekly(). Observed
#' weekly cases are drawn as columns and the forecast summary as nested
#' 90\%, 50\% and 20\% intervals, with a dashed vertical line at the latest
#' date.
#'
#' @param inf_weekly List output from aggregate_weekly(). No default. Its
#'   `summary` element must contain the columns `lower_90`, `upper_90`,
#'   `lower_50`, `upper_50`, `lower_20` and `upper_20`, i.e. aggregate_weekly()
#'   must have been run with quantiles covering the 20\%, 50\% and 90\%
#'   intervals (the default).
#' @param plot_regions Vector of region names to plot. Default is all regions.
#'
#' @importFrom dplyr left_join filter
#' @importFrom ggplot2 ggplot aes geom_vline geom_col geom_errorbar facet_wrap scale_x_date scale_color_brewer scale_fill_manual labs theme_bw theme
#' @author Sophie Meakin
#' @export
#' @return A ggplot2 object of weekly forecasts faceted by region.
#'
plot_weekly <- function(inf_weekly, plot_regions) {
  if (missing(inf_weekly)) {
    stop("inf_weekly not specified, no forecast to plot. Please give a list output from aggregate_weekly().")
  }

  # The interval layers below refer to these columns by name; check for them
  # up front so a summary built from other quantiles fails here with a clear
  # message rather than with "object not found" when the plot is rendered.
  interval_cols <- c(
    "lower_90", "upper_90",
    "lower_50", "upper_50",
    "lower_20", "upper_20"
  )
  missing_cols <- setdiff(interval_cols, names(inf_weekly$summary))
  if (length(missing_cols) > 0) {
    stop(
      "inf_weekly$summary is missing column(s) needed for plotting: ",
      paste(missing_cols, collapse = ", "),
      ". Run aggregate_weekly() with summary_quantiles covering the 20%, 50% and 90% intervals (the default)."
    )
  }

  if (missing(plot_regions)) {
    plot_regions <- unique(inf_weekly$summary$region)
  }

  # One x-axis break per week present in the summary.
  plot_weeks <- unique(inf_weekly$summary$week)

  plot_data <- inf_weekly$summary %>%
    dplyr::left_join(inf_weekly$observed, by = c("region", "week")) %>%
    dplyr::filter(region %in% plot_regions)

  ggplot2::ggplot(plot_data, ggplot2::aes(x = week)) +
    ggplot2::geom_vline(
      xintercept = inf_weekly$latest_date,
      lty = 2, col = "grey30", lwd = 0.8
    ) +
    # observed weekly cases; fill distinguishes full from partial weeks
    ggplot2::geom_col(
      ggplot2::aes(y = confirm, fill = full_week),
      show.legend = FALSE
    ) +
    # nested intervals: the narrower the interval, the thicker the bar
    ggplot2::geom_errorbar(
      ggplot2::aes(ymin = lower_90, ymax = upper_90, col = type),
      width = 0, lwd = 1.5
    ) +
    ggplot2::geom_errorbar(
      ggplot2::aes(ymin = lower_50, ymax = upper_50, col = type),
      width = 0, lwd = 4
    ) +
    ggplot2::geom_errorbar(
      ggplot2::aes(ymin = lower_20, ymax = upper_20, col = type),
      width = 0, lwd = 6
    ) +
    ggplot2::facet_wrap(~ region, ncol = 3, scales = "free_y") +
    ggplot2::scale_x_date(breaks = plot_weeks, date_labels = "%d\n%b") +
    ggplot2::scale_color_brewer(palette = "Set2") +
    ggplot2::scale_fill_manual(
      values = c("TRUE" = "grey60", "FALSE" = "grey90")
    ) +
    ggplot2::labs(x = "Week start", col = "Type") +
    ggplot2::theme_bw() +
    ggplot2::theme(legend.position = "bottom")
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment