@avivajpeyi · Last active April 22, 2025
"""
Gibbs sampler + multivar signal with Adaptive Metropolis within Gibbs
"""
import numpy as np
import matplotlib.pyplot as plt
from scipy.fft import rfft, rfftfreq
from scipy.stats import uniform
from functools import partial
from tqdm.auto import trange
from scipy.signal.windows import tukey
FS = 32.0 # Sampling frequency
DT = 1.0 / FS # Sampling interval
N_SAMPLES = 512
N_DIM = 2
SIGNAL_PARAMS = dict(
A=0.1, # Amplitude
f0=8.0,
)
A_RANGE = (0.01, 2.0)
F0_RANGE = (5.0, 10.0)
N_ITER = 10000
BURN_IN = 500
THINNING = 1
ADAPTIVE_WINDOW = 100 # Window size for adapting the proposal variance
EPS = 1e-8 # Small constant for numerical stability
np.random.seed(0) # Set seed for reproducibility
def unif_logpdf(x, lower, upper):
return uniform.logpdf(x, loc=lower, scale=upper - lower)
def psd_model(freq, a=0.5, b=10.0, c=1.0):
return a / (1 + b * freq ** 2) ** c
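# psd_model above is a smooth Lorentzian-like one-sided noise PSD,
# S(f) = a / (1 + b f^2)^c, which falls off as f^(-2c) at high frequency.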
def signal_model(A, f0, t, dim=N_DIM):
y = A * np.sin(2 * np.pi * f0 * t)
y *= tukey(len(t), alpha=0.1) # Apply Tukey window
return np.array([y] * int(dim)).T
def generate_data(A, f0, noise_psd_func, n_samples=N_SAMPLES, n_dim=N_DIM, dt=DT):
t = np.arange(n_samples) * dt
signal = signal_model(A, f0, t, dim=n_dim)
noise = generate_noise_from_psd(noise_psd_func, n_samples, dt, n_dim)
data = signal + noise
    # Report the optimal matched-filter SNR of the injected signal
    h = np.fft.rfft(signal, axis=0)
    psd = np.array([noise_psd_func(rfftfreq(n_samples, dt))] * n_dim).T
    snr_per_channel = np.sqrt(inner_prod(h, h, psd, dt, n_samples))
    print(f"SNR: {snr_per_channel[0]:.2f}")
return data
def generate_noise_from_psd(noise_psd_func, n_samples=N_SAMPLES, dt=DT, n_dim=N_DIM):
noise_time = np.zeros((n_samples, n_dim))
freqs = rfftfreq(n_samples, dt)
true_psd_full = noise_psd_func(freqs)
for i in range(n_dim):
        white_noise_fft = (np.random.normal(size=len(freqs)) + 1j * np.random.normal(size=len(freqs))) / np.sqrt(2)
        # The DC (and, for even n, Nyquist) bins of a real signal's rfft must be real;
        # irfft ignores their imaginary parts, so draw them as real unit-variance Gaussians
        white_noise_fft[0] = np.random.normal()
        white_noise_fft[-1] = np.random.normal()
colored_noise_fft = white_noise_fft * np.sqrt(true_psd_full)
noise_time[:, i] = np.fft.irfft(colored_noise_fft, n=n_samples)
return noise_time
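# Noise recipe above: draw unit-variance complex Gaussian amplitudes per
# frequency bin, color them by sqrt(S(f)), and inverse-rFFT to the time
# domain. Absolute normalization (factors of dt and N) is not tracked here,
# so the generated noise matches S(f) only up to an overall scale.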
def inner_prod(sig1_f, sig2_f, PSD, delta_t, N_t):
    """Noise-weighted inner product, evaluated per channel. Used for likelihoods and SNRs."""
    return (4 * delta_t / N_t) * np.real(
        np.sum(np.conjugate(sig1_f) * sig2_f / PSD, axis=0)
    )
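# inner_prod is the discrete noise-weighted inner product,
#   <a|b> = 4 * df * Re sum_k a_k* b_k / S(f_k),  with df = 1 / (N dt),
# evaluated independently for each channel (column).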
def snr(d, h, psd, dt, n_samples):
    """Single-channel optimal SNR; d, h, psd are 1D frequency-domain arrays."""
    assert d.ndim == 1, "Input data must be 1D"
    # local name chosen to avoid shadowing the module-level inner_prod()
    ip = np.real(np.sum((np.conj(h) * h) / psd))
    scale = 4 * dt / n_samples
    return np.sqrt(scale * ip)
def network_snr(d, h, psd, dt, n_samples):
    """Network SNR: quadrature sum of the per-channel SNRs."""
    dim = d.shape[1]
    snr_values = [snr(d[:, i], h[:, i], psd[:, i], dt, n_samples) for i in range(dim)]
    return np.sqrt(np.sum(np.array(snr_values) ** 2))
def whittle_lnl(signal_params, data_freq, psd):
    """Whittle log-likelihood, up to constants independent of the signal parameters."""
signal = signal_model(*signal_params)
signal_freq = rfft(signal, axis=0)
dim = signal_freq.shape[1]
lnl = 0.0
for i in range(dim):
p = psd[:, i]
d = data_freq[:, i]
h = signal_freq[:, i]
# 1D Whittle log-likelihood
lnl += - np.sum((np.abs(d - h) ** 2) / p)
return lnl
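# Per channel, the likelihood above is the Whittle form
#   ln L = -sum_k |d_k - h_k|^2 / S_k  (+ terms independent of A, f0).
# The log S_k term is omitted; this is valid because the PSD is held fixed
# while A and f0 are updated, so it cancels in the Metropolis ratio.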
def ln_prior(A, f0):
ln_a = unif_logpdf(A, *A_RANGE)
ln_f0 = unif_logpdf(f0, *F0_RANGE)
return ln_a + ln_f0
def ln_prob(signal_params, model_args, data_freq, psd):
A = signal_params[0]
f0 = signal_params[1]
ln_likelihood = whittle_lnl((A, f0, *model_args), data_freq, psd)
ln_prior_value = ln_prior(A, f0)
return ln_likelihood + ln_prior_value
def periodogram(data, dt=DT):
n_samples = data.shape[0]
freqs = rfftfreq(n_samples, dt)
psd = np.abs(rfft(data, axis=0)) ** 2 / n_samples
return freqs, psd
def compute_ci(posterior_samples, freqs, alpha=0.05):
signal_periodograms = []
t = np.arange(N_SAMPLES) * DT
for A, f0 in posterior_samples:
signal = signal_model(A, f0, t)
_, sp = periodogram(signal)
signal_periodograms.append(sp)
signal_periodograms = np.array(signal_periodograms)
ci_lower = np.percentile(signal_periodograms, 100 * alpha / 2, axis=0)
ci_upper = np.percentile(signal_periodograms, 100 * (1 - alpha / 2), axis=0)
median = np.median(signal_periodograms, axis=0)
return ci_lower, median, ci_upper
def plot_model(data, psd, posterior_samples=None):
n, dim = data.shape
t = np.arange(n) * DT
true_signal = signal_model(*SIGNAL_PARAMS.values(), t, dim)
freq = rfftfreq(n, DT)
_, data_periodogram = periodogram(data)
_, signal_periodogram = periodogram(true_signal)
ci = None
if posterior_samples is not None:
ci = compute_ci(posterior_samples, freq)
fig, axes = plt.subplots(dim, 1, figsize=(10, 4 * dim))
if dim == 1:
axes = [axes] # Ensure axes is iterable
for i in range(dim):
ax = axes[i]
ax.loglog(freq, data_periodogram[:, i], label="Data Periodogram", color='gray', alpha=0.5)
ax.loglog(freq, signal_periodogram[:, i], label="True Signal Periodogram", color='tab:red', alpha=0.5)
ax.loglog(freq, psd[:, i], label="True PSD", color='k', ls='--')
ax.set_xlabel("Frequency [Hz]")
ax.set_ylabel("Power/Frequency [1/Hz]")
ax.set_title(f"Channel {i + 1}")
if ci is not None:
ax.fill_between(freq, ci[0][:, i], ci[-1][:, i], alpha=0.3, label="Posterior", color="tab:blue")
ax.plot(freq, ci[1][:, i], color="tab:blue", lw=2)
axes[0].legend()
plt.tight_layout()
return fig, axes
def adaptive_metropolis_step(current_param, log_prob_fn, proposal_variance, step_size=1.0):
    """Random-walk Metropolis step for a single parameter; the proposal variance is adapted by the caller."""
    proposed_param = np.random.normal(current_param, np.sqrt(step_size * proposal_variance))
    current_logprob = log_prob_fn(current_param)
    proposed_logprob = log_prob_fn(proposed_param)
    # Accept/reject in log space to avoid overflow in exp() for large log-ratios
    if np.log(np.random.rand()) < proposed_logprob - current_logprob:
        return proposed_param, True
    return current_param, False
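# The step above is plain random-walk Metropolis; the "adaptive" part lives in
# the sampler below, which periodically resets the proposal variance to the
# empirical variance of the recent chain history (a simplified, windowed
# variant of Haario-style Adaptive Metropolis).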
def adaptive_gibbs_sampler(data, psd, model_args, n_iter=N_ITER, burn_in=BURN_IN, thinning=THINNING,
adaptive_window=ADAPTIVE_WINDOW, initial_step_size=0.1):
"""Gibbs sampler with Adaptive Metropolis for A and f0."""
n_samples, n_dim = data.shape
data_freq = rfft(data, axis=0)
_lnp = partial(ln_prob, model_args=model_args, data_freq=data_freq, psd=psd)
# Initialize chains
amplitude_chain = [np.random.uniform(*A_RANGE)]
frequency_chain = [np.random.uniform(*F0_RANGE)]
param_history = [[amplitude_chain[0], frequency_chain[0]]]
acceptance_rates = []
proposal_variance_A = initial_step_size ** 2
proposal_variance_f0 = initial_step_size ** 2
for i in trange(n_iter):
current_A = amplitude_chain[-1]
current_f0 = frequency_chain[-1]
# --- Sample Amplitude (Adaptive Metropolis) ---
def log_prob_A(a):
return _lnp((a, current_f0))
        new_A, accepted_A = adaptive_metropolis_step(
            current_A, log_prob_A, proposal_variance_A, step_size=1.0
        )
        amplitude_chain.append(new_A)
# --- Sample Frequency (Adaptive Metropolis) ---
def log_prob_f0(f):
return _lnp((amplitude_chain[-1], f))
        new_f0, accepted_f0 = adaptive_metropolis_step(
            current_f0, log_prob_f0, proposal_variance_f0, step_size=1.0
        )
        frequency_chain.append(new_f0)
param_history.append([amplitude_chain[-1], frequency_chain[-1]])
acceptance_rates.append([accepted_A, accepted_f0])
# --- Adapt proposal variance ---
if (i + 1) % adaptive_window == 0 and i >= burn_in:
history_window = np.array(param_history[-adaptive_window:])
if history_window.shape[0] > 1:
proposal_variance_A = np.var(history_window[:, 0]) + EPS
proposal_variance_f0 = np.var(history_window[:, 1]) + EPS
    # Keep post-burn-in samples, thinned by the requested factor
    amplitude_samples = np.array(amplitude_chain[burn_in::thinning])
    frequency_samples = np.array(frequency_chain[burn_in::thinning])
    posterior_samples = np.vstack((amplitude_samples, frequency_samples)).T
return posterior_samples, acceptance_rates
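# Minimal usage sketch (the __main__ block below does the same, plus plotting):
#   t = np.arange(N_SAMPLES) * DT
#   psd = np.array([psd_model(rfftfreq(N_SAMPLES, DT))] * N_DIM).T
#   data = generate_data(0.1, 8.0, psd_model)
#   samples, acc = adaptive_gibbs_sampler(data, psd, model_args=(t, N_DIM))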
def plot_posterior_trace(posterior_samples, acceptance_rates=None):
# Plot posterior distributions
fig_hist, axes_hist = plt.subplots(2, 1, figsize=(8, 6))
axes_hist[0].hist(posterior_samples[:, 0], bins=50, density=True, alpha=0.7, label="Posterior A")
axes_hist[0].axvline(SIGNAL_PARAMS["A"], color='r', linestyle='--', label="True A")
axes_hist[0].set_xlabel("Amplitude (A)")
axes_hist[0].legend()
axes_hist[1].hist(posterior_samples[:, 1], bins=50, density=True, alpha=0.7, label="Posterior f0")
axes_hist[1].axvline(SIGNAL_PARAMS["f0"], color='r', linestyle='--', label="True f0")
axes_hist[1].set_xlabel("Frequency (f0)")
axes_hist[1].legend()
plt.tight_layout()
plt.savefig('posterior_histogram_adaptive.png')
# plot the trace
fig_trace, axes_trace = plt.subplots(2, 1, figsize=(8, 6))
axes_trace[0].plot(posterior_samples[:, 0], label="Amplitude (A)")
axes_trace[0].set_ylabel("Amplitude (A)")
axes_trace[0].axhline(SIGNAL_PARAMS["A"], color='r', linestyle='--', label="True A")
axes_trace[0].legend()
axes_trace[1].plot(posterior_samples[:, 1], label="Frequency (f0)")
axes_trace[1].set_ylabel("Frequency (f0)")
axes_trace[1].axhline(SIGNAL_PARAMS["f0"], color='r', linestyle='--', label="True f0")
axes_trace[1].legend()
plt.tight_layout()
plt.savefig('posterior_trace_adaptive.png')
if acceptance_rates:
acceptance_rates_arr = np.array(acceptance_rates)
fig_acc, axes_acc = plt.subplots(2, 1, figsize=(8, 6))
axes_acc[0].plot(
np.convolve(acceptance_rates_arr[:, 0], np.ones(ADAPTIVE_WINDOW) / ADAPTIVE_WINDOW, mode='valid'),
label="Acceptance Rate A")
axes_acc[0].set_ylabel("Acceptance Rate (A)")
axes_acc[0].axhline(0.23, color='k', linestyle='--', label="Optimal ~0.23")
axes_acc[0].legend()
axes_acc[1].plot(
np.convolve(acceptance_rates_arr[:, 1], np.ones(ADAPTIVE_WINDOW) / ADAPTIVE_WINDOW, mode='valid'),
label="Acceptance Rate f0")
axes_acc[1].set_ylabel("Acceptance Rate (f0)")
axes_acc[1].axhline(0.23, color='k', linestyle='--', label="Optimal ~0.23")
axes_acc[1].set_xlabel("Iteration")
axes_acc[1].legend()
plt.tight_layout()
plt.savefig('acceptance_rates_adaptive.png')
if __name__ == "__main__":
# Generate data
t = np.arange(N_SAMPLES) * DT
freq = rfftfreq(N_SAMPLES, DT)
true_psd = np.array([psd_model(freq)] * N_DIM).T
data = generate_data(*SIGNAL_PARAMS.values(), noise_psd_func=psd_model, n_samples=N_SAMPLES, n_dim=N_DIM)
# Plot model with true parameters
fig_true, axes_true = plot_model(data, true_psd)
fig_true.suptitle("Data and True Model")
fig_true.savefig("data_adaptive.png")
# Run Adaptive Gibbs sampler
posterior_samples, acceptance_rates = adaptive_gibbs_sampler(data, true_psd, model_args=(t, N_DIM))
    # Discard additional early samples beyond the burn-in already removed
    posterior_samples = posterior_samples[1000:]
plot_posterior_trace(posterior_samples, acceptance_rates)
# plot model + posterior samples
fig_post, _ = plot_model(data, true_psd, posterior_samples=posterior_samples)
fig_post.suptitle("Posterior (Adaptive)")
fig_post.savefig("posterior_adaptive.png")
"""
Gibbs sampler + multivar signal with Adaptive Metropolis within Gibbs
"""
import numpy as np
import matplotlib.pyplot as plt
from scipy.fft import rfft, rfftfreq
from scipy.stats import uniform
from functools import partial
from tqdm.auto import trange
from sgvb_psd.psd_estimator import PSDEstimator
import h5py
FS = 32.0 # Sampling frequency
DT = 1.0 / FS # Sampling interval
N_SAMPLES = 512
N_DIM = 2
SIGNAL_PARAMS = dict(
A=0.1, # Amplitude
f0=8.0,
)
A_RANGE = (0.01, 2.0)
F0_RANGE = (5.0, 10.0)
N_ITER = 10000
SAMPLES_TO_KEEP = 5000
BURN_IN = 500
THINNING = 1
ADAPTIVE_WINDOW = 100 # Window size for adapting the proposal variance
SGVB_ITERATIONS = 500 # run SGVB every 500 iterations
EPS = 1e-8 # Small constant for numerical stability
np.random.seed(0) # Set seed for reproducibility
def unif_logpdf(x, lower, upper):
return uniform.logpdf(x, loc=lower, scale=upper - lower)
def psd_model(freq, a=0.5, b=10.0, c=1.0):
return a / (1 + b * freq ** 2) ** c
def signal_model(A, f0, t, dim=N_DIM):
y = A * np.sin(2 * np.pi * f0 * t)
# y *= tukey(len(t), alpha=0.01) # Apply Tukey window
return np.array([y] * int(dim)).T
def generate_data(A, f0, noise_psd_func, n_samples=N_SAMPLES, n_dim=N_DIM, dt=DT):
t = np.arange(n_samples) * dt
signal = signal_model(A, f0, t, dim=n_dim)
noise = generate_noise_from_psd(noise_psd_func, n_samples, dt, n_dim)
data = signal + noise
    # Report the optimal matched-filter SNR of the injected signal
    h = np.fft.rfft(signal, axis=0)
    psd = np.array([noise_psd_func(rfftfreq(n_samples, dt))] * n_dim).T
    snr_per_channel = np.sqrt(inner_prod(h, h, psd, dt, n_samples))
    print(f"SNR: {snr_per_channel[0]:.2f}")
return data
def generate_noise_from_psd(noise_psd_func, n_samples=N_SAMPLES, dt=DT, n_dim=N_DIM):
noise_time = np.zeros((n_samples, n_dim))
freqs = rfftfreq(n_samples, dt)
true_psd_full = noise_psd_func(freqs)
for i in range(n_dim):
        white_noise_fft = (np.random.normal(size=len(freqs)) + 1j * np.random.normal(size=len(freqs))) / np.sqrt(2)
        # The DC (and, for even n, Nyquist) bins of a real signal's rfft must be real;
        # irfft ignores their imaginary parts, so draw them as real unit-variance Gaussians
        white_noise_fft[0] = np.random.normal()
        white_noise_fft[-1] = np.random.normal()
colored_noise_fft = white_noise_fft * np.sqrt(true_psd_full)
noise_time[:, i] = np.fft.irfft(colored_noise_fft, n=n_samples)
return noise_time
def inner_prod(sig1_f, sig2_f, PSD, delta_t, N_t):
    """Noise-weighted inner product, evaluated per channel. Used for likelihoods and SNRs."""
    return (4 * delta_t / N_t) * np.real(
        np.sum(np.conjugate(sig1_f) * sig2_f / PSD, axis=0)
    )
def snr(d, h, psd, dt, n_samples):
    """Single-channel optimal SNR; d, h, psd are 1D frequency-domain arrays."""
    assert d.ndim == 1, "Input data must be 1D"
    # local name chosen to avoid shadowing the module-level inner_prod()
    ip = np.real(np.sum((np.conj(h) * h) / psd))
    scale = 4 * dt / n_samples
    return np.sqrt(scale * ip)
def network_snr(d, h, psd, dt, n_samples):
    """Network SNR: quadrature sum of the per-channel SNRs."""
    dim = d.shape[1]
    snr_values = [snr(d[:, i], h[:, i], psd[:, i], dt, n_samples) for i in range(dim)]
    return np.sqrt(np.sum(np.array(snr_values) ** 2))
def whittle_lnl(signal_params, data_freq, psd):
    """Whittle log-likelihood, up to constants independent of the signal parameters."""
    signal = signal_model(*signal_params)
    signal_freq = rfft(signal, axis=0)[1:-1]  # drop DC and Nyquist bins to match data_freq / psd
dim = signal_freq.shape[1]
lnl = 0.0
for i in range(dim):
p = psd[:, i]
d = data_freq[:, i]
h = signal_freq[:, i]
# 1D Whittle log-likelihood
lnl += - np.sum((np.abs(d - h) ** 2) / p)
return lnl
def ln_prior(A, f0):
ln_a = unif_logpdf(A, *A_RANGE)
ln_f0 = unif_logpdf(f0, *F0_RANGE)
return ln_a + ln_f0
def ln_prob(signal_params, model_args, data_freq, psd):
A = signal_params[0]
f0 = signal_params[1]
ln_likelihood = whittle_lnl((A, f0, *model_args), data_freq, psd)
ln_prior_value = ln_prior(A, f0)
return ln_likelihood + ln_prior_value
def periodogram(data, dt=DT):
n_samples = data.shape[0]
    freqs = rfftfreq(n_samples, dt)[1:-1]  # drop DC and Nyquist bins
psd = np.abs(rfft(data, axis=0)) ** 2
return freqs, psd[1:-1]
def compute_ci(posterior_samples, freqs, alpha=0.05):
signal_periodograms = []
t = np.arange(N_SAMPLES) * DT
for A, f0 in posterior_samples:
signal = signal_model(A, f0, t)
_, sp = periodogram(signal)
signal_periodograms.append(sp)
signal_periodograms = np.array(signal_periodograms)
ci_lower = np.percentile(signal_periodograms, 100 * alpha / 2, axis=0)
ci_upper = np.percentile(signal_periodograms, 100 * (1 - alpha / 2), axis=0)
median = np.median(signal_periodograms, axis=0)
return ci_lower, median, ci_upper
def plot_model(data, psd, posterior_samples=None, psd_ci=None, welch_psd=None, sgvb_psd=None):
n, dim = data.shape
t = np.arange(n) * DT
true_signal = signal_model(*SIGNAL_PARAMS.values(), t, dim)
freq = rfftfreq(n, DT)[1:-1]
_, data_periodogram = periodogram(data)
_, signal_periodogram = periodogram(true_signal)
ci = None
if posterior_samples is not None:
ci = compute_ci(posterior_samples, freq)
fig, axes = plt.subplots(dim, 1, figsize=(10, 4 * dim))
if dim == 1:
axes = [axes] # Ensure axes is iterable
for i in range(dim):
ax = axes[i]
ax.plot(freq, data_periodogram[:, i], label="Data Periodogram", color='gray', alpha=0.5)
ax.plot(freq, signal_periodogram[:, i], label="True Signal Periodogram", color='tab:red', alpha=0.5)
ax.plot(freq, psd[:, i], label="True PSD", color='k', ls='--')
ax.set_xlabel("Frequency [Hz]")
ax.set_ylabel("Power/Frequency [1/Hz]")
ax.set_title(f"Channel {i + 1}")
if ci is not None:
ax.fill_between(freq, ci[0][:, i], ci[-1][:, i], alpha=0.3, label="Posterior", color="tab:blue")
ax.plot(freq, ci[1][:, i], color="tab:blue", lw=2)
if psd_ci is not None:
_psdci = np.real(psd_ci[..., i, i])
ax.fill_between(freq, _psdci[0], _psdci[-1], alpha=0.3, label="SGVB CI", color="tab:orange")
# ax.plot(freq, _psdci[1], color="tab:orange", lw=2)
if sgvb_psd is not None:
ax.plot(freq, sgvb_psd[:, i], label="SGVB PSD", color='tab:orange', alpha=1, lw=2)
if welch_psd is not None:
ax.plot(freq, welch_psd[:, i], label="Welch PSD", color='tab:green', alpha=0.5)
ax.set_yscale('log')
ax.set_ylim(bottom=1e-6)
ax.set_xlim(freq[0], freq[-1])
axes[0].legend()
plt.tight_layout()
return fig, axes
def adaptive_metropolis_step(current_param, log_prob_fn, proposal_variance, step_size=1.0):
    """Random-walk Metropolis step for a single parameter; the proposal variance is adapted by the caller."""
    proposed_param = np.random.normal(current_param, np.sqrt(step_size * proposal_variance))
    current_logprob = log_prob_fn(current_param)
    proposed_logprob = log_prob_fn(proposed_param)
    # Accept/reject in log space to avoid overflow in exp() for large log-ratios
    if np.log(np.random.rand()) < proposed_logprob - current_logprob:
        return proposed_param, True
    return current_param, False
def adaptive_gibbs_sampler(data, psd, model_args, n_iter=N_ITER, burn_in=BURN_IN, thinning=THINNING,
adaptive_window=ADAPTIVE_WINDOW, initial_step_size=0.1):
"""Gibbs sampler with Adaptive Metropolis for A and f0."""
n_samples, n_dim = data.shape
data_freq = rfft(data, axis=0)[1:-1]
_lnp = partial(ln_prob, model_args=model_args, data_freq=data_freq, psd=psd)
# Initialize chains
amplitude_chain = [np.random.uniform(*A_RANGE)]
frequency_chain = [np.random.uniform(*F0_RANGE)]
param_history = [[amplitude_chain[0], frequency_chain[0]]]
acceptance_rates = []
proposal_variance_A = initial_step_size ** 2
proposal_variance_f0 = initial_step_size ** 2
for i in trange(n_iter):
current_A = amplitude_chain[-1]
current_f0 = frequency_chain[-1]
# --- Sample Amplitude (Adaptive Metropolis) ---
def log_prob_A(a):
return _lnp((a, current_f0))
        new_A, accepted_A = adaptive_metropolis_step(
            current_A, log_prob_A, proposal_variance_A, step_size=1.0
        )
        amplitude_chain.append(new_A)
# --- Sample Frequency (Adaptive Metropolis) ---
def log_prob_f0(f):
return _lnp((amplitude_chain[-1], f))
        new_f0, accepted_f0 = adaptive_metropolis_step(
            current_f0, log_prob_f0, proposal_variance_f0, step_size=1.0
        )
        frequency_chain.append(new_f0)
param_history.append([amplitude_chain[-1], frequency_chain[-1]])
acceptance_rates.append([accepted_A, accepted_f0])
# --- Adapt proposal variance ---
if (i + 1) % adaptive_window == 0 and i >= burn_in:
history_window = np.array(param_history[-adaptive_window:])
if history_window.shape[0] > 1:
proposal_variance_A = np.var(history_window[:, 0]) + EPS
proposal_variance_f0 = np.var(history_window[:, 1]) + EPS
        # --- Re-estimate the noise PSD from the current signal residuals (SGVB) ---
if (i + 1) % SGVB_ITERATIONS == 0 and i >= burn_in:
# current signal model
current_signal = signal_model(amplitude_chain[-1], frequency_chain[-1], *model_args)
residuals = data - current_signal
psd = get_sgvb_psd_estimate(residuals, i)
_lnp = partial(ln_prob, model_args=model_args, data_freq=data_freq, psd=psd)
    # Keep post-burn-in samples, thinned by the requested factor
    amplitude_samples = np.array(amplitude_chain[burn_in::thinning])
    frequency_samples = np.array(frequency_chain[burn_in::thinning])
    posterior_samples = np.vstack((amplitude_samples, frequency_samples)).T
return posterior_samples, acceptance_rates, psd
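# The sampler above is a blocked Gibbs scheme: (i) Metropolis updates of
# (A, f0) conditional on the current PSD, and (ii) every SGVB_ITERATIONS
# steps, a refit of the PSD to the residuals data - signal(A, f0). Note the
# SGVB refit is a point update of the PSD, not a draw from its conditional
# posterior.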
def get_sgvb_psd_estimate(residuals, i: int):
psd_estim = PSDEstimator(x=residuals, fs=1 / DT, nchunks=1, max_hyperparm_eval=1)
psd_all, pointwise_ci, uniform_ci = psd_estim.run(lr=1.7117e-02)
axes = psd_estim.plot()
freq = psd_estim.freq
true_psd = psd_model(freq)
# plot the true PSD
axes[0, 0].plot(freq, true_psd, label="True PSD", color='k', ls='--')
axes[1, 1].plot(freq, true_psd, label="True PSD", color='k', ls='--')
# remove axes [0, 1], [1, 0] from the plot
axes[0, 1].remove()
axes[1, 0].remove()
plt.savefig(f"psd_estim_{i}.png")
# save uniform confidence intervals
    with h5py.File("psd_estim.h5", "w") as f:
f.create_dataset("uniform_ci", data=uniform_ci)
psd = np.real(psd_all[0])
return np.array([psd[:, 0, 0], psd[:, 1, 1]]).T
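# As used here, psd_all[0] from PSDEstimator is an (n_freq, n_dim, n_dim)
# array of spectral matrices; the helper keeps only the real auto-spectra
# ([..., 0, 0] and [..., 1, 1]) to match the per-channel PSD arrays used
# elsewhere in this script.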
def plot_posterior_trace(posterior_samples, acceptance_rates=None):
# Plot posterior distributions
fig_hist, axes_hist = plt.subplots(2, 1, figsize=(8, 6))
axes_hist[0].hist(posterior_samples[:, 0], bins=50, density=True, alpha=0.7, label="Posterior A")
axes_hist[0].axvline(SIGNAL_PARAMS["A"], color='r', linestyle='--', label="True A")
axes_hist[0].set_xlabel("Amplitude (A)")
axes_hist[0].legend()
axes_hist[1].hist(posterior_samples[:, 1], bins=50, density=True, alpha=0.7, label="Posterior f0")
axes_hist[1].axvline(SIGNAL_PARAMS["f0"], color='r', linestyle='--', label="True f0")
axes_hist[1].set_xlabel("Frequency (f0)")
axes_hist[1].legend()
plt.tight_layout()
plt.savefig('posterior_histogram_adaptive.png')
# plot the trace
fig_trace, axes_trace = plt.subplots(2, 1, figsize=(8, 6))
axes_trace[0].plot(posterior_samples[:, 0], label="Amplitude (A)")
axes_trace[0].set_ylabel("Amplitude (A)")
axes_trace[0].axhline(SIGNAL_PARAMS["A"], color='r', linestyle='--', label="True A")
axes_trace[0].legend()
axes_trace[1].plot(posterior_samples[:, 1], label="Frequency (f0)")
axes_trace[1].set_ylabel("Frequency (f0)")
axes_trace[1].axhline(SIGNAL_PARAMS["f0"], color='r', linestyle='--', label="True f0")
axes_trace[1].legend()
plt.tight_layout()
plt.savefig('posterior_trace_adaptive.png')
if acceptance_rates:
acceptance_rates_arr = np.array(acceptance_rates)
fig_acc, axes_acc = plt.subplots(2, 1, figsize=(8, 6))
axes_acc[0].plot(
np.convolve(acceptance_rates_arr[:, 0], np.ones(ADAPTIVE_WINDOW) / ADAPTIVE_WINDOW, mode='valid'),
label="Acceptance Rate A")
axes_acc[0].set_ylabel("Acceptance Rate (A)")
axes_acc[0].axhline(0.23, color='k', linestyle='--', label="Optimal ~0.23")
axes_acc[0].legend()
axes_acc[1].plot(
np.convolve(acceptance_rates_arr[:, 1], np.ones(ADAPTIVE_WINDOW) / ADAPTIVE_WINDOW, mode='valid'),
label="Acceptance Rate f0")
axes_acc[1].set_ylabel("Acceptance Rate (f0)")
axes_acc[1].axhline(0.23, color='k', linestyle='--', label="Optimal ~0.23")
axes_acc[1].set_xlabel("Iteration")
axes_acc[1].legend()
plt.tight_layout()
plt.savefig('acceptance_rates_adaptive.png')
def generate_welch_psd(n_chunks, noise_psd_func, n_samples, dt, n_dim):
    """Median of per-realization periodograms (a Welch-style median PSD estimate)."""
    noises = [generate_noise_from_psd(noise_psd_func, n_samples, dt, n_dim) for _ in range(n_chunks)]
    # median periodogram across the independent noise realizations
psd = np.zeros((n_chunks, (n_samples // 2) - 1, n_dim))
for i in range(n_chunks):
_, psd[i] = periodogram(noises[i])
return np.median(np.real(psd), axis=0)
def load_psd_ci(scaling=1.0):
    # Rescale the saved CI by 1/scaling**2 (power scales as the amplitude squared)
    with h5py.File("psd_estim.h5", "r") as f:
uniform_ci = f["uniform_ci"][:]
return uniform_ci / scaling ** 2
if __name__ == "__main__":
# Generate data
t = np.arange(N_SAMPLES) * DT
freq = rfftfreq(N_SAMPLES, DT)[1:-1]
true_psd = np.array([psd_model(freq)] * N_DIM).T
welch_psd = generate_welch_psd(128, psd_model, N_SAMPLES, DT, N_DIM)
true_signal = signal_model(*SIGNAL_PARAMS.values(), t, dim=N_DIM)
data = generate_data(*SIGNAL_PARAMS.values(), noise_psd_func=psd_model, n_samples=N_SAMPLES, n_dim=N_DIM)
residual = data - true_signal
scaling = np.std(residual)
    # Plot model with true parameters
fig_true, axes_true = plot_model(data, true_psd, welch_psd=welch_psd, psd_ci=load_psd_ci(scaling))
fig_true.savefig("data_adaptive.png")
    # Sanity-check the SGVB estimator once on the true residuals
sgvb_psd = get_sgvb_psd_estimate(data - true_signal, 0)
assert sgvb_psd.shape == true_psd.shape, f"SGVB PSD shape {sgvb_psd.shape} mismatch with true PSD shape {true_psd.shape}"
# Run Adaptive Gibbs sampler
posterior_samples, acceptance_rates, sgvb_psd = adaptive_gibbs_sampler(data, welch_psd, model_args=(t, N_DIM))
# Grab final "SAMPLES_TO_KEEP" posterior samples for plotting
posterior_samples = posterior_samples[-SAMPLES_TO_KEEP:]
plot_posterior_trace(posterior_samples, acceptance_rates)
# plot model + posterior samples
fig_post, _ = plot_model(
data, true_psd,
posterior_samples=posterior_samples,
welch_psd=welch_psd,
psd_ci=load_psd_ci(scaling),
sgvb_psd=sgvb_psd
)
fig_post.savefig("posterior_adaptive.png")