Skip to content

Instantly share code, notes, and snippets.

@avivajpeyi
Last active March 13, 2025 04:49
Show Gist options
  • Save avivajpeyi/4e2702a7e20565053b91726f025c0c44 to your computer and use it in GitHub Desktop.
Save avivajpeyi/4e2702a7e20565053b91726f025c0c44 to your computer and use it in GitHub Desktop.
psd_and_signal_estimation_example.ipynb
Display the source blob
Display the rendered blob
Raw
{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"colab": {
"provenance": [],
"authorship_tag": "ABX9TyNDilo9GPSgNybsPua2RdR8",
"include_colab_link": true
},
"kernelspec": {
"name": "python3",
"display_name": "Python 3"
},
"language_info": {
"name": "python"
},
"widgets": {
"application/vnd.jupyter.widget-state+json": {
"2744c950ab8148c5a26b27889b303076": {
"model_module": "@jupyter-widgets/controls",
"model_name": "HBoxModel",
"model_module_version": "1.5.0",
"state": {
"_dom_classes": [],
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "HBoxModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/controls",
"_view_module_version": "1.5.0",
"_view_name": "HBoxView",
"box_style": "",
"children": [
"IPY_MODEL_a32091baa8e041ff9c2f2898f3b16ccc",
"IPY_MODEL_4b3950b8c8254469bdfa07fae6868736",
"IPY_MODEL_56784a2ef7564b4d8dc388b70cc89218"
],
"layout": "IPY_MODEL_645fb53be59f4be6b96bb93b2236c670"
}
},
"a32091baa8e041ff9c2f2898f3b16ccc": {
"model_module": "@jupyter-widgets/controls",
"model_name": "HTMLModel",
"model_module_version": "1.5.0",
"state": {
"_dom_classes": [],
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "HTMLModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/controls",
"_view_module_version": "1.5.0",
"_view_name": "HTMLView",
"description": "",
"description_tooltip": null,
"layout": "IPY_MODEL_07b05b257ab74478b43447a00c38ddca",
"placeholder": "​",
"style": "IPY_MODEL_ec7e837b56ad40feb620fe4cea65c224",
"value": "Gibbs Sampling: 100%"
}
},
"4b3950b8c8254469bdfa07fae6868736": {
"model_module": "@jupyter-widgets/controls",
"model_name": "FloatProgressModel",
"model_module_version": "1.5.0",
"state": {
"_dom_classes": [],
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "FloatProgressModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/controls",
"_view_module_version": "1.5.0",
"_view_name": "ProgressView",
"bar_style": "success",
"description": "",
"description_tooltip": null,
"layout": "IPY_MODEL_c2ec3a6304584605ba246be1abab33f3",
"max": 30,
"min": 0,
"orientation": "horizontal",
"style": "IPY_MODEL_69fcc6acc49b44408a4d1e4152ab62dd",
"value": 30
}
},
"56784a2ef7564b4d8dc388b70cc89218": {
"model_module": "@jupyter-widgets/controls",
"model_name": "HTMLModel",
"model_module_version": "1.5.0",
"state": {
"_dom_classes": [],
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "HTMLModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/controls",
"_view_module_version": "1.5.0",
"_view_name": "HTMLView",
"description": "",
"description_tooltip": null,
"layout": "IPY_MODEL_69b74cce27f94f8e95fd5f3ab010d676",
"placeholder": "​",
"style": "IPY_MODEL_e50577642db849c088722e88a73ec36f",
"value": " 30/30 [01:57<00:00,  4.01s/it]"
}
},
"645fb53be59f4be6b96bb93b2236c670": {
"model_module": "@jupyter-widgets/base",
"model_name": "LayoutModel",
"model_module_version": "1.2.0",
"state": {
"_model_module": "@jupyter-widgets/base",
"_model_module_version": "1.2.0",
"_model_name": "LayoutModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "LayoutView",
"align_content": null,
"align_items": null,
"align_self": null,
"border": null,
"bottom": null,
"display": null,
"flex": null,
"flex_flow": null,
"grid_area": null,
"grid_auto_columns": null,
"grid_auto_flow": null,
"grid_auto_rows": null,
"grid_column": null,
"grid_gap": null,
"grid_row": null,
"grid_template_areas": null,
"grid_template_columns": null,
"grid_template_rows": null,
"height": null,
"justify_content": null,
"justify_items": null,
"left": null,
"margin": null,
"max_height": null,
"max_width": null,
"min_height": null,
"min_width": null,
"object_fit": null,
"object_position": null,
"order": null,
"overflow": null,
"overflow_x": null,
"overflow_y": null,
"padding": null,
"right": null,
"top": null,
"visibility": null,
"width": null
}
},
"07b05b257ab74478b43447a00c38ddca": {
"model_module": "@jupyter-widgets/base",
"model_name": "LayoutModel",
"model_module_version": "1.2.0",
"state": {
"_model_module": "@jupyter-widgets/base",
"_model_module_version": "1.2.0",
"_model_name": "LayoutModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "LayoutView",
"align_content": null,
"align_items": null,
"align_self": null,
"border": null,
"bottom": null,
"display": null,
"flex": null,
"flex_flow": null,
"grid_area": null,
"grid_auto_columns": null,
"grid_auto_flow": null,
"grid_auto_rows": null,
"grid_column": null,
"grid_gap": null,
"grid_row": null,
"grid_template_areas": null,
"grid_template_columns": null,
"grid_template_rows": null,
"height": null,
"justify_content": null,
"justify_items": null,
"left": null,
"margin": null,
"max_height": null,
"max_width": null,
"min_height": null,
"min_width": null,
"object_fit": null,
"object_position": null,
"order": null,
"overflow": null,
"overflow_x": null,
"overflow_y": null,
"padding": null,
"right": null,
"top": null,
"visibility": null,
"width": null
}
},
"ec7e837b56ad40feb620fe4cea65c224": {
"model_module": "@jupyter-widgets/controls",
"model_name": "DescriptionStyleModel",
"model_module_version": "1.5.0",
"state": {
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "DescriptionStyleModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "StyleView",
"description_width": ""
}
},
"c2ec3a6304584605ba246be1abab33f3": {
"model_module": "@jupyter-widgets/base",
"model_name": "LayoutModel",
"model_module_version": "1.2.0",
"state": {
"_model_module": "@jupyter-widgets/base",
"_model_module_version": "1.2.0",
"_model_name": "LayoutModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "LayoutView",
"align_content": null,
"align_items": null,
"align_self": null,
"border": null,
"bottom": null,
"display": null,
"flex": null,
"flex_flow": null,
"grid_area": null,
"grid_auto_columns": null,
"grid_auto_flow": null,
"grid_auto_rows": null,
"grid_column": null,
"grid_gap": null,
"grid_row": null,
"grid_template_areas": null,
"grid_template_columns": null,
"grid_template_rows": null,
"height": null,
"justify_content": null,
"justify_items": null,
"left": null,
"margin": null,
"max_height": null,
"max_width": null,
"min_height": null,
"min_width": null,
"object_fit": null,
"object_position": null,
"order": null,
"overflow": null,
"overflow_x": null,
"overflow_y": null,
"padding": null,
"right": null,
"top": null,
"visibility": null,
"width": null
}
},
"69fcc6acc49b44408a4d1e4152ab62dd": {
"model_module": "@jupyter-widgets/controls",
"model_name": "ProgressStyleModel",
"model_module_version": "1.5.0",
"state": {
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "ProgressStyleModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "StyleView",
"bar_color": null,
"description_width": ""
}
},
"69b74cce27f94f8e95fd5f3ab010d676": {
"model_module": "@jupyter-widgets/base",
"model_name": "LayoutModel",
"model_module_version": "1.2.0",
"state": {
"_model_module": "@jupyter-widgets/base",
"_model_module_version": "1.2.0",
"_model_name": "LayoutModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "LayoutView",
"align_content": null,
"align_items": null,
"align_self": null,
"border": null,
"bottom": null,
"display": null,
"flex": null,
"flex_flow": null,
"grid_area": null,
"grid_auto_columns": null,
"grid_auto_flow": null,
"grid_auto_rows": null,
"grid_column": null,
"grid_gap": null,
"grid_row": null,
"grid_template_areas": null,
"grid_template_columns": null,
"grid_template_rows": null,
"height": null,
"justify_content": null,
"justify_items": null,
"left": null,
"margin": null,
"max_height": null,
"max_width": null,
"min_height": null,
"min_width": null,
"object_fit": null,
"object_position": null,
"order": null,
"overflow": null,
"overflow_x": null,
"overflow_y": null,
"padding": null,
"right": null,
"top": null,
"visibility": null,
"width": null
}
},
"e50577642db849c088722e88a73ec36f": {
"model_module": "@jupyter-widgets/controls",
"model_name": "DescriptionStyleModel",
"model_module_version": "1.5.0",
"state": {
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "DescriptionStyleModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "StyleView",
"description_width": ""
}
}
}
}
},
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "view-in-github",
"colab_type": "text"
},
"source": [
"<a href=\"https://colab.research.google.com/gist/avivajpeyi/4e2702a7e20565053b91726f025c0c44/psd_and_signal_estimation_example.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
]
},
{
"cell_type": "markdown",
"source": [
"# PSD and Signal estimation\n",
"\n",
"\n"
],
"metadata": {
"id": "gEEwm4ZEN_r0"
}
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "L7IasQS_N59_",
"outputId": "8dae71af-6b46-44ef-9f3c-342643edab36"
},
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"Collecting numpyro\n",
" Downloading numpyro-0.17.0-py3-none-any.whl.metadata (37 kB)\n",
"Requirement already satisfied: spectrum in /usr/local/lib/python3.11/dist-packages (0.9.0)\n",
"Requirement already satisfied: jax>=0.4.25 in /usr/local/lib/python3.11/dist-packages (from numpyro) (0.5.2)\n",
"Requirement already satisfied: jaxlib>=0.4.25 in /usr/local/lib/python3.11/dist-packages (from numpyro) (0.5.1)\n",
"Requirement already satisfied: multipledispatch in /usr/local/lib/python3.11/dist-packages (from numpyro) (1.0.0)\n",
"Requirement already satisfied: numpy in /usr/local/lib/python3.11/dist-packages (from numpyro) (1.26.4)\n",
"Requirement already satisfied: tqdm in /usr/local/lib/python3.11/dist-packages (from numpyro) (4.67.1)\n",
"Requirement already satisfied: easydev in /usr/local/lib/python3.11/dist-packages (from spectrum) (0.13.3)\n",
"Requirement already satisfied: scipy in /usr/local/lib/python3.11/dist-packages (from spectrum) (1.14.1)\n",
"Requirement already satisfied: matplotlib in /usr/local/lib/python3.11/dist-packages (from spectrum) (3.10.0)\n",
"Requirement already satisfied: ml_dtypes>=0.4.0 in /usr/local/lib/python3.11/dist-packages (from jax>=0.4.25->numpyro) (0.4.1)\n",
"Requirement already satisfied: opt_einsum in /usr/local/lib/python3.11/dist-packages (from jax>=0.4.25->numpyro) (3.4.0)\n",
"Requirement already satisfied: colorama<0.5.0,>=0.4.6 in /usr/local/lib/python3.11/dist-packages (from easydev->spectrum) (0.4.6)\n",
"Requirement already satisfied: colorlog<7.0.0,>=6.8.2 in /usr/local/lib/python3.11/dist-packages (from easydev->spectrum) (6.9.0)\n",
"Requirement already satisfied: line-profiler<5.0.0,>=4.1.2 in /usr/local/lib/python3.11/dist-packages (from easydev->spectrum) (4.2.0)\n",
"Requirement already satisfied: pexpect<5.0.0,>=4.9.0 in /usr/local/lib/python3.11/dist-packages (from easydev->spectrum) (4.9.0)\n",
"Requirement already satisfied: platformdirs<5.0.0,>=4.2.0 in /usr/local/lib/python3.11/dist-packages (from easydev->spectrum) (4.3.6)\n",
"Requirement already satisfied: contourpy>=1.0.1 in /usr/local/lib/python3.11/dist-packages (from matplotlib->spectrum) (1.3.1)\n",
"Requirement already satisfied: cycler>=0.10 in /usr/local/lib/python3.11/dist-packages (from matplotlib->spectrum) (0.12.1)\n",
"Requirement already satisfied: fonttools>=4.22.0 in /usr/local/lib/python3.11/dist-packages (from matplotlib->spectrum) (4.56.0)\n",
"Requirement already satisfied: kiwisolver>=1.3.1 in /usr/local/lib/python3.11/dist-packages (from matplotlib->spectrum) (1.4.8)\n",
"Requirement already satisfied: packaging>=20.0 in /usr/local/lib/python3.11/dist-packages (from matplotlib->spectrum) (24.2)\n",
"Requirement already satisfied: pillow>=8 in /usr/local/lib/python3.11/dist-packages (from matplotlib->spectrum) (11.1.0)\n",
"Requirement already satisfied: pyparsing>=2.3.1 in /usr/local/lib/python3.11/dist-packages (from matplotlib->spectrum) (3.2.1)\n",
"Requirement already satisfied: python-dateutil>=2.7 in /usr/local/lib/python3.11/dist-packages (from matplotlib->spectrum) (2.8.2)\n",
"Requirement already satisfied: ptyprocess>=0.5 in /usr/local/lib/python3.11/dist-packages (from pexpect<5.0.0,>=4.9.0->easydev->spectrum) (0.7.0)\n",
"Requirement already satisfied: six>=1.5 in /usr/local/lib/python3.11/dist-packages (from python-dateutil>=2.7->matplotlib->spectrum) (1.17.0)\n",
"Downloading numpyro-0.17.0-py3-none-any.whl (360 kB)\n",
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m360.8/360.8 kB\u001b[0m \u001b[31m15.4 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
"\u001b[?25hInstalling collected packages: numpyro\n",
"Successfully installed numpyro-0.17.0\n"
]
}
],
"source": [
"! pip install numpyro spectrum"
]
},
{
"cell_type": "code",
"source": [
"import os\n",
"import glob\n",
"import numpy as np\n",
"import matplotlib.pyplot as plt\n",
"import scipy.signal\n",
"import scipy.stats\n",
"import emcee\n",
"from spectrum import aryule, arma2psd\n",
"from tqdm.auto import trange\n",
"\n",
"# ------------------------------\n",
"# Helper functions for periodogram and PSD\n",
"# ------------------------------\n",
"def compute_periodogram(signal, dt):\n",
" \"\"\"Compute periodogram using FFT.\"\"\"\n",
" n = len(signal)\n",
" fft_vals = np.fft.fft(signal)\n",
" freqs = np.fft.fftfreq(n, d=dt)[: n // 2]\n",
" periodogram = (np.abs(fft_vals[: n // 2]) ** 2) * dt / n\n",
" return freqs, periodogram\n",
"\n",
"def compute_estimated_psd(AR, P, T, NFFT):\n",
" \"\"\"\n",
" Compute one-sided PSD from an AR model using arma2psd.\n",
" \"\"\"\n",
" psd_full = arma2psd(AR, [1], P, T, NFFT)\n",
" one_sided = psd_full[: NFFT // 2 + 1]\n",
" return one_sided\n",
"\n",
"# ------------------------------\n",
"# Plotting functions\n",
"# ------------------------------\n",
"def plot_progress(iteration, t, data, current_signal, dt, current_AR, current_P, fs,\n",
" true_signal_pdgrm, true_psd):\n",
" \"\"\"\n",
" Plot the posterior-predictive PSD and signal vs. true PSD and true signal.\n",
" Saves a figure for the current Gibbs iteration.\n",
" \"\"\"\n",
" fig, ax = plt.subplots(1, 1, figsize=(5, 3.5))\n",
"\n",
" # Compute periodograms\n",
" freqs, data_periodogram = compute_periodogram(data, dt)\n",
" _, signal_periodogram = compute_periodogram(current_signal, dt)\n",
"\n",
" # Set NFFT and frequency vector for PSD computation\n",
" NFFT = len(freqs) * 2\n",
" psd_freqs = np.linspace(0, fs / 2, NFFT // 2 + 1)\n",
"\n",
" # Estimated noise PSD from current AR model\n",
" est_noise_psd = compute_estimated_psd(current_AR, current_P, T=fs, NFFT=NFFT)\n",
"\n",
" # Plot on a log-log scale\n",
" ax.loglog(freqs, data_periodogram, label='Data', color='gray', alpha=0.5)\n",
" ax.loglog(freqs, signal_periodogram, label='Estimated Signal', color='green', linestyle='-.')\n",
" ax.loglog(freqs, true_signal_pdgrm, label='True Signal', color='green')\n",
" ax.loglog(psd_freqs, true_psd, label='True Noise', color='red')\n",
" ax.loglog(psd_freqs, est_noise_psd, label='Estimated Noise', color='red', linestyle='--')\n",
"\n",
" ax.set_xlim(left=1)\n",
" ax.set_xlabel(\"Frequency [Hz]\")\n",
" ax.set_ylabel(\"PSD\")\n",
" ax.set_title(f\"Iteration: {iteration}\")\n",
" ax.legend()\n",
" ax.set_ylim(bottom=1e-6, top=1e4)\n",
"\n",
" plt.tight_layout()\n",
" plt.savefig(f\"psd_signal_itr{iteration:02d}.png\")\n",
" plt.close()\n",
"\n",
"def plot_trace_with_hist(samples_dict, true_values, bounds, labels, iteration):\n",
" \"\"\"\n",
" Plot trace plots (one per parameter) with adjacent rotated histograms.\n",
" Each parameter's trace is over all emcee walkers.\n",
"\n",
" Parameters:\n",
" samples_dict (dict): Keys are parameter names; each value is an array of shape (n_steps, n_walkers).\n",
" true_values (list): True parameter values.\n",
" bounds (list of tuples): y-limits for each parameter (min, max).\n",
" labels (list of str): Parameter names.\n",
" iteration (int): Current Gibbs iteration (for file naming).\n",
" \"\"\"\n",
" ndim = len(labels)\n",
" fig, axes = plt.subplots(ndim, 2, figsize=(8, 2.5 * ndim),\n",
" gridspec_kw={\"width_ratios\": [3, 1], \"wspace\": 0.05})\n",
"\n",
" for i in range(ndim):\n",
" # Left: Trace plot\n",
" ax_trace = axes[i, 0]\n",
" # samples_dict[label] shape: (n_steps, n_walkers)\n",
" chain = samples_dict[labels[i]]\n",
" n_steps, n_walkers = chain.shape\n",
" for walker in range(n_walkers):\n",
" ax_trace.plot(np.arange(n_steps), chain[:, walker], color=\"k\", alpha=0.3)\n",
" ax_trace.axhline(true_values[i], color=\"red\", linestyle=\"--\", label=\"True\")\n",
" ax_trace.set_ylabel(labels[i])\n",
" ax_trace.set_ylim(bounds[i])\n",
" if i == ndim - 1:\n",
" ax_trace.set_xlabel(\"MCMC Iteration\")\n",
" else:\n",
" ax_trace.set_xticklabels([])\n",
"\n",
" # Right: Rotated histogram\n",
" ax_hist = axes[i, 1]\n",
" all_samples = chain.flatten()\n",
" ax_hist.hist(all_samples, bins=30, orientation=\"horizontal\", color=\"gray\", alpha=0.7)\n",
" ax_hist.axhline(true_values[i], color=\"red\", linestyle=\"--\")\n",
" ax_hist.set_ylim(bounds[i])\n",
" ax_hist.set_xticks([])\n",
" ax_hist.set_yticks([])\n",
"\n",
" plt.suptitle(f\"Trace and Posterior Histogram (Gibbs Iteration {iteration})\", y=1.02)\n",
" plt.tight_layout()\n",
" plt.savefig(f\"trace_hist_itr{iteration:02d}.png\")\n",
" plt.close()\n",
"\n",
"# ------------------------------\n",
"# JSD computation and plotting functions\n",
"# ------------------------------\n",
"def compute_jsd(file1, file2, bin_ranges, n_bins=30):\n",
" \"\"\"\n",
" Compute the Jensen–Shannon Divergence (JSD) between two posterior sample files.\n",
"\n",
" Parameters:\n",
" file1 (str): First sample file.\n",
" file2 (str): Baseline sample file.\n",
" bin_ranges (tuple): ((a_min, a_max), (f_min, f_max)) for histogram binning.\n",
" n_bins (int): Number of bins.\n",
"\n",
" Returns:\n",
" float: Sum of JSD for parameters a and f.\n",
" \"\"\"\n",
" samples1 = np.loadtxt(file1)\n",
" samples2 = np.loadtxt(file2)\n",
"\n",
" def jsd_param(p, q, rng):\n",
" p_hist, _ = np.histogram(p, bins=n_bins, range=rng, density=True)\n",
" q_hist, _ = np.histogram(q, bins=n_bins, range=rng, density=True)\n",
" p_hist += 1e-10\n",
" q_hist += 1e-10\n",
" m = 0.5 * (p_hist + q_hist)\n",
" return 0.5 * (scipy.stats.entropy(p_hist, m) + scipy.stats.entropy(q_hist, m))\n",
"\n",
" jsd_a = jsd_param(samples1[:, 0], samples2[:, 0], bin_ranges[0])\n",
" jsd_f = jsd_param(samples1[:, 1], samples2[:, 1], bin_ranges[1])\n",
" return jsd_a + jsd_f\n",
"\n",
"def plot_jsd(n_bins=30):\n",
" \"\"\"\n",
" Load saved posterior sample files, compute JSD (against the final iteration),\n",
" and plot JSD vs. Gibbs iteration.\n",
" \"\"\"\n",
" sample_files = sorted(glob.glob(\"samples_*.txt\"), key=lambda x: int(x.split(\"_\")[1].split(\".\")[0]))\n",
" if not sample_files:\n",
" raise FileNotFoundError(\"No sample files found.\")\n",
"\n",
" baseline_file = sample_files[-1]\n",
" # Determine bin ranges from all samples\n",
" all_samples = []\n",
" for fpath in sample_files:\n",
" all_samples.append(np.loadtxt(fpath))\n",
" all_samples = np.vstack(all_samples)\n",
" a_range = (all_samples[:, 0].min(), all_samples[:, 0].max())\n",
" f_range = (all_samples[:, 1].min(), all_samples[:, 1].max())\n",
" bin_ranges = (a_range, f_range)\n",
"\n",
" jsd_values = []\n",
" iterations = []\n",
" for fpath in sample_files[:-1]:\n",
" iteration = int(fpath.split(\"_\")[1].split(\".\")[0])\n",
" jsd_val = compute_jsd(fpath, baseline_file, bin_ranges, n_bins=n_bins)\n",
" jsd_values.append(jsd_val)\n",
" iterations.append(iteration)\n",
"\n",
" plt.figure(figsize=(5, 3))\n",
" plt.plot(iterations, jsd_values, marker=\"o\", linestyle=\"-\")\n",
" plt.xlabel(\"Gibbs Iteration\")\n",
" plt.ylabel(\"JSD\")\n",
" plt.title(\"JSD of [a, f] posterior vs. final posterior\")\n",
" plt.savefig(\"jsd_vs_iterations.png\")\n",
" plt.close()\n",
"\n",
"# ------------------------------\n",
"# Emcee log probability function for the sine-wave model with Whittle likelihood\n",
"# ------------------------------\n",
"def log_prob(params, data, t, dt, fs, noise_psd, a_bounds, f_bounds):\n",
" a, f = params\n",
" # Enforce bounds using the prior\n",
" if not (a_bounds[0] <= a <= a_bounds[1] and f_bounds[0] <= f <= f_bounds[1]):\n",
" return -np.inf\n",
"\n",
" # Generate sine-wave model and compute residuals\n",
" model_signal = a * np.sin(2 * np.pi * f * t)\n",
" res_model = data - model_signal\n",
"\n",
" # Compute periodogram of the residuals\n",
" freqs, periodogram = compute_periodogram(res_model, dt)\n",
" NFFT = len(freqs) * 2\n",
" psd_freqs = np.linspace(0, fs / 2, NFFT // 2 + 1)\n",
"\n",
" # Interpolate the current noise PSD onto periodogram frequencies\n",
" interp_psd = np.interp(freqs, noise_psd[0], noise_psd[1])\n",
"\n",
" # Compute the Whittle likelihood (up to a constant)\n",
" logl = -0.5 * np.sum(periodogram / interp_psd + np.log(interp_psd))\n",
" return logl\n",
"\n",
"# ------------------------------\n",
"# Main Gibbs Sampling Workflow with emcee\n",
"# ------------------------------\n",
"np.random.seed(42)\n",
"\n",
"# True sine-wave parameters\n",
"true_a = 20.0 # Amplitude\n",
"true_f = 25.0 # Frequency (Hz)\n",
"\n",
"# AR filter coefficients (for noise generation)\n",
"a_coeff = [1, -2.2137, 2.9403, -2.1697, 0.9606]\n",
"true_rho = 0.1\n",
"\n",
"# Data length and time grid\n",
"n_samples = 1024\n",
"fs = 100 # Hz\n",
"dt = 1.0 / fs\n",
"t = np.linspace(0, (n_samples - 1) * dt, n_samples)\n",
"\n",
"# Generate true signal and noise, then observed data\n",
"true_signal = true_a * np.sin(2 * np.pi * true_f * t)\n",
"noise = scipy.signal.lfilter([1], a_coeff, np.random.randn(n_samples))\n",
"data = true_signal + noise\n",
"\n",
"# Estimate true noise PSD from the pure noise (using aryule)\n",
"order = 4\n",
"AR_est, P_est, _ = aryule(noise, order)\n",
"NFFT = len(data)\n",
"true_psd = compute_estimated_psd(AR_est, P_est, T=fs, NFFT=NFFT)\n",
"psd_freqs = np.linspace(0, fs/2, len(true_psd))\n",
"noise_psd_default = (psd_freqs, true_psd)\n",
"\n",
"# Compute true signal periodogram (for posterior-predictive plot)\n",
"_, true_signal_pdgrm = compute_periodogram(true_signal, dt)\n",
"\n",
"# Parameter bounds for sine model\n",
"a_bounds = (true_a - 5, true_a + 5)\n",
"f_bounds = (true_f - 5, true_f + 5)\n",
"\n",
"# ------------------------------\n",
"# Gibbs Sampler Setup\n",
"# ------------------------------\n",
"n_gibbs = 30\n",
"n_mcmc = 1000 # MCMC steps per Gibbs iteration\n",
"burnin = 300\n",
"\n",
"# Initialize current signal estimate (start with zeros)\n",
"current_signal = np.zeros_like(data)\n",
"current_noise_psd = noise_psd_default\n",
"\n",
"# For saving posterior samples and tracking median estimates\n",
"posterior_samples_files = []\n",
"last_median_a = (a_bounds[0] + a_bounds[1]) / 2\n",
"last_median_f = (f_bounds[0] + f_bounds[1]) / 2\n",
"\n",
"# Define labels and bounds for trace plots\n",
"param_labels = ['a', 'f']\n",
"param_bounds = [a_bounds, f_bounds]\n",
"true_params = [true_a, true_f]\n",
"\n",
"for gibbs_iter in trange(n_gibbs, desc='Gibbs Sampling'):\n",
" # 1. Update noise PSD using residuals\n",
" residuals = data - current_signal\n",
" AR_est, P_est, _ = aryule(residuals, order)\n",
" current_psd = compute_estimated_psd(AR_est, P_est, T=fs, NFFT=NFFT)\n",
" psd_freqs = np.linspace(0, fs/2, len(current_psd))\n",
" current_noise_psd = (psd_freqs, current_psd)\n",
"\n",
" # 2. Sample [a, f] using emcee\n",
" n_walkers = 16\n",
" # Initialize walkers near the previous median with small random perturbations\n",
" initial = np.array([last_median_a, last_median_f]) + 0.1 * np.random.randn(n_walkers, 2)\n",
" sampler = emcee.EnsembleSampler(n_walkers, 2, log_prob,\n",
" args=(data, t, dt, fs, current_noise_psd, a_bounds, f_bounds))\n",
" sampler.run_mcmc(initial, n_mcmc, progress=False)\n",
"\n",
" # Extract posterior samples (discard burn-in)\n",
" samples = sampler.get_chain(discard=burnin, flat=False) # shape: (n_steps, n_walkers, 2)\n",
" # Reshape to get arrays for a and f (for trace plots)\n",
" a_samples = samples[:, :, 0].copy() # shape: (n_steps, n_walkers)\n",
" f_samples = samples[:, :, 1].copy()\n",
"\n",
" # Update medians for next iteration\n",
" last_median_a = np.median(a_samples)\n",
" last_median_f = np.median(f_samples)\n",
"\n",
" # Update current signal using median parameters\n",
" current_signal = last_median_a * np.sin(2 * np.pi * last_median_f * t)\n",
"\n",
" # 3. Save aggregated posterior samples (over walkers) to a text file\n",
" combined_samples = np.column_stack([a_samples.flatten(), f_samples.flatten()])\n",
" filename = f\"samples_{gibbs_iter+1:02d}.txt\"\n",
" np.savetxt(filename, combined_samples)\n",
" posterior_samples_files.append(filename)\n",
"\n",
" # 4. Plot trace and histogram for this iteration\n",
" samples_for_plot = {\"a\": a_samples, \"f\": f_samples}\n",
" plot_trace_with_hist(samples_for_plot, true_params, param_bounds, param_labels, gibbs_iter+1)\n",
"\n",
" # 5. Plot posterior-predictive PSD and signal vs. true ones\n",
" plot_progress(gibbs_iter+1, t, data, current_signal, dt,\n",
" current_AR=AR_est, current_P=P_est, fs=fs,\n",
" true_signal_pdgrm=true_signal_pdgrm, true_psd=true_psd)\n",
"\n",
"# ------------------------------\n",
"# Final JSD Computation and Plotting\n",
"# ------------------------------\n",
"plot_jsd(n_bins=30)\n",
"\n",
"print(\"Gibbs sampling complete!\")\n",
"print(\"JSD plot saved as 'jsd_vs_iterations.png'.\")\n"
],
"metadata": {
"id": "jWAE2ULnPAjB",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 1000,
"referenced_widgets": [
"2744c950ab8148c5a26b27889b303076",
"a32091baa8e041ff9c2f2898f3b16ccc",
"4b3950b8c8254469bdfa07fae6868736",
"56784a2ef7564b4d8dc388b70cc89218",
"645fb53be59f4be6b96bb93b2236c670",
"07b05b257ab74478b43447a00c38ddca",
"ec7e837b56ad40feb620fe4cea65c224",
"c2ec3a6304584605ba246be1abab33f3",
"69fcc6acc49b44408a4d1e4152ab62dd",
"69b74cce27f94f8e95fd5f3ab010d676",
"e50577642db849c088722e88a73ec36f"
]
},
"outputId": "cc1ef2cc-fce1-4695-e470-47aa95d443d4"
},
"execution_count": 12,
"outputs": [
{
"output_type": "display_data",
"data": {
"text/plain": [
"Gibbs Sampling: 0%| | 0/30 [00:00<?, ?it/s]"
],
"application/vnd.jupyter.widget-view+json": {
"version_major": 2,
"version_minor": 0,
"model_id": "2744c950ab8148c5a26b27889b303076"
}
},
"metadata": {}
},
{
"output_type": "stream",
"name": "stderr",
"text": [
"<ipython-input-12-1c3ee5c2da40>:112: UserWarning: This figure includes Axes that are not compatible with tight_layout, so results might be incorrect.\n",
" plt.tight_layout()\n",
"<ipython-input-12-1c3ee5c2da40>:112: UserWarning: This figure includes Axes that are not compatible with tight_layout, so results might be incorrect.\n",
" plt.tight_layout()\n",
"<ipython-input-12-1c3ee5c2da40>:112: UserWarning: This figure includes Axes that are not compatible with tight_layout, so results might be incorrect.\n",
" plt.tight_layout()\n",
"<ipython-input-12-1c3ee5c2da40>:112: UserWarning: This figure includes Axes that are not compatible with tight_layout, so results might be incorrect.\n",
" plt.tight_layout()\n",
"<ipython-input-12-1c3ee5c2da40>:112: UserWarning: This figure includes Axes that are not compatible with tight_layout, so results might be incorrect.\n",
" plt.tight_layout()\n",
"<ipython-input-12-1c3ee5c2da40>:112: UserWarning: This figure includes Axes that are not compatible with tight_layout, so results might be incorrect.\n",
" plt.tight_layout()\n",
"<ipython-input-12-1c3ee5c2da40>:112: UserWarning: This figure includes Axes that are not compatible with tight_layout, so results might be incorrect.\n",
" plt.tight_layout()\n",
"<ipython-input-12-1c3ee5c2da40>:112: UserWarning: This figure includes Axes that are not compatible with tight_layout, so results might be incorrect.\n",
" plt.tight_layout()\n",
"<ipython-input-12-1c3ee5c2da40>:112: UserWarning: This figure includes Axes that are not compatible with tight_layout, so results might be incorrect.\n",
" plt.tight_layout()\n",
"<ipython-input-12-1c3ee5c2da40>:112: UserWarning: This figure includes Axes that are not compatible with tight_layout, so results might be incorrect.\n",
" plt.tight_layout()\n",
"<ipython-input-12-1c3ee5c2da40>:112: UserWarning: This figure includes Axes that are not compatible with tight_layout, so results might be incorrect.\n",
" plt.tight_layout()\n",
"<ipython-input-12-1c3ee5c2da40>:112: UserWarning: This figure includes Axes that are not compatible with tight_layout, so results might be incorrect.\n",
" plt.tight_layout()\n",
"<ipython-input-12-1c3ee5c2da40>:112: UserWarning: This figure includes Axes that are not compatible with tight_layout, so results might be incorrect.\n",
" plt.tight_layout()\n",
"<ipython-input-12-1c3ee5c2da40>:112: UserWarning: This figure includes Axes that are not compatible with tight_layout, so results might be incorrect.\n",
" plt.tight_layout()\n",
"<ipython-input-12-1c3ee5c2da40>:112: UserWarning: This figure includes Axes that are not compatible with tight_layout, so results might be incorrect.\n",
" plt.tight_layout()\n",
"<ipython-input-12-1c3ee5c2da40>:112: UserWarning: This figure includes Axes that are not compatible with tight_layout, so results might be incorrect.\n",
" plt.tight_layout()\n",
"<ipython-input-12-1c3ee5c2da40>:112: UserWarning: This figure includes Axes that are not compatible with tight_layout, so results might be incorrect.\n",
" plt.tight_layout()\n",
"<ipython-input-12-1c3ee5c2da40>:112: UserWarning: This figure includes Axes that are not compatible with tight_layout, so results might be incorrect.\n",
" plt.tight_layout()\n",
"<ipython-input-12-1c3ee5c2da40>:112: UserWarning: This figure includes Axes that are not compatible with tight_layout, so results might be incorrect.\n",
" plt.tight_layout()\n",
"<ipython-input-12-1c3ee5c2da40>:112: UserWarning: This figure includes Axes that are not compatible with tight_layout, so results might be incorrect.\n",
" plt.tight_layout()\n",
"<ipython-input-12-1c3ee5c2da40>:112: UserWarning: This figure includes Axes that are not compatible with tight_layout, so results might be incorrect.\n",
" plt.tight_layout()\n",
"<ipython-input-12-1c3ee5c2da40>:112: UserWarning: This figure includes Axes that are not compatible with tight_layout, so results might be incorrect.\n",
" plt.tight_layout()\n",
"<ipython-input-12-1c3ee5c2da40>:112: UserWarning: This figure includes Axes that are not compatible with tight_layout, so results might be incorrect.\n",
" plt.tight_layout()\n",
"<ipython-input-12-1c3ee5c2da40>:112: UserWarning: This figure includes Axes that are not compatible with tight_layout, so results might be incorrect.\n",
" plt.tight_layout()\n",
"<ipython-input-12-1c3ee5c2da40>:112: UserWarning: This figure includes Axes that are not compatible with tight_layout, so results might be incorrect.\n",
" plt.tight_layout()\n",
"<ipython-input-12-1c3ee5c2da40>:112: UserWarning: This figure includes Axes that are not compatible with tight_layout, so results might be incorrect.\n",
" plt.tight_layout()\n",
"<ipython-input-12-1c3ee5c2da40>:112: UserWarning: This figure includes Axes that are not compatible with tight_layout, so results might be incorrect.\n",
" plt.tight_layout()\n",
"<ipython-input-12-1c3ee5c2da40>:112: UserWarning: This figure includes Axes that are not compatible with tight_layout, so results might be incorrect.\n",
" plt.tight_layout()\n",
"<ipython-input-12-1c3ee5c2da40>:112: UserWarning: This figure includes Axes that are not compatible with tight_layout, so results might be incorrect.\n",
" plt.tight_layout()\n",
"<ipython-input-12-1c3ee5c2da40>:112: UserWarning: This figure includes Axes that are not compatible with tight_layout, so results might be incorrect.\n",
" plt.tight_layout()\n"
]
},
{
"output_type": "stream",
"name": "stdout",
"text": [
"Gibbs sampling complete!\n",
"JSD plot saved as 'jsd_vs_iterations.png'.\n"
]
}
]
},
{
"cell_type": "code",
"source": [
"import imageio\n",
"import os\n",
"\n",
"# Collect image filenames\n",
"image_files = [f\"itr{i:02d}.png\" for i in range(1, n_gibbs + 1)]\n",
"\n",
"# Create GIF\n",
"images = []\n",
"for filename in image_files:\n",
" images.append(imageio.imread(filename))\n",
"\n",
"# Add the last image for a longer duration\n",
"images.append(imageio.imread(image_files[-1]))\n",
"images.append(imageio.imread(image_files[-1]))\n",
"images.append(imageio.imread(image_files[-1]))\n",
"\n",
"\n",
"# Save as GIF\n",
"imageio.mimsave('gibbs_sampling_psd.gif', images, duration=0.5) # Adjust duration as needed\n",
"\n",
"# Clean up individual image files (optional)\n",
"for filename in image_files:\n",
" os.remove(filename)\n",
"\n",
"print(\"GIF created: gibbs_sampling_psd.gif\")"
],
"metadata": {
"id": "SjCi9pFOVrth"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"source": [
"![](gibbs_sampling_psd.gif)"
],
"metadata": {
"id": "KAH1GG9XaoxD"
}
},
{
"cell_type": "code",
"source": [],
"metadata": {
"id": "H46kuPOAaTU6"
},
"execution_count": null,
"outputs": []
}
]
}
@avivajpeyi
Copy link
Author

gibbs_sampling_psd

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment