Last active
August 12, 2020 19:21
-
-
Save michaelosthege/6cd14970dd789247176c4d4a1dd28051 to your computer and use it in GitHub Desktop.
Analysis of Rt.live model estimates of R0 - a comparison between US regions (alternative link: https://nbviewer.jupyter.org/gist/michaelosthege/6cd14970dd789247176c4d4a1dd28051/_R0_regions.ipynb)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import arviz | |
import fastprogress | |
import logging | |
import matplotlib | |
from matplotlib import pyplot | |
import numpy | |
import os | |
import pandas | |
import pathlib | |
import pymc3 | |
import typing | |
# Module-level logger, named after the path of this file.
_log = logging.getLogger(__file__)

# Full region name -> two-letter code.
# The order of the pairs is significant: US_REGIONS below inherits it.
US_REGION_CODES = dict((
    ("Alabama", "AL"),
    ("Alaska", "AK"),
    ("Arizona", "AZ"),
    ("Arkansas", "AR"),
    ("California", "CA"),
    ("Colorado", "CO"),
    ("Connecticut", "CT"),
    ("District of Columbia", "DC"),
    ("Delaware", "DE"),
    ("Florida", "FL"),
    ("Georgia", "GA"),
    ("Hawaii", "HI"),
    ("Idaho", "ID"),
    ("Illinois", "IL"),
    ("Indiana", "IN"),
    ("Iowa", "IA"),
    ("Kansas", "KS"),
    ("Kentucky", "KY"),
    ("Louisiana", "LA"),
    ("Maine", "ME"),
    ("Maryland", "MD"),
    ("Massachusetts", "MA"),
    ("Michigan", "MI"),
    ("Minnesota", "MN"),
    ("Mississippi", "MS"),
    ("Missouri", "MO"),
    ("Montana", "MT"),
    ("Nebraska", "NE"),
    ("Nevada", "NV"),
    ("New Hampshire", "NH"),
    ("New Jersey", "NJ"),
    ("New Mexico", "NM"),
    ("New York", "NY"),
    ("North Dakota", "ND"),
    ("North Carolina", "NC"),
    ("Ohio", "OH"),
    ("Oklahoma", "OK"),
    ("Oregon", "OR"),
    ("Pennsylvania", "PA"),
    ("Rhode Island", "RI"),
    ("South Carolina", "SC"),
    ("South Dakota", "SD"),
    ("Tennessee", "TN"),
    ("Texas", "TX"),
    ("Utah", "UT"),
    ("Vermont", "VT"),
    ("Virginia", "VA"),
    ("Washington", "WA"),
    ("West Virginia", "WV"),
    ("Wisconsin", "WI"),
    ("Wyoming", "WY"),
))

# All region codes, in the same order as the mapping above.
US_REGIONS = [code for code in US_REGION_CODES.values()]

# Name of the per-region subdirectory from which the plot functions read "trace.nc".
DATA_DATE = "2020-06-25"
def plot_r_t(region: str, country: str="us"):
    """Plot the posterior of R_e(t) over time for a single region.

    The trace is read from ``<country>/<region>/<DATA_DATE>/trace.nc``
    (relative to the current working directory) and the figure is shown
    with ``pyplot.show()``.

    Parameters
    ----------
    region : str
        Name of the region subdirectory; also used in the legend label.
    country : str
        Name of the top-level directory; also used in the legend label.

    Returns
    -------
    The return value of ``pyplot.show()``.
    """
    # read the data
    idata = arviz.from_netcdf(pathlib.Path(country, region, DATA_DATE, "trace.nc"))
    posterior = idata.posterior

    fig, ax = pyplot.subplots(
        dpi=140,
        figsize=(10, 4),
    )
    pymc3.gp.util.plot_gp_dist(
        ax=ax,
        x=posterior.date.values,
        # merge chains and draws into one sample axis and put it first
        samples=posterior.r_t.stack(sample=("chain", "draw")).T.values,
    )
    ax.axhline(1, linestyle=":")
    ax.set_ylabel("$R_e(t)$ [-]", fontsize=15)

    # An empty fill_between only serves as a legend handle.
    # The label follows the `country` argument instead of a hard-coded "us",
    # so it always matches the directory the data was read from.
    ax.legend(
        handles=[
            ax.fill_between([], [], color="red", label=f"{country}/{region}")
        ],
        loc="upper left",
        frameon=False,
    )

    # major ticks on Mondays, minor ticks on every day
    ax.xaxis.set_major_locator(
        matplotlib.dates.WeekdayLocator(interval=1, byweekday=matplotlib.dates.MO)
    )
    ax.xaxis.set_minor_locator(matplotlib.dates.DayLocator())
    ax.xaxis.set_tick_params(rotation=90)
    ax.set_ylim(0, 8)
    fig.tight_layout()
    return pyplot.show()
def plot_r_0(regions: typing.Sequence[str], country: str="us"):
    """Plot kernel density estimates of the R_0 posterior for several regions.

    For every region the trace is read from
    ``<country>/<region>/<DATA_DATE>/trace.nc``. One KDE curve is drawn per
    region and an arrow labelled with the region marks its median.
    The figure is shown with ``pyplot.show()``.

    Parameters
    ----------
    regions : sequence of str
        Names of the region subdirectories to load.
    country : str
        Name of the top-level directory.

    Returns
    -------
    region_medians : dict
        Maps each region to the median of its samples.
    region_samples : dict
        Maps each region to the array of samples that was plotted.
    """
    region_samples = {}
    for region in fastprogress.progress_bar(regions, leave=False):
        idata = arviz.from_netcdf(pathlib.Path(country, region, DATA_DATE, "trace.nc"))
        # Index 0 along the leading axis of the stacked r_t
        # (presumably the first date, i.e. R_0 - TODO confirm dimension order).
        region_samples[region] = idata.posterior.r_t.stack(sample=('chain', 'draw')).values[0, :]

    region_medians = {
        region: numpy.median(samples)
        for region, samples in region_samples.items()
    }

    fig, ax = pyplot.subplots(
        dpi=140,
        figsize=(10, 6),
    )
    # iterate in order of increasing median so that neighbouring labels get different heights
    for r, (region, median) in enumerate(sorted(region_medians.items(), key=lambda kv: kv[1])):
        arviz.plot_kde(
            ax=ax,
            values=region_samples[region],
            plot_kwargs=dict(linewidth=0.5)
        )
        # plot arrow to indicate the median
        # The annotation text is passed positionally: its keyword was renamed
        # from `s` to `text` in matplotlib 3.3, positional use works with both.
        ax.annotate(
            region,
            xy=(median, 0),
            # label heights cycle through 9 levels to reduce overlap
            xytext=(median, 0.15 + r % 9 * 0.15),
            horizontalalignment="center", fontweight="bold",
            arrowprops=dict(arrowstyle="-|>", facecolor="black", shrinkA=0, shrinkB=0)
        )

    # hide the y ticks
    ax.set_yticks([])
    ax.axvline(1, linestyle=":")
    ax.set_xlabel("$R_0$ [-]", fontsize=15)
    ax.set_xlim(left=0)
    # raw string: "\m" is not a valid escape sequence in a normal string literal
    ax.set_ylabel(r"$p(R_0 \mid data)$", fontsize=15)
    ax.set_ylim(0)
    fig.tight_layout()
    pyplot.show()
    return region_medians, region_samples
def _get_US_population_densities() -> pandas.DataFrame:
    """Download US population densities from Wikipedia.

    The first table of the Wikipedia page is parsed with ``pandas.read_html``.
    Rows whose name is not a key of ``US_REGION_CODES`` end up without a code
    and are removed by the final ``dropna``.

    Returns
    -------
    pandas.DataFrame
        Indexed by region code (sorted), with the columns
        ``name`` and ``population_density``.
    """
    url = 'https://en.wikipedia.org/wiki/List_of_states_and_territories_of_the_United_States_by_population_density'
    tables = pandas.read_html(url)

    # rename the labels of interest, then keep just those two columns
    renamed = tables[0].rename(columns={
        "State etc.": "name",
        "perkm2": "population_density",
    })
    densities = renamed[[("name", "name"), ("Population density", "population_density")]]
    # drop the upper level of the two-level column header
    densities.columns = densities.columns.droplevel(level=0)

    # names that are not in the lookup table get None and are dropped below
    densities["code"] = [US_REGION_CODES.get(name) for name in densities.name]

    # the table reports very small densities as "<1"; substitute a number
    densities.replace(to_replace="<1", value=0.49, inplace=True)

    return densities.dropna().set_index("code").sort_index()
def plot_scatter_r_0(regions: typing.Sequence[str], on_x="population_density", country: str="us"):
    """Plot the R_0 posterior of several regions against their population density.

    For every region the trace is read from
    ``<country>/<region>/<DATA_DATE>/trace.nc``. Each region is drawn as a
    violin positioned at log10 of its population density and labelled with the
    region name at the median of its samples.

    Parameters
    ----------
    regions : sequence of str
        Names of the region subdirectories to load. They must be present in
        the index of the table returned by ``_get_US_population_densities``.
    on_x : str
        Currently unused; the x-axis is always the population density.
    country : str
        Name of the top-level directory.

    Returns
    -------
    The return value of ``pyplot.show()``.
    """
    df_densities = _get_US_population_densities()

    region_samples = {}
    for region in fastprogress.progress_bar(regions):
        idata = arviz.from_netcdf(pathlib.Path(country, region, DATA_DATE, "trace.nc"))
        # Index 0 along the leading axis of the stacked r_t
        # (presumably the first date, i.e. R_0 - TODO confirm dimension order).
        region_samples[region] = idata.posterior.r_t.stack(sample=('chain', 'draw')).values[0, :]

    # x-position of every region, computed once and shared by violins and labels
    log_densities = {
        region: numpy.log10(float(df_densities.loc[region, "population_density"]))
        for region in region_samples
    }

    fig, ax = pyplot.subplots(dpi=140, figsize=(7, 7))
    ax.violinplot(
        dataset=[
            samples.flatten()
            for samples in region_samples.values()
        ],
        positions=[
            log_densities[region]
            for region in region_samples
        ],
        showextrema=False,
        widths=0.3,
    )
    for region in regions:
        ax.text(
            s=region,
            x=log_densities[region],
            y=numpy.median(region_samples[region]),
            horizontalalignment="center", fontweight="bold", fontsize=6,
        )

    # The axis is in log10 units: tick labels show the value as a power of ten.
    ax.xaxis.set_major_formatter(matplotlib.ticker.StrMethodFormatter("$10^{{{x:.0f}}}$"))
    # minor ticks: 10 linearly spaced values per decade, mapped to log10
    ax.xaxis.set_ticks([
        numpy.log10(x)
        for p in range(-1, 4)
        for x in numpy.linspace(10**p, 10**(p+1), 10)
    ], minor=True)
    # major ticks: both ends of each decade
    ax.xaxis.set_ticks([
        numpy.log10(x)
        for p in range(-1, 4)
        for x in numpy.linspace(10**p, 10**(p+1), 2)
    ], minor=False)
    ax.set_ylim(0)
    ax.set_xlim(-1, 4)
    # raw string: "\m" is not a valid escape sequence in a normal string literal
    ax.set_ylabel(r"$p(R_0 \mid data)$ [-]")
    ax.set_xlabel("population density [1/km²]")
    return pyplot.show()
def read_nowcast_backcast_comparison(
    region: str,
    offset_now: int,
    offset_back: int,
    country: str="us",
    hdi_prob: float=0.94,
) -> pandas.DataFrame:
    """Collect R_t estimates for the same date from runs made on different days.

    Every subdirectory of ``<country>/<region>`` that contains a ``trace.nc``
    is treated as one model run, with the directory name parsed as run date.
    From each run two estimates are taken:

    * the "nowcast" at ``run_date + offset_now`` days
    * the "backcast" at ``nowcast_date + offset_back`` days

    Only dates for which both a nowcast (from one run) and a backcast
    (from another run) are available remain in the result.

    Parameters
    ----------
    region : str
        Name of the region subdirectory.
    offset_now : int
        Offset in days from the run date to the nowcast date.
    offset_back : int
        Offset in days from the nowcast date to the backcast date.
    country : str
        Name of the top-level directory.
    hdi_prob : float
        Probability mass of the highest density interval passed to ``arviz.hdi``.

    Returns
    -------
    pandas.DataFrame
        Indexed by ``date`` (sorted), with run date, median and HDI bounds
        for both the nowcast and the backcast. The median and HDI columns
        are cast to float.
    """
    value_columns = [
        "nowcast_median", "nowcast_hdi_down", "nowcast_hdi_up",
        "backcast_median", "backcast_hdi_down", "backcast_hdi_up",
    ]
    df_results = pandas.DataFrame(columns=[
        "date",
        "nowcast_run_date", "nowcast_median", "nowcast_hdi_down", "nowcast_hdi_up",
        "backcast_run_date", "backcast_median", "backcast_hdi_down", "backcast_hdi_up",
    ]).set_index("date")

    dp_region = pathlib.Path(country, region)
    # os.listdir returns entries in arbitrary order; sort for reproducible results
    for date_str in sorted(os.listdir(dp_region)):
        fp_trace = pathlib.Path(dp_region, date_str, 'trace.nc')
        if not fp_trace.exists():
            continue

        run_date = pandas.Timestamp(date_str)
        nowcast_date = run_date + pandas.DateOffset(days=offset_now)
        backcast_date = nowcast_date + pandas.DateOffset(days=offset_back)

        idata = arviz.from_netcdf(fp_trace)
        r_t = idata.posterior.r_t.stack(sample=('chain', 'draw'))

        # Skip runs that do not cover both dates. Checking the nowcast date too
        # avoids a KeyError from .sel() when only the backcast date is covered.
        available_dates = r_t.indexes["date"]
        if nowcast_date not in available_dates or backcast_date not in available_dates:
            continue

        for kind, date in (("nowcast", nowcast_date), ("backcast", backcast_date)):
            samples = r_t.sel(date=date).values
            df_results.loc[date, f"{kind}_run_date"] = run_date
            df_results.loc[date, f"{kind}_median"] = float(numpy.median(samples))
            df_results.loc[date, [f"{kind}_hdi_down", f"{kind}_hdi_up"]] = tuple(
                arviz.hdi(samples, hdi_prob=hdi_prob)
            )

    # keep only the dates that received both a nowcast and a backcast
    df_results.dropna(inplace=True)
    for col in value_columns:
        df_results[col] = df_results[col].astype(float)
    # rows were inserted in processing order; return them chronologically
    df_results.sort_index(inplace=True)
    return df_results
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment