Created
January 8, 2020 15:42
-
-
Save cjbayesian/1fd16e4c46798c7e0a32965b7e5216cf to your computer and use it in GitHub Desktop.
Plot Kaplan-Meier-style survival curves with error bars
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import scipy as sp | |
def beta_errors(num, denom, level=0.95):
    """Return an equal-tailed Beta interval for the proportion ``num / denom``.

    The interval is taken from a ``Beta(num + 1, denom - num + 1)``
    distribution, i.e. the posterior for a binomial proportion under a
    uniform prior.

    Args:
        num: Number of "successes" (e.g. subjects still alive).
        denom: Number of trials (e.g. subjects in the group).
        level: Probability mass covered by the interval. Defaults to 0.95,
            the value that was previously hard-coded.

    Returns:
        A ``(lower, upper)`` tuple as produced by
        ``scipy.stats.beta.interval``.
    """
    # Import the submodule explicitly: ``import scipy as sp`` alone does not
    # guarantee that ``sp.stats`` is available on every SciPy version.
    from scipy import stats

    return stats.beta.interval(level, num + 1, denom - num + 1)
def plot_km(df, threshold=0.5, max_days=365, y_text_shrink=1, ax=None,
            xticks=None):
    """Plot survival curves with error bands for high- and low-risk groups.

    Rows of ``df`` are split on ``df['Pred'] > threshold``. For each group
    and each day in ``range(max_days)`` the curve shows the plain fraction of
    the group whose ``survival_time_days`` is strictly greater than that day
    (no censoring is modelled). Dashed grey lines show the interval returned
    by ``beta_errors`` for each day. Below the x axis, one text row per group
    lists the number of survivors at every tick position.

    Args:
        df: DataFrame with a ``'Pred'`` column (risk score) and a
            ``'survival_time_days'`` column.
        threshold: Scores strictly above this value form the high-risk group.
        max_days: Number of days to plot; days ``0 .. max_days - 1`` are used.
        y_text_shrink: Multiplier applied to the (negative) y positions of the
            survivor-count rows, to move them closer to or further from the
            axis.
        ax: Matplotlib axes to draw on. When ``None`` a new figure is created.
        xticks: Optional iterable of integer day positions for the x ticks
            and survivor counts. Defaults to ``(0, 45, 90, 135, 180)``. Ticks
            outside ``0 .. max_days - 1`` are dropped, because each tick is
            also used to look up that day's survivor count.

    Raises:
        ValueError: If ``max_days`` is smaller than 1.
    """
    if max_days < 1:
        raise ValueError("max_days must be at least 1")

    days = range(max_days)
    last_day = days[-1]

    if xticks is None:
        xticks = (0, 45, 90, 135, 180)
    # Every tick indexes into the per-day survivor list, so ticks beyond the
    # plotted range would raise IndexError (and negative ones would silently
    # wrap around).
    xticks = [tick for tick in xticks if 0 <= tick <= last_day]
    # The row labels sit a quarter of the second tick position to the left of
    # x=0; fall back to the plotted range when fewer than two ticks remain.
    tick_spacing = xticks[1] if len(xticks) > 1 else max(last_day, 1)
    label_x = -tick_spacing / 4

    is_high_risk = df['Pred'] > threshold
    survival_series = df['survival_time_days']
    # (legend label, survivor-row label, survivor-row y position, data)
    groups = [
        ('High risk', 'High Risk n Survived', -0.15 * y_text_shrink,
         survival_series[is_high_risk]),
        ('Low risk', 'Low Risk n Survived', -0.2 * y_text_shrink,
         survival_series[~is_high_risk]),
    ]

    if ax is None:
        # NOTE(review): relies on a module-level ``plt``
        # (matplotlib.pyplot) that is not imported in the visible part of
        # this file -- confirm it is in scope.
        _, ax = plt.subplots(1, 1)

    total_rows = df.shape[0]
    for label, row_label, y_text, grp in groups:
        group_size = grp.shape[0]
        # Number of group members still alive after each day.
        survivors = [int((grp > day).sum()) for day in days]
        survival_mean = [
            n_alive / group_size if group_size else float('nan')
            for n_alive in survivors
        ]
        ci = [beta_errors(n_alive, group_size) for n_alive in survivors]
        lower = [interval[0] for interval in ci]
        upper = [interval[1] for interval in ci]

        # Guard the legend fraction against an empty DataFrame.
        share = float(group_size) / total_rows if total_rows else float('nan')
        proportions = " (n={}, {:.1%})".format(group_size, share)
        ax.plot(days, survival_mean, '-', label=label + proportions)
        ax.plot(days, lower, '--', color='grey')
        ax.plot(days, upper, '--', color='grey')

        # Survivor-count row below the axis (positions are in data
        # coordinates, hence the negative y values).
        ax.text(label_x, y_text, row_label, horizontalalignment='right')
        for tick in xticks:
            txt = " {}".format(survivors[int(tick)])
            ax.text(tick, y_text, txt, horizontalalignment='center')

    ax.set_xticks(xticks)
    ax.legend(loc=0)
    ax.set_xlim(0, last_day)
    ax.set_ylim(0, 1)
    ax.set_xlabel('Time (days)')
    ax.set_ylabel('Survival Probability')
# Example driver: one survival plot using each patient's latest appointment.
# NOTE(review): ``preds`` and ``plt`` are not defined in the visible part of
# this file -- presumably they come from the surrounding notebook/session.
fig, axx = plt.subplots(1, 1, figsize=(7, 7))
thresh = 0.5

# Sort by patient and appointment time, then keep the last row per patient.
ttmp = (
    preds
    .sort_values(['PAT_ID', 'APPT_TIME'])
    .drop_duplicates('PAT_ID', keep='last')
)

# 181 days gives plotted days 0..180, so the tick at day 180 has a
# survivor count to display.
plot_km(ttmp, threshold=thresh, max_days=181, ax=axx)
axx.grid(True)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment