|
<!DOCTYPE html> |
|
<html xmlns="http://www.w3.org/1999/xhtml" lang="en" xml:lang="en"> |
|
<head> |
|
<meta charset="utf-8" /> |
|
<meta name="generator" content="pandoc" /> |
|
<meta name="viewport" content="width=device-width, initial-scale=1.0, user-scalable=yes" /> |
|
<title>wasserstein_prior</title> |
|
<style> |
|
html { |
|
color: #1a1a1a; |
|
background-color: #fdfdfd; |
|
} |
|
body { |
|
margin: 0 auto; |
|
max-width: 36em; |
|
padding-left: 50px; |
|
padding-right: 50px; |
|
padding-top: 50px; |
|
padding-bottom: 50px; |
|
hyphens: auto; |
|
overflow-wrap: break-word; |
|
text-rendering: optimizeLegibility; |
|
font-kerning: normal; |
|
} |
|
@media (max-width: 600px) { |
|
body { |
|
font-size: 0.9em; |
|
padding: 12px; |
|
} |
|
h1 { |
|
font-size: 1.8em; |
|
} |
|
} |
|
@media print { |
|
html { |
|
background-color: white; |
|
} |
|
body { |
|
background-color: transparent; |
|
color: black; |
|
font-size: 12pt; |
|
} |
|
p, h2, h3 { |
|
orphans: 3; |
|
widows: 3; |
|
} |
|
h2, h3, h4 { |
|
page-break-after: avoid; |
|
} |
|
} |
|
p { |
|
margin: 1em 0; |
|
} |
|
a { |
|
color: #1a1a1a; |
|
} |
|
a:visited { |
|
color: #1a1a1a; |
|
} |
|
img { |
|
max-width: 100%; |
|
} |
|
svg { |
|
height: auto; |
|
max-width: 100%; |
|
} |
|
h1, h2, h3, h4, h5, h6 { |
|
margin-top: 1.4em; |
|
} |
|
h5, h6 { |
|
font-size: 1em; |
|
font-style: italic; |
|
} |
|
h6 { |
|
font-weight: normal; |
|
} |
|
ol, ul { |
|
padding-left: 1.7em; |
|
margin-top: 1em; |
|
} |
|
li > ol, li > ul { |
|
margin-top: 0; |
|
} |
|
blockquote { |
|
margin: 1em 0 1em 1.7em; |
|
padding-left: 1em; |
|
border-left: 2px solid #e6e6e6; |
|
color: #606060; |
|
} |
|
code { |
|
font-family: Menlo, Monaco, Consolas, 'Lucida Console', monospace; |
|
font-size: 85%; |
|
margin: 0; |
|
hyphens: manual; |
|
} |
|
pre { |
|
margin: 1em 0; |
|
overflow: auto; |
|
} |
|
pre code { |
|
padding: 0; |
|
overflow: visible; |
|
overflow-wrap: normal; |
|
} |
|
.sourceCode { |
|
background-color: transparent; |
|
overflow: visible; |
|
} |
|
hr { |
|
border: none; |
|
border-top: 1px solid #1a1a1a; |
|
height: 1px; |
|
margin: 1em 0; |
|
} |
|
table { |
|
margin: 1em 0; |
|
border-collapse: collapse; |
|
width: 100%; |
|
overflow-x: auto; |
|
display: block; |
|
font-variant-numeric: lining-nums tabular-nums; |
|
} |
|
table caption { |
|
margin-bottom: 0.75em; |
|
} |
|
tbody { |
|
margin-top: 0.5em; |
|
border-top: 1px solid #1a1a1a; |
|
border-bottom: 1px solid #1a1a1a; |
|
} |
|
th { |
|
border-top: 1px solid #1a1a1a; |
|
padding: 0.25em 0.5em 0.25em 0.5em; |
|
} |
|
td { |
|
padding: 0.125em 0.5em 0.25em 0.5em; |
|
} |
|
header { |
|
margin-bottom: 4em; |
|
text-align: center; |
|
} |
|
#TOC li { |
|
list-style: none; |
|
} |
|
#TOC ul { |
|
padding-left: 1.3em; |
|
} |
|
#TOC > ul { |
|
padding-left: 0; |
|
} |
|
#TOC a:not(:hover) { |
|
text-decoration: none; |
|
} |
|
code{white-space: pre-wrap;} |
|
span.smallcaps{font-variant: small-caps;} |
|
div.columns{display: flex; gap: min(4vw, 1.5em);} |
|
div.column{flex: auto; overflow-x: auto;} |
|
div.hanging-indent{margin-left: 1.5em; text-indent: -1.5em;} |
|
/* The extra [class] is a hack that increases specificity enough to |
|
override a similar rule in reveal.js */ |
|
ul.task-list[class]{list-style: none;} |
|
ul.task-list li input[type="checkbox"] { |
|
font-size: inherit; |
|
width: 0.8em; |
|
margin: 0 0.8em 0.2em -1.6em; |
|
vertical-align: middle; |
|
} |
|
/* CSS for syntax highlighting */ |
|
html { -webkit-text-size-adjust: 100%; } |
|
pre > code.sourceCode { white-space: pre; position: relative; } |
|
pre > code.sourceCode > span { display: inline-block; line-height: 1.25; } |
|
pre > code.sourceCode > span:empty { height: 1.2em; } |
|
.sourceCode { overflow: visible; } |
|
code.sourceCode > span { color: inherit; text-decoration: inherit; } |
|
div.sourceCode { margin: 1em 0; } |
|
pre.sourceCode { margin: 0; } |
|
@media screen { |
|
div.sourceCode { overflow: auto; } |
|
} |
|
@media print { |
|
pre > code.sourceCode { white-space: pre-wrap; } |
|
pre > code.sourceCode > span { text-indent: -5em; padding-left: 5em; } |
|
} |
|
pre.numberSource code |
|
{ counter-reset: source-line 0; } |
|
pre.numberSource code > span |
|
{ position: relative; left: -4em; counter-increment: source-line; } |
|
pre.numberSource code > span > a:first-child::before |
|
{ content: counter(source-line); |
|
position: relative; left: -1em; text-align: right; vertical-align: baseline; |
|
border: none; display: inline-block; |
|
-webkit-touch-callout: none; -webkit-user-select: none; |
|
-khtml-user-select: none; -moz-user-select: none; |
|
-ms-user-select: none; user-select: none; |
|
padding: 0 4px; width: 4em; |
|
color: #aaaaaa; |
|
} |
|
pre.numberSource { margin-left: 3em; border-left: 1px solid #aaaaaa; padding-left: 4px; } |
|
div.sourceCode |
|
{ } |
|
@media screen { |
|
pre > code.sourceCode > span > a:first-child::before { text-decoration: underline; } |
|
} |
|
code span.al { color: #ff0000; font-weight: bold; } /* Alert */ |
|
code span.an { color: #60a0b0; font-weight: bold; font-style: italic; } /* Annotation */ |
|
code span.at { color: #7d9029; } /* Attribute */ |
|
code span.bn { color: #40a070; } /* BaseN */ |
|
code span.bu { color: #008000; } /* BuiltIn */ |
|
code span.cf { color: #007020; font-weight: bold; } /* ControlFlow */ |
|
code span.ch { color: #4070a0; } /* Char */ |
|
code span.cn { color: #880000; } /* Constant */ |
|
code span.co { color: #60a0b0; font-style: italic; } /* Comment */ |
|
code span.cv { color: #60a0b0; font-weight: bold; font-style: italic; } /* CommentVar */ |
|
code span.do { color: #ba2121; font-style: italic; } /* Documentation */ |
|
code span.dt { color: #902000; } /* DataType */ |
|
code span.dv { color: #40a070; } /* DecVal */ |
|
code span.er { color: #ff0000; font-weight: bold; } /* Error */ |
|
code span.ex { } /* Extension */ |
|
code span.fl { color: #40a070; } /* Float */ |
|
code span.fu { color: #06287e; } /* Function */ |
|
code span.im { color: #008000; font-weight: bold; } /* Import */ |
|
code span.in { color: #60a0b0; font-weight: bold; font-style: italic; } /* Information */ |
|
code span.kw { color: #007020; font-weight: bold; } /* Keyword */ |
|
code span.op { color: #666666; } /* Operator */ |
|
code span.ot { color: #007020; } /* Other */ |
|
code span.pp { color: #bc7a00; } /* Preprocessor */ |
|
code span.sc { color: #4070a0; } /* SpecialChar */ |
|
code span.ss { color: #bb6688; } /* SpecialString */ |
|
code span.st { color: #4070a0; } /* String */ |
|
code span.va { color: #19177c; } /* Variable */ |
|
code span.vs { color: #4070a0; } /* VerbatimString */ |
|
code span.wa { color: #60a0b0; font-weight: bold; font-style: italic; } /* Warning */ |
|
</style> |
|
</head> |
|
<body> |
|
<div class="sourceCode" id="cb1"><pre |
|
class="sourceCode python"><code class="sourceCode python"><span id="cb1-1"><a href="#cb1-1" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> cmdstanpy</span> |
|
<span id="cb1-2"><a href="#cb1-2" aria-hidden="true" tabindex="-1"></a>cmdstanpy.install_cmdstan()</span> |
|
<span id="cb1-3"><a href="#cb1-3" aria-hidden="true" tabindex="-1"></a></span> |
|
<span id="cb1-4"><a href="#cb1-4" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> numpy <span class="im">as</span> np</span> |
|
<span id="cb1-5"><a href="#cb1-5" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> scipy <span class="im">as</span> sp</span> |
|
<span id="cb1-6"><a href="#cb1-6" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> cmdstanpy <span class="im">as</span> csp</span> |
|
<span id="cb1-7"><a href="#cb1-7" aria-hidden="true" tabindex="-1"></a><span class="im">from</span> scipy.special <span class="im">import</span> expit</span> |
|
<span id="cb1-8"><a href="#cb1-8" aria-hidden="true" tabindex="-1"></a></span> |
|
<span id="cb1-9"><a href="#cb1-9" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> ot</span> |
|
<span id="cb1-10"><a href="#cb1-10" aria-hidden="true" tabindex="-1"></a><span class="im">from</span> scipy.optimize <span class="im">import</span> minimize</span> |
|
<span id="cb1-11"><a href="#cb1-11" aria-hidden="true" tabindex="-1"></a><span class="im">from</span> scipy.stats <span class="im">import</span> multivariate_normal</span> |
|
<span id="cb1-12"><a href="#cb1-12" aria-hidden="true" tabindex="-1"></a></span> |
|
<span id="cb1-13"><a href="#cb1-13" aria-hidden="true" tabindex="-1"></a>seed <span class="op">=</span> <span class="dv">583883</span></span> |
|
<span id="cb1-14"><a href="#cb1-14" aria-hidden="true" tabindex="-1"></a>D <span class="op">=</span> <span class="dv">6</span></span> |
|
<span id="cb1-15"><a href="#cb1-15" aria-hidden="true" tabindex="-1"></a>N <span class="op">=</span> <span class="dv">1500</span></span> |
|
<span id="cb1-16"><a href="#cb1-16" aria-hidden="true" tabindex="-1"></a>mu_x_OH <span class="op">=</span> <span class="op">-</span><span class="dv">1</span></span> |
|
<span id="cb1-17"><a href="#cb1-17" aria-hidden="true" tabindex="-1"></a>mu_x_NY <span class="op">=</span> <span class="dv">1</span></span> |
|
<span id="cb1-18"><a href="#cb1-18" aria-hidden="true" tabindex="-1"></a>B <span class="op">=</span> <span class="dv">10_000</span></span> |
|
<span id="cb1-19"><a href="#cb1-19" aria-hidden="true" tabindex="-1"></a>h <span class="op">=</span> <span class="fl">0.04</span></span> |
|
<span id="cb1-20"><a href="#cb1-20" aria-hidden="true" tabindex="-1"></a></span> |
|
<span id="cb1-21"><a href="#cb1-21" aria-hidden="true" tabindex="-1"></a>rng <span class="op">=</span> np.random.default_rng(seed)</span> |
|
<span id="cb1-22"><a href="#cb1-22" aria-hidden="true" tabindex="-1"></a></span> |
|
<span id="cb1-23"><a href="#cb1-23" aria-hidden="true" tabindex="-1"></a><span class="co"># generate single parameter vectors</span></span> |
|
<span id="cb1-24"><a href="#cb1-24" aria-hidden="true" tabindex="-1"></a>beta <span class="op">=</span> rng.normal(loc<span class="op">=</span><span class="fl">0.0</span>, scale<span class="op">=</span><span class="fl">1.0</span>, size<span class="op">=</span>D)</span> |
|
<span id="cb1-25"><a href="#cb1-25" aria-hidden="true" tabindex="-1"></a><span class="bu">print</span>(<span class="ss">f"</span><span class="sc">{</span>beta<span class="op">=</span><span class="sc">}</span><span class="ss">"</span>)</span> |
|
<span id="cb1-26"><a href="#cb1-26" aria-hidden="true" tabindex="-1"></a></span> |
|
<span id="cb1-27"><a href="#cb1-27" aria-hidden="true" tabindex="-1"></a><span class="co"># generate Ohio data from parameters (covariates centered at mu_x_OH)</span></span> |
|
<span id="cb1-28"><a href="#cb1-28" aria-hidden="true" tabindex="-1"></a>x_OH <span class="op">=</span> rng.normal(loc<span class="op">=</span>mu_x_OH, scale<span class="op">=</span><span class="fl">1.0</span>, size<span class="op">=</span>(N, D))</span> |
|
<span id="cb1-29"><a href="#cb1-29" aria-hidden="true" tabindex="-1"></a>y_OH <span class="op">=</span> rng.binomial(n<span class="op">=</span><span class="dv">1</span>, p <span class="op">=</span> expit(x_OH <span class="op">@</span> beta))</span> |
|
<span id="cb1-30"><a href="#cb1-30" aria-hidden="true" tabindex="-1"></a></span> |
|
<span id="cb1-31"><a href="#cb1-31" aria-hidden="true" tabindex="-1"></a><span class="co"># generate NY data from the same parameters (covariates centered at mu_x_NY)</span></span> |
|
<span id="cb1-32"><a href="#cb1-32" aria-hidden="true" tabindex="-1"></a>x_NY <span class="op">=</span> rng.normal(loc<span class="op">=</span>mu_x_NY, scale<span class="op">=</span><span class="fl">1.0</span>, size<span class="op">=</span>(N, D))</span> |
|
<span id="cb1-33"><a href="#cb1-33" aria-hidden="true" tabindex="-1"></a>y_NY <span class="op">=</span> rng.binomial(n<span class="op">=</span><span class="dv">1</span>, p <span class="op">=</span> expit(x_NY <span class="op">@</span> beta))</span> |
|
<span id="cb1-34"><a href="#cb1-34" aria-hidden="true" tabindex="-1"></a><span class="bu">print</span>(<span class="ss">f"</span><span class="sc">{</span>np<span class="sc">.</span>mean(y_NY)<span class="op">=</span><span class="sc">}</span><span class="ss">"</span>) <span class="co"># shouldn't be too extreme</span></span> |
|
<span id="cb1-35"><a href="#cb1-35" aria-hidden="true" tabindex="-1"></a></span> |
|
<span id="cb1-36"><a href="#cb1-36" aria-hidden="true" tabindex="-1"></a><span class="co"># fit OH data with simple logistic regression</span></span> |
|
<span id="cb1-37"><a href="#cb1-37" aria-hidden="true" tabindex="-1"></a>data_OH <span class="op">=</span> {<span class="st">'N'</span>: N, <span class="st">'D'</span>: D, <span class="st">'x'</span>: x_OH, <span class="st">'y'</span>: y_OH }</span> |
|
<span id="cb1-38"><a href="#cb1-38" aria-hidden="true" tabindex="-1"></a>model_OH <span class="op">=</span> csp.CmdStanModel(stan_file<span class="op">=</span><span class="st">'flat-logistic.stan'</span>)</span> |
|
<span id="cb1-39"><a href="#cb1-39" aria-hidden="true" tabindex="-1"></a>fit_OH <span class="op">=</span> model_OH.sample(data<span class="op">=</span>data_OH, chains<span class="op">=</span><span class="dv">4</span>, iter_sampling<span class="op">=</span>B <span class="op">//</span> <span class="dv">4</span>, seed<span class="op">=</span>seed)</span> |
|
<span id="cb1-40"><a href="#cb1-40" aria-hidden="true" tabindex="-1"></a><span class="bu">print</span>(fit_OH.summary())</span> |
|
<span id="cb1-41"><a href="#cb1-41" aria-hidden="true" tabindex="-1"></a>beta_OH_draws <span class="op">=</span> fit_OH.stan_variable(<span class="st">'beta'</span>)</span> |
|
<span id="cb1-42"><a href="#cb1-42" aria-hidden="true" tabindex="-1"></a></span> |
|
<span id="cb1-43"><a href="#cb1-43" aria-hidden="true" tabindex="-1"></a><span class="co"># fit NY data with empirical prior from posterior of OH data</span></span> |
|
<span id="cb1-44"><a href="#cb1-44" aria-hidden="true" tabindex="-1"></a>data_NY <span class="op">=</span> {<span class="st">'N'</span>: N, <span class="st">'D'</span>: D, <span class="st">'x'</span>: x_NY, <span class="st">'y'</span>: y_NY,</span> |
|
<span id="cb1-45"><a href="#cb1-45" aria-hidden="true" tabindex="-1"></a> <span class="st">'h'</span>: h, <span class="st">'B'</span>: B, <span class="st">'beta0'</span>: beta_OH_draws }</span> |
|
<span id="cb1-46"><a href="#cb1-46" aria-hidden="true" tabindex="-1"></a>model_NY <span class="op">=</span> csp.CmdStanModel(stan_file<span class="op">=</span><span class="st">'empirical-logistic.stan'</span>)</span> |
|
<span id="cb1-47"><a href="#cb1-47" aria-hidden="true" tabindex="-1"></a>fit_NY <span class="op">=</span> model_NY.sample(data<span class="op">=</span>data_NY, chains<span class="op">=</span><span class="dv">4</span>, iter_warmup<span class="op">=</span><span class="dv">500</span>, iter_sampling<span class="op">=</span><span class="dv">500</span>, seed<span class="op">=</span>seed, show_progress<span class="op">=</span><span class="va">False</span>)</span> |
|
<span id="cb1-48"><a href="#cb1-48" aria-hidden="true" tabindex="-1"></a><span class="bu">print</span>(fit_NY.summary())</span> |
|
<span id="cb1-49"><a href="#cb1-49" aria-hidden="true" tabindex="-1"></a></span> |
|
<span id="cb1-50"><a href="#cb1-50" aria-hidden="true" tabindex="-1"></a></span> |
|
<span id="cb1-51"><a href="#cb1-51" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> fit_wasserstein_prior(posterior_draws, prior_family<span class="op">=</span><span class="st">'mvn'</span>, random_state<span class="op">=</span><span class="dv">0</span>):</span> |
|
<span id="cb1-52"><a href="#cb1-52" aria-hidden="true" tabindex="-1"></a>    <span class="co">"""</span></span> |
|
<span id="cb1-53"><a href="#cb1-53" aria-hidden="true" tabindex="-1"></a><span class="co">    Fit a parametric prior (diagonal Gaussian for 'mvn') to posterior_draws (n_samples x D) by</span></span> |
|
<span id="cb1-54"><a href="#cb1-54" aria-hidden="true" tabindex="-1"></a><span class="co">    minimizing the squared Wasserstein-2 cost; returns (mu, Sigma) with Sigma a D x D covariance.</span></span> |
|
<span id="cb1-55"><a href="#cb1-55" aria-hidden="true" tabindex="-1"></a><span class="co">    """</span></span> |
|
<span id="cb1-56"><a href="#cb1-56" aria-hidden="true" tabindex="-1"></a>    D <span class="op">=</span> posterior_draws.shape[<span class="dv">1</span>]</span> |
|
<span id="cb1-57"><a href="#cb1-57" aria-hidden="true" tabindex="-1"></a>    n_samples <span class="op">=</span> <span class="bu">len</span>(posterior_draws)</span> |
|
<span id="cb1-58"><a href="#cb1-58" aria-hidden="true" tabindex="-1"></a></span> |
|
<span id="cb1-59"><a href="#cb1-59" aria-hidden="true" tabindex="-1"></a>    <span class="cf">if</span> prior_family <span class="op">==</span> <span class="st">'mvn'</span>:</span> |
|
<span id="cb1-60"><a href="#cb1-60" aria-hidden="true" tabindex="-1"></a>        <span class="kw">def</span> objective(params):</span> |
|
<span id="cb1-61"><a href="#cb1-61" aria-hidden="true" tabindex="-1"></a>            mu <span class="op">=</span> params[:D]</span> |
|
<span id="cb1-62"><a href="#cb1-62" aria-hidden="true" tabindex="-1"></a>            <span class="co"># Diagonal covariance: params hold log std devs, so variance = exp(2 * log_std)</span></span> |
|
<span id="cb1-63"><a href="#cb1-63" aria-hidden="true" tabindex="-1"></a>            log_std <span class="op">=</span> params[D:]</span> |
|
<span id="cb1-64"><a href="#cb1-64" aria-hidden="true" tabindex="-1"></a>            Sigma <span class="op">=</span> np.diag(np.exp(<span class="dv">2</span> <span class="op">*</span> log_std))</span> |
|
<span id="cb1-65"><a href="#cb1-65" aria-hidden="true" tabindex="-1"></a></span> |
|
<span id="cb1-66"><a href="#cb1-66" aria-hidden="true" tabindex="-1"></a>            <span class="co"># Sample with the same random_state on every call so the objective is deterministic</span></span> |
|
<span id="cb1-67"><a href="#cb1-67" aria-hidden="true" tabindex="-1"></a>            fitted_samples <span class="op">=</span> multivariate_normal.rvs(mu, Sigma, size<span class="op">=</span>n_samples, random_state<span class="op">=</span>random_state)</span> |
|
<span id="cb1-68"><a href="#cb1-68" aria-hidden="true" tabindex="-1"></a></span> |
|
<span id="cb1-69"><a href="#cb1-69" aria-hidden="true" tabindex="-1"></a>            <span class="co"># Compute squared Wasserstein-2 distance between the two samples</span></span> |
|
<span id="cb1-70"><a href="#cb1-70" aria-hidden="true" tabindex="-1"></a>            <span class="co"># Create uniform probability distributions</span></span> |
|
<span id="cb1-71"><a href="#cb1-71" aria-hidden="true" tabindex="-1"></a>            a <span class="op">=</span> np.ones(n_samples) <span class="op">/</span> n_samples  <span class="co"># uniform weights for empirical</span></span> |
|
<span id="cb1-72"><a href="#cb1-72" aria-hidden="true" tabindex="-1"></a>            b <span class="op">=</span> np.ones(n_samples) <span class="op">/</span> n_samples  <span class="co"># uniform weights for fitted</span></span> |
|
<span id="cb1-73"><a href="#cb1-73" aria-hidden="true" tabindex="-1"></a></span> |
|
<span id="cb1-74"><a href="#cb1-74" aria-hidden="true" tabindex="-1"></a>            <span class="co"># Compute cost matrix (squared Euclidean distances)</span></span> |
|
<span id="cb1-75"><a href="#cb1-75" aria-hidden="true" tabindex="-1"></a>            cost_matrix <span class="op">=</span> ot.dist(posterior_draws, fitted_samples, metric<span class="op">=</span><span class="st">'sqeuclidean'</span>)</span> |
|
<span id="cb1-76"><a href="#cb1-76" aria-hidden="true" tabindex="-1"></a></span> |
|
<span id="cb1-77"><a href="#cb1-77" aria-hidden="true" tabindex="-1"></a>            <span class="co"># Optimal transport cost; with a squared Euclidean cost this is W2 squared</span></span> |
|
<span id="cb1-78"><a href="#cb1-78" aria-hidden="true" tabindex="-1"></a>            W2_squared <span class="op">=</span> ot.emd2(a, b, cost_matrix)</span> |
|
<span id="cb1-79"><a href="#cb1-79" aria-hidden="true" tabindex="-1"></a>            <span class="cf">return</span> W2_squared</span> |
|
<span id="cb1-80"><a href="#cb1-80" aria-hidden="true" tabindex="-1"></a></span> |
|
<span id="cb1-81"><a href="#cb1-81" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Initialize with sample mean and std</span></span> |
|
<span id="cb1-82"><a href="#cb1-82" aria-hidden="true" tabindex="-1"></a>        init_mu <span class="op">=</span> np.mean(posterior_draws, axis<span class="op">=</span><span class="dv">0</span>)</span> |
|
<span id="cb1-83"><a href="#cb1-83" aria-hidden="true" tabindex="-1"></a>        init_std <span class="op">=</span> np.std(posterior_draws, axis<span class="op">=</span><span class="dv">0</span>)</span> |
|
<span id="cb1-84"><a href="#cb1-84" aria-hidden="true" tabindex="-1"></a>        init_log_std <span class="op">=</span> np.log(init_std <span class="op">+</span> <span class="fl">1e-6</span>)  <span class="co"># avoid log(0)</span></span> |
|
<span id="cb1-85"><a href="#cb1-85" aria-hidden="true" tabindex="-1"></a>        init_params <span class="op">=</span> np.concatenate([init_mu, init_log_std])</span> |
|
<span id="cb1-86"><a href="#cb1-86" aria-hidden="true" tabindex="-1"></a></span> |
|
<span id="cb1-87"><a href="#cb1-87" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Means are unbounded; log std devs are boxed to keep exp() in a sane range</span></span> |
|
<span id="cb1-88"><a href="#cb1-88" aria-hidden="true" tabindex="-1"></a>        bounds <span class="op">=</span> [(<span class="va">None</span>, <span class="va">None</span>)] <span class="op">*</span> D <span class="op">+</span> [(<span class="op">-</span><span class="dv">5</span>, <span class="dv">5</span>)] <span class="op">*</span> D  <span class="co"># bounds for log std devs</span></span> |
|
<span id="cb1-89"><a href="#cb1-89" aria-hidden="true" tabindex="-1"></a></span> |
|
<span id="cb1-90"><a href="#cb1-90" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Optimize</span></span> |
|
<span id="cb1-91"><a href="#cb1-91" aria-hidden="true" tabindex="-1"></a>        <span class="cf">try</span>:</span> |
|
<span id="cb1-92"><a href="#cb1-92" aria-hidden="true" tabindex="-1"></a>            result <span class="op">=</span> minimize(objective, init_params, method<span class="op">=</span><span class="st">'L-BFGS-B'</span>, bounds<span class="op">=</span>bounds)</span> |
|
<span id="cb1-93"><a href="#cb1-93" aria-hidden="true" tabindex="-1"></a></span> |
|
<span id="cb1-94"><a href="#cb1-94" aria-hidden="true" tabindex="-1"></a>            <span class="co"># Extract optimized parameters (covariance diagonal = squared std devs)</span></span> |
|
<span id="cb1-95"><a href="#cb1-95" aria-hidden="true" tabindex="-1"></a>            mu_opt <span class="op">=</span> result.x[:D]</span> |
|
<span id="cb1-96"><a href="#cb1-96" aria-hidden="true" tabindex="-1"></a>            log_std_opt <span class="op">=</span> result.x[D:]</span> |
|
<span id="cb1-97"><a href="#cb1-97" aria-hidden="true" tabindex="-1"></a>            Sigma_opt <span class="op">=</span> np.diag(np.exp(<span class="dv">2</span> <span class="op">*</span> log_std_opt))</span> |
|
<span id="cb1-98"><a href="#cb1-98" aria-hidden="true" tabindex="-1"></a></span> |
|
<span id="cb1-99"><a href="#cb1-99" aria-hidden="true" tabindex="-1"></a>            <span class="cf">return</span> mu_opt, Sigma_opt</span> |
|
<span id="cb1-100"><a href="#cb1-100" aria-hidden="true" tabindex="-1"></a>        <span class="cf">except</span> <span class="pp">Exception</span> <span class="im">as</span> e:</span> |
|
<span id="cb1-101"><a href="#cb1-101" aria-hidden="true" tabindex="-1"></a>            <span class="bu">print</span>(<span class="ss">f"Optimization failed: </span><span class="sc">{</span>e<span class="sc">}</span><span class="ss">"</span>)</span> |
|
<span id="cb1-102"><a href="#cb1-102" aria-hidden="true" tabindex="-1"></a>            <span class="bu">print</span>(<span class="st">"Falling back to moment matching"</span>)</span> |
|
<span id="cb1-103"><a href="#cb1-103" aria-hidden="true" tabindex="-1"></a>            <span class="co"># Fallback to simple moment matching (full sample covariance)</span></span> |
|
<span id="cb1-104"><a href="#cb1-104" aria-hidden="true" tabindex="-1"></a>            mu_fallback <span class="op">=</span> np.mean(posterior_draws, axis<span class="op">=</span><span class="dv">0</span>)</span> |
|
<span id="cb1-105"><a href="#cb1-105" aria-hidden="true" tabindex="-1"></a>            Sigma_fallback <span class="op">=</span> np.cov(posterior_draws.T)</span> |
|
<span id="cb1-106"><a href="#cb1-106" aria-hidden="true" tabindex="-1"></a>            <span class="cf">return</span> mu_fallback, Sigma_fallback</span> |
|
<span id="cb1-107"><a href="#cb1-107" aria-hidden="true" tabindex="-1"></a></span> |
|
<span id="cb1-108"><a href="#cb1-108" aria-hidden="true" tabindex="-1"></a><span class="co"># Or use simple moment matching (the W2-optimal Gaussian when the posterior itself is Gaussian):</span></span> |
|
<span id="cb1-109"><a href="#cb1-109" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> fit_wasserstein_prior_simple(posterior_draws):</span> |
|
<span id="cb1-110"><a href="#cb1-110" aria-hidden="true" tabindex="-1"></a>    <span class="co">"""</span></span> |
|
<span id="cb1-111"><a href="#cb1-111" aria-hidden="true" tabindex="-1"></a><span class="co">    Moment-matched Gaussian: returns (sample mean, D x D sample covariance); W2-optimal only for Gaussian draws</span></span> |
|
<span id="cb1-112"><a href="#cb1-112" aria-hidden="true" tabindex="-1"></a><span class="co">    """</span></span> |
|
<span id="cb1-113"><a href="#cb1-113" aria-hidden="true" tabindex="-1"></a>    mu <span class="op">=</span> np.mean(posterior_draws, axis<span class="op">=</span><span class="dv">0</span>)</span> |
|
<span id="cb1-114"><a href="#cb1-114" aria-hidden="true" tabindex="-1"></a>    Sigma <span class="op">=</span> np.atleast_2d(np.cov(posterior_draws, rowvar<span class="op">=</span><span class="va">False</span>))  <span class="co"># stays D x D even when D == 1</span></span> |
|
<span id="cb1-115"><a href="#cb1-115" aria-hidden="true" tabindex="-1"></a>    <span class="cf">return</span> mu, Sigma</span> |
|
<span id="cb1-116"><a href="#cb1-116" aria-hidden="true" tabindex="-1"></a></span> |
|
<span id="cb1-117"><a href="#cb1-117" aria-hidden="true" tabindex="-1"></a><span class="co"># Use the simple version:</span></span> |
|
<span id="cb1-118"><a href="#cb1-118" aria-hidden="true" tabindex="-1"></a>mu_prior, Sigma_prior <span class="op">=</span> fit_wasserstein_prior_simple(beta_OH_draws)</span> |
|
<span id="cb1-119"><a href="#cb1-119" aria-hidden="true" tabindex="-1"></a></span> |
|
<span id="cb1-120"><a href="#cb1-120" aria-hidden="true" tabindex="-1"></a><span class="co"># Use in Stan model (modified version)</span></span> |
|
<span id="cb1-121"><a href="#cb1-121" aria-hidden="true" tabindex="-1"></a>data_NY_wass <span class="op">=</span> {</span> |
|
<span id="cb1-122"><a href="#cb1-122" aria-hidden="true" tabindex="-1"></a> <span class="st">'N'</span>: N, <span class="st">'D'</span>: D, <span class="st">'x'</span>: x_NY, <span class="st">'y'</span>: y_NY,</span> |
|
<span id="cb1-123"><a href="#cb1-123" aria-hidden="true" tabindex="-1"></a> <span class="st">'mu_prior'</span>: mu_prior, <span class="st">'Sigma_prior'</span>: Sigma_prior</span> |
|
<span id="cb1-124"><a href="#cb1-124" aria-hidden="true" tabindex="-1"></a>}</span></code></pre></div> |
|
<pre><code>DEBUG:cmdstanpy:cmd: make examples/bernoulli/bernoulli |
|
cwd: None |
|
|
|
|
|
CmdStan install directory: /root/.cmdstan |
|
CmdStan version 2.36.0 already installed |
|
Test model compilation |
|
|
|
|
|
DEBUG:cmdstanpy:found newer exe file, not recompiling |
|
DEBUG:cmdstanpy:cmd: /content/flat-logistic info |
|
cwd: None |
|
DEBUG:cmdstanpy:input tempfile: /tmp/tmpy2c2cbr3/vwhyncjv.json |
|
05:05:32 - cmdstanpy - INFO - CmdStan start processing |
|
INFO:cmdstanpy:CmdStan start processing |
|
|
|
|
|
beta=array([ 0.32292202, -1.67712417, 0.80797451, 0.23766868, 0.86741335, |
|
-1.506818 ]) |
|
np.mean(y_NY)=np.float64(0.368) |
|
|
|
|
|
|
|
chain 1 | | 00:00 Status |
|
|
|
|
|
|
|
chain 2 | | 00:00 Status |
|
|
|
|
|
|
|
chain 3 | | 00:00 Status |
|
|
|
|
|
|
|
chain 4 | | 00:00 Status |
|
|
|
|
|
DEBUG:cmdstanpy:idx 0 |
|
DEBUG:cmdstanpy:running CmdStan, num_threads: 1 |
|
DEBUG:cmdstanpy:CmdStan args: ['/content/flat-logistic', 'id=1', 'random', 'seed=583883', 'data', 'file=/tmp/tmpy2c2cbr3/vwhyncjv.json', 'output', 'file=/tmp/tmpy2c2cbr3/flat-logistical8q2pgf/flat-logistic-20250514050532_1.csv', 'method=sample', 'num_samples=2500', 'algorithm=hmc', 'adapt', 'engaged=1'] |
|
DEBUG:cmdstanpy:idx 1 |
|
DEBUG:cmdstanpy:running CmdStan, num_threads: 1 |
|
DEBUG:cmdstanpy:CmdStan args: ['/content/flat-logistic', 'id=2', 'random', 'seed=583883', 'data', 'file=/tmp/tmpy2c2cbr3/vwhyncjv.json', 'output', 'file=/tmp/tmpy2c2cbr3/flat-logistical8q2pgf/flat-logistic-20250514050532_2.csv', 'method=sample', 'num_samples=2500', 'algorithm=hmc', 'adapt', 'engaged=1'] |
|
DEBUG:cmdstanpy:idx 2 |
|
DEBUG:cmdstanpy:running CmdStan, num_threads: 1 |
|
DEBUG:cmdstanpy:CmdStan args: ['/content/flat-logistic', 'id=3', 'random', 'seed=583883', 'data', 'file=/tmp/tmpy2c2cbr3/vwhyncjv.json', 'output', 'file=/tmp/tmpy2c2cbr3/flat-logistical8q2pgf/flat-logistic-20250514050532_3.csv', 'method=sample', 'num_samples=2500', 'algorithm=hmc', 'adapt', 'engaged=1'] |
|
DEBUG:cmdstanpy:idx 3 |
|
DEBUG:cmdstanpy:running CmdStan, num_threads: 1 |
|
DEBUG:cmdstanpy:CmdStan args: ['/content/flat-logistic', 'id=4', 'random', 'seed=583883', 'data', 'file=/tmp/tmpy2c2cbr3/vwhyncjv.json', 'output', 'file=/tmp/tmpy2c2cbr3/flat-logistical8q2pgf/flat-logistic-20250514050532_4.csv', 'method=sample', 'num_samples=2500', 'algorithm=hmc', 'adapt', 'engaged=1'] |
|
|
|
|
|
|
|
|
|
05:05:45 - cmdstanpy - INFO - CmdStan done processing. |
|
INFO:cmdstanpy:CmdStan done processing. |
|
DEBUG:cmdstanpy:runset |
|
RunSet: chains=4, chain_ids=[1, 2, 3, 4], num_processes=4 |
|
cmd (chain 1): |
|
['/content/flat-logistic', 'id=1', 'random', 'seed=583883', 'data', 'file=/tmp/tmpy2c2cbr3/vwhyncjv.json', 'output', 'file=/tmp/tmpy2c2cbr3/flat-logistical8q2pgf/flat-logistic-20250514050532_1.csv', 'method=sample', 'num_samples=2500', 'algorithm=hmc', 'adapt', 'engaged=1'] |
|
retcodes=[0, 0, 0, 0] |
|
per-chain output files (showing chain 1 only): |
|
csv_file: |
|
/tmp/tmpy2c2cbr3/flat-logistical8q2pgf/flat-logistic-20250514050532_1.csv |
|
console_msgs (if any): |
|
/tmp/tmpy2c2cbr3/flat-logistical8q2pgf/flat-logistic-20250514050532_0-stdout.txt |
|
DEBUG:cmdstanpy:Chain 1 console: |
|
method = sample (Default) |
|
sample |
|
num_samples = 2500 |
|
num_warmup = 1000 (Default) |
|
save_warmup = false (Default) |
|
thin = 1 (Default) |
|
adapt |
|
engaged = true (Default) |
|
gamma = 0.05 (Default) |
|
delta = 0.8 (Default) |
|
kappa = 0.75 (Default) |
|
t0 = 10 (Default) |
|
init_buffer = 75 (Default) |
|
term_buffer = 50 (Default) |
|
window = 25 (Default) |
|
save_metric = false (Default) |
|
algorithm = hmc (Default) |
|
hmc |
|
engine = nuts (Default) |
|
nuts |
|
max_depth = 10 (Default) |
|
metric = diag_e (Default) |
|
metric_file = (Default) |
|
stepsize = 1 (Default) |
|
stepsize_jitter = 0 (Default) |
|
num_chains = 1 (Default) |
|
id = 1 (Default) |
|
data |
|
file = /tmp/tmpy2c2cbr3/vwhyncjv.json |
|
init = 2 (Default) |
|
random |
|
seed = 583883 |
|
output |
|
file = /tmp/tmpy2c2cbr3/flat-logistical8q2pgf/flat-logistic-20250514050532_1.csv |
|
diagnostic_file = (Default) |
|
refresh = 100 (Default) |
|
sig_figs = -1 (Default) |
|
profile_file = profile.csv (Default) |
|
save_cmdstan_config = false (Default) |
|
num_threads = 1 (Default) |
|
|
|
|
|
Gradient evaluation took 0.000185 seconds |
|
1000 transitions using 10 leapfrog steps per transition would take 1.85 seconds. |
|
Adjust your expectations accordingly! |
|
|
|
|
|
Iteration: 1 / 3500 [ 0%] (Warmup) |
|
Iteration: 100 / 3500 [ 2%] (Warmup) |
|
Iteration: 200 / 3500 [ 5%] (Warmup) |
|
Iteration: 300 / 3500 [ 8%] (Warmup) |
|
Iteration: 400 / 3500 [ 11%] (Warmup) |
|
Iteration: 500 / 3500 [ 14%] (Warmup) |
|
Iteration: 600 / 3500 [ 17%] (Warmup) |
|
Iteration: 700 / 3500 [ 20%] (Warmup) |
|
Iteration: 800 / 3500 [ 22%] (Warmup) |
|
Iteration: 900 / 3500 [ 25%] (Warmup) |
|
Iteration: 1000 / 3500 [ 28%] (Warmup) |
|
Iteration: 1001 / 3500 [ 28%] (Sampling) |
|
Iteration: 1100 / 3500 [ 31%] (Sampling) |
|
Iteration: 1200 / 3500 [ 34%] (Sampling) |
|
Iteration: 1300 / 3500 [ 37%] (Sampling) |
|
Iteration: 1400 / 3500 [ 40%] (Sampling) |
|
Iteration: 1500 / 3500 [ 42%] (Sampling) |
|
Iteration: 1600 / 3500 [ 45%] (Sampling) |
|
Iteration: 1700 / 3500 [ 48%] (Sampling) |
|
Iteration: 1800 / 3500 [ 51%] (Sampling) |
|
Iteration: 1900 / 3500 [ 54%] (Sampling) |
|
Iteration: 2000 / 3500 [ 57%] (Sampling) |
|
Iteration: 2100 / 3500 [ 60%] (Sampling) |
|
Iteration: 2200 / 3500 [ 62%] (Sampling) |
|
Iteration: 2300 / 3500 [ 65%] (Sampling) |
|
Iteration: 2400 / 3500 [ 68%] (Sampling) |
|
Iteration: 2500 / 3500 [ 71%] (Sampling) |
|
Iteration: 2600 / 3500 [ 74%] (Sampling) |
|
Iteration: 2700 / 3500 [ 77%] (Sampling) |
|
Iteration: 2800 / 3500 [ 80%] (Sampling) |
|
Iteration: 2900 / 3500 [ 82%] (Sampling) |
|
Iteration: 3000 / 3500 [ 85%] (Sampling) |
|
Iteration: 3100 / 3500 [ 88%] (Sampling) |
|
Iteration: 3200 / 3500 [ 91%] (Sampling) |
|
Iteration: 3300 / 3500 [ 94%] (Sampling) |
|
Iteration: 3400 / 3500 [ 97%] (Sampling) |
|
Iteration: 3500 / 3500 [100%] (Sampling) |
|
|
|
Elapsed Time: 1.391 seconds (Warm-up) |
|
3.46 seconds (Sampling) |
|
4.851 seconds (Total) |
|
|
|
|
|
DEBUG:cmdstanpy:cmd: /root/.cmdstan/cmdstan-2.36.0/bin/stansummary --percentiles= 5,50,95 --sig_figs=6 --csv_filename=/tmp/tmpy2c2cbr3/stansummary-flat-logistic-_dkv4lm4.csv /tmp/tmpy2c2cbr3/flat-logistical8q2pgf/flat-logistic-20250514050532_1.csv /tmp/tmpy2c2cbr3/flat-logistical8q2pgf/flat-logistic-20250514050532_2.csv /tmp/tmpy2c2cbr3/flat-logistical8q2pgf/flat-logistic-20250514050532_3.csv /tmp/tmpy2c2cbr3/flat-logistical8q2pgf/flat-logistic-20250514050532_4.csv |
|
cwd: None |
|
|
|
|
|
|
|
|
|
|
|
DEBUG:cmdstanpy:found newer exe file, not recompiling |
|
DEBUG:cmdstanpy:cmd: /content/empirical-logistic info |
|
cwd: None |
|
DEBUG:cmdstanpy:input tempfile: /tmp/tmpy2c2cbr3/yg69ajzg.json |
|
05:05:45 - cmdstanpy - INFO - CmdStan start processing |
|
INFO:cmdstanpy:CmdStan start processing |
|
DEBUG:cmdstanpy:idx 0 |
|
DEBUG:cmdstanpy:running CmdStan, num_threads: 1 |
|
DEBUG:cmdstanpy:CmdStan args: ['/content/empirical-logistic', 'id=1', 'random', 'seed=583883', 'data', 'file=/tmp/tmpy2c2cbr3/yg69ajzg.json', 'output', 'file=/tmp/tmpy2c2cbr3/empirical-logistic3nhl762p/empirical-logistic-20250514050545_1.csv', 'method=sample', 'num_samples=500', 'num_warmup=500', 'algorithm=hmc', 'adapt', 'engaged=1'] |
|
05:05:45 - cmdstanpy - INFO - Chain [1] start processing |
|
DEBUG:cmdstanpy:idx 1 |
|
DEBUG:cmdstanpy:running CmdStan, num_threads: 1 |
|
DEBUG:cmdstanpy:CmdStan args: ['/content/empirical-logistic', 'id=2', 'random', 'seed=583883', 'data', 'file=/tmp/tmpy2c2cbr3/yg69ajzg.json', 'output', 'file=/tmp/tmpy2c2cbr3/empirical-logistic3nhl762p/empirical-logistic-20250514050545_2.csv', 'method=sample', 'num_samples=500', 'num_warmup=500', 'algorithm=hmc', 'adapt', 'engaged=1'] |
|
|
|
|
|
Mean MCSE StdDev MAD 5% 50% \ |
|
lp__ -557.406000 0.026795 1.770460 1.568590 -560.791000 -557.063000 |
|
beta[1] 0.366866 0.000819 0.072424 0.071326 0.250861 0.366657 |
|
beta[2] -1.789300 0.001410 0.107540 0.107689 -1.969140 -1.785910 |
|
beta[3] 0.917581 0.000989 0.079849 0.078577 0.787154 0.916908 |
|
beta[4] 0.245058 0.000795 0.072760 0.073003 0.124279 0.245038 |
|
beta[5] 0.826520 0.000930 0.078852 0.078082 0.699908 0.825228 |
|
beta[6] -1.578740 0.001312 0.098768 0.098519 -1.742420 -1.576840 |
|
|
|
95% ESS_bulk ESS_tail R_hat |
|
lp__ -555.199000 4633.58 5881.42 0.999884 |
|
beta[1] 0.487298 7869.64 6899.19 1.000120 |
|
beta[2] -1.618440 5936.46 6436.66 1.000000 |
|
beta[3] 1.051410 6569.76 6402.72 1.000210 |
|
beta[4] 0.364416 8444.90 6252.19 1.000350 |
|
beta[5] 0.956443 7271.80 6691.55 1.000470 |
|
beta[6] -1.418180 5713.45 5901.44 1.000730 |
|
|
|
|
|
05:05:45 - cmdstanpy - INFO - Chain [2] start processing |
|
INFO:cmdstanpy:Chain [1] start processing |
|
INFO:cmdstanpy:Chain [2] start processing |
|
05:06:48 - cmdstanpy - INFO - Chain [2] done processing |
|
INFO:cmdstanpy:Chain [2] done processing |
|
DEBUG:cmdstanpy:idx 2 |
|
DEBUG:cmdstanpy:running CmdStan, num_threads: 1 |
|
DEBUG:cmdstanpy:CmdStan args: ['/content/empirical-logistic', 'id=3', 'random', 'seed=583883', 'data', 'file=/tmp/tmpy2c2cbr3/yg69ajzg.json', 'output', 'file=/tmp/tmpy2c2cbr3/empirical-logistic3nhl762p/empirical-logistic-20250514050545_3.csv', 'method=sample', 'num_samples=500', 'num_warmup=500', 'algorithm=hmc', 'adapt', 'engaged=1'] |
|
05:06:48 - cmdstanpy - INFO - Chain [3] start processing |
|
INFO:cmdstanpy:Chain [3] start processing |
|
05:06:50 - cmdstanpy - INFO - Chain [1] done processing |
|
INFO:cmdstanpy:Chain [1] done processing |
|
DEBUG:cmdstanpy:idx 3 |
|
DEBUG:cmdstanpy:running CmdStan, num_threads: 1 |
|
DEBUG:cmdstanpy:CmdStan args: ['/content/empirical-logistic', 'id=4', 'random', 'seed=583883', 'data', 'file=/tmp/tmpy2c2cbr3/yg69ajzg.json', 'output', 'file=/tmp/tmpy2c2cbr3/empirical-logistic3nhl762p/empirical-logistic-20250514050545_4.csv', 'method=sample', 'num_samples=500', 'num_warmup=500', 'algorithm=hmc', 'adapt', 'engaged=1'] |
|
05:06:50 - cmdstanpy - INFO - Chain [4] start processing |
|
INFO:cmdstanpy:Chain [4] start processing |
|
05:07:50 - cmdstanpy - INFO - Chain [3] done processing |
|
INFO:cmdstanpy:Chain [3] done processing |
|
05:07:53 - cmdstanpy - INFO - Chain [4] done processing |
|
INFO:cmdstanpy:Chain [4] done processing |
|
DEBUG:cmdstanpy:runset |
|
RunSet: chains=4, chain_ids=[1, 2, 3, 4], num_processes=4 |
|
cmd (chain 1): |
|
['/content/empirical-logistic', 'id=1', 'random', 'seed=583883', 'data', 'file=/tmp/tmpy2c2cbr3/yg69ajzg.json', 'output', 'file=/tmp/tmpy2c2cbr3/empirical-logistic3nhl762p/empirical-logistic-20250514050545_1.csv', 'method=sample', 'num_samples=500', 'num_warmup=500', 'algorithm=hmc', 'adapt', 'engaged=1'] |
|
retcodes=[0, 0, 0, 0] |
|
per-chain output files (showing chain 1 only): |
|
csv_file: |
|
/tmp/tmpy2c2cbr3/empirical-logistic3nhl762p/empirical-logistic-20250514050545_1.csv |
|
console_msgs (if any): |
|
/tmp/tmpy2c2cbr3/empirical-logistic3nhl762p/empirical-logistic-20250514050545_0-stdout.txt |
|
DEBUG:cmdstanpy:Chain 1 console: |
|
method = sample (Default) |
|
sample |
|
num_samples = 500 |
|
num_warmup = 500 |
|
save_warmup = false (Default) |
|
thin = 1 (Default) |
|
adapt |
|
engaged = true (Default) |
|
gamma = 0.05 (Default) |
|
delta = 0.8 (Default) |
|
kappa = 0.75 (Default) |
|
t0 = 10 (Default) |
|
init_buffer = 75 (Default) |
|
term_buffer = 50 (Default) |
|
window = 25 (Default) |
|
save_metric = false (Default) |
|
algorithm = hmc (Default) |
|
hmc |
|
engine = nuts (Default) |
|
nuts |
|
max_depth = 10 (Default) |
|
metric = diag_e (Default) |
|
metric_file = (Default) |
|
stepsize = 1 (Default) |
|
stepsize_jitter = 0 (Default) |
|
num_chains = 1 (Default) |
|
id = 1 (Default) |
|
data |
|
file = /tmp/tmpy2c2cbr3/yg69ajzg.json |
|
init = 2 (Default) |
|
random |
|
seed = 583883 |
|
output |
|
file = /tmp/tmpy2c2cbr3/empirical-logistic3nhl762p/empirical-logistic-20250514050545_1.csv |
|
diagnostic_file = (Default) |
|
refresh = 100 (Default) |
|
sig_figs = -1 (Default) |
|
profile_file = profile.csv (Default) |
|
save_cmdstan_config = false (Default) |
|
num_threads = 1 (Default) |
|
|
|
|
|
Gradient evaluation took 0.008579 seconds |
|
1000 transitions using 10 leapfrog steps per transition would take 85.79 seconds. |
|
Adjust your expectations accordingly! |
|
|
|
|
|
Iteration: 1 / 1000 [ 0%] (Warmup) |
|
Iteration: 100 / 1000 [ 10%] (Warmup) |
|
Iteration: 200 / 1000 [ 20%] (Warmup) |
|
Iteration: 300 / 1000 [ 30%] (Warmup) |
|
Iteration: 400 / 1000 [ 40%] (Warmup) |
|
Iteration: 500 / 1000 [ 50%] (Warmup) |
|
Iteration: 501 / 1000 [ 50%] (Sampling) |
|
Iteration: 600 / 1000 [ 60%] (Sampling) |
|
Iteration: 700 / 1000 [ 70%] (Sampling) |
|
Iteration: 800 / 1000 [ 80%] (Sampling) |
|
Iteration: 900 / 1000 [ 90%] (Sampling) |
|
Iteration: 1000 / 1000 [100%] (Sampling) |
|
|
|
Elapsed Time: 32.347 seconds (Warm-up) |
|
32.174 seconds (Sampling) |
|
64.521 seconds (Total) |
|
|
|
|
|
DEBUG:cmdstanpy:cmd: /root/.cmdstan/cmdstan-2.36.0/bin/stansummary --percentiles= 5,50,95 --sig_figs=6 --csv_filename=/tmp/tmpy2c2cbr3/stansummary-empirical-logistic-v4gv2feg.csv /tmp/tmpy2c2cbr3/empirical-logistic3nhl762p/empirical-logistic-20250514050545_1.csv /tmp/tmpy2c2cbr3/empirical-logistic3nhl762p/empirical-logistic-20250514050545_2.csv /tmp/tmpy2c2cbr3/empirical-logistic3nhl762p/empirical-logistic-20250514050545_3.csv /tmp/tmpy2c2cbr3/empirical-logistic3nhl762p/empirical-logistic-20250514050545_4.csv |
|
cwd: None |
|
|
|
|
|
Mean MCSE StdDev MAD 5% 50% \ |
|
lp__ -559.231000 0.052517 1.670540 1.512990 -562.479000 -558.881000 |
|
beta[1] 0.350303 0.001256 0.053163 0.055772 0.264625 0.349386 |
|
beta[2] -1.684520 0.002128 0.072740 0.073189 -1.806480 -1.683590 |
|
beta[3] 0.816993 0.001509 0.058009 0.055932 0.720248 0.818284 |
|
beta[4] 0.308557 0.001222 0.052301 0.054825 0.222648 0.310081 |
|
beta[5] 0.807412 0.001536 0.056870 0.057270 0.714319 0.807694 |
|
beta[6] -1.547500 0.002062 0.069704 0.068882 -1.664110 -1.546650 |
|
|
|
95% ESS_bulk ESS_tail R_hat |
|
lp__ -557.139000 1042.35 1121.73 1.007360 |
|
beta[1] 0.440071 1791.38 1473.59 1.000620 |
|
beta[2] -1.568570 1172.88 1269.70 1.001860 |
|
beta[3] 0.910145 1522.17 1148.50 0.999944 |
|
beta[4] 0.389971 1861.81 1350.30 1.001190 |
|
beta[5] 0.902555 1385.68 1368.33 1.002960 |
|
beta[6] -1.435650 1155.15 1153.33 0.999938 </code></pre> |
|
<div class="sourceCode" id="cb3"><pre |
|
class="sourceCode python"><code class="sourceCode python"><span id="cb3-1"><a href="#cb3-1" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> fit_wasserstein_prior(draws):</span> |
|
<span id="cb3-2"><a href="#cb3-2" aria-hidden="true" tabindex="-1"></a> <span class="co">"""</span></span> |
|
<span id="cb3-3"><a href="#cb3-3" aria-hidden="true" tabindex="-1"></a><span class="co"> Returns the Gaussian (mu, Sigma) that minimizes W2 distance</span></span> |
|
<span id="cb3-4"><a href="#cb3-4" aria-hidden="true" tabindex="-1"></a><span class="co"> to the empirical measure of 'draws' via matching first two moments.</span></span> |
|
<span id="cb3-5"><a href="#cb3-5" aria-hidden="true" tabindex="-1"></a><span class="co"> """</span></span> |
|
<span id="cb3-6"><a href="#cb3-6" aria-hidden="true" tabindex="-1"></a> mu <span class="op">=</span> np.mean(draws, axis<span class="op">=</span><span class="dv">0</span>)</span> |
|
<span id="cb3-7"><a href="#cb3-7" aria-hidden="true" tabindex="-1"></a> Sigma <span class="op">=</span> np.cov(draws, rowvar<span class="op">=</span><span class="va">False</span>)</span> |
|
<span id="cb3-8"><a href="#cb3-8" aria-hidden="true" tabindex="-1"></a> <span class="cf">return</span> mu, Sigma</span> |
|
<span id="cb3-9"><a href="#cb3-9" aria-hidden="true" tabindex="-1"></a></span> |
|
<span id="cb3-10"><a href="#cb3-10" aria-hidden="true" tabindex="-1"></a></span> |
|
<span id="cb3-11"><a href="#cb3-11" aria-hidden="true" tabindex="-1"></a><span class="co"># After fitting the Ohio model:</span></span> |
|
<span id="cb3-12"><a href="#cb3-12" aria-hidden="true" tabindex="-1"></a>beta_OH_draws <span class="op">=</span> fit_OH.stan_variable(<span class="st">'beta'</span>) <span class="co"># shape (B, D)</span></span> |
|
<span id="cb3-13"><a href="#cb3-13" aria-hidden="true" tabindex="-1"></a></span> |
|
<span id="cb3-14"><a href="#cb3-14" aria-hidden="true" tabindex="-1"></a><span class="co"># Fit the W2‐optimal Gaussian prior</span></span> |
|
<span id="cb3-15"><a href="#cb3-15" aria-hidden="true" tabindex="-1"></a>mu_prior, Sigma_prior <span class="op">=</span> fit_wasserstein_prior(beta_OH_draws)</span> |
|
<span id="cb3-16"><a href="#cb3-16" aria-hidden="true" tabindex="-1"></a></span> |
|
<span id="cb3-17"><a href="#cb3-17" aria-hidden="true" tabindex="-1"></a><span class="co"># Prepare data for the next Stan model (NY with Gaussian prior)</span></span> |
|
<span id="cb3-18"><a href="#cb3-18" aria-hidden="true" tabindex="-1"></a>data_NY_wass <span class="op">=</span> {</span> |
|
<span id="cb3-19"><a href="#cb3-19" aria-hidden="true" tabindex="-1"></a> <span class="st">'N'</span>: N,</span> |
|
<span id="cb3-20"><a href="#cb3-20" aria-hidden="true" tabindex="-1"></a> <span class="st">'D'</span>: D,</span> |
|
<span id="cb3-21"><a href="#cb3-21" aria-hidden="true" tabindex="-1"></a> <span class="st">'x'</span>: x_NY,</span> |
|
<span id="cb3-22"><a href="#cb3-22" aria-hidden="true" tabindex="-1"></a> <span class="st">'y'</span>: y_NY,</span> |
|
<span id="cb3-23"><a href="#cb3-23" aria-hidden="true" tabindex="-1"></a> <span class="st">'mu_prior'</span>: mu_prior,</span> |
|
<span id="cb3-24"><a href="#cb3-24" aria-hidden="true" tabindex="-1"></a> <span class="st">'Sigma_prior'</span>: Sigma_prior</span> |
|
<span id="cb3-25"><a href="#cb3-25" aria-hidden="true" tabindex="-1"></a>}</span> |
|
<span id="cb3-26"><a href="#cb3-26" aria-hidden="true" tabindex="-1"></a></span> |
|
<span id="cb3-27"><a href="#cb3-27" aria-hidden="true" tabindex="-1"></a>model_NY_wass <span class="op">=</span> csp.CmdStanModel(stan_file<span class="op">=</span><span class="st">'gaussian-prior-logistic.stan'</span>)</span> |
|
<span id="cb3-28"><a href="#cb3-28" aria-hidden="true" tabindex="-1"></a>fit_NY_wass <span class="op">=</span> model_NY_wass.sample(data<span class="op">=</span>data_NY_wass, chains<span class="op">=</span><span class="dv">4</span>, iter_warmup<span class="op">=</span><span class="dv">500</span>, iter_sampling<span class="op">=</span><span class="dv">500</span>, seed<span class="op">=</span>seed)</span> |
|
<span id="cb3-29"><a href="#cb3-29" aria-hidden="true" tabindex="-1"></a><span class="bu">print</span>(fit_NY_wass.summary())</span></code></pre></div> |
|
<pre><code>DEBUG:cmdstanpy:found newer exe file, not recompiling |
|
DEBUG:cmdstanpy:cmd: /content/gaussian-prior-logistic info |
|
cwd: None |
|
DEBUG:cmdstanpy:input tempfile: /tmp/tmpy2c2cbr3/mgqi9mu9.json |
|
05:07:53 - cmdstanpy - INFO - CmdStan start processing |
|
INFO:cmdstanpy:CmdStan start processing |
|
|
|
|
|
|
|
chain 1 | | 00:00 Status |
|
|
|
|
|
|
|
chain 2 | | 00:00 Status |
|
|
|
|
|
|
|
chain 3 | | 00:00 Status |
|
|
|
|
|
|
|
chain 4 | | 00:00 Status |
|
|
|
|
|
DEBUG:cmdstanpy:idx 0 |
|
DEBUG:cmdstanpy:running CmdStan, num_threads: 1 |
|
DEBUG:cmdstanpy:idx 1 |
|
DEBUG:cmdstanpy:running CmdStan, num_threads: 1 |
|
DEBUG:cmdstanpy:CmdStan args: ['/content/gaussian-prior-logistic', 'id=2', 'random', 'seed=583883', 'data', 'file=/tmp/tmpy2c2cbr3/mgqi9mu9.json', 'output', 'file=/tmp/tmpy2c2cbr3/gaussian-prior-logistich_ssz9bq/gaussian-prior-logistic-20250514050753_2.csv', 'method=sample', 'num_samples=500', 'num_warmup=500', 'algorithm=hmc', 'adapt', 'engaged=1'] |
|
DEBUG:cmdstanpy:CmdStan args: ['/content/gaussian-prior-logistic', 'id=1', 'random', 'seed=583883', 'data', 'file=/tmp/tmpy2c2cbr3/mgqi9mu9.json', 'output', 'file=/tmp/tmpy2c2cbr3/gaussian-prior-logistich_ssz9bq/gaussian-prior-logistic-20250514050753_1.csv', 'method=sample', 'num_samples=500', 'num_warmup=500', 'algorithm=hmc', 'adapt', 'engaged=1'] |
|
DEBUG:cmdstanpy:idx 2 |
|
DEBUG:cmdstanpy:running CmdStan, num_threads: 1 |
|
DEBUG:cmdstanpy:CmdStan args: ['/content/gaussian-prior-logistic', 'id=3', 'random', 'seed=583883', 'data', 'file=/tmp/tmpy2c2cbr3/mgqi9mu9.json', 'output', 'file=/tmp/tmpy2c2cbr3/gaussian-prior-logistich_ssz9bq/gaussian-prior-logistic-20250514050753_3.csv', 'method=sample', 'num_samples=500', 'num_warmup=500', 'algorithm=hmc', 'adapt', 'engaged=1'] |
|
DEBUG:cmdstanpy:idx 3 |
|
DEBUG:cmdstanpy:running CmdStan, num_threads: 1 |
|
DEBUG:cmdstanpy:CmdStan args: ['/content/gaussian-prior-logistic', 'id=4', 'random', 'seed=583883', 'data', 'file=/tmp/tmpy2c2cbr3/mgqi9mu9.json', 'output', 'file=/tmp/tmpy2c2cbr3/gaussian-prior-logistich_ssz9bq/gaussian-prior-logistic-20250514050753_4.csv', 'method=sample', 'num_samples=500', 'num_warmup=500', 'algorithm=hmc', 'adapt', 'engaged=1'] |
|
|
|
|
|
|
|
|
|
05:07:58 - cmdstanpy - INFO - CmdStan done processing. |
|
INFO:cmdstanpy:CmdStan done processing. |
|
DEBUG:cmdstanpy:runset |
|
RunSet: chains=4, chain_ids=[1, 2, 3, 4], num_processes=4 |
|
cmd (chain 1): |
|
['/content/gaussian-prior-logistic', 'id=1', 'random', 'seed=583883', 'data', 'file=/tmp/tmpy2c2cbr3/mgqi9mu9.json', 'output', 'file=/tmp/tmpy2c2cbr3/gaussian-prior-logistich_ssz9bq/gaussian-prior-logistic-20250514050753_1.csv', 'method=sample', 'num_samples=500', 'num_warmup=500', 'algorithm=hmc', 'adapt', 'engaged=1'] |
|
retcodes=[0, 0, 0, 0] |
|
per-chain output files (showing chain 1 only): |
|
csv_file: |
|
/tmp/tmpy2c2cbr3/gaussian-prior-logistich_ssz9bq/gaussian-prior-logistic-20250514050753_1.csv |
|
console_msgs (if any): |
|
/tmp/tmpy2c2cbr3/gaussian-prior-logistich_ssz9bq/gaussian-prior-logistic-20250514050753_0-stdout.txt |
|
DEBUG:cmdstanpy:Chain 1 console: |
|
method = sample (Default) |
|
sample |
|
num_samples = 500 |
|
num_warmup = 500 |
|
save_warmup = false (Default) |
|
thin = 1 (Default) |
|
adapt |
|
engaged = true (Default) |
|
gamma = 0.05 (Default) |
|
delta = 0.8 (Default) |
|
kappa = 0.75 (Default) |
|
t0 = 10 (Default) |
|
init_buffer = 75 (Default) |
|
term_buffer = 50 (Default) |
|
window = 25 (Default) |
|
save_metric = false (Default) |
|
algorithm = hmc (Default) |
|
hmc |
|
engine = nuts (Default) |
|
nuts |
|
max_depth = 10 (Default) |
|
metric = diag_e (Default) |
|
metric_file = (Default) |
|
stepsize = 1 (Default) |
|
stepsize_jitter = 0 (Default) |
|
num_chains = 1 (Default) |
|
id = 1 (Default) |
|
data |
|
file = /tmp/tmpy2c2cbr3/mgqi9mu9.json |
|
init = 2 (Default) |
|
random |
|
seed = 583883 |
|
output |
|
file = /tmp/tmpy2c2cbr3/gaussian-prior-logistich_ssz9bq/gaussian-prior-logistic-20250514050753_1.csv |
|
diagnostic_file = (Default) |
|
refresh = 100 (Default) |
|
sig_figs = -1 (Default) |
|
profile_file = profile.csv (Default) |
|
save_cmdstan_config = false (Default) |
|
num_threads = 1 (Default) |
|
|
|
|
|
Gradient evaluation took 0.000179 seconds |
|
1000 transitions using 10 leapfrog steps per transition would take 1.79 seconds. |
|
Adjust your expectations accordingly! |
|
|
|
|
|
Iteration: 1 / 1000 [ 0%] (Warmup) |
|
Iteration: 100 / 1000 [ 10%] (Warmup) |
|
Iteration: 200 / 1000 [ 20%] (Warmup) |
|
Iteration: 300 / 1000 [ 30%] (Warmup) |
|
Iteration: 400 / 1000 [ 40%] (Warmup) |
|
Iteration: 500 / 1000 [ 50%] (Warmup) |
|
Iteration: 501 / 1000 [ 50%] (Sampling) |
|
Iteration: 600 / 1000 [ 60%] (Sampling) |
|
Iteration: 700 / 1000 [ 70%] (Sampling) |
|
Iteration: 800 / 1000 [ 80%] (Sampling) |
|
Iteration: 900 / 1000 [ 90%] (Sampling) |
|
Iteration: 1000 / 1000 [100%] (Sampling) |
|
|
|
Elapsed Time: 0.998 seconds (Warm-up) |
|
1.215 seconds (Sampling) |
|
2.213 seconds (Total) |
|
|
|
|
|
DEBUG:cmdstanpy:cmd: /root/.cmdstan/cmdstan-2.36.0/bin/stansummary --percentiles= 5,50,95 --sig_figs=6 --csv_filename=/tmp/tmpy2c2cbr3/stansummary-gaussian-prior-logistic-w4iqkew2.csv /tmp/tmpy2c2cbr3/gaussian-prior-logistich_ssz9bq/gaussian-prior-logistic-20250514050753_1.csv /tmp/tmpy2c2cbr3/gaussian-prior-logistich_ssz9bq/gaussian-prior-logistic-20250514050753_2.csv /tmp/tmpy2c2cbr3/gaussian-prior-logistich_ssz9bq/gaussian-prior-logistic-20250514050753_3.csv /tmp/tmpy2c2cbr3/gaussian-prior-logistich_ssz9bq/gaussian-prior-logistic-20250514050753_4.csv |
|
cwd: None |
|
|
|
|
|
|
|
Mean MCSE StdDev MAD 5% 50% \ |
|
lp__ -568.833000 0.061184 1.775250 1.637530 -572.075000 -568.518000 |
|
beta[1] 0.352418 0.001256 0.049502 0.048350 0.271975 0.351888 |
|
beta[2] -1.696680 0.002096 0.075393 0.072381 -1.818390 -1.698950 |
|
beta[3] 0.824647 0.001493 0.057310 0.057457 0.732214 0.822308 |
|
beta[4] 0.301689 0.001300 0.050560 0.050508 0.216780 0.301548 |
|
beta[5] 0.808698 0.001430 0.056148 0.054268 0.714924 0.808041 |
|
beta[6] -1.550370 0.001959 0.071111 0.067636 -1.671050 -1.548620 |
|
|
|
95% ESS_bulk ESS_tail R_hat |
|
lp__ -566.585000 881.262 1143.34 1.002360 |
|
beta[1] 0.434444 1584.370 1544.86 1.001480 |
|
beta[2] -1.570150 1307.350 1407.69 0.999047 |
|
beta[3] 0.922475 1502.040 1410.28 1.002310 |
|
beta[4] 0.386601 1520.230 1296.52 1.000290 |
|
beta[5] 0.901743 1563.190 1176.63 1.000660 |
|
beta[6] -1.436520 1349.680 1310.96 1.000150 </code></pre> |
|
<div class="sourceCode" id="cb5"><pre |
|
class="sourceCode python"><code class="sourceCode python"><span id="cb5-1"><a href="#cb5-1" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> matplotlib.pyplot <span class="im">as</span> plt</span> |
|
<span id="cb5-2"><a href="#cb5-2" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> seaborn <span class="im">as</span> sns</span> |
|
<span id="cb5-3"><a href="#cb5-3" aria-hidden="true" tabindex="-1"></a></span> |
|
<span id="cb5-4"><a href="#cb5-4" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> plot_prior_comparison(empirical_draws, mu_wass, Sigma_wass, h_kde):</span> |
|
<span id="cb5-5"><a href="#cb5-5" aria-hidden="true" tabindex="-1"></a> <span class="co">"""Compare the empirical posterior draws with the Wasserstein Gaussian approximation visually"""</span></span> |
|
<span id="cb5-6"><a href="#cb5-6" aria-hidden="true" tabindex="-1"></a> fig, axes <span class="op">=</span> plt.subplots(<span class="dv">2</span>, <span class="dv">3</span>, figsize<span class="op">=</span>(<span class="dv">15</span>, <span class="dv">10</span>))</span> |
|
<span id="cb5-7"><a href="#cb5-7" aria-hidden="true" tabindex="-1"></a></span> |
|
<span id="cb5-8"><a href="#cb5-8" aria-hidden="true" tabindex="-1"></a> <span class="cf">for</span> dim <span class="kw">in</span> <span class="bu">range</span>(<span class="bu">min</span>(<span class="dv">6</span>, D)):</span> |
|
<span id="cb5-9"><a href="#cb5-9" aria-hidden="true" tabindex="-1"></a> ax <span class="op">=</span> axes[dim <span class="op">//</span> <span class="dv">3</span>, dim <span class="op">%</span> <span class="dv">3</span>]</span> |
|
<span id="cb5-10"><a href="#cb5-10" aria-hidden="true" tabindex="-1"></a></span> |
|
<span id="cb5-11"><a href="#cb5-11" aria-hidden="true" tabindex="-1"></a> <span class="co"># Plot empirical posterior</span></span> |
|
<span id="cb5-12"><a href="#cb5-12" aria-hidden="true" tabindex="-1"></a> ax.hist(empirical_draws[:, dim], bins<span class="op">=</span><span class="dv">50</span>, alpha<span class="op">=</span><span class="fl">0.3</span>,</span> |
|
<span id="cb5-13"><a href="#cb5-13" aria-hidden="true" tabindex="-1"></a> label<span class="op">=</span><span class="st">'Empirical posterior'</span>, density<span class="op">=</span><span class="va">True</span>)</span> |
|
<span id="cb5-14"><a href="#cb5-14" aria-hidden="true" tabindex="-1"></a></span> |
|
<span id="cb5-15"><a href="#cb5-15" aria-hidden="true" tabindex="-1"></a> <span class="co"># Plot samples from the Wasserstein Gaussian approximation</span></span> |
|
<span id="cb5-16"><a href="#cb5-16" aria-hidden="true" tabindex="-1"></a> x_range <span class="op">=</span> np.linspace(empirical_draws[:, dim].<span class="bu">min</span>(),</span> |
|
<span id="cb5-17"><a href="#cb5-17" aria-hidden="true" tabindex="-1"></a> empirical_draws[:, dim].<span class="bu">max</span>(), <span class="dv">100</span>)</span> |
|
<span id="cb5-18"><a href="#cb5-18" aria-hidden="true" tabindex="-1"></a> kde_samples <span class="op">=</span> rng.normal(mu_wass[dim], np.sqrt(Sigma_wass[dim, dim]), <span class="dv">10000</span>)</span> |
|
<span id="cb5-19"><a href="#cb5-19" aria-hidden="true" tabindex="-1"></a> ax.hist(kde_samples, bins<span class="op">=</span><span class="dv">50</span>, alpha<span class="op">=</span><span class="fl">0.3</span>,</span> |
|
<span id="cb5-20"><a href="#cb5-20" aria-hidden="true" tabindex="-1"></a> label<span class="op">=</span><span class="st">'Wasserstein prior'</span>, density<span class="op">=</span><span class="va">True</span>)</span> |
|
<span id="cb5-21"><a href="#cb5-21" aria-hidden="true" tabindex="-1"></a></span> |
|
<span id="cb5-22"><a href="#cb5-22" aria-hidden="true" tabindex="-1"></a> ax.set_title(<span class="ss">f'Dimension </span><span class="sc">{</span>dim<span class="sc">}</span><span class="ss">'</span>)</span> |
|
<span id="cb5-23"><a href="#cb5-23" aria-hidden="true" tabindex="-1"></a> ax.legend()</span> |
|
<span id="cb5-24"><a href="#cb5-24" aria-hidden="true" tabindex="-1"></a></span> |
|
<span id="cb5-25"><a href="#cb5-25" aria-hidden="true" tabindex="-1"></a> plt.tight_layout()</span> |
|
<span id="cb5-26"><a href="#cb5-26" aria-hidden="true" tabindex="-1"></a> plt.show()</span> |
|
<span id="cb5-27"><a href="#cb5-27" aria-hidden="true" tabindex="-1"></a></span> |
|
<span id="cb5-28"><a href="#cb5-28" aria-hidden="true" tabindex="-1"></a><span class="co"># After fitting both methods</span></span> |
|
<span id="cb5-29"><a href="#cb5-29" aria-hidden="true" tabindex="-1"></a>plot_prior_comparison(beta_OH_draws, mu_prior, Sigma_prior, h)</span></code></pre></div> |
|
<figure> |
|
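<img src="wasserstein_prior_files/wasserstein_prior_2_0.png" |
|
alt="Grid of histograms, one panel per coefficient, overlaying the empirical Ohio posterior draws and samples from the fitted Wasserstein Gaussian prior" /> |
|
<figcaption aria-hidden="true">Empirical posterior draws versus the fitted Wasserstein Gaussian prior, per coefficient</figcaption> |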
|
</figure> |
|
<div class="sourceCode" id="cb6"><pre |
|
class="sourceCode python"><code class="sourceCode python"><span id="cb6-1"><a href="#cb6-1" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> posterior_predictive_check(x_test, beta_draws, beta_true):</span> |
|
<span id="cb6-2"><a href="#cb6-2" aria-hidden="true" tabindex="-1"></a> <span class="co">"""Check if posterior draws generate reasonable predictions"""</span></span> |
|
<span id="cb6-3"><a href="#cb6-3" aria-hidden="true" tabindex="-1"></a> <span class="co"># Generate predictions from posterior draws</span></span> |
|
<span id="cb6-4"><a href="#cb6-4" aria-hidden="true" tabindex="-1"></a> pred_probs <span class="op">=</span> []</span> |
|
<span id="cb6-5"><a href="#cb6-5" aria-hidden="true" tabindex="-1"></a> <span class="cf">for</span> beta_draw <span class="kw">in</span> beta_draws[<span class="op">-</span><span class="dv">1000</span>:]: <span class="co"># Use last 1000 draws</span></span> |
|
<span id="cb6-6"><a href="#cb6-6" aria-hidden="true" tabindex="-1"></a> pred_probs.append(expit(x_test <span class="op">@</span> beta_draw))</span> |
|
<span id="cb6-7"><a href="#cb6-7" aria-hidden="true" tabindex="-1"></a> pred_probs <span class="op">=</span> np.array(pred_probs)</span> |
|
<span id="cb6-8"><a href="#cb6-8" aria-hidden="true" tabindex="-1"></a></span> |
|
<span id="cb6-9"><a href="#cb6-9" aria-hidden="true" tabindex="-1"></a> <span class="co"># True probabilities</span></span> |
|
<span id="cb6-10"><a href="#cb6-10" aria-hidden="true" tabindex="-1"></a> true_probs <span class="op">=</span> expit(x_test <span class="op">@</span> beta_true)</span> |
|
<span id="cb6-11"><a href="#cb6-11" aria-hidden="true" tabindex="-1"></a></span> |
|
<span id="cb6-12"><a href="#cb6-12" aria-hidden="true" tabindex="-1"></a> <span class="co"># Check coverage</span></span> |
|
<span id="cb6-13"><a href="#cb6-13" aria-hidden="true" tabindex="-1"></a> pred_mean <span class="op">=</span> np.mean(pred_probs, axis<span class="op">=</span><span class="dv">0</span>)</span> |
|
<span id="cb6-14"><a href="#cb6-14" aria-hidden="true" tabindex="-1"></a> pred_lower <span class="op">=</span> np.percentile(pred_probs, <span class="fl">2.5</span>, axis<span class="op">=</span><span class="dv">0</span>)</span> |
|
<span id="cb6-15"><a href="#cb6-15" aria-hidden="true" tabindex="-1"></a> pred_upper <span class="op">=</span> np.percentile(pred_probs, <span class="fl">97.5</span>, axis<span class="op">=</span><span class="dv">0</span>)</span> |
|
<span id="cb6-16"><a href="#cb6-16" aria-hidden="true" tabindex="-1"></a></span> |
|
<span id="cb6-17"><a href="#cb6-17" aria-hidden="true" tabindex="-1"></a> coverage <span class="op">=</span> np.mean((true_probs <span class="op">>=</span> pred_lower) <span class="op">&</span> (true_probs <span class="op"><=</span> pred_upper))</span> |
|
<span id="cb6-18"><a href="#cb6-18" aria-hidden="true" tabindex="-1"></a> <span class="bu">print</span>(<span class="ss">f"95% credible interval coverage: </span><span class="sc">{</span>coverage<span class="sc">:.3f}</span><span class="ss">"</span>)</span> |
|
<span id="cb6-19"><a href="#cb6-19" aria-hidden="true" tabindex="-1"></a></span> |
|
<span id="cb6-20"><a href="#cb6-20" aria-hidden="true" tabindex="-1"></a> <span class="co"># Plot calibration</span></span> |
|
<span id="cb6-21"><a href="#cb6-21" aria-hidden="true" tabindex="-1"></a> plt.figure(figsize<span class="op">=</span>(<span class="dv">10</span>, <span class="dv">6</span>))</span> |
|
<span id="cb6-22"><a href="#cb6-22" aria-hidden="true" tabindex="-1"></a> plt.subplot(<span class="dv">1</span>, <span class="dv">2</span>, <span class="dv">1</span>)</span> |
|
<span id="cb6-23"><a href="#cb6-23" aria-hidden="true" tabindex="-1"></a> plt.scatter(true_probs, pred_mean, alpha<span class="op">=</span><span class="fl">0.5</span>)</span> |
|
<span id="cb6-24"><a href="#cb6-24" aria-hidden="true" tabindex="-1"></a> plt.plot([<span class="dv">0</span>, <span class="dv">1</span>], [<span class="dv">0</span>, <span class="dv">1</span>], <span class="st">'r--'</span>)</span> |
|
<span id="cb6-25"><a href="#cb6-25" aria-hidden="true" tabindex="-1"></a> plt.xlabel(<span class="st">'True probability'</span>)</span> |
|
<span id="cb6-26"><a href="#cb6-26" aria-hidden="true" tabindex="-1"></a> plt.ylabel(<span class="st">'Predicted probability'</span>)</span> |
|
<span id="cb6-27"><a href="#cb6-27" aria-hidden="true" tabindex="-1"></a> plt.title(<span class="st">'Calibration'</span>)</span> |
|
<span id="cb6-28"><a href="#cb6-28" aria-hidden="true" tabindex="-1"></a></span> |
|
<span id="cb6-29"><a href="#cb6-29" aria-hidden="true" tabindex="-1"></a> <span class="co"># Plot prediction intervals</span></span> |
|
<span id="cb6-30"><a href="#cb6-30" aria-hidden="true" tabindex="-1"></a> plt.subplot(<span class="dv">1</span>, <span class="dv">2</span>, <span class="dv">2</span>)</span> |
|
<span id="cb6-31"><a href="#cb6-31" aria-hidden="true" tabindex="-1"></a> order <span class="op">=</span> np.argsort(true_probs)</span> |
|
<span id="cb6-32"><a href="#cb6-32" aria-hidden="true" tabindex="-1"></a> plt.fill_between(<span class="bu">range</span>(<span class="bu">len</span>(order)), pred_lower[order], pred_upper[order], alpha<span class="op">=</span><span class="fl">0.3</span>)</span> |
|
<span id="cb6-33"><a href="#cb6-33" aria-hidden="true" tabindex="-1"></a> plt.plot(true_probs[order], <span class="st">'r-'</span>, label<span class="op">=</span><span class="st">'True'</span>)</span> |
|
<span id="cb6-34"><a href="#cb6-34" aria-hidden="true" tabindex="-1"></a> plt.plot(pred_mean[order], <span class="st">'b-'</span>, label<span class="op">=</span><span class="st">'Predicted'</span>)</span> |
|
<span id="cb6-35"><a href="#cb6-35" aria-hidden="true" tabindex="-1"></a> plt.legend()</span> |
|
<span id="cb6-36"><a href="#cb6-36" aria-hidden="true" tabindex="-1"></a> plt.title(<span class="st">'Prediction intervals'</span>)</span> |
|
<span id="cb6-37"><a href="#cb6-37" aria-hidden="true" tabindex="-1"></a> plt.tight_layout()</span> |
|
<span id="cb6-38"><a href="#cb6-38" aria-hidden="true" tabindex="-1"></a> plt.show()</span> |
|
<span id="cb6-39"><a href="#cb6-39" aria-hidden="true" tabindex="-1"></a></span> |
|
<span id="cb6-40"><a href="#cb6-40" aria-hidden="true" tabindex="-1"></a><span class="co"># Test the fit_NY posterior draws (repeat with the other fits to cover all three methods)</span></span> |
|
<span id="cb6-41"><a href="#cb6-41" aria-hidden="true" tabindex="-1"></a>x_test <span class="op">=</span> rng.normal(loc<span class="op">=</span><span class="fl">0.0</span>, scale<span class="op">=</span><span class="fl">1.0</span>, size<span class="op">=</span>(<span class="dv">500</span>, D))</span> |
|
<span id="cb6-42"><a href="#cb6-42" aria-hidden="true" tabindex="-1"></a>posterior_predictive_check(x_test, fit_NY.stan_variable(<span class="st">'beta'</span>), beta)</span></code></pre></div> |
|
<pre><code>95% credible interval coverage: 1.000</code></pre> |
|
<figure> |
|
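<img src="wasserstein_prior_files/wasserstein_prior_3_1.png" |
|
alt="Two panels: a calibration scatter of posterior-mean predicted probability against true probability with a dashed diagonal reference line, and true and predicted probabilities sorted by true probability with a shaded 95% interval band" /> |
|
<figcaption aria-hidden="true">Calibration and 95% prediction intervals from the posterior predictive check</figcaption> |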
|
</figure> |
|
<div class="sourceCode" id="cb8"><pre |
|
class="sourceCode python"><code class="sourceCode python"></code></pre></div> |
|
</body> |
|
</html> |