@soodoku
Last active May 14, 2025 05:52
Wasserstein Prior

Chaining Bayesian Inferences with Wasserstein-Optimal Priors: A Simpler Alternative to Kernel Density Approximation

Builds from: https://statmodeling.stat.columbia.edu/2025/05/13/chaining-bayes-priors-from-posteriors/

The Problem

Bayesian inference becomes challenging when you need to chain analyses across sequential datasets but lack conjugate priors with closed-form posteriors. While textbook solutions work beautifully for simple cases like binomial-beta models, real-world scenarios often involve complex posteriors that resist analytical treatment.

Chenyang Zhong has an elegant solution that uses kernel density estimation to turn a posterior into the prior for the next analysis. Zhong approximates the posterior from the first dataset with an equally weighted mixture of normals centered at the posterior draws:

p(θ | y₁, x₁) ≈ (1/M) Σᵢ Normal(θ | θᵢ⁽¹⁾, h·I)

where θᵢ⁽¹⁾, i = 1, …, M, are posterior draws from the first analysis and h controls the kernel bandwidth. (In the Stan code below, h is passed to normal_lpdf as the standard deviation of each kernel, and the number of draws M is called B.) This creates an empirical prior for analyzing the second dataset.
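
To make the mixture prior concrete, here is a minimal Python sketch of its log density. It assumes h is the standard deviation of each isotropic kernel, which matches how the Stan code in this gist calls normal_lpdf. The function name kde_prior_logpdf is hypothetical and is not taken from Zhong's code.

```python
import numpy as np
from scipy.special import logsumexp
from scipy.stats import norm


def kde_prior_logpdf(theta, draws, h):
    """Log density of an equally weighted mixture of isotropic normal kernels.

    theta : array of shape (D,), point at which to evaluate the prior
    draws : array of shape (M, D), posterior draws from the first analysis
    h     : kernel standard deviation (assumption: scale, not variance)
    """
    theta = np.asarray(theta, dtype=float)
    draws = np.asarray(draws, dtype=float)
    # Log density of theta under each kernel: sum of D independent normals.
    lp = norm.logpdf(theta[None, :], loc=draws, scale=h).sum(axis=1)
    # Average the M kernel densities on the log scale for numerical stability.
    return logsumexp(lp) - np.log(draws.shape[0])
```

The cost of one evaluation grows linearly with the number of draws M, which is why evaluating this prior inside a sampler is the expensive part of the approach.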

The approach is clever—it preserves the shape of the posterior while allowing efficient MCMC sampling. Zhong even develops a sophisticated nearest-neighbor method for fast Metropolis sampling.

A Wasserstein Alternative

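The Wasserstein-2 distance between two Gaussians depends only on their means and covariances, and for any distribution with finite second moments the Gaussian that matches its mean and covariance minimizes the Gelbrich lower bound on the Wasserstein-2 distance. When the posterior is itself Gaussian, this moment-matched fit is exactly Wasserstein-optimal; when the posterior is close to Gaussian, it is close to optimal. This means we can replace Zhong's kernel density approximation with: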

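import numpy as np

def fit_wasserstein_prior(posterior_draws):
    """Gaussian approximation by moment matching; posterior_draws has shape (M, D)."""
    mu = np.mean(posterior_draws, axis=0)    # posterior mean, shape (D,)
    Sigma = np.cov(posterior_draws.T)        # posterior covariance, shape (D, D)
    return mu, Sigma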

That's it. No optimization, no bandwidth selection, no nearest-neighbor graphs.
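
For reference, the squared Wasserstein-2 distance between two Gaussians has a closed form in their means and covariances, which is why a Gaussian fit only needs the first two moments of the draws. The sketch below computes that closed form. The helper names sqrtm_psd and gaussian_w2_squared are hypothetical and not part of the original code.

```python
import numpy as np


def sqrtm_psd(A):
    """Square root of a symmetric positive semidefinite matrix."""
    A = np.asarray(A, dtype=float)
    w, V = np.linalg.eigh((A + A.T) / 2.0)  # symmetrize against round-off
    w = np.clip(w, 0.0, None)               # clip tiny negative eigenvalues
    return (V * np.sqrt(w)) @ V.T


def gaussian_w2_squared(mu1, Sigma1, mu2, Sigma2):
    """Squared W2 distance between N(mu1, Sigma1) and N(mu2, Sigma2).

    W2^2 = |mu1 - mu2|^2
           + tr(Sigma1 + Sigma2 - 2 (Sigma1^{1/2} Sigma2 Sigma1^{1/2})^{1/2})
    """
    mu1 = np.atleast_1d(np.asarray(mu1, dtype=float))
    mu2 = np.atleast_1d(np.asarray(mu2, dtype=float))
    Sigma1 = np.atleast_2d(np.asarray(Sigma1, dtype=float))
    Sigma2 = np.atleast_2d(np.asarray(Sigma2, dtype=float))
    root1 = sqrtm_psd(Sigma1)
    cross = sqrtm_psd(root1 @ Sigma2 @ root1)
    mean_term = np.sum((mu1 - mu2) ** 2)
    cov_term = np.trace(Sigma1 + Sigma2 - 2.0 * cross)
    return float(mean_term + cov_term)


# N(0, 1) versus N(1, 4): (0 - 1)^2 + (1 - 2)^2 = 2
print(gaussian_w2_squared([0.0], [[1.0]], [1.0], [[4.0]]))
```

The distance is zero only when both the means and the covariances agree, so among Gaussians the moment-matched one is the natural candidate.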

Implementation Comparison

Zhong's approach requires:

- Choosing the bandwidth h
- Implementing a mixture model in Stan
- Complex log-sum-exp calculations

Wasserstein approach:

- Automatic parameter selection
- Simple multivariate normal prior
- Standard Stan implementation
- Comparable or faster performance

Here's the key Stan difference:

// Zhong's mixture approach
model {
  y ~ bernoulli_logit(x * beta);
  vector[B] lp;
  for (b in 1:B) {
    lp[b] = normal_lpdf(beta | beta0[b], h);
  }
  target += log_sum_exp(lp) - log(B);
}

// Wasserstein approach  
model {
  y ~ bernoulli_logit(x * beta);
  beta ~ multi_normal(mu_prior, Sigma_prior);
}

When Does This Work?

The moment-matched Gaussian prior works well when:

- The true posterior is approximately Gaussian
- You want to preserve the covariance structure
- Computational simplicity matters

It may underperform when posteriors have:

- Strong multimodality
- Heavy tails
- Complex dependence structures
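
One rough way to check these conditions before committing to the Gaussian prior is to look at the marginal skewness and excess kurtosis of the posterior draws. The sketch below is not part of the original analysis: the function name gaussianity_report and the thresholds 0.5 and 1.0 are arbitrary assumptions. It only inspects one dimension at a time, so it cannot detect nonlinear dependence between parameters.

```python
import numpy as np
from scipy.stats import kurtosis, skew


def gaussianity_report(draws, skew_tol=0.5, kurt_tol=1.0):
    """Flag marginals of the draws that look clearly non-Gaussian.

    draws : array of shape (M, D) of posterior draws
    skew_tol, kurt_tol : illustrative thresholds (assumptions, tune as needed)
    """
    draws = np.asarray(draws, dtype=float)
    s = np.abs(skew(draws, axis=0))
    k = np.abs(kurtosis(draws, axis=0))  # excess kurtosis, 0 for a Gaussian
    return {
        "max_abs_skew": float(s.max()),
        "max_abs_excess_kurtosis": float(k.max()),
        "looks_gaussian": bool(s.max() < skew_tol and k.max() < kurt_tol),
    }
```

For example, gaussianity_report(beta_OH_draws) could be run on the first-stage draws; if looks_gaussian is False, the kernel mixture prior may be the safer choice.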

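// empirical-logistic.stan: logistic regression with the kernel mixture prior
data {
  int<lower=0> N;                      // number of observations
  int<lower=0> D;                      // number of predictors
  matrix[N, D] x;
  array[N] int<lower=0, upper=1> y;
  real<lower=0> h;                     // kernel scale (bandwidth)
  int<lower=0> B;                      // number of posterior draws
  array[B] vector[D] beta0;            // draws from the previous posterior
}
parameters {
  vector[D] beta;
}
model {
  y ~ bernoulli_logit(x * beta);
  vector[B] lp;
  for (b in 1:B) {
    lp[b] = normal_lpdf(beta | beta0[b], h);
  }
  target += log_sum_exp(lp) - log(B);  // log of the average kernel density
}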
// flat-logistic.stan: logistic regression with an implicit flat prior on beta
data {
  int<lower=0> N;
  int<lower=0> D;
  matrix[N, D] x;
  array[N] int<lower=0, upper=1> y;
}
parameters {
  vector[D] beta;
}
model {
  y ~ bernoulli_logit(x * beta);
}
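// Logistic regression with a multivariate normal prior (the Wasserstein approach)
data {
  int<lower=0> N;
  int<lower=0> D;
  matrix[N, D] x;
  array[N] int<lower=0, upper=1> y;
  vector[D] mu_prior;                  // mean of the previous posterior draws
  cov_matrix[D] Sigma_prior;           // covariance of the previous posterior draws
}
parameters {
  vector[D] beta;
}
model {
  y ~ bernoulli_logit(x * beta);
  beta ~ multi_normal(mu_prior, Sigma_prior);
}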
<!DOCTYPE html>
<html xmlns="http://www.w3.org/1999/xhtml" lang="" xml:lang="">
<head>
<meta charset="utf-8" />
<meta name="generator" content="pandoc" />
<meta name="viewport" content="width=device-width, initial-scale=1.0, user-scalable=yes" />
<title>wasserstein_prior</title>
<style>
html {
color: #1a1a1a;
background-color: #fdfdfd;
}
body {
margin: 0 auto;
max-width: 36em;
padding-left: 50px;
padding-right: 50px;
padding-top: 50px;
padding-bottom: 50px;
hyphens: auto;
overflow-wrap: break-word;
text-rendering: optimizeLegibility;
font-kerning: normal;
}
@media (max-width: 600px) {
body {
font-size: 0.9em;
padding: 12px;
}
h1 {
font-size: 1.8em;
}
}
@media print {
html {
background-color: white;
}
body {
background-color: transparent;
color: black;
font-size: 12pt;
}
p, h2, h3 {
orphans: 3;
widows: 3;
}
h2, h3, h4 {
page-break-after: avoid;
}
}
p {
margin: 1em 0;
}
a {
color: #1a1a1a;
}
a:visited {
color: #1a1a1a;
}
img {
max-width: 100%;
}
svg {
height: auto;
max-width: 100%;
}
h1, h2, h3, h4, h5, h6 {
margin-top: 1.4em;
}
h5, h6 {
font-size: 1em;
font-style: italic;
}
h6 {
font-weight: normal;
}
ol, ul {
padding-left: 1.7em;
margin-top: 1em;
}
li > ol, li > ul {
margin-top: 0;
}
blockquote {
margin: 1em 0 1em 1.7em;
padding-left: 1em;
border-left: 2px solid #e6e6e6;
color: #606060;
}
code {
font-family: Menlo, Monaco, Consolas, 'Lucida Console', monospace;
font-size: 85%;
margin: 0;
hyphens: manual;
}
pre {
margin: 1em 0;
overflow: auto;
}
pre code {
padding: 0;
overflow: visible;
overflow-wrap: normal;
}
.sourceCode {
background-color: transparent;
overflow: visible;
}
hr {
border: none;
border-top: 1px solid #1a1a1a;
height: 1px;
margin: 1em 0;
}
table {
margin: 1em 0;
border-collapse: collapse;
width: 100%;
overflow-x: auto;
display: block;
font-variant-numeric: lining-nums tabular-nums;
}
table caption {
margin-bottom: 0.75em;
}
tbody {
margin-top: 0.5em;
border-top: 1px solid #1a1a1a;
border-bottom: 1px solid #1a1a1a;
}
th {
border-top: 1px solid #1a1a1a;
padding: 0.25em 0.5em 0.25em 0.5em;
}
td {
padding: 0.125em 0.5em 0.25em 0.5em;
}
header {
margin-bottom: 4em;
text-align: center;
}
#TOC li {
list-style: none;
}
#TOC ul {
padding-left: 1.3em;
}
#TOC > ul {
padding-left: 0;
}
#TOC a:not(:hover) {
text-decoration: none;
}
code{white-space: pre-wrap;}
span.smallcaps{font-variant: small-caps;}
div.columns{display: flex; gap: min(4vw, 1.5em);}
div.column{flex: auto; overflow-x: auto;}
div.hanging-indent{margin-left: 1.5em; text-indent: -1.5em;}
/* The extra [class] is a hack that increases specificity enough to
override a similar rule in reveal.js */
ul.task-list[class]{list-style: none;}
ul.task-list li input[type="checkbox"] {
font-size: inherit;
width: 0.8em;
margin: 0 0.8em 0.2em -1.6em;
vertical-align: middle;
}
/* CSS for syntax highlighting */
html { -webkit-text-size-adjust: 100%; }
pre > code.sourceCode { white-space: pre; position: relative; }
pre > code.sourceCode > span { display: inline-block; line-height: 1.25; }
pre > code.sourceCode > span:empty { height: 1.2em; }
.sourceCode { overflow: visible; }
code.sourceCode > span { color: inherit; text-decoration: inherit; }
div.sourceCode { margin: 1em 0; }
pre.sourceCode { margin: 0; }
@media screen {
div.sourceCode { overflow: auto; }
}
@media print {
pre > code.sourceCode { white-space: pre-wrap; }
pre > code.sourceCode > span { text-indent: -5em; padding-left: 5em; }
}
pre.numberSource code
{ counter-reset: source-line 0; }
pre.numberSource code > span
{ position: relative; left: -4em; counter-increment: source-line; }
pre.numberSource code > span > a:first-child::before
{ content: counter(source-line);
position: relative; left: -1em; text-align: right; vertical-align: baseline;
border: none; display: inline-block;
-webkit-touch-callout: none; -webkit-user-select: none;
-khtml-user-select: none; -moz-user-select: none;
-ms-user-select: none; user-select: none;
padding: 0 4px; width: 4em;
color: #aaaaaa;
}
pre.numberSource { margin-left: 3em; border-left: 1px solid #aaaaaa; padding-left: 4px; }
div.sourceCode
{ }
@media screen {
pre > code.sourceCode > span > a:first-child::before { text-decoration: underline; }
}
code span.al { color: #ff0000; font-weight: bold; } /* Alert */
code span.an { color: #60a0b0; font-weight: bold; font-style: italic; } /* Annotation */
code span.at { color: #7d9029; } /* Attribute */
code span.bn { color: #40a070; } /* BaseN */
code span.bu { color: #008000; } /* BuiltIn */
code span.cf { color: #007020; font-weight: bold; } /* ControlFlow */
code span.ch { color: #4070a0; } /* Char */
code span.cn { color: #880000; } /* Constant */
code span.co { color: #60a0b0; font-style: italic; } /* Comment */
code span.cv { color: #60a0b0; font-weight: bold; font-style: italic; } /* CommentVar */
code span.do { color: #ba2121; font-style: italic; } /* Documentation */
code span.dt { color: #902000; } /* DataType */
code span.dv { color: #40a070; } /* DecVal */
code span.er { color: #ff0000; font-weight: bold; } /* Error */
code span.ex { } /* Extension */
code span.fl { color: #40a070; } /* Float */
code span.fu { color: #06287e; } /* Function */
code span.im { color: #008000; font-weight: bold; } /* Import */
code span.in { color: #60a0b0; font-weight: bold; font-style: italic; } /* Information */
code span.kw { color: #007020; font-weight: bold; } /* Keyword */
code span.op { color: #666666; } /* Operator */
code span.ot { color: #007020; } /* Other */
code span.pp { color: #bc7a00; } /* Preprocessor */
code span.sc { color: #4070a0; } /* SpecialChar */
code span.ss { color: #bb6688; } /* SpecialString */
code span.st { color: #4070a0; } /* String */
code span.va { color: #19177c; } /* Variable */
code span.vs { color: #4070a0; } /* VerbatimString */
code span.wa { color: #60a0b0; font-weight: bold; font-style: italic; } /* Warning */
</style>
</head>
<body>
<div class="sourceCode" id="cb1"><pre
class="sourceCode python"><code class="sourceCode python"><span id="cb1-1"><a href="#cb1-1" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> cmdstanpy</span>
<span id="cb1-2"><a href="#cb1-2" aria-hidden="true" tabindex="-1"></a>cmdstanpy.install_cmdstan()</span>
<span id="cb1-3"><a href="#cb1-3" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb1-4"><a href="#cb1-4" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> numpy <span class="im">as</span> np</span>
<span id="cb1-5"><a href="#cb1-5" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> scipy <span class="im">as</span> sp</span>
<span id="cb1-6"><a href="#cb1-6" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> cmdstanpy <span class="im">as</span> csp</span>
<span id="cb1-7"><a href="#cb1-7" aria-hidden="true" tabindex="-1"></a><span class="im">from</span> scipy.special <span class="im">import</span> expit</span>
<span id="cb1-8"><a href="#cb1-8" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb1-9"><a href="#cb1-9" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> ot</span>
<span id="cb1-10"><a href="#cb1-10" aria-hidden="true" tabindex="-1"></a><span class="im">from</span> scipy.optimize <span class="im">import</span> minimize</span>
<span id="cb1-11"><a href="#cb1-11" aria-hidden="true" tabindex="-1"></a><span class="im">from</span> scipy.stats <span class="im">import</span> multivariate_normal</span>
<span id="cb1-12"><a href="#cb1-12" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb1-13"><a href="#cb1-13" aria-hidden="true" tabindex="-1"></a>seed <span class="op">=</span> <span class="dv">583883</span></span>
<span id="cb1-14"><a href="#cb1-14" aria-hidden="true" tabindex="-1"></a>D <span class="op">=</span> <span class="dv">6</span></span>
<span id="cb1-15"><a href="#cb1-15" aria-hidden="true" tabindex="-1"></a>N <span class="op">=</span> <span class="dv">1500</span></span>
<span id="cb1-16"><a href="#cb1-16" aria-hidden="true" tabindex="-1"></a>mu_x_OH <span class="op">=</span> <span class="op">-</span><span class="dv">1</span></span>
<span id="cb1-17"><a href="#cb1-17" aria-hidden="true" tabindex="-1"></a>mu_x_NY <span class="op">=</span> <span class="dv">1</span></span>
<span id="cb1-18"><a href="#cb1-18" aria-hidden="true" tabindex="-1"></a>B <span class="op">=</span> <span class="dv">10_000</span></span>
<span id="cb1-19"><a href="#cb1-19" aria-hidden="true" tabindex="-1"></a>h <span class="op">=</span> <span class="fl">0.04</span></span>
<span id="cb1-20"><a href="#cb1-20" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb1-21"><a href="#cb1-21" aria-hidden="true" tabindex="-1"></a>rng <span class="op">=</span> np.random.default_rng(seed)</span>
<span id="cb1-22"><a href="#cb1-22" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb1-23"><a href="#cb1-23" aria-hidden="true" tabindex="-1"></a><span class="co"># generate single parameter vectors</span></span>
<span id="cb1-24"><a href="#cb1-24" aria-hidden="true" tabindex="-1"></a>beta <span class="op">=</span> rng.normal(loc<span class="op">=</span><span class="fl">0.0</span>, scale<span class="op">=</span><span class="fl">1.0</span>, size<span class="op">=</span>D)</span>
<span id="cb1-25"><a href="#cb1-25" aria-hidden="true" tabindex="-1"></a><span class="bu">print</span>(<span class="ss">f&quot;</span><span class="sc">{</span>beta<span class="op">=</span><span class="sc">}</span><span class="ss">&quot;</span>)</span>
<span id="cb1-26"><a href="#cb1-26" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb1-27"><a href="#cb1-27" aria-hidden="true" tabindex="-1"></a><span class="co"># generate Ohio data from parameters</span></span>
<span id="cb1-28"><a href="#cb1-28" aria-hidden="true" tabindex="-1"></a>x_OH <span class="op">=</span> rng.normal(loc<span class="op">=-</span><span class="fl">1.0</span>, scale<span class="op">=</span><span class="fl">1.0</span>, size<span class="op">=</span>(N, D))</span>
<span id="cb1-29"><a href="#cb1-29" aria-hidden="true" tabindex="-1"></a>y_OH <span class="op">=</span> rng.binomial(n<span class="op">=</span><span class="dv">1</span>, p <span class="op">=</span> expit(x_OH <span class="op">@</span> beta))</span>
<span id="cb1-30"><a href="#cb1-30" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb1-31"><a href="#cb1-31" aria-hidden="true" tabindex="-1"></a><span class="co"># generate NY data from parameters</span></span>
<span id="cb1-32"><a href="#cb1-32" aria-hidden="true" tabindex="-1"></a>x_NY <span class="op">=</span> rng.normal(loc<span class="op">=</span><span class="fl">1.0</span>, scale<span class="op">=</span><span class="fl">1.0</span>, size<span class="op">=</span>(N, D))</span>
<span id="cb1-33"><a href="#cb1-33" aria-hidden="true" tabindex="-1"></a>y_NY <span class="op">=</span> rng.binomial(n<span class="op">=</span><span class="dv">1</span>, p <span class="op">=</span> expit(x_NY <span class="op">@</span> beta))</span>
<span id="cb1-34"><a href="#cb1-34" aria-hidden="true" tabindex="-1"></a><span class="bu">print</span>(<span class="ss">f&quot;</span><span class="sc">{</span>np<span class="sc">.</span>mean(y_NY)<span class="op">=</span><span class="sc">}</span><span class="ss">&quot;</span>) <span class="co"># shouldn&#39;t be too extreme</span></span>
<span id="cb1-35"><a href="#cb1-35" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb1-36"><a href="#cb1-36" aria-hidden="true" tabindex="-1"></a><span class="co"># fit OH data with simple logistic regression</span></span>
<span id="cb1-37"><a href="#cb1-37" aria-hidden="true" tabindex="-1"></a>data_OH <span class="op">=</span> {<span class="st">&#39;N&#39;</span>: N, <span class="st">&#39;D&#39;</span>: D, <span class="st">&#39;x&#39;</span>: x_OH, <span class="st">&#39;y&#39;</span>: y_OH }</span>
<span id="cb1-38"><a href="#cb1-38" aria-hidden="true" tabindex="-1"></a>model_OH <span class="op">=</span> csp.CmdStanModel(stan_file<span class="op">=</span><span class="st">&#39;flat-logistic.stan&#39;</span>)</span>
<span id="cb1-39"><a href="#cb1-39" aria-hidden="true" tabindex="-1"></a>fit_OH <span class="op">=</span> model_OH.sample(data<span class="op">=</span>data_OH, chains<span class="op">=</span><span class="dv">4</span>, iter_sampling<span class="op">=</span>B <span class="op">//</span> <span class="dv">4</span>, seed<span class="op">=</span>seed)</span>
<span id="cb1-40"><a href="#cb1-40" aria-hidden="true" tabindex="-1"></a><span class="bu">print</span>(fit_OH.summary())</span>
<span id="cb1-41"><a href="#cb1-41" aria-hidden="true" tabindex="-1"></a>beta_OH_draws <span class="op">=</span> fit_OH.stan_variable(<span class="st">&#39;beta&#39;</span>)</span>
<span id="cb1-42"><a href="#cb1-42" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb1-43"><a href="#cb1-43" aria-hidden="true" tabindex="-1"></a><span class="co"># fit NY data with empirical prior from posterior of OH data</span></span>
<span id="cb1-44"><a href="#cb1-44" aria-hidden="true" tabindex="-1"></a>data_NY <span class="op">=</span> {<span class="st">&#39;N&#39;</span>: N, <span class="st">&#39;D&#39;</span>: D, <span class="st">&#39;x&#39;</span>: x_NY, <span class="st">&#39;y&#39;</span>: y_NY,</span>
<span id="cb1-45"><a href="#cb1-45" aria-hidden="true" tabindex="-1"></a> <span class="st">&#39;h&#39;</span>: h, <span class="st">&#39;B&#39;</span>: B, <span class="st">&#39;beta0&#39;</span>: beta_OH_draws }</span>
<span id="cb1-46"><a href="#cb1-46" aria-hidden="true" tabindex="-1"></a>model_NY <span class="op">=</span> csp.CmdStanModel(stan_file<span class="op">=</span><span class="st">&#39;empirical-logistic.stan&#39;</span>)</span>
<span id="cb1-47"><a href="#cb1-47" aria-hidden="true" tabindex="-1"></a>fit_NY <span class="op">=</span> model_NY.sample(data<span class="op">=</span>data_NY, chains<span class="op">=</span><span class="dv">4</span>, iter_warmup<span class="op">=</span><span class="dv">500</span>, iter_sampling<span class="op">=</span><span class="dv">500</span>, seed<span class="op">=</span>seed, show_progress<span class="op">=</span><span class="va">False</span>)</span>
<span id="cb1-48"><a href="#cb1-48" aria-hidden="true" tabindex="-1"></a><span class="bu">print</span>(fit_NY.summary())</span>
<span id="cb1-49"><a href="#cb1-49" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb1-50"><a href="#cb1-50" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb1-51"><a href="#cb1-51" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> fit_wasserstein_prior(posterior_draws, prior_family<span class="op">=</span><span class="st">&#39;mvn&#39;</span>):</span>
<span id="cb1-52"><a href="#cb1-52" aria-hidden="true" tabindex="-1"></a>    <span class="co">&quot;&quot;&quot;</span></span>
<span id="cb1-53"><a href="#cb1-53" aria-hidden="true" tabindex="-1"></a><span class="co">    Fit a parametric prior by minimizing Wasserstein distance</span></span>
<span id="cb1-54"><a href="#cb1-54" aria-hidden="true" tabindex="-1"></a><span class="co">    to empirical posterior draws</span></span>
<span id="cb1-55"><a href="#cb1-55" aria-hidden="true" tabindex="-1"></a><span class="co">    &quot;&quot;&quot;</span></span>
<span id="cb1-56"><a href="#cb1-56" aria-hidden="true" tabindex="-1"></a>    D <span class="op">=</span> posterior_draws.shape[<span class="dv">1</span>]</span>
<span id="cb1-57"><a href="#cb1-57" aria-hidden="true" tabindex="-1"></a>    n_samples <span class="op">=</span> <span class="bu">len</span>(posterior_draws)</span>
<span id="cb1-58"><a href="#cb1-58" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb1-59"><a href="#cb1-59" aria-hidden="true" tabindex="-1"></a>    <span class="cf">if</span> prior_family <span class="op">==</span> <span class="st">&#39;mvn&#39;</span>:</span>
<span id="cb1-60"><a href="#cb1-60" aria-hidden="true" tabindex="-1"></a>        <span class="kw">def</span> objective(params):</span>
<span id="cb1-61"><a href="#cb1-61" aria-hidden="true" tabindex="-1"></a>            mu <span class="op">=</span> params[:D]</span>
<span id="cb1-62"><a href="#cb1-62" aria-hidden="true" tabindex="-1"></a>            <span class="co"># Use diagonal covariance for simplicity and stability</span></span>
<span id="cb1-63"><a href="#cb1-63" aria-hidden="true" tabindex="-1"></a>            log_std <span class="op">=</span> params[D:]</span>
<span id="cb1-64"><a href="#cb1-64" aria-hidden="true" tabindex="-1"></a>            Sigma <span class="op">=</span> np.diag(np.exp(<span class="dv">2</span> <span class="op">*</span> log_std)) <span class="co"># variances = exp(log_std)**2</span></span>
<span id="cb1-65"><a href="#cb1-65" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb1-66"><a href="#cb1-66" aria-hidden="true" tabindex="-1"></a>            <span class="co"># Sample from fitted distribution (fixed seed keeps the objective deterministic)</span></span>
<span id="cb1-67"><a href="#cb1-67" aria-hidden="true" tabindex="-1"></a>            fitted_samples <span class="op">=</span> multivariate_normal.rvs(mu, Sigma, size<span class="op">=</span>n_samples, random_state<span class="op">=</span>seed)</span>
<span id="cb1-68"><a href="#cb1-68" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb1-69"><a href="#cb1-69" aria-hidden="true" tabindex="-1"></a>            <span class="co"># Compute Wasserstein-2 distance</span></span>
<span id="cb1-70"><a href="#cb1-70" aria-hidden="true" tabindex="-1"></a>            <span class="co"># Create uniform probability distributions</span></span>
<span id="cb1-71"><a href="#cb1-71" aria-hidden="true" tabindex="-1"></a>            a <span class="op">=</span> np.ones(n_samples) <span class="op">/</span> n_samples <span class="co"># uniform weights for empirical</span></span>
<span id="cb1-72"><a href="#cb1-72" aria-hidden="true" tabindex="-1"></a>            b <span class="op">=</span> np.ones(n_samples) <span class="op">/</span> n_samples <span class="co"># uniform weights for fitted</span></span>
<span id="cb1-73"><a href="#cb1-73" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb1-74"><a href="#cb1-74" aria-hidden="true" tabindex="-1"></a>            <span class="co"># Compute cost matrix (squared Euclidean distances)</span></span>
<span id="cb1-75"><a href="#cb1-75" aria-hidden="true" tabindex="-1"></a>            cost_matrix <span class="op">=</span> ot.dist(posterior_draws, fitted_samples, metric<span class="op">=</span><span class="st">&#39;sqeuclidean&#39;</span>)</span>
<span id="cb1-76"><a href="#cb1-76" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb1-77"><a href="#cb1-77" aria-hidden="true" tabindex="-1"></a>            <span class="co"># Compute optimal transport cost (squared Wasserstein-2 distance)</span></span>
<span id="cb1-78"><a href="#cb1-78" aria-hidden="true" tabindex="-1"></a>            W2_distance <span class="op">=</span> ot.emd2(a, b, cost_matrix)</span>
<span id="cb1-79"><a href="#cb1-79" aria-hidden="true" tabindex="-1"></a>            <span class="cf">return</span> W2_distance</span>
<span id="cb1-80"><a href="#cb1-80" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb1-81"><a href="#cb1-81" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Initialize with sample mean and std</span></span>
<span id="cb1-82"><a href="#cb1-82" aria-hidden="true" tabindex="-1"></a>        init_mu <span class="op">=</span> np.mean(posterior_draws, axis<span class="op">=</span><span class="dv">0</span>)</span>
<span id="cb1-83"><a href="#cb1-83" aria-hidden="true" tabindex="-1"></a>        init_std <span class="op">=</span> np.std(posterior_draws, axis<span class="op">=</span><span class="dv">0</span>)</span>
<span id="cb1-84"><a href="#cb1-84" aria-hidden="true" tabindex="-1"></a>        init_log_std <span class="op">=</span> np.log(init_std <span class="op">+</span> <span class="fl">1e-6</span>) <span class="co"># avoid log(0)</span></span>
<span id="cb1-85"><a href="#cb1-85" aria-hidden="true" tabindex="-1"></a>        init_params <span class="op">=</span> np.concatenate([init_mu, init_log_std])</span>
<span id="cb1-86"><a href="#cb1-86" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb1-87"><a href="#cb1-87" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Set bounds to ensure reasonable values</span></span>
<span id="cb1-88"><a href="#cb1-88" aria-hidden="true" tabindex="-1"></a>        bounds <span class="op">=</span> [(<span class="va">None</span>, <span class="va">None</span>)] <span class="op">*</span> D <span class="op">+</span> [(<span class="op">-</span><span class="dv">5</span>, <span class="dv">5</span>)] <span class="op">*</span> D <span class="co"># bounds for log std devs</span></span>
<span id="cb1-89"><a href="#cb1-89" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb1-90"><a href="#cb1-90" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Optimize</span></span>
<span id="cb1-91"><a href="#cb1-91" aria-hidden="true" tabindex="-1"></a>        <span class="cf">try</span>:</span>
<span id="cb1-92"><a href="#cb1-92" aria-hidden="true" tabindex="-1"></a>            result <span class="op">=</span> minimize(objective, init_params, method<span class="op">=</span><span class="st">&#39;L-BFGS-B&#39;</span>, bounds<span class="op">=</span>bounds)</span>
<span id="cb1-93"><a href="#cb1-93" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb1-94"><a href="#cb1-94" aria-hidden="true" tabindex="-1"></a>            <span class="co"># Extract optimized parameters</span></span>
<span id="cb1-95"><a href="#cb1-95" aria-hidden="true" tabindex="-1"></a>            mu_opt <span class="op">=</span> result.x[:D]</span>
<span id="cb1-96"><a href="#cb1-96" aria-hidden="true" tabindex="-1"></a>            log_std_opt <span class="op">=</span> result.x[D:]</span>
<span id="cb1-97"><a href="#cb1-97" aria-hidden="true" tabindex="-1"></a>            Sigma_opt <span class="op">=</span> np.diag(np.exp(<span class="dv">2</span> <span class="op">*</span> log_std_opt))</span>
<span id="cb1-98"><a href="#cb1-98" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb1-99"><a href="#cb1-99" aria-hidden="true" tabindex="-1"></a>            <span class="cf">return</span> mu_opt, Sigma_opt</span>
<span id="cb1-100"><a href="#cb1-100" aria-hidden="true" tabindex="-1"></a>        <span class="cf">except</span> <span class="pp">Exception</span> <span class="im">as</span> e:</span>
<span id="cb1-101"><a href="#cb1-101" aria-hidden="true" tabindex="-1"></a>            <span class="bu">print</span>(<span class="ss">f&quot;Optimization failed: </span><span class="sc">{</span>e<span class="sc">}</span><span class="ss">&quot;</span>)</span>
<span id="cb1-102"><a href="#cb1-102" aria-hidden="true" tabindex="-1"></a>            <span class="bu">print</span>(<span class="st">&quot;Falling back to moment matching&quot;</span>)</span>
<span id="cb1-103"><a href="#cb1-103" aria-hidden="true" tabindex="-1"></a>            <span class="co"># Fallback to simple moment matching</span></span>
<span id="cb1-104"><a href="#cb1-104" aria-hidden="true" tabindex="-1"></a>            mu_fallback <span class="op">=</span> np.mean(posterior_draws, axis<span class="op">=</span><span class="dv">0</span>)</span>
<span id="cb1-105"><a href="#cb1-105" aria-hidden="true" tabindex="-1"></a>            Sigma_fallback <span class="op">=</span> np.cov(posterior_draws.T)</span>
<span id="cb1-106"><a href="#cb1-106" aria-hidden="true" tabindex="-1"></a>            <span class="cf">return</span> mu_fallback, Sigma_fallback</span>
<span id="cb1-107"><a href="#cb1-107" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb1-108"><a href="#cb1-108" aria-hidden="true" tabindex="-1"></a><span class="co"># Or use the much simpler closed-form solution:</span></span>
<span id="cb1-109"><a href="#cb1-109" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> fit_wasserstein_prior_simple(posterior_draws):</span>
<span id="cb1-110"><a href="#cb1-110" aria-hidden="true" tabindex="-1"></a>    <span class="co">&quot;&quot;&quot;</span></span>
<span id="cb1-111"><a href="#cb1-111" aria-hidden="true" tabindex="-1"></a><span class="co">    Closed-form Wasserstein-optimal Gaussian: just match moments</span></span>
<span id="cb1-112"><a href="#cb1-112" aria-hidden="true" tabindex="-1"></a><span class="co">    &quot;&quot;&quot;</span></span>
<span id="cb1-113"><a href="#cb1-113" aria-hidden="true" tabindex="-1"></a>    mu <span class="op">=</span> np.mean(posterior_draws, axis<span class="op">=</span><span class="dv">0</span>)</span>
<span id="cb1-114"><a href="#cb1-114" aria-hidden="true" tabindex="-1"></a>    Sigma <span class="op">=</span> np.cov(posterior_draws.T)</span>
<span id="cb1-115"><a href="#cb1-115" aria-hidden="true" tabindex="-1"></a>    <span class="cf">return</span> mu, Sigma</span>
<span id="cb1-116"><a href="#cb1-116" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb1-117"><a href="#cb1-117" aria-hidden="true" tabindex="-1"></a><span class="co"># Use the simple version:</span></span>
<span id="cb1-118"><a href="#cb1-118" aria-hidden="true" tabindex="-1"></a>mu_prior, Sigma_prior <span class="op">=</span> fit_wasserstein_prior_simple(beta_OH_draws)</span>
<span id="cb1-119"><a href="#cb1-119" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb1-120"><a href="#cb1-120" aria-hidden="true" tabindex="-1"></a><span class="co"># Use in Stan model (modified version)</span></span>
<span id="cb1-121"><a href="#cb1-121" aria-hidden="true" tabindex="-1"></a>data_NY_wass <span class="op">=</span> {</span>
<span id="cb1-122"><a href="#cb1-122" aria-hidden="true" tabindex="-1"></a> <span class="st">&#39;N&#39;</span>: N, <span class="st">&#39;D&#39;</span>: D, <span class="st">&#39;x&#39;</span>: x_NY, <span class="st">&#39;y&#39;</span>: y_NY,</span>
<span id="cb1-123"><a href="#cb1-123" aria-hidden="true" tabindex="-1"></a> <span class="st">&#39;mu_prior&#39;</span>: mu_prior, <span class="st">&#39;Sigma_prior&#39;</span>: Sigma_prior</span>
<span id="cb1-124"><a href="#cb1-124" aria-hidden="true" tabindex="-1"></a>}</span></code></pre></div>
<pre><code>DEBUG:cmdstanpy:cmd: make examples/bernoulli/bernoulli
cwd: None
CmdStan install directory: /root/.cmdstan
CmdStan version 2.36.0 already installed
Test model compilation
DEBUG:cmdstanpy:found newer exe file, not recompiling
DEBUG:cmdstanpy:cmd: /content/flat-logistic info
cwd: None
DEBUG:cmdstanpy:input tempfile: /tmp/tmpy2c2cbr3/vwhyncjv.json
05:05:32 - cmdstanpy - INFO - CmdStan start processing
INFO:cmdstanpy:CmdStan start processing
beta=array([ 0.32292202, -1.67712417, 0.80797451, 0.23766868, 0.86741335,
-1.506818 ])
np.mean(y_NY)=np.float64(0.368)
chain 1 | | 00:00 Status
chain 2 | | 00:00 Status
chain 3 | | 00:00 Status
chain 4 | | 00:00 Status
DEBUG:cmdstanpy:idx 0
DEBUG:cmdstanpy:running CmdStan, num_threads: 1
DEBUG:cmdstanpy:CmdStan args: [&#39;/content/flat-logistic&#39;, &#39;id=1&#39;, &#39;random&#39;, &#39;seed=583883&#39;, &#39;data&#39;, &#39;file=/tmp/tmpy2c2cbr3/vwhyncjv.json&#39;, &#39;output&#39;, &#39;file=/tmp/tmpy2c2cbr3/flat-logistical8q2pgf/flat-logistic-20250514050532_1.csv&#39;, &#39;method=sample&#39;, &#39;num_samples=2500&#39;, &#39;algorithm=hmc&#39;, &#39;adapt&#39;, &#39;engaged=1&#39;]
DEBUG:cmdstanpy:idx 1
DEBUG:cmdstanpy:running CmdStan, num_threads: 1
DEBUG:cmdstanpy:CmdStan args: [&#39;/content/flat-logistic&#39;, &#39;id=2&#39;, &#39;random&#39;, &#39;seed=583883&#39;, &#39;data&#39;, &#39;file=/tmp/tmpy2c2cbr3/vwhyncjv.json&#39;, &#39;output&#39;, &#39;file=/tmp/tmpy2c2cbr3/flat-logistical8q2pgf/flat-logistic-20250514050532_2.csv&#39;, &#39;method=sample&#39;, &#39;num_samples=2500&#39;, &#39;algorithm=hmc&#39;, &#39;adapt&#39;, &#39;engaged=1&#39;]
DEBUG:cmdstanpy:idx 2
DEBUG:cmdstanpy:running CmdStan, num_threads: 1
DEBUG:cmdstanpy:CmdStan args: [&#39;/content/flat-logistic&#39;, &#39;id=3&#39;, &#39;random&#39;, &#39;seed=583883&#39;, &#39;data&#39;, &#39;file=/tmp/tmpy2c2cbr3/vwhyncjv.json&#39;, &#39;output&#39;, &#39;file=/tmp/tmpy2c2cbr3/flat-logistical8q2pgf/flat-logistic-20250514050532_3.csv&#39;, &#39;method=sample&#39;, &#39;num_samples=2500&#39;, &#39;algorithm=hmc&#39;, &#39;adapt&#39;, &#39;engaged=1&#39;]
DEBUG:cmdstanpy:idx 3
DEBUG:cmdstanpy:running CmdStan, num_threads: 1
DEBUG:cmdstanpy:CmdStan args: [&#39;/content/flat-logistic&#39;, &#39;id=4&#39;, &#39;random&#39;, &#39;seed=583883&#39;, &#39;data&#39;, &#39;file=/tmp/tmpy2c2cbr3/vwhyncjv.json&#39;, &#39;output&#39;, &#39;file=/tmp/tmpy2c2cbr3/flat-logistical8q2pgf/flat-logistic-20250514050532_4.csv&#39;, &#39;method=sample&#39;, &#39;num_samples=2500&#39;, &#39;algorithm=hmc&#39;, &#39;adapt&#39;, &#39;engaged=1&#39;]
05:05:45 - cmdstanpy - INFO - CmdStan done processing.
INFO:cmdstanpy:CmdStan done processing.
DEBUG:cmdstanpy:runset
RunSet: chains=4, chain_ids=[1, 2, 3, 4], num_processes=4
cmd (chain 1):
[&#39;/content/flat-logistic&#39;, &#39;id=1&#39;, &#39;random&#39;, &#39;seed=583883&#39;, &#39;data&#39;, &#39;file=/tmp/tmpy2c2cbr3/vwhyncjv.json&#39;, &#39;output&#39;, &#39;file=/tmp/tmpy2c2cbr3/flat-logistical8q2pgf/flat-logistic-20250514050532_1.csv&#39;, &#39;method=sample&#39;, &#39;num_samples=2500&#39;, &#39;algorithm=hmc&#39;, &#39;adapt&#39;, &#39;engaged=1&#39;]
retcodes=[0, 0, 0, 0]
per-chain output files (showing chain 1 only):
csv_file:
/tmp/tmpy2c2cbr3/flat-logistical8q2pgf/flat-logistic-20250514050532_1.csv
console_msgs (if any):
/tmp/tmpy2c2cbr3/flat-logistical8q2pgf/flat-logistic-20250514050532_0-stdout.txt
DEBUG:cmdstanpy:Chain 1 console:
method = sample (Default)
sample
num_samples = 2500
num_warmup = 1000 (Default)
save_warmup = false (Default)
thin = 1 (Default)
adapt
engaged = true (Default)
gamma = 0.05 (Default)
delta = 0.8 (Default)
kappa = 0.75 (Default)
t0 = 10 (Default)
init_buffer = 75 (Default)
term_buffer = 50 (Default)
window = 25 (Default)
save_metric = false (Default)
algorithm = hmc (Default)
hmc
engine = nuts (Default)
nuts
max_depth = 10 (Default)
metric = diag_e (Default)
metric_file = (Default)
stepsize = 1 (Default)
stepsize_jitter = 0 (Default)
num_chains = 1 (Default)
id = 1 (Default)
data
file = /tmp/tmpy2c2cbr3/vwhyncjv.json
init = 2 (Default)
random
seed = 583883
output
file = /tmp/tmpy2c2cbr3/flat-logistical8q2pgf/flat-logistic-20250514050532_1.csv
diagnostic_file = (Default)
refresh = 100 (Default)
sig_figs = -1 (Default)
profile_file = profile.csv (Default)
save_cmdstan_config = false (Default)
num_threads = 1 (Default)
Gradient evaluation took 0.000185 seconds
1000 transitions using 10 leapfrog steps per transition would take 1.85 seconds.
Adjust your expectations accordingly!
Elapsed Time: 1.391 seconds (Warm-up)
3.46 seconds (Sampling)
4.851 seconds (Total)
DEBUG:cmdstanpy:cmd: /root/.cmdstan/cmdstan-2.36.0/bin/stansummary --percentiles= 5,50,95 --sig_figs=6 --csv_filename=/tmp/tmpy2c2cbr3/stansummary-flat-logistic-_dkv4lm4.csv /tmp/tmpy2c2cbr3/flat-logistical8q2pgf/flat-logistic-20250514050532_1.csv /tmp/tmpy2c2cbr3/flat-logistical8q2pgf/flat-logistic-20250514050532_2.csv /tmp/tmpy2c2cbr3/flat-logistical8q2pgf/flat-logistic-20250514050532_3.csv /tmp/tmpy2c2cbr3/flat-logistical8q2pgf/flat-logistic-20250514050532_4.csv
cwd: None
DEBUG:cmdstanpy:found newer exe file, not recompiling
DEBUG:cmdstanpy:cmd: /content/empirical-logistic info
cwd: None
DEBUG:cmdstanpy:input tempfile: /tmp/tmpy2c2cbr3/yg69ajzg.json
05:05:45 - cmdstanpy - INFO - CmdStan start processing
INFO:cmdstanpy:CmdStan start processing
DEBUG:cmdstanpy:idx 0
DEBUG:cmdstanpy:running CmdStan, num_threads: 1
DEBUG:cmdstanpy:CmdStan args: [&#39;/content/empirical-logistic&#39;, &#39;id=1&#39;, &#39;random&#39;, &#39;seed=583883&#39;, &#39;data&#39;, &#39;file=/tmp/tmpy2c2cbr3/yg69ajzg.json&#39;, &#39;output&#39;, &#39;file=/tmp/tmpy2c2cbr3/empirical-logistic3nhl762p/empirical-logistic-20250514050545_1.csv&#39;, &#39;method=sample&#39;, &#39;num_samples=500&#39;, &#39;num_warmup=500&#39;, &#39;algorithm=hmc&#39;, &#39;adapt&#39;, &#39;engaged=1&#39;]
05:05:45 - cmdstanpy - INFO - Chain [1] start processing
DEBUG:cmdstanpy:idx 1
DEBUG:cmdstanpy:running CmdStan, num_threads: 1
DEBUG:cmdstanpy:CmdStan args: [&#39;/content/empirical-logistic&#39;, &#39;id=2&#39;, &#39;random&#39;, &#39;seed=583883&#39;, &#39;data&#39;, &#39;file=/tmp/tmpy2c2cbr3/yg69ajzg.json&#39;, &#39;output&#39;, &#39;file=/tmp/tmpy2c2cbr3/empirical-logistic3nhl762p/empirical-logistic-20250514050545_2.csv&#39;, &#39;method=sample&#39;, &#39;num_samples=500&#39;, &#39;num_warmup=500&#39;, &#39;algorithm=hmc&#39;, &#39;adapt&#39;, &#39;engaged=1&#39;]
                Mean        MCSE      StdDev         MAD          5%         50%         95%    ESS_bulk    ESS_tail       R_hat
lp__     -557.406000    0.026795    1.770460    1.568590 -560.791000 -557.063000 -555.199000     4633.58     5881.42    0.999884
beta[1]     0.366866    0.000819    0.072424    0.071326    0.250861    0.366657    0.487298     7869.64     6899.19    1.000120
beta[2]    -1.789300    0.001410    0.107540    0.107689   -1.969140   -1.785910   -1.618440     5936.46     6436.66    1.000000
beta[3]     0.917581    0.000989    0.079849    0.078577    0.787154    0.916908    1.051410     6569.76     6402.72    1.000210
beta[4]     0.245058    0.000795    0.072760    0.073003    0.124279    0.245038    0.364416     8444.90     6252.19    1.000350
beta[5]     0.826520    0.000930    0.078852    0.078082    0.699908    0.825228    0.956443     7271.80     6691.55    1.000470
beta[6]    -1.578740    0.001312    0.098768    0.098519   -1.742420   -1.576840   -1.418180     5713.45     5901.44    1.000730
05:05:45 - cmdstanpy - INFO - Chain [2] start processing
INFO:cmdstanpy:Chain [1] start processing
INFO:cmdstanpy:Chain [2] start processing
05:06:48 - cmdstanpy - INFO - Chain [2] done processing
INFO:cmdstanpy:Chain [2] done processing
DEBUG:cmdstanpy:idx 2
DEBUG:cmdstanpy:running CmdStan, num_threads: 1
DEBUG:cmdstanpy:CmdStan args: [&#39;/content/empirical-logistic&#39;, &#39;id=3&#39;, &#39;random&#39;, &#39;seed=583883&#39;, &#39;data&#39;, &#39;file=/tmp/tmpy2c2cbr3/yg69ajzg.json&#39;, &#39;output&#39;, &#39;file=/tmp/tmpy2c2cbr3/empirical-logistic3nhl762p/empirical-logistic-20250514050545_3.csv&#39;, &#39;method=sample&#39;, &#39;num_samples=500&#39;, &#39;num_warmup=500&#39;, &#39;algorithm=hmc&#39;, &#39;adapt&#39;, &#39;engaged=1&#39;]
05:06:48 - cmdstanpy - INFO - Chain [3] start processing
INFO:cmdstanpy:Chain [3] start processing
05:06:50 - cmdstanpy - INFO - Chain [1] done processing
INFO:cmdstanpy:Chain [1] done processing
DEBUG:cmdstanpy:idx 3
DEBUG:cmdstanpy:running CmdStan, num_threads: 1
DEBUG:cmdstanpy:CmdStan args: [&#39;/content/empirical-logistic&#39;, &#39;id=4&#39;, &#39;random&#39;, &#39;seed=583883&#39;, &#39;data&#39;, &#39;file=/tmp/tmpy2c2cbr3/yg69ajzg.json&#39;, &#39;output&#39;, &#39;file=/tmp/tmpy2c2cbr3/empirical-logistic3nhl762p/empirical-logistic-20250514050545_4.csv&#39;, &#39;method=sample&#39;, &#39;num_samples=500&#39;, &#39;num_warmup=500&#39;, &#39;algorithm=hmc&#39;, &#39;adapt&#39;, &#39;engaged=1&#39;]
05:06:50 - cmdstanpy - INFO - Chain [4] start processing
INFO:cmdstanpy:Chain [4] start processing
05:07:50 - cmdstanpy - INFO - Chain [3] done processing
INFO:cmdstanpy:Chain [3] done processing
05:07:53 - cmdstanpy - INFO - Chain [4] done processing
INFO:cmdstanpy:Chain [4] done processing
DEBUG:cmdstanpy:runset
RunSet: chains=4, chain_ids=[1, 2, 3, 4], num_processes=4
cmd (chain 1):
[&#39;/content/empirical-logistic&#39;, &#39;id=1&#39;, &#39;random&#39;, &#39;seed=583883&#39;, &#39;data&#39;, &#39;file=/tmp/tmpy2c2cbr3/yg69ajzg.json&#39;, &#39;output&#39;, &#39;file=/tmp/tmpy2c2cbr3/empirical-logistic3nhl762p/empirical-logistic-20250514050545_1.csv&#39;, &#39;method=sample&#39;, &#39;num_samples=500&#39;, &#39;num_warmup=500&#39;, &#39;algorithm=hmc&#39;, &#39;adapt&#39;, &#39;engaged=1&#39;]
retcodes=[0, 0, 0, 0]
per-chain output files (showing chain 1 only):
csv_file:
/tmp/tmpy2c2cbr3/empirical-logistic3nhl762p/empirical-logistic-20250514050545_1.csv
console_msgs (if any):
/tmp/tmpy2c2cbr3/empirical-logistic3nhl762p/empirical-logistic-20250514050545_0-stdout.txt
DEBUG:cmdstanpy:Chain 1 console:
method = sample (Default)
sample
num_samples = 500
num_warmup = 500
save_warmup = false (Default)
thin = 1 (Default)
adapt
engaged = true (Default)
gamma = 0.05 (Default)
delta = 0.8 (Default)
kappa = 0.75 (Default)
t0 = 10 (Default)
init_buffer = 75 (Default)
term_buffer = 50 (Default)
window = 25 (Default)
save_metric = false (Default)
algorithm = hmc (Default)
hmc
engine = nuts (Default)
nuts
max_depth = 10 (Default)
metric = diag_e (Default)
metric_file = (Default)
stepsize = 1 (Default)
stepsize_jitter = 0 (Default)
num_chains = 1 (Default)
id = 1 (Default)
data
file = /tmp/tmpy2c2cbr3/yg69ajzg.json
init = 2 (Default)
random
seed = 583883
output
file = /tmp/tmpy2c2cbr3/empirical-logistic3nhl762p/empirical-logistic-20250514050545_1.csv
diagnostic_file = (Default)
refresh = 100 (Default)
sig_figs = -1 (Default)
profile_file = profile.csv (Default)
save_cmdstan_config = false (Default)
num_threads = 1 (Default)
Gradient evaluation took 0.008579 seconds
1000 transitions using 10 leapfrog steps per transition would take 85.79 seconds.
Adjust your expectations accordingly!
Elapsed Time: 32.347 seconds (Warm-up)
32.174 seconds (Sampling)
64.521 seconds (Total)
DEBUG:cmdstanpy:cmd: /root/.cmdstan/cmdstan-2.36.0/bin/stansummary --percentiles= 5,50,95 --sig_figs=6 --csv_filename=/tmp/tmpy2c2cbr3/stansummary-empirical-logistic-v4gv2feg.csv /tmp/tmpy2c2cbr3/empirical-logistic3nhl762p/empirical-logistic-20250514050545_1.csv /tmp/tmpy2c2cbr3/empirical-logistic3nhl762p/empirical-logistic-20250514050545_2.csv /tmp/tmpy2c2cbr3/empirical-logistic3nhl762p/empirical-logistic-20250514050545_3.csv /tmp/tmpy2c2cbr3/empirical-logistic3nhl762p/empirical-logistic-20250514050545_4.csv
cwd: None
                Mean        MCSE      StdDev         MAD          5%         50%         95%    ESS_bulk    ESS_tail       R_hat
lp__     -559.231000    0.052517    1.670540    1.512990 -562.479000 -558.881000 -557.139000     1042.35     1121.73    1.007360
beta[1]     0.350303    0.001256    0.053163    0.055772    0.264625    0.349386    0.440071     1791.38     1473.59    1.000620
beta[2]    -1.684520    0.002128    0.072740    0.073189   -1.806480   -1.683590   -1.568570     1172.88     1269.70    1.001860
beta[3]     0.816993    0.001509    0.058009    0.055932    0.720248    0.818284    0.910145     1522.17     1148.50    0.999944
beta[4]     0.308557    0.001222    0.052301    0.054825    0.222648    0.310081    0.389971     1861.81     1350.30    1.001190
beta[5]     0.807412    0.001536    0.056870    0.057270    0.714319    0.807694    0.902555     1385.68     1368.33    1.002960
beta[6]    -1.547500    0.002062    0.069704    0.068882   -1.664110   -1.546650   -1.435650     1155.15     1153.33    0.999938 </code></pre>
<div class="sourceCode" id="cb3"><pre
class="sourceCode python"><code class="sourceCode python"><span id="cb3-1"><a href="#cb3-1" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> fit_wasserstein_prior(draws):</span>
<span id="cb3-2"><a href="#cb3-2" aria-hidden="true" tabindex="-1"></a> <span class="co">&quot;&quot;&quot;</span></span>
<span id="cb3-3"><a href="#cb3-3" aria-hidden="true" tabindex="-1"></a><span class="co"> Returns the Gaussian (mu, Sigma) that minimizes W2 distance</span></span>
<span id="cb3-4"><a href="#cb3-4" aria-hidden="true" tabindex="-1"></a><span class="co"> to the empirical measure of &#39;draws&#39; via matching first two moments.</span></span>
<span id="cb3-5"><a href="#cb3-5" aria-hidden="true" tabindex="-1"></a><span class="co"> &quot;&quot;&quot;</span></span>
<span id="cb3-6"><a href="#cb3-6" aria-hidden="true" tabindex="-1"></a> mu <span class="op">=</span> np.mean(draws, axis<span class="op">=</span><span class="dv">0</span>)</span>
<span id="cb3-7"><a href="#cb3-7" aria-hidden="true" tabindex="-1"></a> Sigma <span class="op">=</span> np.cov(draws, rowvar<span class="op">=</span><span class="va">False</span>)</span>
<span id="cb3-8"><a href="#cb3-8" aria-hidden="true" tabindex="-1"></a> <span class="cf">return</span> mu, Sigma</span>
<span id="cb3-9"><a href="#cb3-9" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb3-10"><a href="#cb3-10" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb3-11"><a href="#cb3-11" aria-hidden="true" tabindex="-1"></a><span class="co"># After fitting the Ohio model:</span></span>
<span id="cb3-12"><a href="#cb3-12" aria-hidden="true" tabindex="-1"></a>beta_OH_draws <span class="op">=</span> fit_OH.stan_variable(<span class="st">&#39;beta&#39;</span>) <span class="co"># shape (B, D)</span></span>
<span id="cb3-13"><a href="#cb3-13" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb3-14"><a href="#cb3-14" aria-hidden="true" tabindex="-1"></a><span class="co"># Fit the W2‐optimal Gaussian prior</span></span>
<span id="cb3-15"><a href="#cb3-15" aria-hidden="true" tabindex="-1"></a>mu_prior, Sigma_prior <span class="op">=</span> fit_wasserstein_prior(beta_OH_draws)</span>
<span id="cb3-16"><a href="#cb3-16" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb3-17"><a href="#cb3-17" aria-hidden="true" tabindex="-1"></a><span class="co"># Prepare data for the next Stan model (NY with Gaussian prior)</span></span>
<span id="cb3-18"><a href="#cb3-18" aria-hidden="true" tabindex="-1"></a>data_NY_wass <span class="op">=</span> {</span>
<span id="cb3-19"><a href="#cb3-19" aria-hidden="true" tabindex="-1"></a> <span class="st">&#39;N&#39;</span>: N,</span>
<span id="cb3-20"><a href="#cb3-20" aria-hidden="true" tabindex="-1"></a> <span class="st">&#39;D&#39;</span>: D,</span>
<span id="cb3-21"><a href="#cb3-21" aria-hidden="true" tabindex="-1"></a> <span class="st">&#39;x&#39;</span>: x_NY,</span>
<span id="cb3-22"><a href="#cb3-22" aria-hidden="true" tabindex="-1"></a> <span class="st">&#39;y&#39;</span>: y_NY,</span>
<span id="cb3-23"><a href="#cb3-23" aria-hidden="true" tabindex="-1"></a> <span class="st">&#39;mu_prior&#39;</span>: mu_prior,</span>
<span id="cb3-24"><a href="#cb3-24" aria-hidden="true" tabindex="-1"></a> <span class="st">&#39;Sigma_prior&#39;</span>: Sigma_prior</span>
<span id="cb3-25"><a href="#cb3-25" aria-hidden="true" tabindex="-1"></a>}</span>
<span id="cb3-26"><a href="#cb3-26" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb3-27"><a href="#cb3-27" aria-hidden="true" tabindex="-1"></a>model_NY_wass <span class="op">=</span> csp.CmdStanModel(stan_file<span class="op">=</span><span class="st">&#39;gaussian-prior-logistic.stan&#39;</span>)</span>
<span id="cb3-28"><a href="#cb3-28" aria-hidden="true" tabindex="-1"></a>fit_NY_wass <span class="op">=</span> model_NY_wass.sample(data<span class="op">=</span>data_NY_wass, chains<span class="op">=</span><span class="dv">4</span>, iter_warmup<span class="op">=</span><span class="dv">500</span>, iter_sampling<span class="op">=</span><span class="dv">500</span>, seed<span class="op">=</span>seed)</span>
<span id="cb3-29"><a href="#cb3-29" aria-hidden="true" tabindex="-1"></a><span class="bu">print</span>(fit_NY_wass.summary())</span></code></pre></div>
<pre><code>DEBUG:cmdstanpy:found newer exe file, not recompiling
DEBUG:cmdstanpy:cmd: /content/gaussian-prior-logistic info
cwd: None
DEBUG:cmdstanpy:input tempfile: /tmp/tmpy2c2cbr3/mgqi9mu9.json
05:07:53 - cmdstanpy - INFO - CmdStan start processing
INFO:cmdstanpy:CmdStan start processing
DEBUG:cmdstanpy:idx 0
DEBUG:cmdstanpy:running CmdStan, num_threads: 1
DEBUG:cmdstanpy:idx 1
DEBUG:cmdstanpy:running CmdStan, num_threads: 1
DEBUG:cmdstanpy:CmdStan args: [&#39;/content/gaussian-prior-logistic&#39;, &#39;id=2&#39;, &#39;random&#39;, &#39;seed=583883&#39;, &#39;data&#39;, &#39;file=/tmp/tmpy2c2cbr3/mgqi9mu9.json&#39;, &#39;output&#39;, &#39;file=/tmp/tmpy2c2cbr3/gaussian-prior-logistich_ssz9bq/gaussian-prior-logistic-20250514050753_2.csv&#39;, &#39;method=sample&#39;, &#39;num_samples=500&#39;, &#39;num_warmup=500&#39;, &#39;algorithm=hmc&#39;, &#39;adapt&#39;, &#39;engaged=1&#39;]
DEBUG:cmdstanpy:CmdStan args: [&#39;/content/gaussian-prior-logistic&#39;, &#39;id=1&#39;, &#39;random&#39;, &#39;seed=583883&#39;, &#39;data&#39;, &#39;file=/tmp/tmpy2c2cbr3/mgqi9mu9.json&#39;, &#39;output&#39;, &#39;file=/tmp/tmpy2c2cbr3/gaussian-prior-logistich_ssz9bq/gaussian-prior-logistic-20250514050753_1.csv&#39;, &#39;method=sample&#39;, &#39;num_samples=500&#39;, &#39;num_warmup=500&#39;, &#39;algorithm=hmc&#39;, &#39;adapt&#39;, &#39;engaged=1&#39;]
DEBUG:cmdstanpy:idx 2
DEBUG:cmdstanpy:running CmdStan, num_threads: 1
DEBUG:cmdstanpy:CmdStan args: [&#39;/content/gaussian-prior-logistic&#39;, &#39;id=3&#39;, &#39;random&#39;, &#39;seed=583883&#39;, &#39;data&#39;, &#39;file=/tmp/tmpy2c2cbr3/mgqi9mu9.json&#39;, &#39;output&#39;, &#39;file=/tmp/tmpy2c2cbr3/gaussian-prior-logistich_ssz9bq/gaussian-prior-logistic-20250514050753_3.csv&#39;, &#39;method=sample&#39;, &#39;num_samples=500&#39;, &#39;num_warmup=500&#39;, &#39;algorithm=hmc&#39;, &#39;adapt&#39;, &#39;engaged=1&#39;]
DEBUG:cmdstanpy:idx 3
DEBUG:cmdstanpy:running CmdStan, num_threads: 1
DEBUG:cmdstanpy:CmdStan args: [&#39;/content/gaussian-prior-logistic&#39;, &#39;id=4&#39;, &#39;random&#39;, &#39;seed=583883&#39;, &#39;data&#39;, &#39;file=/tmp/tmpy2c2cbr3/mgqi9mu9.json&#39;, &#39;output&#39;, &#39;file=/tmp/tmpy2c2cbr3/gaussian-prior-logistich_ssz9bq/gaussian-prior-logistic-20250514050753_4.csv&#39;, &#39;method=sample&#39;, &#39;num_samples=500&#39;, &#39;num_warmup=500&#39;, &#39;algorithm=hmc&#39;, &#39;adapt&#39;, &#39;engaged=1&#39;]
05:07:58 - cmdstanpy - INFO - CmdStan done processing.
INFO:cmdstanpy:CmdStan done processing.
DEBUG:cmdstanpy:runset
RunSet: chains=4, chain_ids=[1, 2, 3, 4], num_processes=4
cmd (chain 1):
[&#39;/content/gaussian-prior-logistic&#39;, &#39;id=1&#39;, &#39;random&#39;, &#39;seed=583883&#39;, &#39;data&#39;, &#39;file=/tmp/tmpy2c2cbr3/mgqi9mu9.json&#39;, &#39;output&#39;, &#39;file=/tmp/tmpy2c2cbr3/gaussian-prior-logistich_ssz9bq/gaussian-prior-logistic-20250514050753_1.csv&#39;, &#39;method=sample&#39;, &#39;num_samples=500&#39;, &#39;num_warmup=500&#39;, &#39;algorithm=hmc&#39;, &#39;adapt&#39;, &#39;engaged=1&#39;]
retcodes=[0, 0, 0, 0]
per-chain output files (showing chain 1 only):
csv_file:
/tmp/tmpy2c2cbr3/gaussian-prior-logistich_ssz9bq/gaussian-prior-logistic-20250514050753_1.csv
console_msgs (if any):
/tmp/tmpy2c2cbr3/gaussian-prior-logistich_ssz9bq/gaussian-prior-logistic-20250514050753_0-stdout.txt
DEBUG:cmdstanpy:Chain 1 console:
method = sample (Default)
sample
num_samples = 500
num_warmup = 500
save_warmup = false (Default)
thin = 1 (Default)
adapt
engaged = true (Default)
gamma = 0.05 (Default)
delta = 0.8 (Default)
kappa = 0.75 (Default)
t0 = 10 (Default)
init_buffer = 75 (Default)
term_buffer = 50 (Default)
window = 25 (Default)
save_metric = false (Default)
algorithm = hmc (Default)
hmc
engine = nuts (Default)
nuts
max_depth = 10 (Default)
metric = diag_e (Default)
metric_file = (Default)
stepsize = 1 (Default)
stepsize_jitter = 0 (Default)
num_chains = 1 (Default)
id = 1 (Default)
data
file = /tmp/tmpy2c2cbr3/mgqi9mu9.json
init = 2 (Default)
random
seed = 583883
output
file = /tmp/tmpy2c2cbr3/gaussian-prior-logistich_ssz9bq/gaussian-prior-logistic-20250514050753_1.csv
diagnostic_file = (Default)
refresh = 100 (Default)
sig_figs = -1 (Default)
profile_file = profile.csv (Default)
save_cmdstan_config = false (Default)
num_threads = 1 (Default)
Gradient evaluation took 0.000179 seconds
1000 transitions using 10 leapfrog steps per transition would take 1.79 seconds.
Adjust your expectations accordingly!
Elapsed Time: 0.998 seconds (Warm-up)
1.215 seconds (Sampling)
2.213 seconds (Total)
DEBUG:cmdstanpy:cmd: /root/.cmdstan/cmdstan-2.36.0/bin/stansummary --percentiles= 5,50,95 --sig_figs=6 --csv_filename=/tmp/tmpy2c2cbr3/stansummary-gaussian-prior-logistic-w4iqkew2.csv /tmp/tmpy2c2cbr3/gaussian-prior-logistich_ssz9bq/gaussian-prior-logistic-20250514050753_1.csv /tmp/tmpy2c2cbr3/gaussian-prior-logistich_ssz9bq/gaussian-prior-logistic-20250514050753_2.csv /tmp/tmpy2c2cbr3/gaussian-prior-logistich_ssz9bq/gaussian-prior-logistic-20250514050753_3.csv /tmp/tmpy2c2cbr3/gaussian-prior-logistich_ssz9bq/gaussian-prior-logistic-20250514050753_4.csv
cwd: None
                Mean        MCSE      StdDev         MAD          5%         50%         95%    ESS_bulk    ESS_tail       R_hat
lp__     -568.833000    0.061184    1.775250    1.637530 -572.075000 -568.518000 -566.585000     881.262     1143.34    1.002360
beta[1]     0.352418    0.001256    0.049502    0.048350    0.271975    0.351888    0.434444    1584.370     1544.86    1.001480
beta[2]    -1.696680    0.002096    0.075393    0.072381   -1.818390   -1.698950   -1.570150    1307.350     1407.69    0.999047
beta[3]     0.824647    0.001493    0.057310    0.057457    0.732214    0.822308    0.922475    1502.040     1410.28    1.002310
beta[4]     0.301689    0.001300    0.050560    0.050508    0.216780    0.301548    0.386601    1520.230     1296.52    1.000290
beta[5]     0.808698    0.001430    0.056148    0.054268    0.714924    0.808041    0.901743    1563.190     1176.63    1.000660
beta[6]    -1.550370    0.001959    0.071111    0.067636   -1.671050   -1.548620   -1.436520    1349.680     1310.96    1.000150 </code></pre>
<div class="sourceCode" id="cb5"><pre
class="sourceCode python"><code class="sourceCode python"><span id="cb5-1"><a href="#cb5-1" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> matplotlib.pyplot <span class="im">as</span> plt</span>
<span id="cb5-2"><a href="#cb5-2" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> seaborn <span class="im">as</span> sns</span>
<span id="cb5-3"><a href="#cb5-3" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb5-4"><a href="#cb5-4" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> plot_prior_comparison(empirical_draws, mu_wass, Sigma_wass, h_kde):</span>
<span id="cb5-5"><a href="#cb5-5" aria-hidden="true" tabindex="-1"></a> <span class="co">&quot;&quot;&quot;Compare the empirical posterior draws with the Wasserstein (moment-matched Gaussian) prior, one panel per dimension&quot;&quot;&quot;</span></span>
<span id="cb5-6"><a href="#cb5-6" aria-hidden="true" tabindex="-1"></a> fig, axes <span class="op">=</span> plt.subplots(<span class="dv">2</span>, <span class="dv">3</span>, figsize<span class="op">=</span>(<span class="dv">15</span>, <span class="dv">10</span>))</span>
<span id="cb5-7"><a href="#cb5-7" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb5-8"><a href="#cb5-8" aria-hidden="true" tabindex="-1"></a> <span class="cf">for</span> dim <span class="kw">in</span> <span class="bu">range</span>(<span class="bu">min</span>(<span class="dv">6</span>, D)):</span>
<span id="cb5-9"><a href="#cb5-9" aria-hidden="true" tabindex="-1"></a> ax <span class="op">=</span> axes[dim <span class="op">//</span> <span class="dv">3</span>, dim <span class="op">%</span> <span class="dv">3</span>]</span>
<span id="cb5-10"><a href="#cb5-10" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb5-11"><a href="#cb5-11" aria-hidden="true" tabindex="-1"></a> <span class="co"># Plot empirical posterior</span></span>
<span id="cb5-12"><a href="#cb5-12" aria-hidden="true" tabindex="-1"></a> ax.hist(empirical_draws[:, dim], bins<span class="op">=</span><span class="dv">50</span>, alpha<span class="op">=</span><span class="fl">0.3</span>,</span>
<span id="cb5-13"><a href="#cb5-13" aria-hidden="true" tabindex="-1"></a> label<span class="op">=</span><span class="st">&#39;Empirical posterior&#39;</span>, density<span class="op">=</span><span class="va">True</span>)</span>
<span id="cb5-14"><a href="#cb5-14" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb5-15"><a href="#cb5-15" aria-hidden="true" tabindex="-1"></a> <span class="co"># Plot draws from the Wasserstein (moment-matched Gaussian) prior</span></span>
<span id="cb5-16"><a href="#cb5-16" aria-hidden="true" tabindex="-1"></a> x_range <span class="op">=</span> np.linspace(empirical_draws[:, dim].<span class="bu">min</span>(),</span>
<span id="cb5-17"><a href="#cb5-17" aria-hidden="true" tabindex="-1"></a> empirical_draws[:, dim].<span class="bu">max</span>(), <span class="dv">100</span>)</span>
<span id="cb5-18"><a href="#cb5-18" aria-hidden="true" tabindex="-1"></a> wass_samples <span class="op">=</span> rng.normal(mu_wass[dim], np.sqrt(Sigma_wass[dim, dim]), <span class="dv">10000</span>)</span>
<span id="cb5-19"><a href="#cb5-19" aria-hidden="true" tabindex="-1"></a> ax.hist(wass_samples, bins<span class="op">=</span><span class="dv">50</span>, alpha<span class="op">=</span><span class="fl">0.3</span>,</span>
<span id="cb5-20"><a href="#cb5-20" aria-hidden="true" tabindex="-1"></a> label<span class="op">=</span><span class="st">&#39;Wasserstein prior&#39;</span>, density<span class="op">=</span><span class="va">True</span>)</span>
<span id="cb5-21"><a href="#cb5-21" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb5-22"><a href="#cb5-22" aria-hidden="true" tabindex="-1"></a> ax.set_title(<span class="ss">f&#39;Dimension </span><span class="sc">{</span>dim<span class="sc">}</span><span class="ss">&#39;</span>)</span>
<span id="cb5-23"><a href="#cb5-23" aria-hidden="true" tabindex="-1"></a> ax.legend()</span>
<span id="cb5-24"><a href="#cb5-24" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb5-25"><a href="#cb5-25" aria-hidden="true" tabindex="-1"></a> plt.tight_layout()</span>
<span id="cb5-26"><a href="#cb5-26" aria-hidden="true" tabindex="-1"></a> plt.show()</span>
<span id="cb5-27"><a href="#cb5-27" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb5-28"><a href="#cb5-28" aria-hidden="true" tabindex="-1"></a><span class="co"># After fitting both methods</span></span>
<span id="cb5-29"><a href="#cb5-29" aria-hidden="true" tabindex="-1"></a>plot_prior_comparison(beta_OH_draws, mu_prior, Sigma_prior, h)</span></code></pre></div>
<figure>
<img src="wasserstein_prior_files/wasserstein_prior_2_0.png"
alt="png" />
<figcaption aria-hidden="true">png</figcaption>
</figure>
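<p>Before a moment-matched prior is passed to Stan, it can help to confirm that the sample covariance is usable as a covariance matrix. The sketch below is a hypothetical helper, not part of the notebook: the name <code>validate_gaussian_prior</code> and the <code>jitter</code> default are assumptions, and whether the check is needed depends on how <code>gaussian-prior-logistic.stan</code> (not shown here) declares <code>Sigma_prior</code>. It uses synthetic draws rather than the Ohio posterior.</p>

```python
import numpy as np


def validate_gaussian_prior(mu, Sigma, jitter=1e-10):
    """Return (mu, Sigma) with Sigma symmetrized and confirmed positive definite.

    Hypothetical helper: a tiny diagonal jitter is added only if the
    Cholesky factorization of the symmetrized matrix fails.
    """
    mu = np.atleast_1d(np.asarray(mu, dtype=float))
    Sigma = np.atleast_2d(np.asarray(Sigma, dtype=float))
    if Sigma.shape != (mu.size, mu.size):
        raise ValueError("Sigma must be a square matrix matching the length of mu")
    # Remove floating-point asymmetry left over from the covariance estimate.
    Sigma = 0.5 * (Sigma + Sigma.T)
    try:
        np.linalg.cholesky(Sigma)
    except np.linalg.LinAlgError:
        Sigma = Sigma + jitter * np.eye(mu.size)
        np.linalg.cholesky(Sigma)  # raises again if still not positive definite
    return mu, Sigma


# Synthetic stand-in for posterior draws with shape (B, D).
rng_demo = np.random.default_rng(0)
demo_draws = rng_demo.normal(size=(2000, 6))
mu_demo = np.mean(demo_draws, axis=0)
Sigma_demo = np.cov(demo_draws, rowvar=False)
mu_checked, Sigma_checked = validate_gaussian_prior(mu_demo, Sigma_demo)
print(Sigma_checked.shape)
```

<p>With thousands of draws in six dimensions the check should pass without jitter. It matters more when the number of draws is small relative to the dimension, where the sample covariance can be singular.</p>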
<div class="sourceCode" id="cb6"><pre
class="sourceCode python"><code class="sourceCode python"><span id="cb6-1"><a href="#cb6-1" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> posterior_predictive_check(x_test, beta_draws, beta_true):</span>
<span id="cb6-2"><a href="#cb6-2" aria-hidden="true" tabindex="-1"></a> <span class="co">&quot;&quot;&quot;Check if posterior draws generate reasonable predictions&quot;&quot;&quot;</span></span>
<span id="cb6-3"><a href="#cb6-3" aria-hidden="true" tabindex="-1"></a> <span class="co"># Generate predictions from posterior draws</span></span>
<span id="cb6-4"><a href="#cb6-4" aria-hidden="true" tabindex="-1"></a> pred_probs <span class="op">=</span> []</span>
<span id="cb6-5"><a href="#cb6-5" aria-hidden="true" tabindex="-1"></a> <span class="cf">for</span> beta_draw <span class="kw">in</span> beta_draws[<span class="op">-</span><span class="dv">1000</span>:]: <span class="co"># Use last 1000 draws</span></span>
<span id="cb6-6"><a href="#cb6-6" aria-hidden="true" tabindex="-1"></a> pred_probs.append(expit(x_test <span class="op">@</span> beta_draw))</span>
<span id="cb6-7"><a href="#cb6-7" aria-hidden="true" tabindex="-1"></a> pred_probs <span class="op">=</span> np.array(pred_probs)</span>
<span id="cb6-8"><a href="#cb6-8" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb6-9"><a href="#cb6-9" aria-hidden="true" tabindex="-1"></a> <span class="co"># True probabilities</span></span>
<span id="cb6-10"><a href="#cb6-10" aria-hidden="true" tabindex="-1"></a> true_probs <span class="op">=</span> expit(x_test <span class="op">@</span> beta_true)</span>
<span id="cb6-11"><a href="#cb6-11" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb6-12"><a href="#cb6-12" aria-hidden="true" tabindex="-1"></a> <span class="co"># Check coverage</span></span>
<span id="cb6-13"><a href="#cb6-13" aria-hidden="true" tabindex="-1"></a> pred_mean <span class="op">=</span> np.mean(pred_probs, axis<span class="op">=</span><span class="dv">0</span>)</span>
<span id="cb6-14"><a href="#cb6-14" aria-hidden="true" tabindex="-1"></a> pred_lower <span class="op">=</span> np.percentile(pred_probs, <span class="fl">2.5</span>, axis<span class="op">=</span><span class="dv">0</span>)</span>
<span id="cb6-15"><a href="#cb6-15" aria-hidden="true" tabindex="-1"></a> pred_upper <span class="op">=</span> np.percentile(pred_probs, <span class="fl">97.5</span>, axis<span class="op">=</span><span class="dv">0</span>)</span>
<span id="cb6-16"><a href="#cb6-16" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb6-17"><a href="#cb6-17" aria-hidden="true" tabindex="-1"></a> coverage <span class="op">=</span> np.mean((true_probs <span class="op">&gt;=</span> pred_lower) <span class="op">&amp;</span> (true_probs <span class="op">&lt;=</span> pred_upper))</span>
<span id="cb6-18"><a href="#cb6-18" aria-hidden="true" tabindex="-1"></a> <span class="bu">print</span>(<span class="ss">f&quot;95% credible interval coverage: </span><span class="sc">{</span>coverage<span class="sc">:.3f}</span><span class="ss">&quot;</span>)</span>
<span id="cb6-19"><a href="#cb6-19" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb6-20"><a href="#cb6-20" aria-hidden="true" tabindex="-1"></a> <span class="co"># Plot calibration</span></span>
<span id="cb6-21"><a href="#cb6-21" aria-hidden="true" tabindex="-1"></a> plt.figure(figsize<span class="op">=</span>(<span class="dv">10</span>, <span class="dv">6</span>))</span>
<span id="cb6-22"><a href="#cb6-22" aria-hidden="true" tabindex="-1"></a> plt.subplot(<span class="dv">1</span>, <span class="dv">2</span>, <span class="dv">1</span>)</span>
<span id="cb6-23"><a href="#cb6-23" aria-hidden="true" tabindex="-1"></a> plt.scatter(true_probs, pred_mean, alpha<span class="op">=</span><span class="fl">0.5</span>)</span>
<span id="cb6-24"><a href="#cb6-24" aria-hidden="true" tabindex="-1"></a> plt.plot([<span class="dv">0</span>, <span class="dv">1</span>], [<span class="dv">0</span>, <span class="dv">1</span>], <span class="st">&#39;r--&#39;</span>)</span>
<span id="cb6-25"><a href="#cb6-25" aria-hidden="true" tabindex="-1"></a> plt.xlabel(<span class="st">&#39;True probability&#39;</span>)</span>
<span id="cb6-26"><a href="#cb6-26" aria-hidden="true" tabindex="-1"></a> plt.ylabel(<span class="st">&#39;Predicted probability&#39;</span>)</span>
<span id="cb6-27"><a href="#cb6-27" aria-hidden="true" tabindex="-1"></a> plt.title(<span class="st">&#39;Calibration&#39;</span>)</span>
<span id="cb6-28"><a href="#cb6-28" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb6-29"><a href="#cb6-29" aria-hidden="true" tabindex="-1"></a> <span class="co"># Plot prediction intervals</span></span>
<span id="cb6-30"><a href="#cb6-30" aria-hidden="true" tabindex="-1"></a> plt.subplot(<span class="dv">1</span>, <span class="dv">2</span>, <span class="dv">2</span>)</span>
<span id="cb6-31"><a href="#cb6-31" aria-hidden="true" tabindex="-1"></a> order <span class="op">=</span> np.argsort(true_probs)</span>
<span id="cb6-32"><a href="#cb6-32" aria-hidden="true" tabindex="-1"></a> plt.fill_between(<span class="bu">range</span>(<span class="bu">len</span>(order)), pred_lower[order], pred_upper[order], alpha<span class="op">=</span><span class="fl">0.3</span>)</span>
<span id="cb6-33"><a href="#cb6-33" aria-hidden="true" tabindex="-1"></a> plt.plot(true_probs[order], <span class="st">&#39;r-&#39;</span>, label<span class="op">=</span><span class="st">&#39;True&#39;</span>)</span>
<span id="cb6-34"><a href="#cb6-34" aria-hidden="true" tabindex="-1"></a> plt.plot(pred_mean[order], <span class="st">&#39;b-&#39;</span>, label<span class="op">=</span><span class="st">&#39;Predicted&#39;</span>)</span>
<span id="cb6-35"><a href="#cb6-35" aria-hidden="true" tabindex="-1"></a> plt.legend()</span>
<span id="cb6-36"><a href="#cb6-36" aria-hidden="true" tabindex="-1"></a> plt.title(<span class="st">&#39;Prediction intervals&#39;</span>)</span>
<span id="cb6-37"><a href="#cb6-37" aria-hidden="true" tabindex="-1"></a> plt.tight_layout()</span>
<span id="cb6-38"><a href="#cb6-38" aria-hidden="true" tabindex="-1"></a> plt.show()</span>
<span id="cb6-39"><a href="#cb6-39" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb6-40"><a href="#cb6-40" aria-hidden="true" tabindex="-1"></a><span class="co"># Check the NY fit that uses the empirical (KDE) prior</span></span>
<span id="cb6-41"><a href="#cb6-41" aria-hidden="true" tabindex="-1"></a>x_test <span class="op">=</span> rng.normal(loc<span class="op">=</span><span class="fl">0.0</span>, scale<span class="op">=</span><span class="fl">1.0</span>, size<span class="op">=</span>(<span class="dv">500</span>, D))</span>
<span id="cb6-42"><a href="#cb6-42" aria-hidden="true" tabindex="-1"></a>posterior_predictive_check(x_test, fit_NY.stan_variable(<span class="st">&#39;beta&#39;</span>), beta)</span></code></pre></div>
<pre><code>95% credible interval coverage: 1.000</code></pre>
<figure>
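<img src="wasserstein_prior_files/wasserstein_prior_3_1.png"
alt="Calibration scatter plot (left) and prediction intervals (right)" />
<figcaption aria-hidden="true">Calibration (left) and prediction intervals (right)</figcaption>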
</figure>
</body>
</html>
# -*- coding: utf-8 -*-
"""wasserstein_prior.ipynb
Automatically generated by Colab.
Original file is located at
https://colab.research.google.com/drive/1veSI3587AFs0jRmwOQSiA_7K42lWa4HX
"""
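import cmdstanpy as csp
import numpy as np
import ot  # POT (Python Optimal Transport), used by the optimization-based fit below
from scipy.optimize import minimize
from scipy.special import expit
from scipy.stats import multivariate_normal

# One-time CmdStan download and build (a Colab convenience);
# skip this call if CmdStan is already installed.
csp.install_cmdstan()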
seed = 583883
D = 6
N = 1500
mu_x_OH = -1
mu_x_NY = 1
B = 10_000
h = 0.04
rng = np.random.default_rng(seed)
# generate a single parameter vector
beta = rng.normal(loc=0.0, scale=1.0, size=D)
print(f"{beta=}")
# generate Ohio data from parameters (covariates centered at mu_x_OH)
x_OH = rng.normal(loc=mu_x_OH, scale=1.0, size=(N, D))
y_OH = rng.binomial(n=1, p=expit(x_OH @ beta))
# generate NY data from parameters (covariates centered at mu_x_NY)
x_NY = rng.normal(loc=mu_x_NY, scale=1.0, size=(N, D))
y_NY = rng.binomial(n=1, p=expit(x_NY @ beta))
print(f"{np.mean(y_NY)=}")  # shouldn't be too extreme
# fit OH data with simple logistic regression
data_OH = {'N': N, 'D': D, 'x': x_OH, 'y': y_OH }
model_OH = csp.CmdStanModel(stan_file='flat-logistic.stan')
fit_OH = model_OH.sample(data=data_OH, chains=4, iter_sampling=B // 4, seed=seed)
print(fit_OH.summary())
beta_OH_draws = fit_OH.stan_variable('beta')
# fit NY data with empirical prior from posterior of OH data
data_NY = {'N': N, 'D': D, 'x': x_NY, 'y': y_NY,
'h': h, 'B': B, 'beta0': beta_OH_draws }
model_NY = csp.CmdStanModel(stan_file='empirical-logistic.stan')
fit_NY = model_NY.sample(data=data_NY, chains=4, iter_warmup=500, iter_sampling=500, seed=seed, show_progress=False)
print(fit_NY.summary())
def fit_wasserstein_prior(posterior_draws, prior_family='mvn'):
    """
    Fit a parametric prior by numerically minimizing the squared
    Wasserstein-2 distance to the empirical posterior draws.

    Only prior_family='mvn' (Gaussian with diagonal covariance) is implemented.
    Kept for reference: the closed-form moment-matching versions further down
    need no optimization, and the last one reuses this function name.
    """
    if prior_family != 'mvn':
        raise ValueError(f"unsupported prior_family: {prior_family!r}")

    D = posterior_draws.shape[1]
    n_samples = len(posterior_draws)

    # Uniform weights on both point clouds
    a = np.ones(n_samples) / n_samples  # empirical draws
    b = np.ones(n_samples) / n_samples  # draws from the fitted Gaussian

    # Fixed standard-normal noise (common random numbers) keeps the objective
    # a deterministic function of params; fresh noise on every call would
    # corrupt the finite-difference gradients that L-BFGS-B relies on.
    z = np.random.default_rng(0).standard_normal((n_samples, D))

    def objective(params):
        mu = params[:D]
        # Diagonal covariance for simplicity and stability
        std = np.exp(params[D:])
        fitted_samples = mu + z * std
        # Cost matrix of squared Euclidean distances (n_samples x n_samples,
        # so this is slow and memory-hungry for large n_samples)
        cost_matrix = ot.dist(posterior_draws, fitted_samples, metric='sqeuclidean')
        # Optimal transport cost = squared W2 distance between the two samples
        return ot.emd2(a, b, cost_matrix)

    # Initialize at the sample mean and standard deviation
    init_mu = np.mean(posterior_draws, axis=0)
    init_std = np.std(posterior_draws, axis=0)
    init_log_std = np.log(init_std + 1e-6)  # avoid log(0)
    init_params = np.concatenate([init_mu, init_log_std])

    # Means are unconstrained; log standard deviations stay in [-5, 5]
    bounds = [(None, None)] * D + [(-5, 5)] * D

    try:
        result = minimize(objective, init_params, method='L-BFGS-B', bounds=bounds)
        mu_opt = result.x[:D]
        # exp(log_std) is a standard deviation, so square it for the covariance
        Sigma_opt = np.diag(np.exp(2 * result.x[D:]))
        return mu_opt, Sigma_opt
    except Exception as e:
        print(f"Optimization failed: {e}")
        print("Falling back to moment matching")
        # Note: the fallback returns a full covariance, not a diagonal one
        mu_fallback = np.mean(posterior_draws, axis=0)
        Sigma_fallback = np.cov(posterior_draws.T)
        return mu_fallback, Sigma_fallback
# Or use the much simpler closed-form solution:
def fit_wasserstein_prior_simple(posterior_draws):
    """
    Closed-form Wasserstein-optimal Gaussian: just match moments
    """
    mu = np.mean(posterior_draws, axis=0)
    Sigma = np.cov(posterior_draws.T)
    return mu, Sigma
# Use the simple version:
mu_prior, Sigma_prior = fit_wasserstein_prior_simple(beta_OH_draws)
# Use in Stan model (modified version)
data_NY_wass = {
'N': N, 'D': D, 'x': x_NY, 'y': y_NY,
'mu_prior': mu_prior, 'Sigma_prior': Sigma_prior
}
def fit_wasserstein_prior(draws):
    """
    Returns the Gaussian (mu, Sigma) that minimizes W2 distance
    to the empirical measure of 'draws' via matching first two moments.
    """
    mu = np.mean(draws, axis=0)
    Sigma = np.cov(draws, rowvar=False)
    return mu, Sigma
# After fitting the Ohio model:
beta_OH_draws = fit_OH.stan_variable('beta') # shape (B, D)
# Fit the W2‐optimal Gaussian prior
mu_prior, Sigma_prior = fit_wasserstein_prior(beta_OH_draws)
# Prepare data for the next Stan model (NY with Gaussian prior)
data_NY_wass = {
'N': N,
'D': D,
'x': x_NY,
'y': y_NY,
'mu_prior': mu_prior,
'Sigma_prior': Sigma_prior
}
model_NY_wass = csp.CmdStanModel(stan_file='gaussian-prior-logistic.stan')
fit_NY_wass = model_NY_wass.sample(data=data_NY_wass, chains=4, iter_warmup=500, iter_sampling=500, seed=seed)
print(fit_NY_wass.summary())
import matplotlib.pyplot as plt


def plot_prior_comparison(empirical_draws, mu_wass, Sigma_wass, h_kde):
    """Overlay the empirical posterior and the Wasserstein (Gaussian) prior, per dimension.

    h_kde is accepted but not used: the KDE approximation itself is not drawn.
    """
    n_dims = min(6, empirical_draws.shape[1])
    fig, axes = plt.subplots(2, 3, figsize=(15, 10))
    for dim in range(n_dims):
        ax = axes[dim // 3, dim % 3]
        # Histogram of the empirical posterior draws
        ax.hist(empirical_draws[:, dim], bins=50, alpha=0.3,
                label='Empirical posterior', density=True)
        # Histogram of draws from the matching marginal of the Gaussian prior
        wass_samples = rng.normal(mu_wass[dim], np.sqrt(Sigma_wass[dim, dim]), 10000)
        ax.hist(wass_samples, bins=50, alpha=0.3,
                label='Wasserstein prior', density=True)
        ax.set_title(f'Dimension {dim}')
        ax.legend()
    plt.tight_layout()
    plt.show()
# After fitting both methods
plot_prior_comparison(beta_OH_draws, mu_prior, Sigma_prior, h)
def posterior_predictive_check(x_test, beta_draws, beta_true):
    """Check if posterior draws generate reasonable predictions"""
    # Predicted probabilities from the last 1000 posterior draws,
    # shape (n_draws, n_test)
    pred_probs = expit(beta_draws[-1000:] @ x_test.T)
    # True probabilities
    true_probs = expit(x_test @ beta_true)
    # Pointwise posterior mean and central 95% interval
    pred_mean = np.mean(pred_probs, axis=0)
    pred_lower = np.percentile(pred_probs, 2.5, axis=0)
    pred_upper = np.percentile(pred_probs, 97.5, axis=0)
    # Fraction of test points whose true probability lies inside the interval
    coverage = np.mean((true_probs >= pred_lower) & (true_probs <= pred_upper))
    print(f"95% credible interval coverage: {coverage:.3f}")
    # Plot calibration
    plt.figure(figsize=(10, 6))
    plt.subplot(1, 2, 1)
    plt.scatter(true_probs, pred_mean, alpha=0.5)
    plt.plot([0, 1], [0, 1], 'r--')
    plt.xlabel('True probability')
    plt.ylabel('Predicted probability')
    plt.title('Calibration')
    # Plot prediction intervals, test points sorted by true probability
    plt.subplot(1, 2, 2)
    order = np.argsort(true_probs)
    plt.fill_between(range(len(order)), pred_lower[order], pred_upper[order], alpha=0.3)
    plt.plot(true_probs[order], 'r-', label='True')
    plt.plot(pred_mean[order], 'b-', label='Predicted')
    plt.legend()
    plt.title('Prediction intervals')
    plt.tight_layout()
    plt.show()
# Check the NY fit that uses the empirical (KDE) prior
x_test = rng.normal(loc=0.0, scale=1.0, size=(500, D))
posterior_predictive_check(x_test, fit_NY.stan_variable('beta'), beta)
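The coverage check above looks at one fit at a time. One way to compare two sets of posterior draws directly is to summarize each by its mean and covariance and use the closed-form squared W2 distance between the two resulting Gaussians. The sketch below is not part of the original notebook: the helper names `psd_sqrt` and `gaussian_w2_squared` are hypothetical, and the draws are synthetic stand-ins so the block runs without Stan.

```python
import numpy as np


def psd_sqrt(S):
    """Symmetric square root of a positive semi-definite matrix."""
    S = np.atleast_2d(np.asarray(S, dtype=float))
    vals, vecs = np.linalg.eigh((S + S.T) / 2.0)
    # Clip tiny negative eigenvalues caused by round-off
    return (vecs * np.sqrt(np.clip(vals, 0.0, None))) @ vecs.T


def gaussian_w2_squared(mu_a, Sigma_a, mu_b, Sigma_b):
    """Squared W2 distance between N(mu_a, Sigma_a) and N(mu_b, Sigma_b).

    Uses ||mu_a - mu_b||^2 + tr(Sigma_a + Sigma_b - 2 (B^{1/2} A B^{1/2})^{1/2}),
    with A = Sigma_a and B = Sigma_b.
    """
    mu_a = np.atleast_1d(np.asarray(mu_a, dtype=float))
    mu_b = np.atleast_1d(np.asarray(mu_b, dtype=float))
    Sigma_a = np.atleast_2d(np.asarray(Sigma_a, dtype=float))
    Sigma_b = np.atleast_2d(np.asarray(Sigma_b, dtype=float))
    root_b = psd_sqrt(Sigma_b)
    cross = psd_sqrt(root_b @ Sigma_a @ root_b)
    mean_term = np.sum((mu_a - mu_b) ** 2)
    cov_term = np.trace(Sigma_a + Sigma_b - 2.0 * cross)
    return float(mean_term + cov_term)


# Synthetic stand-ins for two sets of posterior draws (assumption, not real fits)
rng_demo = np.random.default_rng(0)
draws_a = rng_demo.normal(loc=0.0, scale=1.0, size=(2000, 3))
draws_b = rng_demo.normal(loc=0.5, scale=1.5, size=(2000, 3))

w2_sq = gaussian_w2_squared(
    draws_a.mean(axis=0), np.cov(draws_a, rowvar=False),
    draws_b.mean(axis=0), np.cov(draws_b, rowvar=False),
)
print(f"squared W2 between the two Gaussian summaries: {w2_sq:.3f}")
```

With the objects from the script, the same call could take the moments of `fit_NY.stan_variable('beta')` and `fit_NY_wass.stan_variable('beta')`. Note that this compares Gaussian summaries of the draws only, so any non-Gaussian shape in either posterior is ignored.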