Skip to content

Instantly share code, notes, and snippets.

@yoursdearboy
Created January 6, 2017 20:46
Show Gist options
  • Save yoursdearboy/a5dc83a9fa31e40d491ecb390c8b6c6a to your computer and use it in GitHub Desktop.
Save yoursdearboy/a5dc83a9fa31e40d491ecb390c8b6c6a to your computer and use it in GitHub Desktop.
Good looking survival curves using ggplot2
# I recommend using [GGally::ggsurv](http://ggobi.github.io/ggally).
# I can't remember why, but it didn't work for my purposes, so I wrote this.
#
# - ggsurv.data - produces a data.frame to use with ggplot
# * survfit object
# * optional subset condition
#
# - ggsurv.plot - plots a survfit object or a data.frame
# * survfit / data.frame
# * ms - when a data.frame is passed, TRUE/FALSE: whether the model is multistate
# * strated - when a data.frame is passed, TRUE/FALSE: whether the model is stratified across groups
# * ci - whether to show confidence intervals
# * cens - whether to show censoring marks
#
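# The data frame contains the following columns:
#
# * surv (prev if the model is multistate / cumulative incidence)
# * low, up (confidence interval bounds)
# * time
# * stderr
# * cens (number censored), nev (number of events)
# * group (if the model is stratified)
# * event (if the model is multistate)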
library(reshape2)
library(ggplot2)
# Convert a survfit object into a long data.frame for ggplot2.
#
# fit - a survfit object. For a multistate fit (class "survfitms") every
#       column of fit$prev except the last one is used, one block of rows
#       per column. NOTE(review): the dropped last column is presumably the
#       initial / event-free state -- confirm for your survival version.
# ... - optional subset conditions, e.g. `time <= 365`. Each one is
#       evaluated with the output columns in scope, falling back to the
#       caller's environment. Only rows where every condition is TRUE (not
#       NA) are kept. NULL conditions are ignored.
#
# Returns a data.frame with columns time, stderr, up, low, cens and nev.
# It also has surv (or prev and event for multistate fits), and group for
# stratified fits (a factor of stratum indices 1..k). One starting row at
# time 0 is prepended per curve: surv = 1 for survival fits and prev = 0 for
# multistate fits, each with zero-width bounds.
ggsurv.data <- function(fit, ...) {
  # Capture the conditions unevaluated. A deparse/parse round trip would
  # split long expressions over several strings. Remember the caller so that
  # free variables in the conditions resolve there.
  conds <- Filter(Negate(is.null), as.list(substitute(list(...)))[-1L])
  caller <- parent.frame()

  ms <- inherits(fit, "survfitms")
  strated <- !is.null(fit$strata)
  nstrata <- length(fit$strata)
  nevt <- if (ms) ncol(fit$prev) - 1L else 1L
  idx <- seq_len(nevt)

  # Per-time vectors are repeated once per event. This lines them up with the
  # column-major stacking of the multistate matrices below.
  time <- rep(fit$time, nevt)
  cens <- rep(fit$n.censor, nevt)
  stderr <- fit$std.err
  up <- fit$upper
  low <- fit$lower
  nev <- fit$n.event
  if (ms) {
    # drop = FALSE keeps a matrix even when only one event column is used.
    # Without it, melt() would get a plain vector.
    stack_events <- function(m) melt(m[, idx, drop = FALSE])$value
    stderr <- stack_events(stderr)
    up <- stack_events(up)
    low <- stack_events(low)
    nev <- stack_events(nev)
  }
  out <- data.frame(time, stderr, up, low, cens, nev)

  if (strated) {
    # fit$strata holds the number of time points in each stratum. The
    # per-time labels are repeated once per event, like `time` above.
    group <- factor(rep(rep(seq_len(nstrata), fit$strata), nevt))
    out <- cbind.data.frame(out, data.frame(group))
  }

  if (ms) {
    prev_by_ev <- melt(fit$prev[, idx, drop = FALSE],
                       value.name = "prev", varnames = c("id", "event"))
    out <- cbind.data.frame(
      out,
      data.frame(prev = prev_by_ev[, "prev"], event = prev_by_ev[, "event"])
    )
  } else {
    out <- cbind.data.frame(out, data.frame(surv = fit$surv))
  }

  # Starting rows at time 0, one per curve. `group` is a factor with the same
  # levels as above, so rbind() keeps it a factor instead of coercing it to
  # character.
  strata_levels <- factor(seq_len(nstrata))
  initial <- if (ms) {
    if (strated) {
      cbind.data.frame(expand.grid(event = idx, group = strata_levels),
                       data.frame(prev = 0, low = 0, up = 0))
    } else {
      data.frame(prev = 0, up = 0, low = 0, event = idx)
    }
  } else if (strated) {
    data.frame(surv = 1, up = 1, low = 1, group = strata_levels)
  } else {
    data.frame(surv = 1, low = 1, up = 1)
  }
  initial <- cbind.data.frame(initial,
                              data.frame(time = 0, stderr = 0, cens = 0, nev = 0))
  out <- rbind.data.frame(initial, out)

  if (length(conds) > 0L) {
    keep <- Reduce(`&`, lapply(conds, eval, envir = out, enclos = caller))
    out <- out[keep & !is.na(keep), , drop = FALSE]
  }
  out
}
# Plot survival (or multistate prevalence) curves with ggplot2.
#
# input   - a survfit object (converted with ggsurv.data) or a data.frame
#           with the columns ggsurv.data produces.
# ms      - for data.frame input: TRUE plots the `prev` column, FALSE plots
#           `surv`. When NULL, this is inferred from whether a `prev` column
#           exists. Recomputed from the object for survfit input.
# strated - for data.frame input: TRUE colours curves by `group`. When NULL,
#           this is inferred from whether a `group` column exists. Recomputed
#           from the object for survfit input.
# ci      - if TRUE, draw the `up` / `low` bounds as dashed steps.
# cens    - if TRUE, draw shape-3 ("+") points on rows with cens > 0.
#
# Returns the ggplot object.
ggsurv.plot <- function(input, ms = NULL, strated = NULL, ci = FALSE, cens = TRUE) {
  if (inherits(input, "survfit")) {
    ms <- inherits(input, "survfitms")
    strated <- !is.null(input$strata)
    input <- ggsurv.data(input)
  }
  # For data.frame input the flags previously had to be given explicitly;
  # `if (NULL)` would error. Infer them from the columns instead.
  if (is.null(ms)) ms <- "prev" %in% names(input)
  if (is.null(strated)) strated <- "group" %in% names(input)

  # Unquote the column symbol. This gives the same mapping (and axis label)
  # as writing aes(y = surv) / aes(y = prev) directly, without aes_string().
  ycol <- as.name(if (ms) "prev" else "surv")
  mapping <- if (strated) {
    aes(x = time, y = !!ycol, color = group)
  } else {
    aes(x = time, y = !!ycol)
  }

  out <- ggplot(input, mapping) + geom_step(direction = "hv")
  if (ci) {
    out <- out +
      geom_step(aes(y = up), lty = 2) +
      geom_step(aes(y = low), lty = 2)
  }
  if (cens) {
    # Inside subset(), `cens` refers to the data column, not this argument.
    out <- out + geom_point(data = function(data) subset(data, cens > 0), shape = 3)
  }
  out
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment