@AustinRochford
Created March 15, 2025 22:01
Joint Modeling of Longitudinal and Survival Outcomes in PyMC
{
"cells": [
{
"cell_type": "markdown",
"id": "e15a67fc-ab90-45d5-9685-c98b3d70717d",
"metadata": {},
"source": [
"It should be clear from the back posts on this blog ([1](https://austinrochford.com/posts/2015-10-05-bayes-survival.html), [2](https://austinrochford.com/posts/2017-10-02-bayes-param-survival.html), [3](https://austinrochford.com/posts/revisit-survival-pymc.html)) that I have a long-standing interest in [survival analysis](https://en.wikipedia.org/wiki/Survival_analysis). Over the last few years, I have been sporadically learning about joint models for longitudinal and time-to-event data. These models augment survival models with information from other (non-survival) outcomes measured repeatedly on the subjects over time, and this additional information can improve the resulting survival estimates. Recently I succeeded in wrapping my head around the theory of one class of such models, and I wanted to record that understanding here for my future self, along with anyone else it may help. This post contains a crash course in the basics of these joint models, along with a worked example in Python using [PyMC](https://www.pymc.io/welcome.html).\n",
"\n",
"My present understanding of this topic is largely based on [Dimitris Rizopoulos's](https://www.drizopoulos.com/) excellent presentation [_Joint Modeling of Longitudinal and Time-to-Event Data\n",
"with Applications in R_](https://www.drizopoulos.com/courses/EMC/ESP72.pdf) and the [documentation](https://drizopoulos.github.io/JMbayes2/) for his R package, `JMbayes2`."
]
},
{
"cell_type": "markdown",
"id": "bd24e306-7f92-4dc6-bed2-99f980d634be",
"metadata": {},
"source": [
"## Theory"
]
},
{
"cell_type": "markdown",
"id": "88ee121f-93da-44fc-a765-84476d598a62",
"metadata": {},
"source": [
"### Survival analysis\n",
"\n",
"For the survival component of our models, we will use the [proportional hazards model](https://en.wikipedia.org/wiki/Proportional_hazards_model#The_Cox_model) that I have written about in two previous posts ([2023](https://austinrochford.com/posts/revisit-survival-pymc.html), [2015](https://austinrochford.com/posts/2015-10-05-bayes-survival.html)). In this model, we represent the [hazard function](https://en.wikipedia.org/wiki/Survival_analysis#Hazard_function_and_cumulative_hazard_function) of the $i$-th subject associated with covariates $\\mathbf{x}_i$ as\n",
"\n",
"$$\\lambda(t\\ |\\ \\mathbf{x}_i) = \\lambda_0(t) \\cdot \\exp(\\alpha \\cdot \\mathbf{x}_i),$$\n",
"\n",
"where $\\lambda_0(t)$ is the baseline hazard at time $t$ and $\\alpha$ is a vector of regression coefficients.\n",
"\n",
"In this post, we will use the equivalent Poisson model discussed in the past posts to perform inference on these survival models."
]
},
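{
"cell_type": "markdown",
"id": "3f6a1c2e-8d4b-4e7a-9b15-6c2d0e9f4a71",
"metadata": {},
"source": [
"Here is a brief sketch of the standard argument for why this equivalence holds. It assumes three things. Time is divided into unit-length intervals. The hazard is constant within each interval. Each subject is at risk for the whole of every interval $0, 1, \\ldots, t_i$ up to and including their last one. The notation $t_i$, $d_i$, and $d_{i, t}$ is introduced here only for this sketch. Suppose subject $i$ is observed through interval $t_i$, with $d_i = 1$ if they experienced the event and $d_i = 0$ if they were censored. Their likelihood contribution is then\n",
"\n",
"$$L_i = \\lambda(t_i\\ |\\ \\mathbf{x}_i)^{d_i} \\cdot \\exp\\left(-\\sum_{t = 0}^{t_i} \\lambda(t\\ |\\ \\mathbf{x}_i)\\right).$$\n",
"\n",
"Now let $d_{i, t} = 1$ if $t = t_i$ and $d_i = 1$, and $d_{i, t} = 0$ otherwise. Because each $d_{i, t} \\in \\{0, 1\\}$, we have $d_{i, t}! = 1$, so\n",
"\n",
"$$\\prod_{t = 0}^{t_i} \\text{Poisson}\\left(d_{i, t}\\ |\\ \\lambda(t\\ |\\ \\mathbf{x}_i)\\right) = \\prod_{t = 0}^{t_i} \\lambda(t\\ |\\ \\mathbf{x}_i)^{d_{i, t}} e^{-\\lambda(t\\ |\\ \\mathbf{x}_i)} = L_i.$$\n",
"\n",
"That is, under these assumptions we can treat each subject-interval indicator $d_{i, t}$ as an independent Poisson observation with mean $\\lambda(t\\ |\\ \\mathbf{x}_i)$, and this reproduces the survival likelihood."
]
},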
{
"cell_type": "markdown",
"id": "0cfdc122-48a3-42bb-a95e-08b43887d888",
"metadata": {},
"source": [
"### Joint model\n",
"\n",
"The goal of this post is to show how we can improve our models by incorporating information from longitudinal outcomes into our survival models. We denote the value of the longitudinal outcome for the $i$-th subject at time $t$ by $y_{i, t}$. There are many ways to incorporate this information into our survival model (entire [books](https://www.routledge.com/Joint-Modeling-of-Longitudinal-and-Time-to-Event-Data/Elashoff-li-Li/p/book/9780367570576) have been written on the subject); in this post we take the approach of assuming independence of the survival and longitudinal outcomes conditional on random effects. Specifically, we posit a random effects model for $y_{i, t}$, $y_{i, t} \\sim N(\\mu_{i, t}, \\sigma^2)$ with\n",
"\n",
"$$\\mu_{i, t} = \\beta \\cdot \\mathbf{x}_i + \\gamma_{i, t},$$\n",
"\n",
"where $\\gamma_{i, t}$ is a set of [random effects](https://en.wikipedia.org/wiki/Random_effects_model) that can vary based on the subject and time.\n",
"\n",
"Our conditional independence model assumes that the longitudinal outcome only influences survival through the random effects $\\gamma_{i, t}$, and incorporates these into the survival model as\n",
"\n",
"$$\\lambda(t\\ |\\ \\mathbf{x}_i, \\gamma_{i, t}) = \\lambda_0(t) \\cdot \\exp(\\alpha \\cdot \\mathbf{x}_i + \\nu \\cdot \\gamma_{i, t}).$$"
]
},
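{
"cell_type": "markdown",
"id": "9c4e7b20-5a3d-4f81-b6e2-1d8a3c5f7e94",
"metadata": {},
"source": [
"Written out, the conditional independence assumption means the likelihood factors into a survival part and a longitudinal part, given subject $i$'s random effects $\\gamma_i = (\\gamma_{i, t})_t$:\n",
"\n",
"$$p(t_i, d_i, \\mathbf{y}_i\\ |\\ \\mathbf{x}_i, \\gamma_i) = p(t_i, d_i\\ |\\ \\mathbf{x}_i, \\gamma_i) \\cdot \\prod_t N(y_{i, t}\\ |\\ \\mu_{i, t}, \\sigma^2).$$\n",
"\n",
"The notation is introduced only for this sketch:\n",
"\n",
"* $t_i$ is the time at which subject $i$ exits the study,\n",
"* $d_i$ indicates whether that exit was due to the event of interest,\n",
"* $\\mathbf{y}_i$ collects the subject's longitudinal measurements, and\n",
"* the product runs over the times at which those measurements were taken.\n",
"\n",
"The product form reflects the standard random effects assumption that measurement errors are independent given $\\gamma_i$. The two factors are linked only through the shared random effects $\\gamma_i$, which enter the hazard via the coefficient $\\nu$."
]
},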
{
"cell_type": "markdown",
"id": "8fe82ecf-b707-4273-91ec-2f53234450d9",
"metadata": {},
"source": [
"## Worked example"
]
},
{
"cell_type": "markdown",
"id": "023146e4-b963-48f0-b718-9ac070b7fb76",
"metadata": {},
"source": [
"First we make the necessary Python imports and do some light configuration."
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "24ccfd3c-2493-4da5-9deb-8143cd438063",
"metadata": {},
"outputs": [],
"source": [
"%matplotlib inline"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "ccc5b703-a817-4931-bc11-c60e15f04c79",
"metadata": {},
"outputs": [],
"source": [
"import arviz as az\n",
"from matplotlib import pyplot as plt\n",
"import numpy as np\n",
"import nutpie\n",
"import polars as pl\n",
"import pymc as pm\n",
"from pytensor import tensor as pt\n",
"import seaborn as sns\n",
"from seaborn import objects as so"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "fb1bed02-0d07-4327-8349-69a3f8806557",
"metadata": {},
"outputs": [],
"source": [
"sns.set(color_codes=True)"
]
},
{
"cell_type": "markdown",
"id": "790400e5-f319-4730-91ac-5af9a1ece9fc",
"metadata": {},
"source": [
"### Load the data\n",
"\n",
"In this worked example, we use longitudinal [data](https://vincentarelbundock.github.io/Rdatasets/doc/survival/pbcseq.html) from a Mayo Clinic study on primary biliary cirrhosis."
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "076dd427-b9cf-483f-809c-642945de1897",
"metadata": {},
"outputs": [],
"source": [
"DATA_PATH = \"https://vincentarelbundock.github.io/Rdatasets/csv/survival/pbcseq.csv\""
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "fb2f44ec-f393-4116-9340-298f182be677",
"metadata": {},
"outputs": [],
"source": [
"COLS = [\n",
" \"id\",\n",
" \"status\",\n",
" \"trt\",\n",
" \"day\",\n",
" \"bili\",\n",
"]"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "4ec683b4-0a1c-41d0-b57d-198f70f8c6f5",
"metadata": {},
"outputs": [],
"source": [
"df = pl.read_csv(DATA_PATH, columns=COLS)"
]
},
{
"cell_type": "markdown",
"id": "1653250c-0784-4b1d-80d3-a199026dd0ee",
"metadata": {},
"source": [
"#### Data exploration and transformation\n",
"\n",
"We examine this data below."
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "266bdcb5-3704-48d2-a6f2-feb7466cbc12",
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div><style>\n",
".dataframe > thead > tr,\n",
".dataframe > tbody > tr {\n",
" text-align: right;\n",
" white-space: pre-wrap;\n",
"}\n",
"</style>\n",
"<small>shape: (1_945, 5)</small><table border=\"1\" class=\"dataframe\"><thead><tr><th>id</th><th>status</th><th>trt</th><th>day</th><th>bili</th></tr><tr><td>i64</td><td>i64</td><td>i64</td><td>i64</td><td>f64</td></tr></thead><tbody><tr><td>1</td><td>2</td><td>1</td><td>0</td><td>14.5</td></tr><tr><td>1</td><td>2</td><td>1</td><td>192</td><td>21.3</td></tr><tr><td>2</td><td>0</td><td>1</td><td>0</td><td>1.1</td></tr><tr><td>2</td><td>0</td><td>1</td><td>182</td><td>0.8</td></tr><tr><td>2</td><td>0</td><td>1</td><td>365</td><td>1.0</td></tr><tr><td>&hellip;</td><td>&hellip;</td><td>&hellip;</td><td>&hellip;</td><td>&hellip;</td></tr><tr><td>312</td><td>0</td><td>0</td><td>0</td><td>6.4</td></tr><tr><td>312</td><td>0</td><td>0</td><td>206</td><td>5.5</td></tr><tr><td>312</td><td>0</td><td>0</td><td>390</td><td>7.4</td></tr><tr><td>312</td><td>0</td><td>0</td><td>775</td><td>16.3</td></tr><tr><td>312</td><td>0</td><td>0</td><td>1075</td><td>23.4</td></tr></tbody></table></div>"
],
"text/plain": [
"shape: (1_945, 5)\n",
"┌─────┬────────┬─────┬──────┬──────┐\n",
"│ id ┆ status ┆ trt ┆ day ┆ bili │\n",
"│ --- ┆ --- ┆ --- ┆ --- ┆ --- │\n",
"│ i64 ┆ i64 ┆ i64 ┆ i64 ┆ f64 │\n",
"╞═════╪════════╪═════╪══════╪══════╡\n",
"│ 1 ┆ 2 ┆ 1 ┆ 0 ┆ 14.5 │\n",
"│ 1 ┆ 2 ┆ 1 ┆ 192 ┆ 21.3 │\n",
"│ 2 ┆ 0 ┆ 1 ┆ 0 ┆ 1.1 │\n",
"│ 2 ┆ 0 ┆ 1 ┆ 182 ┆ 0.8 │\n",
"│ 2 ┆ 0 ┆ 1 ┆ 365 ┆ 1.0 │\n",
"│ … ┆ … ┆ … ┆ … ┆ … │\n",
"│ 312 ┆ 0 ┆ 0 ┆ 0 ┆ 6.4 │\n",
"│ 312 ┆ 0 ┆ 0 ┆ 206 ┆ 5.5 │\n",
"│ 312 ┆ 0 ┆ 0 ┆ 390 ┆ 7.4 │\n",
"│ 312 ┆ 0 ┆ 0 ┆ 775 ┆ 16.3 │\n",
"│ 312 ┆ 0 ┆ 0 ┆ 1075 ┆ 23.4 │\n",
"└─────┴────────┴─────┴──────┴──────┘"
]
},
"execution_count": 7,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"df"
]
},
{
"cell_type": "markdown",
"id": "3d02229c-4a7c-4fd2-8549-2c9cd86692b0",
"metadata": {},
"source": [
"* `id` is the case number of the subject.\n",
"* `status` indicates the subject's status at the end of their time in the study:\n",
" * `0` indicates that they were alive at the end of the study,\n",
" * `1` indicates that they exited the study upon receiving a liver transplant,\n",
" * and `2` indicates that they died during the study.\n",
"* `trt` indicates whether they received the true treatment or a placebo.\n",
"* `day` indicates the number of days between enrollment of the patient and the visit.\n",
"* `bili` indicates the concentration of [bilirubin](https://www.mayoclinic.org/tests-procedures/bilirubin/about/pac-20393041) in the blood during that visit, in mg/dL.\n",
"\n",
"The survival outcome is derived from the `status`, and the longitudinal outcome is derived from `bili`.\n",
"\n",
"First we (crudely) reduce the `day` column to monthly (really 30-day) granularity for ease of modeling."
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "fe9fe0e8-ef06-4e54-b29a-f6dfeabda25e",
"metadata": {},
"outputs": [],
"source": [
"df = df.with_columns(month=pl.col(\"day\") // 30)"
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "f707877c-3ae0-4df3-b557-df261a138f2d",
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div><style>\n",
".dataframe > thead > tr,\n",
".dataframe > tbody > tr {\n",
" text-align: right;\n",
" white-space: pre-wrap;\n",
"}\n",
"</style>\n",
"<small>shape: (1_945, 6)</small><table border=\"1\" class=\"dataframe\"><thead><tr><th>id</th><th>status</th><th>trt</th><th>day</th><th>bili</th><th>month</th></tr><tr><td>i64</td><td>i64</td><td>i64</td><td>i64</td><td>f64</td><td>i64</td></tr></thead><tbody><tr><td>1</td><td>2</td><td>1</td><td>0</td><td>14.5</td><td>0</td></tr><tr><td>1</td><td>2</td><td>1</td><td>192</td><td>21.3</td><td>6</td></tr><tr><td>2</td><td>0</td><td>1</td><td>0</td><td>1.1</td><td>0</td></tr><tr><td>2</td><td>0</td><td>1</td><td>182</td><td>0.8</td><td>6</td></tr><tr><td>2</td><td>0</td><td>1</td><td>365</td><td>1.0</td><td>12</td></tr><tr><td>&hellip;</td><td>&hellip;</td><td>&hellip;</td><td>&hellip;</td><td>&hellip;</td><td>&hellip;</td></tr><tr><td>312</td><td>0</td><td>0</td><td>0</td><td>6.4</td><td>0</td></tr><tr><td>312</td><td>0</td><td>0</td><td>206</td><td>5.5</td><td>6</td></tr><tr><td>312</td><td>0</td><td>0</td><td>390</td><td>7.4</td><td>13</td></tr><tr><td>312</td><td>0</td><td>0</td><td>775</td><td>16.3</td><td>25</td></tr><tr><td>312</td><td>0</td><td>0</td><td>1075</td><td>23.4</td><td>35</td></tr></tbody></table></div>"
],
"text/plain": [
"shape: (1_945, 6)\n",
"┌─────┬────────┬─────┬──────┬──────┬───────┐\n",
"│ id ┆ status ┆ trt ┆ day ┆ bili ┆ month │\n",
"│ --- ┆ --- ┆ --- ┆ --- ┆ --- ┆ --- │\n",
"│ i64 ┆ i64 ┆ i64 ┆ i64 ┆ f64 ┆ i64 │\n",
"╞═════╪════════╪═════╪══════╪══════╪═══════╡\n",
"│ 1 ┆ 2 ┆ 1 ┆ 0 ┆ 14.5 ┆ 0 │\n",
"│ 1 ┆ 2 ┆ 1 ┆ 192 ┆ 21.3 ┆ 6 │\n",
"│ 2 ┆ 0 ┆ 1 ┆ 0 ┆ 1.1 ┆ 0 │\n",
"│ 2 ┆ 0 ┆ 1 ┆ 182 ┆ 0.8 ┆ 6 │\n",
"│ 2 ┆ 0 ┆ 1 ┆ 365 ┆ 1.0 ┆ 12 │\n",
"│ … ┆ … ┆ … ┆ … ┆ … ┆ … │\n",
"│ 312 ┆ 0 ┆ 0 ┆ 0 ┆ 6.4 ┆ 0 │\n",
"│ 312 ┆ 0 ┆ 0 ┆ 206 ┆ 5.5 ┆ 6 │\n",
"│ 312 ┆ 0 ┆ 0 ┆ 390 ┆ 7.4 ┆ 13 │\n",
"│ 312 ┆ 0 ┆ 0 ┆ 775 ┆ 16.3 ┆ 25 │\n",
"│ 312 ┆ 0 ┆ 0 ┆ 1075 ┆ 23.4 ┆ 35 │\n",
"└─────┴────────┴─────┴──────┴──────┴───────┘"
]
},
"execution_count": 9,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"df"
]
},
{
"cell_type": "markdown",
"id": "27faec9e-5b97-46e4-bc50-a7a2df6fd847",
"metadata": {},
"source": [
"Next we reduce this longitudinal dataframe, which may have multiple rows per subject, to a dataframe that has one row per subject."
]
},
{
"cell_type": "code",
"execution_count": 10,
"id": "4dc90926-e9ce-4f9f-a406-8a4cbf3fd4ef",
"metadata": {},
"outputs": [],
"source": [
"subj_df = (\n",
" df.group_by(\"id\")\n",
" .agg(pl.col(\"month\").max(), pl.col(\"trt\").first(), pl.col(\"status\").first())\n",
" .sort(\"id\")\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 11,
"id": "f198f6c2-a6dc-4005-9600-b4b9296ba1e9",
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div><style>\n",
".dataframe > thead > tr,\n",
".dataframe > tbody > tr {\n",
" text-align: right;\n",
" white-space: pre-wrap;\n",
"}\n",
"</style>\n",
"<small>shape: (312, 4)</small><table border=\"1\" class=\"dataframe\"><thead><tr><th>id</th><th>month</th><th>trt</th><th>status</th></tr><tr><td>i64</td><td>i64</td><td>i64</td><td>i64</td></tr></thead><tbody><tr><td>1</td><td>6</td><td>1</td><td>2</td></tr><tr><td>2</td><td>107</td><td>1</td><td>0</td></tr><tr><td>3</td><td>24</td><td>1</td><td>2</td></tr><tr><td>4</td><td>60</td><td>1</td><td>2</td></tr><tr><td>5</td><td>48</td><td>0</td><td>1</td></tr><tr><td>&hellip;</td><td>&hellip;</td><td>&hellip;</td><td>&hellip;</td></tr><tr><td>308</td><td>38</td><td>1</td><td>0</td></tr><tr><td>309</td><td>42</td><td>0</td><td>0</td></tr><tr><td>310</td><td>45</td><td>1</td><td>0</td></tr><tr><td>311</td><td>36</td><td>1</td><td>0</td></tr><tr><td>312</td><td>35</td><td>0</td><td>0</td></tr></tbody></table></div>"
],
"text/plain": [
"shape: (312, 4)\n",
"┌─────┬───────┬─────┬────────┐\n",
"│ id ┆ month ┆ trt ┆ status │\n",
"│ --- ┆ --- ┆ --- ┆ --- │\n",
"│ i64 ┆ i64 ┆ i64 ┆ i64 │\n",
"╞═════╪═══════╪═════╪════════╡\n",
"│ 1 ┆ 6 ┆ 1 ┆ 2 │\n",
"│ 2 ┆ 107 ┆ 1 ┆ 0 │\n",
"│ 3 ┆ 24 ┆ 1 ┆ 2 │\n",
"│ 4 ┆ 60 ┆ 1 ┆ 2 │\n",
"│ 5 ┆ 48 ┆ 0 ┆ 1 │\n",
"│ … ┆ … ┆ … ┆ … │\n",
"│ 308 ┆ 38 ┆ 1 ┆ 0 │\n",
"│ 309 ┆ 42 ┆ 0 ┆ 0 │\n",
"│ 310 ┆ 45 ┆ 1 ┆ 0 │\n",
"│ 311 ┆ 36 ┆ 1 ┆ 0 │\n",
"│ 312 ┆ 35 ┆ 0 ┆ 0 │\n",
"└─────┴───────┴─────┴────────┘"
]
},
"execution_count": 11,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"subj_df"
]
},
{
"cell_type": "markdown",
"id": "de38b14c-659c-4717-9408-bbbd90d1cfbb",
"metadata": {},
"source": [
"* `id`, `trt`, and `status` have retained their meanings from the longitudinal data frame.\n",
"* `month` is the month (really the 30-day period) of the subject's last recorded visit, which we treat as the time at which they exited the study."
]
},
{
"cell_type": "markdown",
"id": "cae7c1e3-4536-4acb-bd9c-4a1f06d2d5a5",
"metadata": {},
"source": [
"### Modeling\n",
"\n",
"We now turn to modeling the impact of treatment on survival using this data."
]
},
{
"cell_type": "markdown",
"id": "e95681e2-887e-4d14-aa9f-0554956b10da",
"metadata": {},
"source": [
"#### Survival model\n",
"\n",
"We first implement a pure survival model for two reasons:\n",
"\n",
"1. it is a key component of the joint model, and\n",
"2. its inferences will provide a good baseline against which to compare those of the joint model.\n",
"\n",
"First we derive NumPy arrays indicating the time each subject spent in the study (`t`), whether or not they died during the study (`died`), and whether or not they were treated (`trt`)."
]
},
{
"cell_type": "code",
"execution_count": 12,
"id": "cc840a85-1097-4e37-a4ae-b04fa4eacef6",
"metadata": {},
"outputs": [],
"source": [
"t = subj_df[\"month\"].to_numpy()\n",
"died = subj_df[\"status\"].eq(2).to_numpy()\n",
"trt = subj_df[\"trt\"].eq(1).to_numpy()"
]
},
{
"cell_type": "markdown",
"id": "f20c7ba6-a559-4468-aac7-2525fdbe4ee6",
"metadata": {},
"source": [
"Next we derive some ancillary quantities needed to perform inference on the proportional hazards model with a Poisson likelihood. For a detailed treatment of these quantities, refer to a previous [post](https://austinrochford.com/posts/revisit-survival-pymc.html)."
]
},
{
"cell_type": "code",
"execution_count": 13,
"id": "bd4c1a72-6da9-41bf-9341-8a064fd78365",
"metadata": {},
"outputs": [],
"source": [
"exposed = np.full((subj_df.shape[0], t.max() + 2), True, dtype=np.bool_)\n",
"np.put_along_axis(exposed, t[:, np.newaxis] + 1, False, axis=1)\n",
"exposed = np.minimum.accumulate(exposed, axis=1)"
]
},
{
"cell_type": "code",
"execution_count": 14,
"id": "49b4ed08-2d18-4c5f-a4e3-2757074bb2fb",
"metadata": {},
"outputs": [],
"source": [
"died_ = np.full_like(exposed, False, dtype=np.bool_)\n",
"np.put_along_axis(died_, t[:, np.newaxis], died[:, np.newaxis], axis=1)\n",
"\n",
"assert (died_ & ~exposed).sum() == 0"
]
},
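{
"cell_type": "markdown",
"id": "b27d5e19-6c0a-4b3f-8e41-7f9a2d6c1b58",
"metadata": {},
"source": [
"To make the layout of these arrays concrete, consider a hypothetical subject (not one from this dataset) whose last month in the study is $t_i = 2$ and who died. In a toy version of these arrays with five columns:\n",
"\n",
"* their row of `exposed` is `[True, True, True, False, False]`, since they are at risk in months $0$, $1$, and $2$, and\n",
"* their row of `died_` is `[False, False, True, False, False]`, since their death is recorded in month $t_i = 2$.\n",
"\n",
"Had the same subject been censored instead, their row of `died_` would be all `False`. The assertion above checks that no death is recorded in a month in which the subject is not at risk."
]
},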
{
"cell_type": "markdown",
"id": "712a2b28-a300-4388-9f4e-de64679330a6",
"metadata": {},
"source": [
"We are now ready to begin building the survival model with PyMC. For the logarithm of the baseline hazard, we choose a hierarchical normal prior,\n",
"\n",
"$$\n",
"\\begin{align}\n",
" \\mu_{\\lambda_0}\n",
" & \\sim N(0, 2.5^2) \\\\\n",
" \\sigma_{\\lambda_0}\n",
" & \\sim \\text{Half-}N\\left(\\frac{2.5^2}{1 - 2 / \\pi}\\right) \\\\\n",
" \\log \\lambda_0(t)\n",
" & \\sim N(\\mu_{\\lambda_0}, \\sigma_{\\lambda_0}^2).\n",
"\\end{align}\n",
"$$\n",
"\n",
"For computational efficiency, we implement this prior using a [non-centered parameterization](https://twiecki.io/blog/2017/02/08/bayesian-hierchical-non-centered/)."
]
},
{
"cell_type": "code",
"execution_count": 15,
"id": "965ecf65-79ba-4079-babf-b91ff3cee51b",
"metadata": {},
"outputs": [],
"source": [
"# the scale necessary to make a halfnormal distribution have unit variance\n",
"HALFNORMAL_SCALE = 1 / np.sqrt(1 - 2 / np.pi)"
]
},
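{
"cell_type": "markdown",
"id": "e4a81f36-2b9c-4d57-a0e3-5c7f1b9d2a64",
"metadata": {},
"source": [
"For reference, here is where this constant comes from. If $Z \\sim N(0, 1)$ and $X = s \\cdot |Z|$, then $X$ follows a half-normal distribution with scale $s$, and\n",
"\n",
"$$\n",
"\\begin{align}\n",
" \\mathbb{E}(X) & = s \\sqrt{2 / \\pi} \\\\\n",
" \\mathbb{E}(X^2) & = s^2 \\cdot \\mathbb{E}(Z^2) = s^2 \\\\\n",
" \\text{Var}(X) & = s^2 - \\left(s \\sqrt{2 / \\pi}\\right)^2 = s^2 \\left(1 - \\frac{2}{\\pi}\\right).\n",
"\\end{align}\n",
"$$\n",
"\n",
"Setting $\\text{Var}(X) = 1$ gives $s = 1 / \\sqrt{1 - 2 / \\pi} \\approx 1.659$, which is `HALFNORMAL_SCALE`. Consequently, a half-normal prior with scale `2.5 * HALFNORMAL_SCALE` has standard deviation $2.5$."
]
},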
{
"cell_type": "code",
"execution_count": 16,
"id": "7f263dd3-8f2b-4025-95f2-9bbd995fb214",
"metadata": {},
"outputs": [],
"source": [
"def noncentered_normal(name, *, dims, μ=None):\n",
" if μ is None:\n",
" μ = pm.Normal(f\"μ_{name}\", 0, 2.5)\n",
"\n",
" Δ = pm.Normal(f\"Δ_{name}\", 0, 1, dims=dims)\n",
" σ = pm.HalfNormal(f\"σ_{name}\", 2.5 * HALFNORMAL_SCALE)\n",
"\n",
" return pm.Deterministic(name, μ + Δ * σ, dims=dims)"
]
},
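{
"cell_type": "markdown",
"id": "a1f9c3d7-6e2b-4a58-9d04-8b3e5f7c2e16",
"metadata": {},
"source": [
"In equations, `noncentered_normal` draws an auxiliary standard normal variable and then shifts and scales it. We write $\\theta_t$ for the returned variable; this is a generic name used only here, and in the survival model below it is $\\log \\lambda_0(t)$:\n",
"\n",
"$$\n",
"\\begin{align}\n",
" \\Delta_t & \\sim N(0, 1) \\\\\n",
" \\theta_t & = \\mu + \\Delta_t \\cdot \\sigma.\n",
"\\end{align}\n",
"$$\n",
"\n",
"Conditional on $\\mu$ and $\\sigma$, this gives $\\theta_t \\sim N(\\mu, \\sigma^2)$, exactly as in the centered prior. The sampler works with $\\Delta_t$, whose prior does not depend on $\\sigma$. This typically helps it avoid the funnel-shaped posterior geometry that centered hierarchical priors can produce when the data only weakly constrain the group-level scale."
]
},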
{
"cell_type": "code",
"execution_count": 17,
"id": "a2727155-3d17-401a-8849-715732978029",
"metadata": {},
"outputs": [],
"source": [
"coords = {\"drug\": np.array([False, True]), \"t\": np.arange(t.max() + 2)}"
]
},
{
"cell_type": "code",
"execution_count": 18,
"id": "200bf773-e090-4e84-bec7-c4f535f9c622",
"metadata": {},
"outputs": [],
"source": [
"with pm.Model(coords=coords) as surv_model:\n",
" log_λ0 = noncentered_normal(\"log_λ0\", dims=\"t\")\n",
" λ0 = pt.exp(log_λ0)"
]
},
{
"cell_type": "markdown",
"id": "8dbc0db5-2b06-41fe-a672-85d8991adf47",
"metadata": {},
"source": [
"Now we introduce the regression component of the model, making survival dependent on treatment.\n",
"\n",
"We let $\\alpha_{\\text{trt}} \\sim N(0, 2.5^2)$ and define the hazard function as\n",
"\n",
"$$\\lambda(t\\ |\\ x_{\\text{trt}, i}) = \\lambda_0(t) \\cdot \\exp(\\alpha_{\\text{trt}} \\cdot x_{\\text{trt}, i}).$$"
]
},
{
"cell_type": "code",
"execution_count": 19,
"id": "d5225384-c321-4aec-8a0e-6638662082e0",
"metadata": {},
"outputs": [],
"source": [
"with surv_model:\n",
" α_trt = pm.Normal(\"α_trt\", 0, 2.5)\n",
"\n",
" λ = pt.outer(pt.exp(α_trt * trt), λ0)"
]
},
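{
"cell_type": "markdown",
"id": "7d3b9e52-4f1a-4c6d-8b27-0e5a9c3f1d86",
"metadata": {},
"source": [
"Concretely, `pt.outer` builds the whole matrix of hazards at once. Writing $\\lambda_{i, t}$ (notation introduced here) for the hazard of subject $i$ in month $t$, row $i$ corresponds to subject $i$ and column $t$ to month $t$, with entries\n",
"\n",
"$$\\lambda_{i, t} = \\exp(\\alpha_{\\text{trt}} \\cdot x_{\\text{trt}, i}) \\cdot \\lambda_0(t).$$\n",
"\n",
"The matrix therefore has the same shape as `exposed` and `died_`. Because $x_{\\text{trt}, i} \\in \\{0, 1\\}$, the rows for untreated subjects equal the baseline hazard. The rows for treated subjects are the baseline hazard multiplied by the hazard ratio $\\exp(\\alpha_{\\text{trt}})$."
]
},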
{
"cell_type": "markdown",
"id": "ec74b4d5-251a-4152-b6f8-9b9b7d31fc12",
"metadata": {},
"source": [
"Note that we have not included an intercept term in our regression. Combined with the baseline hazard, it would lead to an [unidentified](https://en.wikipedia.org/wiki/Identifiability) model, since any change to the intercept could be exactly offset by rescaling $\\lambda_0(t)$.\n",
"\n",
"Finally we specify the Poisson likelihood for our model."
]
},
{
"cell_type": "code",
"execution_count": 20,
"id": "28414ff6-4582-490f-b90d-062d969a4c15",
"metadata": {},
"outputs": [],
"source": [
"with surv_model:\n",
" pm.Poisson(\"died\", exposed * λ, observed=died_)"
]
},
{
"cell_type": "markdown",
"id": "5d8abc89-8ccc-48c5-b84f-55d8ca4f2ae4",
"metadata": {},
"source": [
"Before sampling, we define the survival function of our model for untreated and treated subjects as a deterministic quantity, computed from the cumulative hazard, so that we obtain samples from its posterior distribution."
]
},
{
"cell_type": "code",
"execution_count": 21,
"id": "24a70e35-7665-48a0-ac11-f7e9bef583b0",
"metadata": {},
"outputs": [],
"source": [
"with surv_model:\n",
" λ_pred = pt.outer(pt.exp(α_trt * np.array([0, 1])), λ0)\n",
" Λ_pred = λ_pred.cumsum(axis=1)\n",
" sf_pred = pm.Deterministic(\"sf_pred\", pt.exp(-Λ_pred), dims=(\"drug\", \"t\"))"
]
},
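{
"cell_type": "markdown",
"id": "c58f2a07-9e3d-4b16-a7c4-3d1e8b6f0a29",
"metadata": {},
"source": [
"In terms of the model's quantities, this cell computes the discrete cumulative hazard and the corresponding survival function for untreated ($x = 0$) and treated ($x = 1$) subjects:\n",
"\n",
"$$\n",
"\\begin{align}\n",
" \\Lambda(t\\ |\\ x) & = \\sum_{s = 0}^{t} \\lambda_0(s) \\cdot \\exp(\\alpha_{\\text{trt}} \\cdot x) \\\\\n",
" S(t\\ |\\ x) & = \\exp\\left(-\\Lambda(t\\ |\\ x)\\right).\n",
"\\end{align}\n",
"$$\n",
"\n",
"The cumulative sum includes $s = t$, so under this model $S(t\\ |\\ x)$ is the probability of surviving through the end of month $t$. This matches the likelihood's convention that a subject is at risk for all of month $t$. The hazard ratio also factors out of the sum, so $S(t\\ |\\ 1) = S(t\\ |\\ 0)^{\\exp(\\alpha_{\\text{trt}})}$. In this model, the treated and untreated survival curves are always related by this power transformation."
]
},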
{
"cell_type": "markdown",
"id": "ebb4098a-b2e8-4db3-b29f-b395afb3c55f",
"metadata": {},
"source": [
"We are now ready to sample from our model."
]
},
{
"cell_type": "code",
"execution_count": 22,
"id": "5835ca18-bac3-435a-8689-4df8b92d4fdf",
"metadata": {},
"outputs": [],
"source": [
"SAMPLER_KWARGS = {\"cores\": 8, \"seed\": 1234567890}"
]
},
{
"cell_type": "code",
"execution_count": 23,
"id": "14f6c7d7-04fa-4288-acfa-05d2148a76ce",
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"\n",
"<div class=\"nutpie\">\n",
" <p><strong>Sampler Progress</strong></p>\n",
" <p>Total Chains: <span id=\"total-chains\">6</span></p>\n",
" <p>Active Chains: <span id=\"active-chains\">0</span></p>\n",
" <p>\n",
" Finished Chains:\n",
" <span id=\"active-chains\">6</span>\n",
" </p>\n",
" <p>Sampling for now</p>\n",
" <p>\n",
" Estimated Time to Completion:\n",
" <span id=\"eta\">now</span>\n",
" </p>\n",
"\n",
" <progress\n",
" id=\"total-progress-bar\"\n",
" max=\"7800\"\n",
" value=\"7800\">\n",
" </progress>\n",
" <table>\n",
" <thead>\n",
" <tr>\n",
" <th>Progress</th>\n",
" <th>Draws</th>\n",
" <th>Divergences</th>\n",
" <th>Step Size</th>\n",
" <th>Gradients/Draw</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody id=\"chain-details\">\n",
" \n",
" <tr>\n",
" <td class=\"progress-cell\">\n",
" <progress\n",
" max=\"1300\"\n",
" value=\"1300\">\n",
" </progress>\n",
" </td>\n",
" <td>1300</td>\n",
" <td>0</td>\n",
" <td>0.41</td>\n",
" <td>15</td>\n",
" </tr>\n",
" \n",
" <tr>\n",
" <td class=\"progress-cell\">\n",
" <progress\n",
" max=\"1300\"\n",
" value=\"1300\">\n",
" </progress>\n",
" </td>\n",
" <td>1300</td>\n",
" <td>0</td>\n",
" <td>0.39</td>\n",
" <td>15</td>\n",
" </tr>\n",
" \n",
" <tr>\n",
" <td class=\"progress-cell\">\n",
" <progress\n",
" max=\"1300\"\n",
" value=\"1300\">\n",
" </progress>\n",
" </td>\n",
" <td>1300</td>\n",
" <td>0</td>\n",
" <td>0.41</td>\n",
" <td>15</td>\n",
" </tr>\n",
" \n",
" <tr>\n",
" <td class=\"progress-cell\">\n",
" <progress\n",
" max=\"1300\"\n",
" value=\"1300\">\n",
" </progress>\n",
" </td>\n",
" <td>1300</td>\n",
" <td>0</td>\n",
" <td>0.40</td>\n",
" <td>15</td>\n",
" </tr>\n",
" \n",
" <tr>\n",
" <td class=\"progress-cell\">\n",
" <progress\n",
" max=\"1300\"\n",
" value=\"1300\">\n",
" </progress>\n",
" </td>\n",
" <td>1300</td>\n",
" <td>0</td>\n",
" <td>0.40</td>\n",
" <td>15</td>\n",
" </tr>\n",
" \n",
" <tr>\n",
" <td class=\"progress-cell\">\n",
" <progress\n",
" max=\"1300\"\n",
" value=\"1300\">\n",
" </progress>\n",
" </td>\n",
" <td>1300</td>\n",
" <td>0</td>\n",
" <td>0.42</td>\n",
" <td>15</td>\n",
" </tr>\n",
" \n",
" </tr>\n",
" </tbody>\n",
" </table>\n",
"</div>\n"
],
"text/plain": [
"<nutpie.sample._BackgroundSampler at 0x30ad7f2c0>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"surv_trace = nutpie.sample(nutpie.compile_pymc_model(surv_model), **SAMPLER_KWARGS)"
]
},
{
"cell_type": "markdown",
"id": "a3811197-52b7-4e7e-b80d-3b851c09152d",
"metadata": {},
"source": [
"Standard sampling diagnostics show no cause for concern."
]
},
{
"cell_type": "code",
"execution_count": 24,
"id": "b2fdc5ae-2786-4d41-9b53-51b49a140ec5",
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div><svg style=\"position: absolute; width: 0; height: 0; overflow: hidden\">\n",
"<defs>\n",
"<symbol id=\"icon-database\" viewBox=\"0 0 32 32\">\n",
"<path d=\"M16 0c-8.837 0-16 2.239-16 5v4c0 2.761 7.163 5 16 5s16-2.239 16-5v-4c0-2.761-7.163-5-16-5z\"></path>\n",
"<path d=\"M16 17c-8.837 0-16-2.239-16-5v6c0 2.761 7.163 5 16 5s16-2.239 16-5v-6c0 2.761-7.163 5-16 5z\"></path>\n",
"<path d=\"M16 26c-8.837 0-16-2.239-16-5v6c0 2.761 7.163 5 16 5s16-2.239 16-5v-6c0 2.761-7.163 5-16 5z\"></path>\n",
"</symbol>\n",
"<symbol id=\"icon-file-text2\" viewBox=\"0 0 32 32\">\n",
"<path d=\"M28.681 7.159c-0.694-0.947-1.662-2.053-2.724-3.116s-2.169-2.030-3.116-2.724c-1.612-1.182-2.393-1.319-2.841-1.319h-15.5c-1.378 0-2.5 1.121-2.5 2.5v27c0 1.378 1.122 2.5 2.5 2.5h23c1.378 0 2.5-1.122 2.5-2.5v-19.5c0-0.448-0.137-1.23-1.319-2.841zM24.543 5.457c0.959 0.959 1.712 1.825 2.268 2.543h-4.811v-4.811c0.718 0.556 1.584 1.309 2.543 2.268zM28 29.5c0 0.271-0.229 0.5-0.5 0.5h-23c-0.271 0-0.5-0.229-0.5-0.5v-27c0-0.271 0.229-0.5 0.5-0.5 0 0 15.499-0 15.5 0v7c0 0.552 0.448 1 1 1h7v19.5z\"></path>\n",
"<path d=\"M23 26h-14c-0.552 0-1-0.448-1-1s0.448-1 1-1h14c0.552 0 1 0.448 1 1s-0.448 1-1 1z\"></path>\n",
"<path d=\"M23 22h-14c-0.552 0-1-0.448-1-1s0.448-1 1-1h14c0.552 0 1 0.448 1 1s-0.448 1-1 1z\"></path>\n",
"<path d=\"M23 18h-14c-0.552 0-1-0.448-1-1s0.448-1 1-1h14c0.552 0 1 0.448 1 1s-0.448 1-1 1z\"></path>\n",
"</symbol>\n",
"</defs>\n",
"</svg>\n",
"<style>/* CSS stylesheet for displaying xarray objects in jupyterlab.\n",
" *\n",
" */\n",
"\n",
":root {\n",
" --xr-font-color0: var(--jp-content-font-color0, rgba(0, 0, 0, 1));\n",
" --xr-font-color2: var(--jp-content-font-color2, rgba(0, 0, 0, 0.54));\n",
" --xr-font-color3: var(--jp-content-font-color3, rgba(0, 0, 0, 0.38));\n",
" --xr-border-color: var(--jp-border-color2, #e0e0e0);\n",
" --xr-disabled-color: var(--jp-layout-color3, #bdbdbd);\n",
" --xr-background-color: var(--jp-layout-color0, white);\n",
" --xr-background-color-row-even: var(--jp-layout-color1, white);\n",
" --xr-background-color-row-odd: var(--jp-layout-color2, #eeeeee);\n",
"}\n",
"\n",
"html[theme=dark],\n",
"html[data-theme=dark],\n",
"body[data-theme=dark],\n",
"body.vscode-dark {\n",
" --xr-font-color0: rgba(255, 255, 255, 1);\n",
" --xr-font-color2: rgba(255, 255, 255, 0.54);\n",
" --xr-font-color3: rgba(255, 255, 255, 0.38);\n",
" --xr-border-color: #1F1F1F;\n",
" --xr-disabled-color: #515151;\n",
" --xr-background-color: #111111;\n",
" --xr-background-color-row-even: #111111;\n",
" --xr-background-color-row-odd: #313131;\n",
"}\n",
"\n",
".xr-wrap {\n",
" display: block !important;\n",
" min-width: 300px;\n",
" max-width: 700px;\n",
"}\n",
"\n",
".xr-text-repr-fallback {\n",
" /* fallback to plain text repr when CSS is not injected (untrusted notebook) */\n",
" display: none;\n",
"}\n",
"\n",
".xr-header {\n",
" padding-top: 6px;\n",
" padding-bottom: 6px;\n",
" margin-bottom: 4px;\n",
" border-bottom: solid 1px var(--xr-border-color);\n",
"}\n",
"\n",
".xr-header > div,\n",
".xr-header > ul {\n",
" display: inline;\n",
" margin-top: 0;\n",
" margin-bottom: 0;\n",
"}\n",
"\n",
".xr-obj-type,\n",
".xr-array-name {\n",
" margin-left: 2px;\n",
" margin-right: 10px;\n",
"}\n",
"\n",
".xr-obj-type {\n",
" color: var(--xr-font-color2);\n",
"}\n",
"\n",
".xr-sections {\n",
" padding-left: 0 !important;\n",
" display: grid;\n",
" grid-template-columns: 150px auto auto 1fr 0 20px 0 20px;\n",
"}\n",
"\n",
".xr-section-item {\n",
" display: contents;\n",
"}\n",
"\n",
".xr-section-item input {\n",
" display: inline-block;\n",
" opacity: 0;\n",
"}\n",
"\n",
".xr-section-item input + label {\n",
" color: var(--xr-disabled-color);\n",
"}\n",
"\n",
".xr-section-item input:enabled + label {\n",
" cursor: pointer;\n",
" color: var(--xr-font-color2);\n",
"}\n",
"\n",
".xr-section-item input:focus + label {\n",
" border: 2px solid var(--xr-font-color0);\n",
"}\n",
"\n",
".xr-section-item input:enabled + label:hover {\n",
" color: var(--xr-font-color0);\n",
"}\n",
"\n",
".xr-section-summary {\n",
" grid-column: 1;\n",
" color: var(--xr-font-color2);\n",
" font-weight: 500;\n",
"}\n",
"\n",
".xr-section-summary > span {\n",
" display: inline-block;\n",
" padding-left: 0.5em;\n",
"}\n",
"\n",
".xr-section-summary-in:disabled + label {\n",
" color: var(--xr-font-color2);\n",
"}\n",
"\n",
".xr-section-summary-in + label:before {\n",
" display: inline-block;\n",
" content: '►';\n",
" font-size: 11px;\n",
" width: 15px;\n",
" text-align: center;\n",
"}\n",
"\n",
".xr-section-summary-in:disabled + label:before {\n",
" color: var(--xr-disabled-color);\n",
"}\n",
"\n",
".xr-section-summary-in:checked + label:before {\n",
" content: '▼';\n",
"}\n",
"\n",
".xr-section-summary-in:checked + label > span {\n",
" display: none;\n",
"}\n",
"\n",
".xr-section-summary,\n",
".xr-section-inline-details {\n",
" padding-top: 4px;\n",
" padding-bottom: 4px;\n",
"}\n",
"\n",
".xr-section-inline-details {\n",
" grid-column: 2 / -1;\n",
"}\n",
"\n",
".xr-section-details {\n",
" display: none;\n",
" grid-column: 1 / -1;\n",
" margin-bottom: 5px;\n",
"}\n",
"\n",
".xr-section-summary-in:checked ~ .xr-section-details {\n",
" display: contents;\n",
"}\n",
"\n",
".xr-array-wrap {\n",
" grid-column: 1 / -1;\n",
" display: grid;\n",
" grid-template-columns: 20px auto;\n",
"}\n",
"\n",
".xr-array-wrap > label {\n",
" grid-column: 1;\n",
" vertical-align: top;\n",
"}\n",
"\n",
".xr-preview {\n",
" color: var(--xr-font-color3);\n",
"}\n",
"\n",
".xr-array-preview,\n",
".xr-array-data {\n",
" padding: 0 5px !important;\n",
" grid-column: 2;\n",
"}\n",
"\n",
".xr-array-data,\n",
".xr-array-in:checked ~ .xr-array-preview {\n",
" display: none;\n",
"}\n",
"\n",
".xr-array-in:checked ~ .xr-array-data,\n",
".xr-array-preview {\n",
" display: inline-block;\n",
"}\n",
"\n",
".xr-dim-list {\n",
" display: inline-block !important;\n",
" list-style: none;\n",
" padding: 0 !important;\n",
" margin: 0;\n",
"}\n",
"\n",
".xr-dim-list li {\n",
" display: inline-block;\n",
" padding: 0;\n",
" margin: 0;\n",
"}\n",
"\n",
".xr-dim-list:before {\n",
" content: '(';\n",
"}\n",
"\n",
".xr-dim-list:after {\n",
" content: ')';\n",
"}\n",
"\n",
".xr-dim-list li:not(:last-child):after {\n",
" content: ',';\n",
" padding-right: 5px;\n",
"}\n",
"\n",
".xr-has-index {\n",
" font-weight: bold;\n",
"}\n",
"\n",
".xr-var-list,\n",
".xr-var-item {\n",
" display: contents;\n",
"}\n",
"\n",
".xr-var-item > div,\n",
".xr-var-item label,\n",
".xr-var-item > .xr-var-name span {\n",
" background-color: var(--xr-background-color-row-even);\n",
" margin-bottom: 0;\n",
"}\n",
"\n",
".xr-var-item > .xr-var-name:hover span {\n",
" padding-right: 5px;\n",
"}\n",
"\n",
".xr-var-list > li:nth-child(odd) > div,\n",
".xr-var-list > li:nth-child(odd) > label,\n",
".xr-var-list > li:nth-child(odd) > .xr-var-name span {\n",
" background-color: var(--xr-background-color-row-odd);\n",
"}\n",
"\n",
".xr-var-name {\n",
" grid-column: 1;\n",
"}\n",
"\n",
".xr-var-dims {\n",
" grid-column: 2;\n",
"}\n",
"\n",
".xr-var-dtype {\n",
" grid-column: 3;\n",
" text-align: right;\n",
" color: var(--xr-font-color2);\n",
"}\n",
"\n",
".xr-var-preview {\n",
" grid-column: 4;\n",
"}\n",
"\n",
".xr-index-preview {\n",
" grid-column: 2 / 5;\n",
" color: var(--xr-font-color2);\n",
"}\n",
"\n",
".xr-var-name,\n",
".xr-var-dims,\n",
".xr-var-dtype,\n",
".xr-preview,\n",
".xr-attrs dt {\n",
" white-space: nowrap;\n",
" overflow: hidden;\n",
" text-overflow: ellipsis;\n",
" padding-right: 10px;\n",
"}\n",
"\n",
".xr-var-name:hover,\n",
".xr-var-dims:hover,\n",
".xr-var-dtype:hover,\n",
".xr-attrs dt:hover {\n",
" overflow: visible;\n",
" width: auto;\n",
" z-index: 1;\n",
"}\n",
"\n",
".xr-var-attrs,\n",
".xr-var-data,\n",
".xr-index-data {\n",
" display: none;\n",
" background-color: var(--xr-background-color) !important;\n",
" padding-bottom: 5px !important;\n",
"}\n",
"\n",
".xr-var-attrs-in:checked ~ .xr-var-attrs,\n",
".xr-var-data-in:checked ~ .xr-var-data,\n",
".xr-index-data-in:checked ~ .xr-index-data {\n",
" display: block;\n",
"}\n",
"\n",
".xr-var-data > table {\n",
" float: right;\n",
"}\n",
"\n",
".xr-var-name span,\n",
".xr-var-data,\n",
".xr-index-name div,\n",
".xr-index-data,\n",
".xr-attrs {\n",
" padding-left: 25px !important;\n",
"}\n",
"\n",
".xr-attrs,\n",
".xr-var-attrs,\n",
".xr-var-data,\n",
".xr-index-data {\n",
" grid-column: 1 / -1;\n",
"}\n",
"\n",
"dl.xr-attrs {\n",
" padding: 0;\n",
" margin: 0;\n",
" display: grid;\n",
" grid-template-columns: 125px auto;\n",
"}\n",
"\n",
".xr-attrs dt,\n",
".xr-attrs dd {\n",
" padding: 0;\n",
" margin: 0;\n",
" float: left;\n",
" padding-right: 10px;\n",
" width: auto;\n",
"}\n",
"\n",
".xr-attrs dt {\n",
" font-weight: normal;\n",
" grid-column: 1;\n",
"}\n",
"\n",
".xr-attrs dt:hover span {\n",
" display: inline-block;\n",
" background: var(--xr-background-color);\n",
" padding-right: 10px;\n",
"}\n",
"\n",
".xr-attrs dd {\n",
" grid-column: 2;\n",
" white-space: pre-wrap;\n",
" word-break: break-all;\n",
"}\n",
"\n",
".xr-icon-database,\n",
".xr-icon-file-text2,\n",
".xr-no-icon {\n",
" display: inline-block;\n",
" vertical-align: middle;\n",
" width: 1em;\n",
" height: 1.5em !important;\n",
" stroke-width: 0;\n",
" stroke: currentColor;\n",
" fill: currentColor;\n",
"}\n",
"</style><pre class='xr-text-repr-fallback'>&lt;xarray.DataArray ()&gt; Size: 8B\n",
"array(1.00598908)</pre><div class='xr-wrap' style='display:none'><div class='xr-header'><div class='xr-obj-type'>xarray.DataArray</div><div class='xr-array-name'></div></div><ul class='xr-sections'><li class='xr-section-item'><div class='xr-array-wrap'><input id='section-98fc2642-d0e5-4743-87cf-d04c1e4840e6' class='xr-array-in' type='checkbox' checked><label for='section-98fc2642-d0e5-4743-87cf-d04c1e4840e6' title='Show/hide data repr'><svg class='icon xr-icon-database'><use xlink:href='#icon-database'></use></svg></label><div class='xr-array-preview xr-preview'><span>1.006</span></div><div class='xr-array-data'><pre>array(1.00598908)</pre></div></div></li><li class='xr-section-item'><input id='section-f580247d-c8df-4ec3-805a-7768c555ec67' class='xr-section-summary-in' type='checkbox' disabled ><label for='section-f580247d-c8df-4ec3-805a-7768c555ec67' class='xr-section-summary' title='Expand/collapse section'>Coordinates: <span>(0)</span></label><div class='xr-section-inline-details'></div><div class='xr-section-details'><ul class='xr-var-list'></ul></div></li><li class='xr-section-item'><input id='section-153c886e-5dba-4b8c-8dcb-653327a8238d' class='xr-section-summary-in' type='checkbox' disabled ><label for='section-153c886e-5dba-4b8c-8dcb-653327a8238d' class='xr-section-summary' title='Expand/collapse section'>Indexes: <span>(0)</span></label><div class='xr-section-inline-details'></div><div class='xr-section-details'><ul class='xr-var-list'></ul></div></li><li class='xr-section-item'><input id='section-89d5aceb-5eb7-424a-97f9-d3045e338b4d' class='xr-section-summary-in' type='checkbox' disabled ><label for='section-89d5aceb-5eb7-424a-97f9-d3045e338b4d' class='xr-section-summary' title='Expand/collapse section'>Attributes: <span>(0)</span></label><div class='xr-section-inline-details'></div><div class='xr-section-details'><dl class='xr-attrs'></dl></div></li></ul></div></div>"
],
"text/plain": [
"<xarray.DataArray ()> Size: 8B\n",
"array(1.00598908)"
]
},
"execution_count": 24,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"az.rhat(surv_trace).max().to_array().max()"
]
},
{
"cell_type": "code",
"execution_count": 25,
"id": "33bc4d44-1d92-4d3c-b597-e4c00e068bd9",
"metadata": {},
"outputs": [
{
"data": {
"image/png": "",
"text/plain": [
"<Figure size 640x480 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"az.plot_energy(surv_trace);"
]
},
{
"cell_type": "markdown",
"id": "4bbe3598-0107-4105-95cd-60a32ebe2faf",
"metadata": {},
"source": [
"This model shows little, if any, influence of treatment on survival, as illustrated in the following plots."
]
},
{
"cell_type": "code",
"execution_count": 26,
"id": "86725096-4c03-4d53-af46-56e210aac5d1",
"metadata": {},
"outputs": [],
"source": [
"ALPHA = 0.05\n",
"\n",
"ci = so.Perc([100 * ALPHA / 2, 100 * (1 - ALPHA / 2)])"
]
},
{
"cell_type": "code",
"execution_count": 27,
"id": "1bda569c-3fc8-400a-b7ae-0a98638861f7",
"metadata": {},
"outputs": [
{
"data": {
"image/png": "",
"text/plain": [
"<Figure size 1400x600 with 2 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"fig, (α_ax, sf_ax) = plt.subplots(figsize=(14, 6), ncols=2)\n",
"\n",
"az.plot_posterior(surv_trace, var_names=\"α_trt\", ref_val=0, ax=α_ax)\n",
"\n",
"(\n",
" so.Plot(\n",
" surv_trace.posterior[\"sf_pred\"].to_dataframe(), x=\"t\", y=\"sf_pred\", color=\"drug\"\n",
" )\n",
" .add(so.Line(), so.Agg())\n",
" .add(so.Band(), ci)\n",
" .scale(color=so.Nominal(), y=so.Continuous().tick(every=0.25).label(like=\"{x:.0%}\"))\n",
" .limit(x=(0, t.max()), y=(0, 1))\n",
" .label(x=\"Month\", y=\"Posterior predictive\\nsurvival function\")\n",
" .on(sf_ax)\n",
" .show()\n",
")\n",
"\n",
"fig.tight_layout();"
]
},
{
"cell_type": "markdown",
"id": "d05e84a4-7aa0-42ff-859e-e1f4a900cb8a",
"metadata": {},
"source": [
"#### Joint model\n",
"\n",
"We now get to the core of this post: implementing the joint model and observing how its inferences differ from those of the pure survival model.\n",
"\n",
"First we derive NumPy arrays for the longitudinal outcome (the concentration of bilirubin, `bili`), the zero-based index of the subject at each visit (`i`), and the time of each visit (`t_visit`)."
]
},
{
"cell_type": "code",
"execution_count": 28,
"id": "c4a0acd4-234e-4a0d-980d-0d039d7b5537",
"metadata": {},
"outputs": [],
"source": [
"def make_time_scaler(t_max):\n",
"    def time_scaler(t):\n",
"        # Map times in [0, t_max] onto [0, 1]\n",
"        return t / t_max\n",
"\n",
"    return time_scaler"
]
},
{
"cell_type": "code",
"execution_count": 29,
"id": "02fc60b1-4971-47cc-85b8-f139d9d177ce",
"metadata": {},
"outputs": [],
"source": [
"bili = df[\"bili\"].to_numpy()\n",
"i = (df[\"id\"] - df[\"id\"].min()).to_numpy()\n",
"\n",
"time_scaler = make_time_scaler(df[\"month\"].max())\n",
"t_visit = time_scaler(df[\"month\"].to_numpy())"
]
},
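{
"cell_type": "markdown",
"id": "3f6a9c1e-2b7d-4e58-9a10-6c4d2e8f7b31",
"metadata": {},
"source": [
"As a minimal sketch of this preprocessing, the following cell applies the same steps to a small, made-up long-format table with three subjects and hypothetical visit months and bilirubin values. The `_toy` names are illustrative only. Subject IDs are shifted to zero-based indices, which assumes the IDs are consecutive integers. Visit months are divided by the largest observed month so that scaled time lies in $[0, 1]$, and the outcome is log-transformed, as in the likelihood below."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "7c2e5a90-1d4b-4f3a-8e6c-9b0d1a2f3e45",
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"# Hypothetical long-format visit data: one row per (subject, visit)\n",
"subject_id_toy = np.array([101, 101, 101, 102, 102, 103])\n",
"month_toy = np.array([0.0, 6.0, 12.0, 0.0, 24.0, 0.0])\n",
"bili_toy = np.array([1.2, 1.5, 2.1, 0.8, 0.9, 3.4])\n",
"\n",
"# Zero-based subject index for each visit, used to look up per-subject random effects\n",
"i_toy = subject_id_toy - subject_id_toy.min()\n",
"\n",
"# True division maps visit months onto [0, 1]\n",
"t_visit_toy = month_toy / month_toy.max()\n",
"\n",
"# The longitudinal outcome is modeled on the log scale\n",
"log_bili_toy = np.log(bili_toy)"
]
},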
{
"cell_type": "markdown",
"id": "18fd5c25-36e1-4926-b7cc-d09459c24532",
"metadata": {},
"source": [
"We also add the subject IDs to our model's coordinates, so that the per-subject random effects below can be declared with `dims=\"id\"`."
]
},
{
"cell_type": "code",
"execution_count": 30,
"id": "0c6f5fdd-b0dd-4a80-a742-d36667d03ecb",
"metadata": {},
"outputs": [],
"source": [
"coords[\"id\"] = subj_df[\"id\"].to_numpy()"
]
},
{
"cell_type": "markdown",
"id": "f7381a1d-f9f7-4cca-86b6-6221e9e6afab",
"metadata": {},
"source": [
"We are now ready to specify a random effects model for the longitudinal outcome. We let\n",
"\n",
"$$\\mu_{\\text{bili}, i, t} = \\gamma_{0, i} + \\gamma_{t, i} \\cdot t + \\beta_{\\text{trt}} \\cdot x_{\\text{trt}, i}.$$\n",
"\n",
"We place a normal prior on the treatment coefficient and noncentered hierarchical normal random effects priors on the intercept and time coefficient."
]
},
{
"cell_type": "code",
"execution_count": 31,
"id": "fdf17a80-587f-49f2-8f25-2f045dce93d3",
"metadata": {},
"outputs": [],
"source": [
"with pm.Model(coords=coords) as joint_model:\n",
" γ0 = noncentered_normal(\"γ0\", dims=\"id\")\n",
" γ_t = noncentered_normal(\"γ\", dims=\"id\")\n",
" β_trt = pm.Normal(\"β_trt\", 0, 2.5)\n",
"\n",
" μ_bili = γ0[i] + γ_t[i] * t_visit + β_trt * trt[i]"
]
},
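{
"cell_type": "markdown",
"id": "9d1b4e27-6a3c-4c8f-b2d5-0e7f1a9c3b62",
"metadata": {},
"source": [
"To make the structure of $\\mu_{\\text{bili}, i, t}$ concrete, the following cell forward-simulates this mean in plain NumPy under hypothetical values. It uses the usual noncentered parameterization $\\gamma_i = \\mu + \\sigma z_i$ with $z_i \\sim N(0, 1)$. The `noncentered_normal` helper defined earlier may use different hyperpriors. The subject count, visit layout, and parameter values below are invented for illustration. As in the model, indexing the per-subject arrays with the visit-level subject index copies each subject's effects to every one of that subject's visits."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "e4a7c3d1-8f2b-4a6e-9c05-3d1b7e2f6a98",
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"rng = np.random.default_rng(42)\n",
"\n",
"n_subj_sim = 4\n",
"# Hypothetical visit layout: subject index and scaled time for each visit\n",
"i_sim = np.array([0, 0, 1, 1, 1, 2, 3, 3])\n",
"t_sim = np.array([0.0, 0.5, 0.0, 0.3, 0.6, 0.0, 0.0, 1.0])\n",
"# Hypothetical per-subject treatment indicator\n",
"trt_sim = np.array([0, 1, 0, 1])\n",
"\n",
"\n",
"def noncentered_draw(mu, sigma, size, rng):\n",
"    # Noncentered parameterization: shift and scale standard normal draws\n",
"    return mu + sigma * rng.standard_normal(size)\n",
"\n",
"\n",
"# Hypothetical hyperparameters, for illustration only\n",
"γ0_sim = noncentered_draw(0.5, 0.8, n_subj_sim, rng)\n",
"γ_t_sim = noncentered_draw(0.3, 0.5, n_subj_sim, rng)\n",
"β_trt_sim = -0.2\n",
"\n",
"# Per-visit mean of log bilirubin, mirroring μ_bili in the model\n",
"μ_sim = γ0_sim[i_sim] + γ_t_sim[i_sim] * t_sim + β_trt_sim * trt_sim[i_sim]"
]
},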
{
"cell_type": "markdown",
"id": "689443e1-d9b3-466a-a519-fc15d0b85f06",
"metadata": {},
"source": [
"We then specify the likelihood for the longitudinal outcome as\n",
"\n",
"$$\\log y_{\\text{bili}, i, t} \\sim N(\\mu_{\\text{bili}, i, t}, \\sigma_{\\text{bili}}^2)$$\n",
"\n",
"with $\\sigma_{\\text{bili}} \\sim \\text{Half-}N(2.5^2)$."
]
},
{
"cell_type": "code",
"execution_count": 32,
"id": "a0420c8b-ea25-445d-8269-f9976bf0a6b8",
"metadata": {},
"outputs": [],
"source": [
"with joint_model:\n",
" σ_bili = pm.HalfNormal(\"σ_bili\", 2.5 * HALFNORMAL_SCALE)\n",
" pm.Normal(\"log_bili\", μ_bili, σ_bili, observed=np.log(bili))"
]
},
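{
"cell_type": "markdown",
"id": "5b8e2f41-3c7a-4d9e-a1b6-2f0c8e4d7a13",
"metadata": {},
"source": [
"Placing a normal likelihood on $\\log y_{\\text{bili}}$ is equivalent to placing a lognormal likelihood on $y_{\\text{bili}}$ itself. The two log-densities differ only by the Jacobian term $-\\sum \\log y_{\\text{bili}}$, which involves no parameters and therefore leaves the posterior unchanged. The following cell is a small NumPy check of this fact. Both log-densities are written out by hand, and the data and parameter values are made up."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "c1d9a6e3-7b4f-4e2a-8d3c-5f9b0a1e2c74",
"metadata": {},
"outputs": [],
"source": [
"import math\n",
"import numpy as np\n",
"\n",
"\n",
"def normal_logpdf(x, mu, sigma):\n",
"    # Log-density of N(mu, sigma^2) evaluated at x\n",
"    return -0.5 * math.log(2 * math.pi) - np.log(sigma) - 0.5 * ((x - mu) / sigma) ** 2\n",
"\n",
"\n",
"def lognormal_logpdf(y, mu, sigma):\n",
"    # Lognormal log-density written directly in terms of y > 0\n",
"    return (\n",
"        -np.log(y)\n",
"        - np.log(sigma)\n",
"        - 0.5 * math.log(2 * math.pi)\n",
"        - (np.log(y) - mu) ** 2 / (2 * sigma**2)\n",
"    )\n",
"\n",
"\n",
"y_toy = np.array([0.7, 1.4, 3.2])  # hypothetical bilirubin values\n",
"mu_toy = np.array([0.0, 0.3, 1.0])  # hypothetical means on the log scale\n",
"sigma_toy = 0.5  # hypothetical standard deviation on the log scale\n",
"\n",
"ll_log_scale = normal_logpdf(np.log(y_toy), mu_toy, sigma_toy).sum()\n",
"ll_lognormal = lognormal_logpdf(y_toy, mu_toy, sigma_toy).sum()\n",
"\n",
"# Parameter-free Jacobian term that separates the two log-likelihoods\n",
"jacobian = -np.log(y_toy).sum()"
]
},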
{
"cell_type": "markdown",
"id": "8e877b7f-3c9a-4992-85ea-40247218f0eb",
"metadata": {},
"source": [
"We specify the baseline hazard exactly as in the survival model, with a noncentered normal prior on its logarithm in each time interval."
]
},
{
"cell_type": "code",
"execution_count": 33,
"id": "ca2cc066-a576-4a98-84bd-99d1e07f8ff0",
"metadata": {},
"outputs": [],
"source": [
"with joint_model:\n",
" log_λ0 = noncentered_normal(\"log_λ0\", dims=\"t\")\n",
" λ0 = pt.exp(log_λ0)"
]
},
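{
"cell_type": "markdown",
"id": "4f208a58-232e-42c5-9263-545a38128885",
"metadata": {},
"source": [
"To link the two submodels, we let the survival model's linear predictor depend on each subject's random effects from the longitudinal model,\n",
"\n",
"$$\\eta_{i, t} = \\alpha_\\text{trt} \\cdot x_{\\text{trt}, i} + \\nu_0 \\cdot \\gamma_{0, i} + \\nu_t \\cdot \\gamma_{t, i} \\cdot t,$$\n",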
"\n",
"with $\\alpha_\\text{trt}, \\nu_0, \\nu_t \\sim N(0, 2.5^2)$."
]
},
{
"cell_type": "code",
"execution_count": 34,
"id": "c7d65767-0910-4193-946a-94bb57927b8a",
"metadata": {},
"outputs": [],
"source": [
"t_surv = time_scaler(coords[\"t\"])"
]
},
{
"cell_type": "code",
"execution_count": 35,
"id": "b4290406-e5cf-4f59-acd8-c97e274e6f73",
"metadata": {},
"outputs": [],
"source": [
"with joint_model:\n",
" α_trt = pm.Normal(\"α_trt\", 0, 2.5)\n",
" ν0 = pm.Normal(\"ν0\", 0, 2.5)\n",
" ν_t = pm.Normal(\"ν_t\", 0, 2.5)\n",
"\n",
" η = sum(\n",
" [\n",
" pt.atleast_2d(α_trt * trt + ν0 * γ0).T,\n",
" ν_t * pt.outer(γ_t, pt.as_tensor(t_surv)),\n",
" ]\n",
" )"
]
},
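{
"cell_type": "markdown",
"id": "2a6f8d3b-9e1c-4b7a-85d2-4c3e9f0a6b18",
"metadata": {},
"source": [
"The two terms summed into `η` have different shapes. The first is a column of shape `(n_subjects, 1)` holding the time-constant part. The outer product has shape `(n_subjects, n_intervals)`. Broadcasting their sum therefore yields one linear predictor per subject and interval. The following NumPy sketch mirrors that construction with hypothetical values, using `[:, None]` and `np.outer` in place of `pt.atleast_2d(...).T` and `pt.outer`, and checks one entry against the formula for $\\eta_{i, t}$."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "f7e3b1c9-4a2d-4c8e-9b6f-1d5a3c7e2b90",
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"rng = np.random.default_rng(0)\n",
"\n",
"n_subj, n_t = 3, 5\n",
"trt_np = np.array([0.0, 1.0, 1.0])  # hypothetical treatment indicators\n",
"γ0_np = rng.normal(size=n_subj)  # hypothetical random intercepts\n",
"γ_t_np = rng.normal(size=n_subj)  # hypothetical random time slopes\n",
"t_surv_np = np.linspace(0.0, 1.0, n_t)  # hypothetical scaled interval times\n",
"α_trt_np, ν0_np, ν_t_np = -0.5, 0.8, 1.2  # hypothetical coefficients\n",
"\n",
"# Time-constant part as an (n_subj, 1) column\n",
"const_part = (α_trt_np * trt_np + ν0_np * γ0_np)[:, None]\n",
"# Time-varying part ν_t * γ_{t, i} * t with shape (n_subj, n_t)\n",
"slope_part = ν_t_np * np.outer(γ_t_np, t_surv_np)\n",
"\n",
"# Broadcasting the sum gives one linear predictor per subject and interval\n",
"η_np = const_part + slope_part"
]
},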
{
"cell_type": "markdown",
"id": "0e1e2590-61aa-4646-816c-0afc9328fd47",
"metadata": {},
"source": [
"As before, we model the hazard rate as $\\lambda_{i, t} = \\lambda_{0, t} \\cdot \\exp(\\eta_{i, t})$ and use a Poisson likelihood for the death indicators, with mean equal to the exposure in each interval times the hazard."
]
},
{
"cell_type": "code",
"execution_count": 36,
"id": "e27bd685-1f23-422e-bf36-cae6ee51fed1",
"metadata": {},
"outputs": [],
"source": [
"with joint_model:\n",
" λ = λ0 * pt.exp(η)\n",
"\n",
" pm.Poisson(\"died\", exposed * λ, observed=died_)"
]
},
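{
"cell_type": "markdown",
"id": "8c4d2a7e-5f3b-4e1c-9a8d-6b2f0e3c1d57",
"metadata": {},
"source": [
"The Poisson likelihood here is the standard device for piecewise-constant hazards. Suppose subject $i$ is at risk for $e_{i, t}$ time units in interval $t$, and $d_{i, t} \\in \\{0, 1\\}$ indicates death in that interval. The exact survival log-likelihood is $\\sum_t \\left(d_{i, t} \\log \\lambda_{i, t} - e_{i, t} \\lambda_{i, t}\\right)$. It differs from the Poisson log-likelihood of $d_{i, t}$ with mean $e_{i, t} \\lambda_{i, t}$ only by $\\sum_t d_{i, t} \\log e_{i, t}$, which involves no parameters. The following cell checks this numerically for one hypothetical subject. It assumes that `exposed` and `died_` hold the time at risk and the death indicator for each subject and interval."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "1e9b7c3a-2d6f-4a4e-b8c1-7f3d5a9e0c26",
"metadata": {},
"outputs": [],
"source": [
"import math\n",
"import numpy as np\n",
"\n",
"# Hypothetical subject: at risk for four intervals, dies during the last one\n",
"λ_toy = np.array([0.02, 0.03, 0.05, 0.04])  # piecewise-constant hazard\n",
"e_toy = np.array([1.0, 1.0, 1.0, 0.4])  # time at risk in each interval\n",
"d_toy = np.array([0, 0, 0, 1])  # death indicator in each interval\n",
"\n",
"# Exact piecewise-exponential survival log-likelihood\n",
"ll_surv = np.sum(d_toy * np.log(λ_toy) - e_toy * λ_toy)\n",
"\n",
"# Poisson log-pmf of the death indicators with mean e * λ\n",
"m_toy = e_toy * λ_toy\n",
"log_fact = np.array([math.lgamma(d + 1) for d in d_toy])\n",
"ll_pois = np.sum(d_toy * np.log(m_toy) - m_toy - log_fact)\n",
"\n",
"# Parameter-free offset between the two log-likelihoods\n",
"offset = np.sum(d_toy * np.log(e_toy))"
]
},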
{
"cell_type": "markdown",
"id": "3756e02a-d5be-45ed-933f-05bcf108b757",
"metadata": {},
"source": [
"As before, we compute the model's survival function from the cumulative hazard, $S(t) = \\exp(-\\Lambda(t))$, and then sample from the model. Note that we plug the average values of the random effects $\\gamma_0$ and $\\gamma_t$ into the linear predictor to obtain predictions for an average subject."
]
},
{
"cell_type": "code",
"execution_count": 37,
"id": "669925c4-364f-40fd-ab0f-6a2eee494c1e",
"metadata": {},
"outputs": [],
"source": [
"with joint_model:\n",
" η_pred = pt.add.outer(\n",
" α_trt * np.array([0, 1]) + ν0 * γ0.mean(),\n",
" ν_t * γ_t.mean() * t_surv,\n",
" )\n",
" λ_pred = λ0 * pt.exp(η_pred)\n",
" Λ_pred = λ_pred.cumsum(axis=1)\n",
" sf_pred = pm.Deterministic(\"sf_pred\", pt.exp(-Λ_pred), dims=(\"drug\", \"t\"))"
]
},
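{
"cell_type": "markdown",
"id": "6d3a9e1b-4c8f-4d2a-a7e5-0b9c2f6d8e31",
"metadata": {},
"source": [
"These lines follow the usual relationship between hazard and survival. With a piecewise-constant hazard, the cumulative hazard is $\\Lambda(t) = \\sum_{s \\leq t} \\lambda_s$, assuming unit-length intervals as the `cumsum` does, and $S(t) = \\exp(-\\Lambda(t))$. The following NumPy sketch repeats the calculation for two hypothetical arms whose log hazards differ by a constant $\\alpha_\\text{trt}$. Because the hazards are proportional, the treated arm's survival curve satisfies $S_1(t) = S_0(t)^{\\exp(\\alpha_\\text{trt})}$, which gives a simple check."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "a8f2c6d4-3e1b-4b9a-8c7d-2e5f1a0b9c63",
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"λ0_toy = np.array([0.010, 0.012, 0.015, 0.020, 0.025])  # hypothetical baseline hazard\n",
"α_trt_toy = -0.3  # hypothetical log hazard ratio\n",
"\n",
"# Rows are drug = 0 and drug = 1, mirroring the (drug, t) layout of sf_pred\n",
"η_toy = α_trt_toy * np.array([0.0, 1.0])[:, None]  # shape (2, 1)\n",
"λ_toy = λ0_toy * np.exp(η_toy)  # shape (2, 5) by broadcasting\n",
"Λ_toy = λ_toy.cumsum(axis=1)  # cumulative hazard\n",
"sf_toy = np.exp(-Λ_toy)  # survival function"
]
},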
{
"cell_type": "code",
"execution_count": 38,
"id": "01a3d9f8-8be3-4a53-b5b5-6e3f74a06105",
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"\n",
"<style>\n",
" :root {\n",
" --column-width-1: 40%; /* Progress column width */\n",
" --column-width-2: 15%; /* Chain column width */\n",
" --column-width-3: 15%; /* Divergences column width */\n",
" --column-width-4: 15%; /* Step Size column width */\n",
" --column-width-5: 15%; /* Gradients/Draw column width */\n",
" }\n",
"\n",
" .nutpie {\n",
" max-width: 800px;\n",
" margin: 10px auto;\n",
" font-family: 'Segoe UI', Tahoma, Geneva, Verdana, sans-serif;\n",
" //color: #333;\n",
" //background-color: #fff;\n",
" padding: 10px;\n",
" box-shadow: 0 4px 6px rgba(0,0,0,0.1);\n",
" border-radius: 8px;\n",
" font-size: 14px; /* Smaller font size for a more compact look */\n",
" }\n",
" .nutpie table {\n",
" width: 100%;\n",
" border-collapse: collapse; /* Remove any extra space between borders */\n",
" }\n",
" .nutpie th, .nutpie td {\n",
" padding: 8px 10px; /* Reduce padding to make table more compact */\n",
" text-align: left;\n",
" border-bottom: 1px solid #888;\n",
" }\n",
" .nutpie th {\n",
" //background-color: #f0f0f0;\n",
" }\n",
"\n",
" .nutpie th:nth-child(1) { width: var(--column-width-1); }\n",
" .nutpie th:nth-child(2) { width: var(--column-width-2); }\n",
" .nutpie th:nth-child(3) { width: var(--column-width-3); }\n",
" .nutpie th:nth-child(4) { width: var(--column-width-4); }\n",
" .nutpie th:nth-child(5) { width: var(--column-width-5); }\n",
"\n",
" .nutpie progress {\n",
" width: 100%;\n",
" height: 15px; /* Smaller progress bars */\n",
" border-radius: 5px;\n",
" }\n",
" progress::-webkit-progress-bar {\n",
" background-color: #eee;\n",
" border-radius: 5px;\n",
" }\n",
" progress::-webkit-progress-value {\n",
" background-color: #5cb85c;\n",
" border-radius: 5px;\n",
" }\n",
" progress::-moz-progress-bar {\n",
" background-color: #5cb85c;\n",
" border-radius: 5px;\n",
" }\n",
" .nutpie .progress-cell {\n",
" width: 100%;\n",
" }\n",
"\n",
" .nutpie p strong { font-size: 16px; font-weight: bold; }\n",
"\n",
" @media (prefers-color-scheme: dark) {\n",
" .nutpie {\n",
" //color: #ddd;\n",
" //background-color: #1e1e1e;\n",
" box-shadow: 0 4px 6px rgba(0,0,0,0.2);\n",
" }\n",
" .nutpie table, .nutpie th, .nutpie td {\n",
" border-color: #555;\n",
" color: #ccc;\n",
" }\n",
" .nutpie th {\n",
" background-color: #2a2a2a;\n",
" }\n",
" .nutpie progress::-webkit-progress-bar {\n",
" background-color: #444;\n",
" }\n",
" .nutpie progress::-webkit-progress-value {\n",
" background-color: #3178c6;\n",
" }\n",
" .nutpie progress::-moz-progress-bar {\n",
" background-color: #3178c6;\n",
" }\n",
" }\n",
"</style>\n"
],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"\n",
"<div class=\"nutpie\">\n",
" <p><strong>Sampler Progress</strong></p>\n",
" <p>Total Chains: <span id=\"total-chains\">6</span></p>\n",
" <p>Active Chains: <span id=\"active-chains\">0</span></p>\n",
" <p>\n",
" Finished Chains:\n",
" <span id=\"active-chains\">6</span>\n",
" </p>\n",
" <p>Sampling for a minute</p>\n",
" <p>\n",
" Estimated Time to Completion:\n",
" <span id=\"eta\">now</span>\n",
" </p>\n",
"\n",
" <progress\n",
" id=\"total-progress-bar\"\n",
" max=\"7800\"\n",
" value=\"7800\">\n",
" </progress>\n",
" <table>\n",
" <thead>\n",
" <tr>\n",
" <th>Progress</th>\n",
" <th>Draws</th>\n",
" <th>Divergences</th>\n",
" <th>Step Size</th>\n",
" <th>Gradients/Draw</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody id=\"chain-details\">\n",
" \n",
" <tr>\n",
" <td class=\"progress-cell\">\n",
" <progress\n",
" max=\"1300\"\n",
" value=\"1300\">\n",
" </progress>\n",
" </td>\n",
" <td>1300</td>\n",
" <td>1</td>\n",
" <td>0.14</td>\n",
" <td>31</td>\n",
" </tr>\n",
" \n",
" <tr>\n",
" <td class=\"progress-cell\">\n",
" <progress\n",
" max=\"1300\"\n",
" value=\"1300\">\n",
" </progress>\n",
" </td>\n",
" <td>1300</td>\n",
" <td>2</td>\n",
" <td>0.13</td>\n",
" <td>31</td>\n",
" </tr>\n",
" \n",
" <tr>\n",
" <td class=\"progress-cell\">\n",
" <progress\n",
" max=\"1300\"\n",
" value=\"1300\">\n",
" </progress>\n",
" </td>\n",
" <td>1300</td>\n",
" <td>1</td>\n",
" <td>0.15</td>\n",
" <td>31</td>\n",
" </tr>\n",
" \n",
" <tr>\n",
" <td class=\"progress-cell\">\n",
" <progress\n",
" max=\"1300\"\n",
" value=\"1300\">\n",
" </progress>\n",
" </td>\n",
" <td>1300</td>\n",
" <td>1</td>\n",
" <td>0.15</td>\n",
" <td>31</td>\n",
" </tr>\n",
" \n",
" <tr>\n",
" <td class=\"progress-cell\">\n",
" <progress\n",
" max=\"1300\"\n",
" value=\"1300\">\n",
" </progress>\n",
" </td>\n",
" <td>1300</td>\n",
" <td>1</td>\n",
" <td>0.14</td>\n",
" <td>31</td>\n",
" </tr>\n",
" \n",
" <tr>\n",
" <td class=\"progress-cell\">\n",
" <progress\n",
" max=\"1300\"\n",
" value=\"1300\">\n",
" </progress>\n",
" </td>\n",
" <td>1300</td>\n",
" <td>0</td>\n",
" <td>0.13</td>\n",
" <td>31</td>\n",
" </tr>\n",
" \n",
" </tr>\n",
" </tbody>\n",
" </table>\n",
"</div>\n"
],
"text/plain": [
"<nutpie.sample._BackgroundSampler at 0x34b514a70>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"joint_trace = nutpie.sample(\n",
" nutpie.compile_pymc_model(joint_model), target_accept=0.95, **SAMPLER_KWARGS\n",
")"
]
},
{
"cell_type": "markdown",
"id": "01271828-1af0-4d2a-be05-707ce7e7d628",
"metadata": {},
"source": [
"Again, the standard sampling diagnostics show little cause for concern: the sampler reports only a handful of divergences, and the largest $\\hat{R}$ is about 1.02."
]
},
{
"cell_type": "code",
"execution_count": 39,
"id": "870b8cd0-b9d4-48a6-b528-fda3f6628572",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<xarray.DataArray ()> Size: 8B\n",
"array(1.02313563)"
]
},
"execution_count": 39,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"az.rhat(joint_trace).max().to_array().max()"
]
},
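{
"cell_type": "markdown",
"id": "4e7b1d9c-8a3f-4c6e-9d2b-5a0f3e8c1b72",
"metadata": {},
"source": [
"For intuition about what `az.rhat` summarizes, the following cell computes the classic split-$\\hat{R}$ for a single scalar quantity in NumPy. Each chain is split in half, and the variance between the half-chain means is compared with the variance within the half-chains. ArviZ's default is a rank-normalized variant of split-$\\hat{R}$, so its values can differ somewhat from this simpler version. The chains here are simulated rather than taken from `joint_trace`."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "d2c8f4a6-1b9e-4e3d-a5c7-8f0b2d6e9a14",
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"\n",
"def split_rhat(draws):\n",
"    # draws: array of shape (n_chains, n_draws) for one scalar quantity\n",
"    half = draws.shape[1] // 2\n",
"    # Split every chain into its first and last halves (drops the middle draw if odd)\n",
"    split = np.concatenate([draws[:, :half], draws[:, -half:]], axis=0)\n",
"    n = split.shape[1]\n",
"    within = split.var(axis=1, ddof=1).mean()  # W: mean within-chain variance\n",
"    between = n * split.mean(axis=1).var(ddof=1)  # B: between-chain variance\n",
"    var_plus = (n - 1) / n * within + between / n  # pooled variance estimate\n",
"    return np.sqrt(var_plus / within)\n",
"\n",
"\n",
"rng = np.random.default_rng(1)\n",
"mixed = rng.normal(size=(6, 1000))  # six well-mixed chains\n",
"stuck = mixed + 0.5 * np.arange(6)[:, None]  # chains centered at different values"
]
},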
{
"cell_type": "code",
"execution_count": 40,
"id": "53293951-7508-4b25-a901-289aac4b7d4d",
"metadata": {},
"outputs": [
{
"data": {
"image/png": "",
"text/plain": [
"<Figure size 640x480 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"az.plot_energy(joint_trace);"
]
},
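{
"cell_type": "markdown",
"id": "0b5e9c2d-7f4a-4a1e-8b3c-6d9f2a5e1c80",
"metadata": {},
"source": [
"The energy plot compares the marginal distribution of the Hamiltonian energy with the distribution of its changes between successive draws. A common numerical summary of the same idea is the estimated Bayesian fraction of missing information (E-BFMI), which ArviZ reports via `az.bfmi`. The following NumPy sketch computes $\\sum_n (E_n - E_{n - 1})^2 / \\sum_n (E_n - \\bar{E})^2$ for simulated energy sequences rather than for `joint_trace`. Values below roughly 0.3 are commonly treated as a warning sign of poor exploration."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "b9a1e5f3-6c2d-4d8b-9e7a-3c0f4b8d2e56",
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"\n",
"def e_bfmi(energy):\n",
"    # energy: 1-D array of Hamiltonian energies, one per draw of a single chain\n",
"    energy = np.asarray(energy, dtype=float)\n",
"    num = np.sum(np.diff(energy) ** 2)\n",
"    den = np.sum((energy - energy.mean()) ** 2)\n",
"    return num / den\n",
"\n",
"\n",
"rng = np.random.default_rng(2)\n",
"n_draws = 5000\n",
"iid_energy = rng.normal(size=n_draws)  # energies that decorrelate fully between draws\n",
"\n",
"# Strongly autocorrelated AR(1) energies (rho = 0.95) mimic poor energy exploration\n",
"rho = 0.95\n",
"ar_energy = np.empty(n_draws)\n",
"ar_energy[0] = rng.normal()\n",
"for k in range(1, n_draws):\n",
"    ar_energy[k] = rho * ar_energy[k - 1] + np.sqrt(1 - rho**2) * rng.normal()"
]
},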
{
"cell_type": "markdown",
"id": "5588e21d-da30-47a0-93b8-5e62cf1f1f65",
"metadata": {},
"source": [
"This model shows a stronger influence of treatment on survival, as illustrated in the following charts."
]
},
{
"cell_type": "code",
"execution_count": 41,
"id": "aedfad58-2a3b-4187-877c-e4f6fe50e7c2",
"metadata": {},
"outputs": [
{
"data": {
"image/png": "",
"text/plain": [
"<Figure size 1400x600 with 2 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"fig, (α_ax, sf_ax) = plt.subplots(figsize=(14, 6), ncols=2)\n",
"\n",
"az.plot_posterior(joint_trace, var_names=\"α_trt\", ref_val=0, ax=α_ax)\n",
"\n",
"(\n",
" so.Plot(\n",
" joint_trace.posterior[\"sf_pred\"].to_dataframe(),\n",
" x=\"t\",\n",
" y=\"sf_pred\",\n",
" color=\"drug\",\n",
" )\n",
" .add(so.Line(), so.Agg())\n",
" .add(so.Band(), ci)\n",
" .scale(color=so.Nominal(), y=so.Continuous().tick(every=0.25).label(like=\"{x:.0%}\"))\n",
" .limit(x=(0, t.max()), y=(0, 1))\n",
" .label(x=\"Month\", y=\"Posterior predictive\\nsurvival function\")\n",
" .on(sf_ax)\n",
" .show()\n",
")\n",
"\n",
"fig.tight_layout();"
]
},
{
"cell_type": "markdown",
"id": "ae697ea0-068f-4144-b2c1-705f474ecaee",
"metadata": {},
"source": [
"The actual data from this study contains more covariates and longitudinal outcomes than we have included in this model. This example illustrates a framework for including more of the information in order to improve our estimate of the impact of treatment on survival."
]
},
{
"cell_type": "markdown",
"id": "bb3d7051-7e58-4e4a-aae4-5d3ee710773c",
"metadata": {},
"source": [
"This post is available as a Jupyter notebook [here]()."
]
},
{
"cell_type": "code",
"execution_count": 42,
"id": "1c507e76-5772-4d81-874e-0c437816ef39",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Last updated: Sat Mar 15 2025\n",
"\n",
"Python implementation: CPython\n",
"Python version : 3.12.5\n",
"IPython version : 8.29.0\n",
"\n",
"pytensor : 2.26.3\n",
"pymc : 5.18.2\n",
"arviz : 0.20.0\n",
"matplotlib: 3.9.2\n",
"numpy : 1.26.4\n",
"polars : 1.14.0\n",
"seaborn : 0.13.2\n",
"nutpie : 0.13.2\n",
"\n"
]
}
],
"source": [
"%load_ext watermark\n",
"%watermark -n -u -v -iv"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.5"
},
"nikola": {
"date": "2025-03-15",
"slug": "joint-long-surv",
"title": "Joint Modeling of Longitudinal and Survival Outcomes in PyMC"
}
},
"nbformat": 4,
"nbformat_minor": 5
}