Last active
March 13, 2025 04:49
-
-
Save avivajpeyi/4e2702a7e20565053b91726f025c0c44 to your computer and use it in GitHub Desktop.
psd_and_signal_estimation_example.ipynb
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"nbformat": 4, | |
"nbformat_minor": 0, | |
"metadata": { | |
"colab": { | |
"provenance": [], | |
"authorship_tag": "ABX9TyNDilo9GPSgNybsPua2RdR8", | |
"include_colab_link": true | |
}, | |
"kernelspec": { | |
"name": "python3", | |
"display_name": "Python 3" | |
}, | |
"language_info": { | |
"name": "python" | |
}, | |
"widgets": { | |
"application/vnd.jupyter.widget-state+json": { | |
"2744c950ab8148c5a26b27889b303076": { | |
"model_module": "@jupyter-widgets/controls", | |
"model_name": "HBoxModel", | |
"model_module_version": "1.5.0", | |
"state": { | |
"_dom_classes": [], | |
"_model_module": "@jupyter-widgets/controls", | |
"_model_module_version": "1.5.0", | |
"_model_name": "HBoxModel", | |
"_view_count": null, | |
"_view_module": "@jupyter-widgets/controls", | |
"_view_module_version": "1.5.0", | |
"_view_name": "HBoxView", | |
"box_style": "", | |
"children": [ | |
"IPY_MODEL_a32091baa8e041ff9c2f2898f3b16ccc", | |
"IPY_MODEL_4b3950b8c8254469bdfa07fae6868736", | |
"IPY_MODEL_56784a2ef7564b4d8dc388b70cc89218" | |
], | |
"layout": "IPY_MODEL_645fb53be59f4be6b96bb93b2236c670" | |
} | |
}, | |
"a32091baa8e041ff9c2f2898f3b16ccc": { | |
"model_module": "@jupyter-widgets/controls", | |
"model_name": "HTMLModel", | |
"model_module_version": "1.5.0", | |
"state": { | |
"_dom_classes": [], | |
"_model_module": "@jupyter-widgets/controls", | |
"_model_module_version": "1.5.0", | |
"_model_name": "HTMLModel", | |
"_view_count": null, | |
"_view_module": "@jupyter-widgets/controls", | |
"_view_module_version": "1.5.0", | |
"_view_name": "HTMLView", | |
"description": "", | |
"description_tooltip": null, | |
"layout": "IPY_MODEL_07b05b257ab74478b43447a00c38ddca", | |
"placeholder": "", | |
"style": "IPY_MODEL_ec7e837b56ad40feb620fe4cea65c224", | |
"value": "Gibbs Sampling: 100%" | |
} | |
}, | |
"4b3950b8c8254469bdfa07fae6868736": { | |
"model_module": "@jupyter-widgets/controls", | |
"model_name": "FloatProgressModel", | |
"model_module_version": "1.5.0", | |
"state": { | |
"_dom_classes": [], | |
"_model_module": "@jupyter-widgets/controls", | |
"_model_module_version": "1.5.0", | |
"_model_name": "FloatProgressModel", | |
"_view_count": null, | |
"_view_module": "@jupyter-widgets/controls", | |
"_view_module_version": "1.5.0", | |
"_view_name": "ProgressView", | |
"bar_style": "success", | |
"description": "", | |
"description_tooltip": null, | |
"layout": "IPY_MODEL_c2ec3a6304584605ba246be1abab33f3", | |
"max": 30, | |
"min": 0, | |
"orientation": "horizontal", | |
"style": "IPY_MODEL_69fcc6acc49b44408a4d1e4152ab62dd", | |
"value": 30 | |
} | |
}, | |
"56784a2ef7564b4d8dc388b70cc89218": { | |
"model_module": "@jupyter-widgets/controls", | |
"model_name": "HTMLModel", | |
"model_module_version": "1.5.0", | |
"state": { | |
"_dom_classes": [], | |
"_model_module": "@jupyter-widgets/controls", | |
"_model_module_version": "1.5.0", | |
"_model_name": "HTMLModel", | |
"_view_count": null, | |
"_view_module": "@jupyter-widgets/controls", | |
"_view_module_version": "1.5.0", | |
"_view_name": "HTMLView", | |
"description": "", | |
"description_tooltip": null, | |
"layout": "IPY_MODEL_69b74cce27f94f8e95fd5f3ab010d676", | |
"placeholder": "", | |
"style": "IPY_MODEL_e50577642db849c088722e88a73ec36f", | |
"value": " 30/30 [01:57<00:00, 4.01s/it]" | |
} | |
}, | |
"645fb53be59f4be6b96bb93b2236c670": { | |
"model_module": "@jupyter-widgets/base", | |
"model_name": "LayoutModel", | |
"model_module_version": "1.2.0", | |
"state": { | |
"_model_module": "@jupyter-widgets/base", | |
"_model_module_version": "1.2.0", | |
"_model_name": "LayoutModel", | |
"_view_count": null, | |
"_view_module": "@jupyter-widgets/base", | |
"_view_module_version": "1.2.0", | |
"_view_name": "LayoutView", | |
"align_content": null, | |
"align_items": null, | |
"align_self": null, | |
"border": null, | |
"bottom": null, | |
"display": null, | |
"flex": null, | |
"flex_flow": null, | |
"grid_area": null, | |
"grid_auto_columns": null, | |
"grid_auto_flow": null, | |
"grid_auto_rows": null, | |
"grid_column": null, | |
"grid_gap": null, | |
"grid_row": null, | |
"grid_template_areas": null, | |
"grid_template_columns": null, | |
"grid_template_rows": null, | |
"height": null, | |
"justify_content": null, | |
"justify_items": null, | |
"left": null, | |
"margin": null, | |
"max_height": null, | |
"max_width": null, | |
"min_height": null, | |
"min_width": null, | |
"object_fit": null, | |
"object_position": null, | |
"order": null, | |
"overflow": null, | |
"overflow_x": null, | |
"overflow_y": null, | |
"padding": null, | |
"right": null, | |
"top": null, | |
"visibility": null, | |
"width": null | |
} | |
}, | |
"07b05b257ab74478b43447a00c38ddca": { | |
"model_module": "@jupyter-widgets/base", | |
"model_name": "LayoutModel", | |
"model_module_version": "1.2.0", | |
"state": { | |
"_model_module": "@jupyter-widgets/base", | |
"_model_module_version": "1.2.0", | |
"_model_name": "LayoutModel", | |
"_view_count": null, | |
"_view_module": "@jupyter-widgets/base", | |
"_view_module_version": "1.2.0", | |
"_view_name": "LayoutView", | |
"align_content": null, | |
"align_items": null, | |
"align_self": null, | |
"border": null, | |
"bottom": null, | |
"display": null, | |
"flex": null, | |
"flex_flow": null, | |
"grid_area": null, | |
"grid_auto_columns": null, | |
"grid_auto_flow": null, | |
"grid_auto_rows": null, | |
"grid_column": null, | |
"grid_gap": null, | |
"grid_row": null, | |
"grid_template_areas": null, | |
"grid_template_columns": null, | |
"grid_template_rows": null, | |
"height": null, | |
"justify_content": null, | |
"justify_items": null, | |
"left": null, | |
"margin": null, | |
"max_height": null, | |
"max_width": null, | |
"min_height": null, | |
"min_width": null, | |
"object_fit": null, | |
"object_position": null, | |
"order": null, | |
"overflow": null, | |
"overflow_x": null, | |
"overflow_y": null, | |
"padding": null, | |
"right": null, | |
"top": null, | |
"visibility": null, | |
"width": null | |
} | |
}, | |
"ec7e837b56ad40feb620fe4cea65c224": { | |
"model_module": "@jupyter-widgets/controls", | |
"model_name": "DescriptionStyleModel", | |
"model_module_version": "1.5.0", | |
"state": { | |
"_model_module": "@jupyter-widgets/controls", | |
"_model_module_version": "1.5.0", | |
"_model_name": "DescriptionStyleModel", | |
"_view_count": null, | |
"_view_module": "@jupyter-widgets/base", | |
"_view_module_version": "1.2.0", | |
"_view_name": "StyleView", | |
"description_width": "" | |
} | |
}, | |
"c2ec3a6304584605ba246be1abab33f3": { | |
"model_module": "@jupyter-widgets/base", | |
"model_name": "LayoutModel", | |
"model_module_version": "1.2.0", | |
"state": { | |
"_model_module": "@jupyter-widgets/base", | |
"_model_module_version": "1.2.0", | |
"_model_name": "LayoutModel", | |
"_view_count": null, | |
"_view_module": "@jupyter-widgets/base", | |
"_view_module_version": "1.2.0", | |
"_view_name": "LayoutView", | |
"align_content": null, | |
"align_items": null, | |
"align_self": null, | |
"border": null, | |
"bottom": null, | |
"display": null, | |
"flex": null, | |
"flex_flow": null, | |
"grid_area": null, | |
"grid_auto_columns": null, | |
"grid_auto_flow": null, | |
"grid_auto_rows": null, | |
"grid_column": null, | |
"grid_gap": null, | |
"grid_row": null, | |
"grid_template_areas": null, | |
"grid_template_columns": null, | |
"grid_template_rows": null, | |
"height": null, | |
"justify_content": null, | |
"justify_items": null, | |
"left": null, | |
"margin": null, | |
"max_height": null, | |
"max_width": null, | |
"min_height": null, | |
"min_width": null, | |
"object_fit": null, | |
"object_position": null, | |
"order": null, | |
"overflow": null, | |
"overflow_x": null, | |
"overflow_y": null, | |
"padding": null, | |
"right": null, | |
"top": null, | |
"visibility": null, | |
"width": null | |
} | |
}, | |
"69fcc6acc49b44408a4d1e4152ab62dd": { | |
"model_module": "@jupyter-widgets/controls", | |
"model_name": "ProgressStyleModel", | |
"model_module_version": "1.5.0", | |
"state": { | |
"_model_module": "@jupyter-widgets/controls", | |
"_model_module_version": "1.5.0", | |
"_model_name": "ProgressStyleModel", | |
"_view_count": null, | |
"_view_module": "@jupyter-widgets/base", | |
"_view_module_version": "1.2.0", | |
"_view_name": "StyleView", | |
"bar_color": null, | |
"description_width": "" | |
} | |
}, | |
"69b74cce27f94f8e95fd5f3ab010d676": { | |
"model_module": "@jupyter-widgets/base", | |
"model_name": "LayoutModel", | |
"model_module_version": "1.2.0", | |
"state": { | |
"_model_module": "@jupyter-widgets/base", | |
"_model_module_version": "1.2.0", | |
"_model_name": "LayoutModel", | |
"_view_count": null, | |
"_view_module": "@jupyter-widgets/base", | |
"_view_module_version": "1.2.0", | |
"_view_name": "LayoutView", | |
"align_content": null, | |
"align_items": null, | |
"align_self": null, | |
"border": null, | |
"bottom": null, | |
"display": null, | |
"flex": null, | |
"flex_flow": null, | |
"grid_area": null, | |
"grid_auto_columns": null, | |
"grid_auto_flow": null, | |
"grid_auto_rows": null, | |
"grid_column": null, | |
"grid_gap": null, | |
"grid_row": null, | |
"grid_template_areas": null, | |
"grid_template_columns": null, | |
"grid_template_rows": null, | |
"height": null, | |
"justify_content": null, | |
"justify_items": null, | |
"left": null, | |
"margin": null, | |
"max_height": null, | |
"max_width": null, | |
"min_height": null, | |
"min_width": null, | |
"object_fit": null, | |
"object_position": null, | |
"order": null, | |
"overflow": null, | |
"overflow_x": null, | |
"overflow_y": null, | |
"padding": null, | |
"right": null, | |
"top": null, | |
"visibility": null, | |
"width": null | |
} | |
}, | |
"e50577642db849c088722e88a73ec36f": { | |
"model_module": "@jupyter-widgets/controls", | |
"model_name": "DescriptionStyleModel", | |
"model_module_version": "1.5.0", | |
"state": { | |
"_model_module": "@jupyter-widgets/controls", | |
"_model_module_version": "1.5.0", | |
"_model_name": "DescriptionStyleModel", | |
"_view_count": null, | |
"_view_module": "@jupyter-widgets/base", | |
"_view_module_version": "1.2.0", | |
"_view_name": "StyleView", | |
"description_width": "" | |
} | |
} | |
} | |
} | |
}, | |
"cells": [ | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"id": "view-in-github", | |
"colab_type": "text" | |
}, | |
"source": [ | |
"<a href=\"https://colab.research.google.com/gist/avivajpeyi/4e2702a7e20565053b91726f025c0c44/psd_and_signal_estimation_example.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"source": [ | |
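"# PSD and Signal estimation\n", | |
"\n", | |
"Jointly estimate the amplitude $a$ and frequency $f$ of a sinusoid buried in coloured noise, together with the noise PSD. Each iteration of a Gibbs-style loop (1) fits an AR(4) model to the current residuals with Yule-Walker (`aryule`) to update the noise PSD, then (2) samples $[a, f]$ with `emcee` under a Whittle likelihood. Per-iteration trace/PSD figures and a JSD-vs-iteration convergence plot are written to the working directory.\n", | |
"\n" | |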
], | |
"metadata": { | |
"id": "gEEwm4ZEN_r0" | |
} | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 10, | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/" | |
}, | |
"id": "L7IasQS_N59_", | |
"outputId": "8dae71af-6b46-44ef-9f3c-342643edab36" | |
}, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"name": "stdout", | |
"text": [ | |
"Collecting numpyro\n", | |
" Downloading numpyro-0.17.0-py3-none-any.whl.metadata (37 kB)\n", | |
"Requirement already satisfied: spectrum in /usr/local/lib/python3.11/dist-packages (0.9.0)\n", | |
"Requirement already satisfied: jax>=0.4.25 in /usr/local/lib/python3.11/dist-packages (from numpyro) (0.5.2)\n", | |
"Requirement already satisfied: jaxlib>=0.4.25 in /usr/local/lib/python3.11/dist-packages (from numpyro) (0.5.1)\n", | |
"Requirement already satisfied: multipledispatch in /usr/local/lib/python3.11/dist-packages (from numpyro) (1.0.0)\n", | |
"Requirement already satisfied: numpy in /usr/local/lib/python3.11/dist-packages (from numpyro) (1.26.4)\n", | |
"Requirement already satisfied: tqdm in /usr/local/lib/python3.11/dist-packages (from numpyro) (4.67.1)\n", | |
"Requirement already satisfied: easydev in /usr/local/lib/python3.11/dist-packages (from spectrum) (0.13.3)\n", | |
"Requirement already satisfied: scipy in /usr/local/lib/python3.11/dist-packages (from spectrum) (1.14.1)\n", | |
"Requirement already satisfied: matplotlib in /usr/local/lib/python3.11/dist-packages (from spectrum) (3.10.0)\n", | |
"Requirement already satisfied: ml_dtypes>=0.4.0 in /usr/local/lib/python3.11/dist-packages (from jax>=0.4.25->numpyro) (0.4.1)\n", | |
"Requirement already satisfied: opt_einsum in /usr/local/lib/python3.11/dist-packages (from jax>=0.4.25->numpyro) (3.4.0)\n", | |
"Requirement already satisfied: colorama<0.5.0,>=0.4.6 in /usr/local/lib/python3.11/dist-packages (from easydev->spectrum) (0.4.6)\n", | |
"Requirement already satisfied: colorlog<7.0.0,>=6.8.2 in /usr/local/lib/python3.11/dist-packages (from easydev->spectrum) (6.9.0)\n", | |
"Requirement already satisfied: line-profiler<5.0.0,>=4.1.2 in /usr/local/lib/python3.11/dist-packages (from easydev->spectrum) (4.2.0)\n", | |
"Requirement already satisfied: pexpect<5.0.0,>=4.9.0 in /usr/local/lib/python3.11/dist-packages (from easydev->spectrum) (4.9.0)\n", | |
"Requirement already satisfied: platformdirs<5.0.0,>=4.2.0 in /usr/local/lib/python3.11/dist-packages (from easydev->spectrum) (4.3.6)\n", | |
"Requirement already satisfied: contourpy>=1.0.1 in /usr/local/lib/python3.11/dist-packages (from matplotlib->spectrum) (1.3.1)\n", | |
"Requirement already satisfied: cycler>=0.10 in /usr/local/lib/python3.11/dist-packages (from matplotlib->spectrum) (0.12.1)\n", | |
"Requirement already satisfied: fonttools>=4.22.0 in /usr/local/lib/python3.11/dist-packages (from matplotlib->spectrum) (4.56.0)\n", | |
"Requirement already satisfied: kiwisolver>=1.3.1 in /usr/local/lib/python3.11/dist-packages (from matplotlib->spectrum) (1.4.8)\n", | |
"Requirement already satisfied: packaging>=20.0 in /usr/local/lib/python3.11/dist-packages (from matplotlib->spectrum) (24.2)\n", | |
"Requirement already satisfied: pillow>=8 in /usr/local/lib/python3.11/dist-packages (from matplotlib->spectrum) (11.1.0)\n", | |
"Requirement already satisfied: pyparsing>=2.3.1 in /usr/local/lib/python3.11/dist-packages (from matplotlib->spectrum) (3.2.1)\n", | |
"Requirement already satisfied: python-dateutil>=2.7 in /usr/local/lib/python3.11/dist-packages (from matplotlib->spectrum) (2.8.2)\n", | |
"Requirement already satisfied: ptyprocess>=0.5 in /usr/local/lib/python3.11/dist-packages (from pexpect<5.0.0,>=4.9.0->easydev->spectrum) (0.7.0)\n", | |
"Requirement already satisfied: six>=1.5 in /usr/local/lib/python3.11/dist-packages (from python-dateutil>=2.7->matplotlib->spectrum) (1.17.0)\n", | |
"Downloading numpyro-0.17.0-py3-none-any.whl (360 kB)\n", | |
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m360.8/360.8 kB\u001b[0m \u001b[31m15.4 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", | |
"\u001b[?25hInstalling collected packages: numpyro\n", | |
"Successfully installed numpyro-0.17.0\n" | |
] | |
} | |
], | |
"source": [ | |
"! pip install numpyro spectrum" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"import os\n", | |
"import glob\n", | |
"import numpy as np\n", | |
"import matplotlib.pyplot as plt\n", | |
"import scipy.signal\n", | |
"import scipy.stats\n", | |
"import emcee\n", | |
"from spectrum import aryule, arma2psd\n", | |
"from tqdm.auto import trange\n", | |
"\n", | |
"# ------------------------------\n", | |
"# Helper functions for periodogram and PSD\n", | |
"# ------------------------------\n", | |
def compute_periodogram(signal, dt):
    """Periodogram of ``signal`` on its first ``len(signal) // 2`` FFT bins.

    Parameters
    ----------
    signal : array_like
        Time series sampled every ``dt``.
    dt : float
        Sampling interval.

    Returns
    -------
    freqs : numpy.ndarray
        Non-negative frequencies ``k / (n * dt)`` for ``k = 0 .. n//2 - 1``
        (the Nyquist bin is not included).
    periodogram : numpy.ndarray
        ``|FFT|**2 * dt / n`` on those bins (no one-sided factor of 2).
    """
    n_points = len(signal)
    keep = n_points // 2
    fourier = np.fft.fft(signal)
    frequencies = np.fft.fftfreq(n_points, d=dt)[:keep]
    power = np.abs(fourier[:keep]) ** 2 * dt / n_points
    return frequencies, power
"\n", | |
def compute_estimated_psd(AR, P, T, NFFT):
    """
    Compute the one-sided PSD of a pure autoregressive (AR) model via arma2psd.

    Parameters
    ----------
    AR : array_like
        AR coefficients without the leading 1, as returned by ``aryule``.
    P : float
        AR noise-variance estimate (second output of ``aryule``); passed as
        ``rho`` to ``arma2psd``.
    T : float
        Forwarded unchanged to ``arma2psd``'s ``T`` argument; every caller in
        this notebook passes the sampling frequency ``fs``.
    NFFT : int
        Number of FFT points used to evaluate the two-sided PSD.

    Returns
    -------
    numpy.ndarray
        The first ``NFFT // 2 + 1`` bins (0 .. Nyquist) of the PSD.
    """
    # Pure AR model: leave the MA part out entirely. arma2psd prepends an
    # implicit leading 1 to B exactly as it does for A, so the former B=[1]
    # added a spurious MA(1) factor |1 + e^{-iw}|^2 (x4 at DC, 0 at Nyquist).
    psd_full = arma2psd(A=AR, B=None, rho=P, T=T, NFFT=NFFT)
    return psd_full[: NFFT // 2 + 1]
"\n", | |
"# ------------------------------\n", | |
"# Plotting functions\n", | |
"# ------------------------------\n", | |
def plot_progress(iteration, t, data, current_signal, dt, current_AR, current_P, fs,
                  true_signal_pdgrm, true_psd):
    """
    Save a log-log snapshot of the current Gibbs state next to the truth.

    Draws the data periodogram, the periodograms of the current and true
    signals, the reference ("true") noise PSD and the noise PSD implied by the
    current AR fit, then writes ``psd_signal_itr{iteration:02d}.png``.

    Parameters
    ----------
    iteration : int
        Gibbs iteration number (used in the title and file name).
    t : array_like
        Time grid; accepted for call-site symmetry but not used here.
    data, current_signal : array_like
        Observed data and the current signal estimate.
    dt : float
        Sampling interval.
    current_AR, current_P :
        AR coefficients and noise variance from the latest ``aryule`` fit.
    fs : float
        Sampling frequency.
    true_signal_pdgrm : array_like
        Periodogram of the true signal on the ``compute_periodogram`` grid.
    true_psd : array_like
        Reference noise PSD with one more point than that grid (0..Nyquist).
    """
    fig, ax = plt.subplots(1, 1, figsize=(5, 3.5))

    freqs, data_periodogram = compute_periodogram(data, dt)
    signal_periodogram = compute_periodogram(current_signal, dt)[1]

    # Evaluate the AR PSD on an FFT grid twice the periodogram length, so it
    # spans 0..Nyquist with one more bin than ``freqs``.
    nfft = 2 * len(freqs)
    psd_freqs = np.linspace(0, fs / 2, nfft // 2 + 1)
    est_noise_psd = compute_estimated_psd(current_AR, current_P, T=fs, NFFT=nfft)

    # (x, y, style) triples, listed in legend order.
    curves = (
        (freqs, data_periodogram, {"label": 'Data', "color": 'gray', "alpha": 0.5}),
        (freqs, signal_periodogram, {"label": 'Estimated Signal', "color": 'green', "linestyle": '-.'}),
        (freqs, true_signal_pdgrm, {"label": 'True Signal', "color": 'green'}),
        (psd_freqs, true_psd, {"label": 'True Noise', "color": 'red'}),
        (psd_freqs, est_noise_psd, {"label": 'Estimated Noise', "color": 'red', "linestyle": '--'}),
    )
    for x_vals, y_vals, style in curves:
        ax.loglog(x_vals, y_vals, **style)

    ax.set_xlim(left=1)
    ax.set_xlabel("Frequency [Hz]")
    ax.set_ylabel("PSD")
    ax.set_title(f"Iteration: {iteration}")
    ax.legend()
    ax.set_ylim(bottom=1e-6, top=1e4)

    fig.tight_layout()
    fig.savefig(f"psd_signal_itr{iteration:02d}.png")
    plt.close(fig)
"\n", | |
def plot_trace_with_hist(samples_dict, true_values, bounds, labels, iteration):
    """
    Plot trace plots (one per parameter) with adjacent rotated histograms.
    Each parameter's trace shows every emcee walker.

    Parameters:
        samples_dict (dict): Keys are parameter names; each value is an array of shape (n_steps, n_walkers).
        true_values (list): True parameter values (drawn as dashed red lines).
        bounds (list of tuples): y-limits for each parameter (min, max).
        labels (list of str): Parameter names; also the keys into samples_dict.
        iteration (int): Current Gibbs iteration (for the title and file name).

    Saves ``trace_hist_itr{iteration:02d}.png`` in the working directory.
    """
    ndim = len(labels)
    # squeeze=False keeps ``axes`` 2-D even for a single parameter, so the
    # per-row unpacking below also works when ndim == 1.
    fig, axes = plt.subplots(ndim, 2, figsize=(8, 2.5 * ndim), squeeze=False,
                             gridspec_kw={"width_ratios": [3, 1], "wspace": 0.05})

    for i, label in enumerate(labels):
        ax_trace, ax_hist = axes[i]
        chain = samples_dict[label]  # (n_steps, n_walkers)
        n_steps, _ = chain.shape

        # Left: trace plot -- matplotlib draws one line per column (walker).
        ax_trace.plot(np.arange(n_steps), chain, color="k", alpha=0.3)
        ax_trace.axhline(true_values[i], color="red", linestyle="--", label="True")
        ax_trace.set_ylabel(label)
        ax_trace.set_ylim(bounds[i])
        if i == ndim - 1:
            ax_trace.set_xlabel("MCMC Iteration")
        else:
            ax_trace.set_xticklabels([])

        # Right: rotated histogram of all walkers pooled together.
        ax_hist.hist(chain.ravel(), bins=30, orientation="horizontal", color="gray", alpha=0.7)
        ax_hist.axhline(true_values[i], color="red", linestyle="--")
        ax_hist.set_ylim(bounds[i])
        ax_hist.set_xticks([])
        ax_hist.set_yticks([])

    fig.suptitle(f"Trace and Posterior Histogram (Gibbs Iteration {iteration})", y=1.02)
    # No tight_layout(): the explicit gridspec ``wspace`` marks the subplot
    # params as locally modified, so matplotlib skips every axes and only
    # warns. bbox_inches="tight" keeps the suptitle (above the top edge at
    # y=1.02) from being clipped in the saved file.
    fig.savefig(f"trace_hist_itr{iteration:02d}.png", bbox_inches="tight")
    plt.close(fig)
"\n", | |
"# ------------------------------\n", | |
"# JSD computation and plotting functions\n", | |
"# ------------------------------\n", | |
def compute_jsd(file1, file2, bin_ranges, n_bins=30):
    """
    Compute the Jensen-Shannon Divergence (JSD) between two posterior sample files.

    Each file holds one sample per row, with column 0 = a and column 1 = f.

    Parameters:
        file1 (str): First sample file.
        file2 (str): Baseline sample file.
        bin_ranges (tuple): ((a_min, a_max), (f_min, f_max)) for histogram binning.
        n_bins (int): Number of bins per parameter.

    Returns:
        float: Sum of the per-parameter JSDs for a and f (natural-log units).
    """
    # ndmin=2 keeps a file holding a single sample row 2-D; otherwise loadtxt
    # returns a 1-D array and the column indexing below raises IndexError.
    samples1 = np.loadtxt(file1, ndmin=2)
    samples2 = np.loadtxt(file2, ndmin=2)

    def jsd_param(p, q, rng):
        """JSD between histograms of p and q on a common binning grid."""
        p_hist, _ = np.histogram(p, bins=n_bins, range=rng, density=True)
        q_hist, _ = np.histogram(q, bins=n_bins, range=rng, density=True)
        # Small floor so empty bins do not produce log(0); scipy's entropy()
        # renormalises its inputs, so the floor barely shifts the result.
        p_hist += 1e-10
        q_hist += 1e-10
        m = 0.5 * (p_hist + q_hist)
        return 0.5 * (scipy.stats.entropy(p_hist, m) + scipy.stats.entropy(q_hist, m))

    jsd_a = jsd_param(samples1[:, 0], samples2[:, 0], bin_ranges[0])
    jsd_f = jsd_param(samples1[:, 1], samples2[:, 1], bin_ranges[1])
    return jsd_a + jsd_f
"\n", | |
def plot_jsd(n_bins=30, sample_files=None):
    """
    Plot the JSD of each Gibbs iteration's [a, f] posterior against the final one.

    Parameters
    ----------
    n_bins : int
        Histogram bins per parameter, forwarded to ``compute_jsd``.
    sample_files : list of str, optional
        Sample files to compare, named ``samples_<iteration>.txt``. Defaults to
        every ``samples_*.txt`` in the working directory. Passing the list
        written by the current run avoids picking up stale files from an
        earlier run; a stale file would otherwise become the baseline if it
        has the highest iteration number.

    Raises
    ------
    FileNotFoundError
        If no sample files are found.

    Saves ``jsd_vs_iterations.png`` in the working directory.
    """
    def _iteration_of(path):
        # "samples_07.txt" -> 7; basename/splitext so directory names or a
        # prefix containing "_" do not break the parse.
        stem = os.path.splitext(os.path.basename(path))[0]
        return int(stem.rsplit("_", 1)[1])

    if sample_files is None:
        sample_files = glob.glob("samples_*.txt")
    sample_files = sorted(sample_files, key=_iteration_of)
    if not sample_files:
        raise FileNotFoundError("No sample files found.")

    # The highest-numbered iteration is the reference posterior.
    baseline_file = sample_files[-1]

    # Shared bin ranges spanning the samples of every iteration.
    all_samples = np.vstack([np.loadtxt(fpath, ndmin=2) for fpath in sample_files])
    a_range = (all_samples[:, 0].min(), all_samples[:, 0].max())
    f_range = (all_samples[:, 1].min(), all_samples[:, 1].max())
    bin_ranges = (a_range, f_range)

    iterations = []
    jsd_values = []
    for fpath in sample_files[:-1]:
        iterations.append(_iteration_of(fpath))
        jsd_values.append(compute_jsd(fpath, baseline_file, bin_ranges, n_bins=n_bins))

    fig, ax = plt.subplots(figsize=(5, 3))
    ax.plot(iterations, jsd_values, marker="o", linestyle="-")
    ax.set_xlabel("Gibbs Iteration")
    ax.set_ylabel("JSD")
    ax.set_title("JSD of [a, f] posterior vs. final posterior")
    fig.savefig("jsd_vs_iterations.png")
    plt.close(fig)
"\n", | |
"# ------------------------------\n", | |
"# Emcee log probability function for the sine-wave model with Whittle likelihood\n", | |
"# ------------------------------\n", | |
def log_prob(params, data, t, dt, fs, noise_psd, a_bounds, f_bounds):
    """
    Log-posterior of the sine-wave parameters [a, f] under a Whittle likelihood.

    The prior is uniform on the box ``a_bounds x f_bounds``. Outside it the
    function returns ``-inf``; inside it returns the Whittle log-likelihood of
    the residuals ``data - a*sin(2*pi*f*t)``, up to an additive constant.

    Parameters
    ----------
    params : sequence of two floats
        ``(a, f)`` -- amplitude and frequency [Hz].
    data, t : array_like
        Observed series and its time grid.
    dt : float
        Sampling interval.
    fs : float
        Sampling frequency. Not used; kept so the emcee ``args`` tuple keeps
        its positional layout.
    noise_psd : tuple (frequency_grid, psd_values)
        Noise PSD, linearly interpolated onto the periodogram frequencies.
    a_bounds, f_bounds : tuple (low, high)
        Inclusive prior bounds for a and f.
    """
    a, f = params
    # Uniform prior: zero probability outside the bounds.
    if not (a_bounds[0] <= a <= a_bounds[1] and f_bounds[0] <= f <= f_bounds[1]):
        return -np.inf

    residuals = data - a * np.sin(2 * np.pi * f * t)
    freqs, periodogram = compute_periodogram(residuals, dt)
    psd = np.interp(freqs, noise_psd[0], noise_psd[1])

    # Whittle: each periodogram ordinate is approximately S(f) * Exp(1), so
    # log L = -sum(log S + I / S). This assumes periodogram and PSD share the
    # same (two-sided) normalisation, as with compute_estimated_psd(T=fs).
    # The former -0.5 prefactor (real-Gaussian form) tempered the likelihood
    # by 2 and widened the posterior by about sqrt(2).
    return -np.sum(np.log(psd) + periodogram / psd)
"\n", | |
"# ------------------------------\n", | |
"# Main Gibbs Sampling Workflow with emcee\n", | |
"# ------------------------------\n", | |
np.random.seed(42)

# True sine-wave parameters
true_a = 20.0  # Amplitude
true_f = 25.0  # Frequency (Hz)

# AR filter coefficients (for noise generation); a_coeff[0] is the leading 1
a_coeff = [1, -2.2137, 2.9403, -2.1697, 0.9606]
# NOTE(review): true_rho is not used below -- the AR filter is driven by
# unit-variance np.random.randn noise. Kept in case later cells read it.
true_rho = 0.1

# Data length and time grid
n_samples = 1024
fs = 100  # Hz
dt = 1.0 / fs
t = np.linspace(0, (n_samples - 1) * dt, n_samples)

# Generate true signal and noise, then observed data
true_signal = true_a * np.sin(2 * np.pi * true_f * t)
noise = scipy.signal.lfilter([1], a_coeff, np.random.randn(n_samples))
data = true_signal + noise

# Reference ("true") noise PSD: a Yule-Walker AR fit to the realised pure
# noise -- an estimate, not the analytic PSD of a_coeff.
order = 4
AR_est, P_est, _ = aryule(noise, order)
NFFT = len(data)
true_psd = compute_estimated_psd(AR_est, P_est, T=fs, NFFT=NFFT)
psd_freqs = np.linspace(0, fs/2, len(true_psd))
noise_psd_default = (psd_freqs, true_psd)

# Compute true signal periodogram (for posterior-predictive plot)
_, true_signal_pdgrm = compute_periodogram(true_signal, dt)

# Parameter bounds for sine model.
# NOTE(review): the bounds are centred on the truth, so the first iteration's
# walkers (initialised at the bound midpoints) start at the true values.
a_bounds = (true_a - 5, true_a + 5)
f_bounds = (true_f - 5, true_f + 5)

# ------------------------------
# Gibbs Sampler Setup
# ------------------------------
n_gibbs = 30
n_mcmc = 1000  # MCMC steps per Gibbs iteration
burnin = 300
n_walkers = 16  # emcee walkers; loop-invariant, so set once here

# Remove posterior-sample files left by a previous run. plot_jsd() globs
# samples_*.txt and uses the highest-numbered file as its baseline, so a stale
# file from an earlier run with a larger n_gibbs would silently become the
# reference posterior.
for stale_file in glob.glob("samples_*.txt"):
    os.remove(stale_file)

# Initialize current signal estimate (start with zeros)
current_signal = np.zeros_like(data)
current_noise_psd = noise_psd_default

# For saving posterior samples and tracking median estimates
posterior_samples_files = []
last_median_a = (a_bounds[0] + a_bounds[1]) / 2
last_median_f = (f_bounds[0] + f_bounds[1]) / 2

# Define labels and bounds for trace plots
param_labels = ['a', 'f']
param_bounds = [a_bounds, f_bounds]
true_params = [true_a, true_f]

for gibbs_iter in trange(n_gibbs, desc='Gibbs Sampling'):
    # 1. Update noise PSD from the residuals (Yule-Walker point estimate)
    residuals = data - current_signal
    AR_est, P_est, _ = aryule(residuals, order)
    current_psd = compute_estimated_psd(AR_est, P_est, T=fs, NFFT=NFFT)
    psd_freqs = np.linspace(0, fs/2, len(current_psd))
    current_noise_psd = (psd_freqs, current_psd)

    # 2. Sample [a, f] using emcee, with walkers initialised near the previous
    #    median plus small random perturbations
    initial = np.array([last_median_a, last_median_f]) + 0.1 * np.random.randn(n_walkers, 2)
    sampler = emcee.EnsembleSampler(n_walkers, 2, log_prob,
                                    args=(data, t, dt, fs, current_noise_psd, a_bounds, f_bounds))
    sampler.run_mcmc(initial, n_mcmc, progress=False)

    # Extract posterior samples (discard burn-in)
    samples = sampler.get_chain(discard=burnin, flat=False)  # shape: (n_steps, n_walkers, 2)
    a_samples = samples[:, :, 0].copy()  # shape: (n_steps, n_walkers)
    f_samples = samples[:, :, 1].copy()

    # Update medians for next iteration
    last_median_a = np.median(a_samples)
    last_median_f = np.median(f_samples)

    # Update current signal using median parameters
    current_signal = last_median_a * np.sin(2 * np.pi * last_median_f * t)

    # 3. Save aggregated posterior samples (over walkers) to a text file
    combined_samples = np.column_stack([a_samples.flatten(), f_samples.flatten()])
    filename = f"samples_{gibbs_iter+1:02d}.txt"
    np.savetxt(filename, combined_samples)
    posterior_samples_files.append(filename)

    # 4. Plot trace and histogram for this iteration
    samples_for_plot = {"a": a_samples, "f": f_samples}
    plot_trace_with_hist(samples_for_plot, true_params, param_bounds, param_labels, gibbs_iter+1)

    # 5. Plot posterior-predictive PSD and signal vs. true ones
    plot_progress(gibbs_iter+1, t, data, current_signal, dt,
                  current_AR=AR_est, current_P=P_est, fs=fs,
                  true_signal_pdgrm=true_signal_pdgrm, true_psd=true_psd)

# ------------------------------
# Final JSD Computation and Plotting
# ------------------------------
plot_jsd(n_bins=30)

print("Gibbs sampling complete!")
print("JSD plot saved as 'jsd_vs_iterations.png'.")
], | |
"metadata": { | |
"id": "jWAE2ULnPAjB", | |
"colab": { | |
"base_uri": "https://localhost:8080/", | |
"height": 1000, | |
"referenced_widgets": [ | |
"2744c950ab8148c5a26b27889b303076", | |
"a32091baa8e041ff9c2f2898f3b16ccc", | |
"4b3950b8c8254469bdfa07fae6868736", | |
"56784a2ef7564b4d8dc388b70cc89218", | |
"645fb53be59f4be6b96bb93b2236c670", | |
"07b05b257ab74478b43447a00c38ddca", | |
"ec7e837b56ad40feb620fe4cea65c224", | |
"c2ec3a6304584605ba246be1abab33f3", | |
"69fcc6acc49b44408a4d1e4152ab62dd", | |
"69b74cce27f94f8e95fd5f3ab010d676", | |
"e50577642db849c088722e88a73ec36f" | |
] | |
}, | |
"outputId": "cc1ef2cc-fce1-4695-e470-47aa95d443d4" | |
}, | |
"execution_count": 12, | |
"outputs": [ | |
{ | |
"output_type": "display_data", | |
"data": { | |
"text/plain": [ | |
"Gibbs Sampling: 0%| | 0/30 [00:00<?, ?it/s]" | |
], | |
"application/vnd.jupyter.widget-view+json": { | |
"version_major": 2, | |
"version_minor": 0, | |
"model_id": "2744c950ab8148c5a26b27889b303076" | |
} | |
}, | |
"metadata": {} | |
}, | |
{ | |
"output_type": "stream", | |
"name": "stderr", | |
"text": [ | |
"<ipython-input-12-1c3ee5c2da40>:112: UserWarning: This figure includes Axes that are not compatible with tight_layout, so results might be incorrect.\n", | |
" plt.tight_layout()\n", | |
"<ipython-input-12-1c3ee5c2da40>:112: UserWarning: This figure includes Axes that are not compatible with tight_layout, so results might be incorrect.\n", | |
" plt.tight_layout()\n", | |
"<ipython-input-12-1c3ee5c2da40>:112: UserWarning: This figure includes Axes that are not compatible with tight_layout, so results might be incorrect.\n", | |
" plt.tight_layout()\n", | |
"<ipython-input-12-1c3ee5c2da40>:112: UserWarning: This figure includes Axes that are not compatible with tight_layout, so results might be incorrect.\n", | |
" plt.tight_layout()\n", | |
"<ipython-input-12-1c3ee5c2da40>:112: UserWarning: This figure includes Axes that are not compatible with tight_layout, so results might be incorrect.\n", | |
" plt.tight_layout()\n", | |
"<ipython-input-12-1c3ee5c2da40>:112: UserWarning: This figure includes Axes that are not compatible with tight_layout, so results might be incorrect.\n", | |
" plt.tight_layout()\n", | |
"<ipython-input-12-1c3ee5c2da40>:112: UserWarning: This figure includes Axes that are not compatible with tight_layout, so results might be incorrect.\n", | |
" plt.tight_layout()\n", | |
"<ipython-input-12-1c3ee5c2da40>:112: UserWarning: This figure includes Axes that are not compatible with tight_layout, so results might be incorrect.\n", | |
" plt.tight_layout()\n", | |
"<ipython-input-12-1c3ee5c2da40>:112: UserWarning: This figure includes Axes that are not compatible with tight_layout, so results might be incorrect.\n", | |
" plt.tight_layout()\n", | |
"<ipython-input-12-1c3ee5c2da40>:112: UserWarning: This figure includes Axes that are not compatible with tight_layout, so results might be incorrect.\n", | |
" plt.tight_layout()\n", | |
"<ipython-input-12-1c3ee5c2da40>:112: UserWarning: This figure includes Axes that are not compatible with tight_layout, so results might be incorrect.\n", | |
" plt.tight_layout()\n", | |
"<ipython-input-12-1c3ee5c2da40>:112: UserWarning: This figure includes Axes that are not compatible with tight_layout, so results might be incorrect.\n", | |
" plt.tight_layout()\n", | |
"<ipython-input-12-1c3ee5c2da40>:112: UserWarning: This figure includes Axes that are not compatible with tight_layout, so results might be incorrect.\n", | |
" plt.tight_layout()\n", | |
"<ipython-input-12-1c3ee5c2da40>:112: UserWarning: This figure includes Axes that are not compatible with tight_layout, so results might be incorrect.\n", | |
" plt.tight_layout()\n", | |
"<ipython-input-12-1c3ee5c2da40>:112: UserWarning: This figure includes Axes that are not compatible with tight_layout, so results might be incorrect.\n", | |
" plt.tight_layout()\n", | |
"<ipython-input-12-1c3ee5c2da40>:112: UserWarning: This figure includes Axes that are not compatible with tight_layout, so results might be incorrect.\n", | |
" plt.tight_layout()\n", | |
"<ipython-input-12-1c3ee5c2da40>:112: UserWarning: This figure includes Axes that are not compatible with tight_layout, so results might be incorrect.\n", | |
" plt.tight_layout()\n", | |
"<ipython-input-12-1c3ee5c2da40>:112: UserWarning: This figure includes Axes that are not compatible with tight_layout, so results might be incorrect.\n", | |
" plt.tight_layout()\n", | |
"<ipython-input-12-1c3ee5c2da40>:112: UserWarning: This figure includes Axes that are not compatible with tight_layout, so results might be incorrect.\n", | |
" plt.tight_layout()\n", | |
"<ipython-input-12-1c3ee5c2da40>:112: UserWarning: This figure includes Axes that are not compatible with tight_layout, so results might be incorrect.\n", | |
" plt.tight_layout()\n", | |
"<ipython-input-12-1c3ee5c2da40>:112: UserWarning: This figure includes Axes that are not compatible with tight_layout, so results might be incorrect.\n", | |
" plt.tight_layout()\n", | |
"<ipython-input-12-1c3ee5c2da40>:112: UserWarning: This figure includes Axes that are not compatible with tight_layout, so results might be incorrect.\n", | |
" plt.tight_layout()\n", | |
"<ipython-input-12-1c3ee5c2da40>:112: UserWarning: This figure includes Axes that are not compatible with tight_layout, so results might be incorrect.\n", | |
" plt.tight_layout()\n", | |
"<ipython-input-12-1c3ee5c2da40>:112: UserWarning: This figure includes Axes that are not compatible with tight_layout, so results might be incorrect.\n", | |
" plt.tight_layout()\n", | |
"<ipython-input-12-1c3ee5c2da40>:112: UserWarning: This figure includes Axes that are not compatible with tight_layout, so results might be incorrect.\n", | |
" plt.tight_layout()\n", | |
"<ipython-input-12-1c3ee5c2da40>:112: UserWarning: This figure includes Axes that are not compatible with tight_layout, so results might be incorrect.\n", | |
" plt.tight_layout()\n", | |
"<ipython-input-12-1c3ee5c2da40>:112: UserWarning: This figure includes Axes that are not compatible with tight_layout, so results might be incorrect.\n", | |
" plt.tight_layout()\n", | |
"<ipython-input-12-1c3ee5c2da40>:112: UserWarning: This figure includes Axes that are not compatible with tight_layout, so results might be incorrect.\n", | |
" plt.tight_layout()\n", | |
"<ipython-input-12-1c3ee5c2da40>:112: UserWarning: This figure includes Axes that are not compatible with tight_layout, so results might be incorrect.\n", | |
" plt.tight_layout()\n", | |
"<ipython-input-12-1c3ee5c2da40>:112: UserWarning: This figure includes Axes that are not compatible with tight_layout, so results might be incorrect.\n", | |
" plt.tight_layout()\n" | |
] | |
}, | |
{ | |
"output_type": "stream", | |
"name": "stdout", | |
"text": [ | |
"Gibbs sampling complete!\n", | |
"JSD plot saved as 'jsd_vs_iterations.png'.\n" | |
] | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"import imageio\n", | |
"import os\n", | |
"\n", | |
"# Collect image filenames\n", | |
"image_files = [f\"itr{i:02d}.png\" for i in range(1, n_gibbs + 1)]\n", | |
"\n", | |
"# Create GIF\n", | |
"images = []\n", | |
"for filename in image_files:\n", | |
" images.append(imageio.imread(filename))\n", | |
"\n", | |
"# Add the last image for a longer duration\n", | |
"images.append(imageio.imread(image_files[-1]))\n", | |
"images.append(imageio.imread(image_files[-1]))\n", | |
"images.append(imageio.imread(image_files[-1]))\n", | |
"\n", | |
"\n", | |
"# Save as GIF\n", | |
"imageio.mimsave('gibbs_sampling_psd.gif', images, duration=0.5) # Adjust duration as needed\n", | |
"\n", | |
"# Clean up individual image files (optional)\n", | |
"for filename in image_files:\n", | |
" os.remove(filename)\n", | |
"\n", | |
"print(\"GIF created: gibbs_sampling_psd.gif\")" | |
], | |
"metadata": { | |
"id": "SjCi9pFOVrth" | |
}, | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "markdown", | |
"source": [ | |
"" | |
], | |
"metadata": { | |
"id": "KAH1GG9XaoxD" | |
} | |
}, | |
{ | |
"cell_type": "code", | |
"source": [], | |
"metadata": { | |
"id": "H46kuPOAaTU6" | |
}, | |
"execution_count": null, | |
"outputs": [] | |
} | |
] | |
} |
Author
avivajpeyi
commented
Mar 5, 2025
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment