""" | |
Gibbs sampler + multivar signal with Adaptive Metropolis within Gibbs | |
""" | |
import numpy as np | |
import matplotlib.pyplot as plt | |
from scipy.fft import rfft, rfftfreq | |
from scipy.stats import uniform | |
from functools import partial | |
from tqdm.auto import trange | |
from scipy.signal.windows import tukey | |
FS = 32.0  # Sampling frequency [Hz]
DT = 1.0 / FS  # Sampling interval [s]
N_SAMPLES = 512
N_DIM = 2  # Number of channels

SIGNAL_PARAMS = dict(
    A=0.1,  # Amplitude
    f0=8.0,  # Frequency [Hz]
)
A_RANGE = (0.01, 2.0)
F0_RANGE = (5.0, 10.0)

N_ITER = 10000
BURN_IN = 500
THINNING = 1
ADAPTIVE_WINDOW = 100  # Window size for adapting the proposal variance
EPS = 1e-8  # Small constant for numerical stability

np.random.seed(0)  # Set seed for reproducibility
def unif_logpdf(x, lower, upper):
    return uniform.logpdf(x, loc=lower, scale=upper - lower)


def psd_model(freq, a=0.5, b=10.0, c=1.0):
    # Smooth, Lorentzian-like one-sided noise PSD: S(f) = a / (1 + b f^2)^c
    return a / (1 + b * freq ** 2) ** c


def signal_model(A, f0, t, dim=N_DIM):
    y = A * np.sin(2 * np.pi * f0 * t)
    y *= tukey(len(t), alpha=0.1)  # Apply Tukey window to soften the edges
    return np.array([y] * int(dim)).T
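# Illustrative shape note (not executed): signal_model(0.1, 8.0, t) returns an
# array of shape (len(t), N_DIM) whose columns are identical copies of the
# tapered sinusoid.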
def generate_data(A, f0, noise_psd_func, n_samples=N_SAMPLES, n_dim=N_DIM, dt=DT):
    t = np.arange(n_samples) * dt
    signal = signal_model(A, f0, t, dim=n_dim)
    noise = generate_noise_from_psd(noise_psd_func, n_samples, dt, n_dim)
    data = signal + noise
    # Optimal matched-filter SNR (per channel; the channels are identical
    # here, so report channel 0)
    h = np.fft.rfft(signal, axis=0)
    psd = np.array([noise_psd_func(rfftfreq(n_samples, dt))] * n_dim).T
    SNR = np.sqrt(inner_prod(h, h, psd, dt, n_samples))[0]
    print(f"SNR: {SNR:.2f}")
    return data
def generate_noise_from_psd(noise_psd_func, n_samples=N_SAMPLES, dt=DT, n_dim=N_DIM):
    noise_time = np.zeros((n_samples, n_dim))
    freqs = rfftfreq(n_samples, dt)
    true_psd_full = noise_psd_func(freqs)
    for i in range(n_dim):
        white_noise_fft = (np.random.normal(size=len(freqs)) + 1j * np.random.normal(size=len(freqs))) / np.sqrt(2)
        white_noise_fft[0] /= np.sqrt(2)  # Adjust DC component
        colored_noise_fft = white_noise_fft * np.sqrt(true_psd_full)
        noise_time[:, i] = np.fft.irfft(colored_noise_fft, n=n_samples)
    return noise_time
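# Colouring recipe above, per channel: draw complex white noise W_k with
# E|W_k|^2 = 1 per rfft bin, scale by sqrt(S(f_k)), and inverse-FFT. This is
# self-consistent with whittle_lnl below, which compares raw-rfft data against
# the same S(f_k); calibrating the output to a physical one-sided PSD would
# need additional FFT-convention factors, which this script does not apply.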
def inner_prod(sig1_f, sig2_f, PSD, delta_t, N_t):
    # Noise-weighted inner product, useful for likelihoods and SNRs.
    # Sums over axis 0, so multichannel inputs give one value per channel.
    return (4 * delta_t / N_t) * np.real(
        np.sum(np.conjugate(sig1_f) * sig2_f / PSD, axis=0)
    )


def snr(d, h, psd, dt, n_samples):
    """Compute the optimal SNR for a single channel (d, h, psd are 1D arrays)."""
    assert d.ndim == 1, "Input data must be 1D"
    power = np.real(np.sum((np.conj(h) * h) / psd))  # avoid shadowing inner_prod
    scale = 4 * dt / n_samples
    return np.sqrt(scale * power)
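# Matched-filter SNR in the same convention as inner_prod above:
# rho^2 = (4 dt / N) sum_k |h_k|^2 / S_k, so channel-wise
# snr(d, h, psd, dt, N) == sqrt(inner_prod(h, h, psd, dt, N)).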
def network_snr(d, h, psd, dt, n_samples):
    # Network SNR: quadrature sum of the per-channel SNRs
    dim = d.shape[1]
    snr_values = [snr(d[:, i], h[:, i], psd[:, i], dt, n_samples) for i in range(dim)]
    return np.sqrt(np.sum(np.array(snr_values) ** 2))
def whittle_lnl(signal_params, data_freq, psd):
    """Calculates the Whittle log-likelihood."""
    signal = signal_model(*signal_params)
    signal_freq = rfft(signal, axis=0)
    dim = signal_freq.shape[1]
    lnl = 0.0
    for i in range(dim):
        p = psd[:, i]
        d = data_freq[:, i]
        h = signal_freq[:, i]
        # 1D Whittle log-likelihood
        lnl += -np.sum((np.abs(d - h) ** 2) / p)
    return lnl
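# Full Whittle form for reference (per channel, up to an additive constant):
# ln L = -sum_k [ ln S_k + |d_k - h_k|^2 / S_k ]. The sum(ln S_k) term is
# constant in (A, f0) because the PSD is held fixed here, so only the second
# term is computed above.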
def ln_prior(A, f0):
    ln_a = unif_logpdf(A, *A_RANGE)
    ln_f0 = unif_logpdf(f0, *F0_RANGE)
    return ln_a + ln_f0


def ln_prob(signal_params, model_args, data_freq, psd):
    A, f0 = signal_params
    ln_likelihood = whittle_lnl((A, f0, *model_args), data_freq, psd)
    return ln_likelihood + ln_prior(A, f0)


def periodogram(data, dt=DT):
    n_samples = data.shape[0]
    freqs = rfftfreq(n_samples, dt)
    psd = np.abs(rfft(data, axis=0)) ** 2 / n_samples
    return freqs, psd
def compute_ci(posterior_samples, freqs, alpha=0.05):
    signal_periodograms = []
    t = np.arange(N_SAMPLES) * DT
    for A, f0 in posterior_samples:
        signal = signal_model(A, f0, t)
        _, sp = periodogram(signal)
        signal_periodograms.append(sp)
    signal_periodograms = np.array(signal_periodograms)
    ci_lower = np.percentile(signal_periodograms, 100 * alpha / 2, axis=0)
    ci_upper = np.percentile(signal_periodograms, 100 * (1 - alpha / 2), axis=0)
    median = np.median(signal_periodograms, axis=0)
    return ci_lower, median, ci_upper
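# The band above is a posterior-predictive interval for the *signal*
# periodogram: one periodogram per posterior draw of (A, f0), then pointwise
# percentiles across draws. (The `freqs` argument is currently unused.)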
def plot_model(data, psd, posterior_samples=None):
    n, dim = data.shape
    t = np.arange(n) * DT
    true_signal = signal_model(*SIGNAL_PARAMS.values(), t, dim)
    freq = rfftfreq(n, DT)
    _, data_periodogram = periodogram(data)
    _, signal_periodogram = periodogram(true_signal)
    ci = None
    if posterior_samples is not None:
        ci = compute_ci(posterior_samples, freq)
    fig, axes = plt.subplots(dim, 1, figsize=(10, 4 * dim))
    if dim == 1:
        axes = [axes]  # Ensure axes is iterable
    for i in range(dim):
        ax = axes[i]
        ax.loglog(freq, data_periodogram[:, i], label="Data Periodogram", color='gray', alpha=0.5)
        ax.loglog(freq, signal_periodogram[:, i], label="True Signal Periodogram", color='tab:red', alpha=0.5)
        ax.loglog(freq, psd[:, i], label="True PSD", color='k', ls='--')
        ax.set_xlabel("Frequency [Hz]")
        ax.set_ylabel("Power/Frequency [1/Hz]")
        ax.set_title(f"Channel {i + 1}")
        if ci is not None:
            ax.fill_between(freq, ci[0][:, i], ci[-1][:, i], alpha=0.3, label="Posterior", color="tab:blue")
            ax.plot(freq, ci[1][:, i], color="tab:blue", lw=2)
    axes[0].legend()
    plt.tight_layout()
    return fig, axes
def adaptive_metropolis_step(current_param, log_prob_fn, proposal_variance, step_size=1.0):
    """Single Metropolis step with a Gaussian random-walk proposal."""
    proposed_param = np.random.normal(current_param, np.sqrt(step_size * proposal_variance))
    current_logprob = log_prob_fn(current_param)
    proposed_logprob = log_prob_fn(proposed_param)
    # Accept/reject in log space to avoid overflow for large log-ratios
    if np.log(np.random.rand()) < proposed_logprob - current_logprob:
        return proposed_param, True
    return current_param, False
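# Note on the adaptation used below: each proposal variance is refreshed every
# ADAPTIVE_WINDOW iterations from the empirical variance of the recent
# samples. This is a simple heuristic; the classic Haario-style Adaptive
# Metropolis would additionally scale the empirical variance by 2.38^2 / d
# (d = 1 per parameter here) and let the adaptation vanish over time to
# preserve ergodicity.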
def adaptive_gibbs_sampler(data, psd, model_args, n_iter=N_ITER, burn_in=BURN_IN, thinning=THINNING,
                           adaptive_window=ADAPTIVE_WINDOW, initial_step_size=0.1):
    """Gibbs sampler with Adaptive Metropolis updates for A and f0."""
    n_samples, n_dim = data.shape
    data_freq = rfft(data, axis=0)
    _lnp = partial(ln_prob, model_args=model_args, data_freq=data_freq, psd=psd)

    # Initialize chains from draws over the prior ranges
    amplitude_chain = [np.random.uniform(*A_RANGE)]
    frequency_chain = [np.random.uniform(*F0_RANGE)]
    param_history = [[amplitude_chain[0], frequency_chain[0]]]
    acceptance_rates = []
    proposal_variance_A = initial_step_size ** 2
    proposal_variance_f0 = initial_step_size ** 2

    for i in trange(n_iter):
        current_A = amplitude_chain[-1]
        current_f0 = frequency_chain[-1]

        # --- Sample amplitude (Metropolis, conditional on f0) ---
        def log_prob_A(a):
            return _lnp((a, current_f0))

        new_A, accepted_A = adaptive_metropolis_step(
            current_A, log_prob_A, proposal_variance_A, step_size=1.0
        )
        amplitude_chain.append(new_A)

        # --- Sample frequency (Metropolis, conditional on the new A) ---
        def log_prob_f0(f):
            return _lnp((amplitude_chain[-1], f))

        new_f0, accepted_f0 = adaptive_metropolis_step(
            current_f0, log_prob_f0, proposal_variance_f0, step_size=1.0
        )
        frequency_chain.append(new_f0)

        param_history.append([amplitude_chain[-1], frequency_chain[-1]])
        acceptance_rates.append([accepted_A, accepted_f0])

        # --- Adapt proposal variances from the recent history ---
        if (i + 1) % adaptive_window == 0 and i >= burn_in:
            history_window = np.array(param_history[-adaptive_window:])
            if history_window.shape[0] > 1:
                proposal_variance_A = np.var(history_window[:, 0]) + EPS
                proposal_variance_f0 = np.var(history_window[:, 1]) + EPS

    # Keep post-burn-in samples, thinned
    amplitude_samples = np.array(amplitude_chain[burn_in::thinning])
    frequency_samples = np.array(frequency_chain[burn_in::thinning])
    posterior_samples = np.vstack((amplitude_samples, frequency_samples)).T
    return posterior_samples, acceptance_rates
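# Illustrative diagnostic (hypothetical helper, not called anywhere in this
# script): a crude effective-sample-size estimate from the chain's
# autocorrelation, truncated at the first non-positive lag.
# Usage sketch: effective_sample_size(posterior_samples[:, 0])
def effective_sample_size(chain):
    x = np.asarray(chain, dtype=float)
    x = x - x.mean()
    n = len(x)
    # Normalised autocorrelation at lags 0..n-1 (acf[0] == 1 by construction)
    acf = np.correlate(x, x, mode="full")[n - 1:] / (n * x.var())
    tau = 1.0  # integrated autocorrelation time
    for k in range(1, n):
        if acf[k] <= 0:
            break
        tau += 2.0 * acf[k]
    return n / tau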
def plot_posterior_trace(posterior_samples, acceptance_rates=None):
    # Posterior histograms
    fig_hist, axes_hist = plt.subplots(2, 1, figsize=(8, 6))
    axes_hist[0].hist(posterior_samples[:, 0], bins=50, density=True, alpha=0.7, label="Posterior A")
    axes_hist[0].axvline(SIGNAL_PARAMS["A"], color='r', linestyle='--', label="True A")
    axes_hist[0].set_xlabel("Amplitude (A)")
    axes_hist[0].legend()
    axes_hist[1].hist(posterior_samples[:, 1], bins=50, density=True, alpha=0.7, label="Posterior f0")
    axes_hist[1].axvline(SIGNAL_PARAMS["f0"], color='r', linestyle='--', label="True f0")
    axes_hist[1].set_xlabel("Frequency (f0)")
    axes_hist[1].legend()
    plt.tight_layout()
    plt.savefig('posterior_histogram_adaptive.png')

    # Trace plots
    fig_trace, axes_trace = plt.subplots(2, 1, figsize=(8, 6))
    axes_trace[0].plot(posterior_samples[:, 0], label="Amplitude (A)")
    axes_trace[0].set_ylabel("Amplitude (A)")
    axes_trace[0].axhline(SIGNAL_PARAMS["A"], color='r', linestyle='--', label="True A")
    axes_trace[0].legend()
    axes_trace[1].plot(posterior_samples[:, 1], label="Frequency (f0)")
    axes_trace[1].set_ylabel("Frequency (f0)")
    axes_trace[1].axhline(SIGNAL_PARAMS["f0"], color='r', linestyle='--', label="True f0")
    axes_trace[1].legend()
    plt.tight_layout()
    plt.savefig('posterior_trace_adaptive.png')

    # Rolling acceptance rates
    if acceptance_rates:
        acceptance_rates_arr = np.array(acceptance_rates)
        fig_acc, axes_acc = plt.subplots(2, 1, figsize=(8, 6))
        axes_acc[0].plot(
            np.convolve(acceptance_rates_arr[:, 0], np.ones(ADAPTIVE_WINDOW) / ADAPTIVE_WINDOW, mode='valid'),
            label="Acceptance Rate A")
        axes_acc[0].set_ylabel("Acceptance Rate (A)")
        axes_acc[0].axhline(0.23, color='k', linestyle='--', label="Optimal ~0.23")
        axes_acc[0].legend()
        axes_acc[1].plot(
            np.convolve(acceptance_rates_arr[:, 1], np.ones(ADAPTIVE_WINDOW) / ADAPTIVE_WINDOW, mode='valid'),
            label="Acceptance Rate f0")
        axes_acc[1].set_ylabel("Acceptance Rate (f0)")
        axes_acc[1].axhline(0.23, color='k', linestyle='--', label="Optimal ~0.23")
        axes_acc[1].set_xlabel("Iteration")
        axes_acc[1].legend()
        plt.tight_layout()
        plt.savefig('acceptance_rates_adaptive.png')
if __name__ == "__main__":
    # Generate data
    t = np.arange(N_SAMPLES) * DT
    freq = rfftfreq(N_SAMPLES, DT)
    true_psd = np.array([psd_model(freq)] * N_DIM).T
    data = generate_data(*SIGNAL_PARAMS.values(), noise_psd_func=psd_model, n_samples=N_SAMPLES, n_dim=N_DIM)

    # Plot data against the true model
    fig_true, axes_true = plot_model(data, true_psd)
    fig_true.suptitle("Data and True Model")
    fig_true.savefig("data_adaptive.png")

    # Run the adaptive Gibbs sampler
    posterior_samples, acceptance_rates = adaptive_gibbs_sampler(data, true_psd, model_args=(t, N_DIM))

    # Discard a further 1000 samples as extra burn-in before plotting
    posterior_samples = posterior_samples[1000:]
    plot_posterior_trace(posterior_samples, acceptance_rates)

    # Plot model + posterior samples
    fig_post, _ = plot_model(data, true_psd, posterior_samples=posterior_samples)
    fig_post.suptitle("Posterior (Adaptive)")
    fig_post.savefig("posterior_adaptive.png")
# ---------------------------------------------------------------------------
# Second script: the same sampler, but the noise PSD is re-estimated with
# SGVB (sgvb_psd.PSDEstimator) from the running signal residuals.
# ---------------------------------------------------------------------------
""" | |
Gibbs sampler + multivar signal with Adaptive Metropolis within Gibbs | |
""" | |
import numpy as np | |
import matplotlib.pyplot as plt | |
from scipy.fft import rfft, rfftfreq | |
from scipy.stats import uniform | |
from functools import partial | |
from tqdm.auto import trange | |
from sgvb_psd.psd_estimator import PSDEstimator | |
import h5py | |
FS = 32.0 # Sampling frequency | |
DT = 1.0 / FS # Sampling interval | |
N_SAMPLES = 512 | |
N_DIM = 2 | |
SIGNAL_PARAMS = dict( | |
A=0.1, # Amplitude | |
f0=8.0, | |
) | |
A_RANGE = (0.01, 2.0) | |
F0_RANGE = (5.0, 10.0) | |
N_ITER = 10000 | |
SAMPLES_TO_KEEP = 5000 | |
BURN_IN = 500 | |
THINNING = 1 | |
ADAPTIVE_WINDOW = 100 # Window size for adapting the proposal variance | |
SGVB_ITERATIONS = 500 # run SGVB every 500 iterations | |
EPS = 1e-8 # Small constant for numerical stability | |
np.random.seed(0) # Set seed for reproducibility | |
def unif_logpdf(x, lower, upper):
    return uniform.logpdf(x, loc=lower, scale=upper - lower)


def psd_model(freq, a=0.5, b=10.0, c=1.0):
    return a / (1 + b * freq ** 2) ** c


def signal_model(A, f0, t, dim=N_DIM):
    y = A * np.sin(2 * np.pi * f0 * t)
    # y *= tukey(len(t), alpha=0.01)  # Tukey window disabled in this variant
    return np.array([y] * int(dim)).T
def generate_data(A, f0, noise_psd_func, n_samples=N_SAMPLES, n_dim=N_DIM, dt=DT):
    t = np.arange(n_samples) * dt
    signal = signal_model(A, f0, t, dim=n_dim)
    noise = generate_noise_from_psd(noise_psd_func, n_samples, dt, n_dim)
    data = signal + noise
    # Optimal matched-filter SNR (per channel; the channels are identical
    # here, so report channel 0)
    h = np.fft.rfft(signal, axis=0)
    psd = np.array([noise_psd_func(rfftfreq(n_samples, dt))] * n_dim).T
    SNR = np.sqrt(inner_prod(h, h, psd, dt, n_samples))[0]
    print(f"SNR: {SNR:.2f}")
    return data
def generate_noise_from_psd(noise_psd_func, n_samples=N_SAMPLES, dt=DT, n_dim=N_DIM):
    noise_time = np.zeros((n_samples, n_dim))
    freqs = rfftfreq(n_samples, dt)
    true_psd_full = noise_psd_func(freqs)
    for i in range(n_dim):
        white_noise_fft = (np.random.normal(size=len(freqs)) + 1j * np.random.normal(size=len(freqs))) / np.sqrt(2)
        white_noise_fft[0] /= np.sqrt(2)  # Adjust DC component
        colored_noise_fft = white_noise_fft * np.sqrt(true_psd_full)
        noise_time[:, i] = np.fft.irfft(colored_noise_fft, n=n_samples)
    return noise_time
def inner_prod(sig1_f, sig2_f, PSD, delta_t, N_t):
    # Noise-weighted inner product, useful for likelihoods and SNRs.
    # Sums over axis 0, so multichannel inputs give one value per channel.
    return (4 * delta_t / N_t) * np.real(
        np.sum(np.conjugate(sig1_f) * sig2_f / PSD, axis=0)
    )


def snr(d, h, psd, dt, n_samples):
    """Compute the optimal SNR for a single channel (d, h, psd are 1D arrays)."""
    assert d.ndim == 1, "Input data must be 1D"
    power = np.real(np.sum((np.conj(h) * h) / psd))  # avoid shadowing inner_prod
    scale = 4 * dt / n_samples
    return np.sqrt(scale * power)


def network_snr(d, h, psd, dt, n_samples):
    # Network SNR: quadrature sum of the per-channel SNRs
    dim = d.shape[1]
    snr_values = [snr(d[:, i], h[:, i], psd[:, i], dt, n_samples) for i in range(dim)]
    return np.sqrt(np.sum(np.array(snr_values) ** 2))
def whittle_lnl(signal_params, data_freq, psd):
    """Calculates the Whittle log-likelihood."""
    signal = signal_model(*signal_params)
    # Drop the DC and Nyquist bins to match the [1:-1] slicing used for the
    # data FFT and the PSD grids elsewhere in this script
    signal_freq = rfft(signal, axis=0)[1:-1]
    dim = signal_freq.shape[1]
    lnl = 0.0
    for i in range(dim):
        p = psd[:, i]
        d = data_freq[:, i]
        h = signal_freq[:, i]
        # 1D Whittle log-likelihood
        lnl += -np.sum((np.abs(d - h) ** 2) / p)
    return lnl
def ln_prior(A, f0):
    ln_a = unif_logpdf(A, *A_RANGE)
    ln_f0 = unif_logpdf(f0, *F0_RANGE)
    return ln_a + ln_f0


def ln_prob(signal_params, model_args, data_freq, psd):
    A, f0 = signal_params
    ln_likelihood = whittle_lnl((A, f0, *model_args), data_freq, psd)
    return ln_likelihood + ln_prior(A, f0)


def periodogram(data, dt=DT):
    n_samples = data.shape[0]
    freqs = rfftfreq(n_samples, dt)[1:-1]
    psd = np.abs(rfft(data, axis=0)) ** 2
    return freqs, psd[1:-1]
def compute_ci(posterior_samples, freqs, alpha=0.05):
    signal_periodograms = []
    t = np.arange(N_SAMPLES) * DT
    for A, f0 in posterior_samples:
        signal = signal_model(A, f0, t)
        _, sp = periodogram(signal)
        signal_periodograms.append(sp)
    signal_periodograms = np.array(signal_periodograms)
    ci_lower = np.percentile(signal_periodograms, 100 * alpha / 2, axis=0)
    ci_upper = np.percentile(signal_periodograms, 100 * (1 - alpha / 2), axis=0)
    median = np.median(signal_periodograms, axis=0)
    return ci_lower, median, ci_upper
def plot_model(data, psd, posterior_samples=None, psd_ci=None, welch_psd=None, sgvb_psd=None):
    n, dim = data.shape
    t = np.arange(n) * DT
    true_signal = signal_model(*SIGNAL_PARAMS.values(), t, dim)
    freq = rfftfreq(n, DT)[1:-1]
    _, data_periodogram = periodogram(data)
    _, signal_periodogram = periodogram(true_signal)
    ci = None
    if posterior_samples is not None:
        ci = compute_ci(posterior_samples, freq)
    fig, axes = plt.subplots(dim, 1, figsize=(10, 4 * dim))
    if dim == 1:
        axes = [axes]  # Ensure axes is iterable
    for i in range(dim):
        ax = axes[i]
        ax.plot(freq, data_periodogram[:, i], label="Data Periodogram", color='gray', alpha=0.5)
        ax.plot(freq, signal_periodogram[:, i], label="True Signal Periodogram", color='tab:red', alpha=0.5)
        ax.plot(freq, psd[:, i], label="True PSD", color='k', ls='--')
        ax.set_xlabel("Frequency [Hz]")
        ax.set_ylabel("Power/Frequency [1/Hz]")
        ax.set_title(f"Channel {i + 1}")
        if ci is not None:
            ax.fill_between(freq, ci[0][:, i], ci[-1][:, i], alpha=0.3, label="Posterior", color="tab:blue")
            ax.plot(freq, ci[1][:, i], color="tab:blue", lw=2)
        if psd_ci is not None:
            _psdci = np.real(psd_ci[..., i, i])
            ax.fill_between(freq, _psdci[0], _psdci[-1], alpha=0.3, label="SGVB CI", color="tab:orange")
            # ax.plot(freq, _psdci[1], color="tab:orange", lw=2)
        if sgvb_psd is not None:
            ax.plot(freq, sgvb_psd[:, i], label="SGVB PSD", color='tab:orange', alpha=1, lw=2)
        if welch_psd is not None:
            ax.plot(freq, welch_psd[:, i], label="Welch PSD", color='tab:green', alpha=0.5)
        ax.set_yscale('log')
        ax.set_ylim(bottom=1e-6)
        ax.set_xlim(freq[0], freq[-1])
    axes[0].legend()
    plt.tight_layout()
    return fig, axes
def adaptive_metropolis_step(current_param, log_prob_fn, proposal_variance, step_size=1.0):
    """Single Metropolis step with a Gaussian random-walk proposal."""
    proposed_param = np.random.normal(current_param, np.sqrt(step_size * proposal_variance))
    current_logprob = log_prob_fn(current_param)
    proposed_logprob = log_prob_fn(proposed_param)
    # Accept/reject in log space to avoid overflow for large log-ratios
    if np.log(np.random.rand()) < proposed_logprob - current_logprob:
        return proposed_param, True
    return current_param, False
def adaptive_gibbs_sampler(data, psd, model_args, n_iter=N_ITER, burn_in=BURN_IN, thinning=THINNING,
                           adaptive_window=ADAPTIVE_WINDOW, initial_step_size=0.1):
    """Gibbs sampler with Adaptive Metropolis for (A, f0) and SGVB PSD updates."""
    n_samples, n_dim = data.shape
    data_freq = rfft(data, axis=0)[1:-1]
    _lnp = partial(ln_prob, model_args=model_args, data_freq=data_freq, psd=psd)

    # Initialize chains from draws over the prior ranges
    amplitude_chain = [np.random.uniform(*A_RANGE)]
    frequency_chain = [np.random.uniform(*F0_RANGE)]
    param_history = [[amplitude_chain[0], frequency_chain[0]]]
    acceptance_rates = []
    proposal_variance_A = initial_step_size ** 2
    proposal_variance_f0 = initial_step_size ** 2

    for i in trange(n_iter):
        current_A = amplitude_chain[-1]
        current_f0 = frequency_chain[-1]

        # --- Sample amplitude (Metropolis, conditional on f0) ---
        def log_prob_A(a):
            return _lnp((a, current_f0))

        new_A, accepted_A = adaptive_metropolis_step(
            current_A, log_prob_A, proposal_variance_A, step_size=1.0
        )
        amplitude_chain.append(new_A)

        # --- Sample frequency (Metropolis, conditional on the new A) ---
        def log_prob_f0(f):
            return _lnp((amplitude_chain[-1], f))

        new_f0, accepted_f0 = adaptive_metropolis_step(
            current_f0, log_prob_f0, proposal_variance_f0, step_size=1.0
        )
        frequency_chain.append(new_f0)

        param_history.append([amplitude_chain[-1], frequency_chain[-1]])
        acceptance_rates.append([accepted_A, accepted_f0])

        # --- Adapt proposal variances from the recent history ---
        if (i + 1) % adaptive_window == 0 and i >= burn_in:
            history_window = np.array(param_history[-adaptive_window:])
            if history_window.shape[0] > 1:
                proposal_variance_A = np.var(history_window[:, 0]) + EPS
                proposal_variance_f0 = np.var(history_window[:, 1]) + EPS

        # --- Refresh the noise PSD estimate with SGVB ---
        if (i + 1) % SGVB_ITERATIONS == 0 and i >= burn_in:
            # Residuals under the current signal model
            current_signal = signal_model(amplitude_chain[-1], frequency_chain[-1], *model_args)
            residuals = data - current_signal
            psd = get_sgvb_psd_estimate(residuals, i)
            _lnp = partial(ln_prob, model_args=model_args, data_freq=data_freq, psd=psd)

    # Keep post-burn-in samples, thinned
    amplitude_samples = np.array(amplitude_chain[burn_in::thinning])
    frequency_samples = np.array(frequency_chain[burn_in::thinning])
    posterior_samples = np.vstack((amplitude_samples, frequency_samples)).T
    return posterior_samples, acceptance_rates, psd
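# Structure of the sampler above: a blocked scheme that alternates
# Metropolis-within-Gibbs updates of (A, f0 | PSD) with periodic refreshes of
# (PSD | residuals) via SGVB. The PSD refresh plugs in a point estimate rather
# than drawing from the PSD's conditional posterior, so this is an
# approximation to a full Gibbs update of the noise model.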
def get_sgvb_psd_estimate(residuals, i: int):
    psd_estim = PSDEstimator(x=residuals, fs=1 / DT, nchunks=1, max_hyperparm_eval=1)
    psd_all, pointwise_ci, uniform_ci = psd_estim.run(lr=1.7117e-02)
    axes = psd_estim.plot()
    freq = psd_estim.freq
    true_psd = psd_model(freq)
    # Overlay the true PSD on the diagonal (auto-spectrum) panels
    axes[0, 0].plot(freq, true_psd, label="True PSD", color='k', ls='--')
    axes[1, 1].plot(freq, true_psd, label="True PSD", color='k', ls='--')
    # Remove the off-diagonal (cross-spectrum) panels from the plot
    axes[0, 1].remove()
    axes[1, 0].remove()
    plt.savefig(f"psd_estim_{i}.png")
    # Save the uniform confidence intervals for later plotting
    with h5py.File("psd_estim.h5", "w") as f:
        f.create_dataset("uniform_ci", data=uniform_ci)
    psd = np.real(psd_all[0])
    return np.array([psd[:, 0, 0], psd[:, 1, 1]]).T
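# Note: only the diagonal auto-spectra psd[:, 0, 0] and psd[:, 1, 1] are kept,
# since whittle_lnl above treats the channels as independent and ignores the
# cross-spectral terms that the multivariate SGVB estimate also provides.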
def plot_posterior_trace(posterior_samples, acceptance_rates=None):
    # Posterior histograms
    fig_hist, axes_hist = plt.subplots(2, 1, figsize=(8, 6))
    axes_hist[0].hist(posterior_samples[:, 0], bins=50, density=True, alpha=0.7, label="Posterior A")
    axes_hist[0].axvline(SIGNAL_PARAMS["A"], color='r', linestyle='--', label="True A")
    axes_hist[0].set_xlabel("Amplitude (A)")
    axes_hist[0].legend()
    axes_hist[1].hist(posterior_samples[:, 1], bins=50, density=True, alpha=0.7, label="Posterior f0")
    axes_hist[1].axvline(SIGNAL_PARAMS["f0"], color='r', linestyle='--', label="True f0")
    axes_hist[1].set_xlabel("Frequency (f0)")
    axes_hist[1].legend()
    plt.tight_layout()
    plt.savefig('posterior_histogram_adaptive.png')

    # Trace plots
    fig_trace, axes_trace = plt.subplots(2, 1, figsize=(8, 6))
    axes_trace[0].plot(posterior_samples[:, 0], label="Amplitude (A)")
    axes_trace[0].set_ylabel("Amplitude (A)")
    axes_trace[0].axhline(SIGNAL_PARAMS["A"], color='r', linestyle='--', label="True A")
    axes_trace[0].legend()
    axes_trace[1].plot(posterior_samples[:, 1], label="Frequency (f0)")
    axes_trace[1].set_ylabel("Frequency (f0)")
    axes_trace[1].axhline(SIGNAL_PARAMS["f0"], color='r', linestyle='--', label="True f0")
    axes_trace[1].legend()
    plt.tight_layout()
    plt.savefig('posterior_trace_adaptive.png')

    # Rolling acceptance rates
    if acceptance_rates:
        acceptance_rates_arr = np.array(acceptance_rates)
        fig_acc, axes_acc = plt.subplots(2, 1, figsize=(8, 6))
        axes_acc[0].plot(
            np.convolve(acceptance_rates_arr[:, 0], np.ones(ADAPTIVE_WINDOW) / ADAPTIVE_WINDOW, mode='valid'),
            label="Acceptance Rate A")
        axes_acc[0].set_ylabel("Acceptance Rate (A)")
        axes_acc[0].axhline(0.23, color='k', linestyle='--', label="Optimal ~0.23")
        axes_acc[0].legend()
        axes_acc[1].plot(
            np.convolve(acceptance_rates_arr[:, 1], np.ones(ADAPTIVE_WINDOW) / ADAPTIVE_WINDOW, mode='valid'),
            label="Acceptance Rate f0")
        axes_acc[1].set_ylabel("Acceptance Rate (f0)")
        axes_acc[1].axhline(0.23, color='k', linestyle='--', label="Optimal ~0.23")
        axes_acc[1].set_xlabel("Iteration")
        axes_acc[1].legend()
        plt.tight_layout()
        plt.savefig('acceptance_rates_adaptive.png')
def generate_welch_psd(n_chunks, noise_psd_func, n_samples, dt, n_dim):
    """Median periodogram over `n_chunks` independent noise realisations."""
    noises = [generate_noise_from_psd(noise_psd_func, n_samples, dt, n_dim) for _ in range(n_chunks)]
    # Median across realisations (Welch-style averaging, but over fresh
    # simulated realisations rather than segments of a single series)
    psd = np.zeros((n_chunks, (n_samples // 2) - 1, n_dim))
    for i in range(n_chunks):
        _, psd[i] = periodogram(noises[i])
    return np.median(np.real(psd), axis=0)
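# Shape check: rfft of n_samples=512 points gives 257 bins; dropping the DC
# and Nyquist bins in periodogram() leaves 255 = (n_samples // 2) - 1, which
# is why the array above is allocated with that middle dimension.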
def load_psd_ci(scaling=1.0):
    with h5py.File("psd_estim.h5", "r") as f:
        uniform_ci = f["uniform_ci"][:]
    return uniform_ci / scaling ** 2
if __name__ == "__main__":
    # Generate data
    t = np.arange(N_SAMPLES) * DT
    freq = rfftfreq(N_SAMPLES, DT)[1:-1]
    true_psd = np.array([psd_model(freq)] * N_DIM).T
    welch_psd = generate_welch_psd(128, psd_model, N_SAMPLES, DT, N_DIM)
    true_signal = signal_model(*SIGNAL_PARAMS.values(), t, dim=N_DIM)
    data = generate_data(*SIGNAL_PARAMS.values(), noise_psd_func=psd_model, n_samples=N_SAMPLES, n_dim=N_DIM)
    residual = data - true_signal
    scaling = np.std(residual)

    # Test SGVB once before sampling (this also writes psd_estim.h5, which
    # load_psd_ci() reads below)
    sgvb_psd = get_sgvb_psd_estimate(data - true_signal, 0)
    assert sgvb_psd.shape == true_psd.shape, \
        f"SGVB PSD shape {sgvb_psd.shape} mismatch with true PSD shape {true_psd.shape}"

    # Plot data against the true model
    fig_true, axes_true = plot_model(data, true_psd, welch_psd=welch_psd, psd_ci=load_psd_ci(scaling))
    fig_true.savefig("data_adaptive.png")

    # Run the adaptive Gibbs sampler, starting from the Welch PSD estimate
    posterior_samples, acceptance_rates, sgvb_psd = adaptive_gibbs_sampler(data, welch_psd, model_args=(t, N_DIM))

    # Keep the final SAMPLES_TO_KEEP posterior samples for plotting
    posterior_samples = posterior_samples[-SAMPLES_TO_KEEP:]
    plot_posterior_trace(posterior_samples, acceptance_rates)

    # Plot model + posterior samples
    fig_post, _ = plot_model(
        data, true_psd,
        posterior_samples=posterior_samples,
        welch_psd=welch_psd,
        psd_ci=load_psd_ci(scaling),
        sgvb_psd=sgvb_psd,
    )
    fig_post.savefig("posterior_adaptive.png")