Created
November 13, 2025 23:40
-
-
Save avivajpeyi/b5329ca13d51c0d6b77465f96400d748 to your computer and use it in GitHub Desktop.
penalty_matrix_for_psplines
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| { | |
| "cells": [ | |
| { | |
| "cell_type": "markdown", | |
| "metadata": { | |
| "id": "view-in-github", | |
| "colab_type": "text" | |
| }, | |
| "source": [ | |
| "<a href=\"https://colab.research.google.com/gist/avivajpeyi/b5329ca13d51c0d6b77465f96400d748/penalty_matrix_for_psplines.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "source": [ | |
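| "# Penalty matrix for P-splines\n", | |
| "\n", | |
| "Compares the first-derivative roughness penalty $P_{jk} = \\int B_j'(x)\\, B_k'(x)\\, dx$ of a cubic B-spline basis on non-uniform knots, as computed by scikit-fda's `L2Regularization`, against a brute-force numerical-quadrature estimate of the same integral." | |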
| ], | |
| "metadata": { | |
| "id": "ymAGJYlClECF" | |
| } | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": null, | |
| "metadata": { | |
| "colab": { | |
| "base_uri": "https://localhost:8080/", | |
| "height": 651 | |
| }, | |
| "id": "lIYdn1woOS1n", | |
| "outputId": "a546efc8-721e-4bb5-eeb6-49e3af0c659a" | |
| }, | |
| "outputs": [ | |
| { | |
| "output_type": "stream", | |
| "name": "stdout", | |
| "text": [ | |
| "Shape: (8, 8)\n", | |
| "Frobenius norm ||P_quad - P_skfda|| = 0.2300470611495532\n", | |
| "Max abs diff = 0.11499252631071144\n" | |
| ] | |
| }, | |
| { | |
| "output_type": "display_data", | |
| "data": { | |
| "text/plain": [ | |
| "<Figure size 1400x600 with 4 Axes>" | |
| ], | |
| "image/png": "\n" | |
| }, | |
| "metadata": {} | |
| } | |
| ], | |
| "source": [ | |
| "! pip install scikit-fda -q\n", | |
| "\n", | |
| "import numpy as np\n", | |
| "import matplotlib.pyplot as plt\n", | |
| "from skfda.misc.operators import LinearDifferentialOperator\n", | |
| "from skfda.misc.regularization import L2Regularization\n", | |
| "from skfda.representation import basis as skfda_basis\n", | |
| "\n", | |
| "\n", | |
def penalty_skfda(knots, degree):
    """Reference penalty matrix computed by scikit-fda.

    A B-spline basis of order ``degree + 1`` is built on ``knots`` and
    ``L2Regularization`` with ``LinearDifferentialOperator(1)`` (first
    derivative) returns its penalty (Gram) matrix.

    Returns
    -------
    tuple
        ``(P, n_basis, basis)``: the penalty matrix as a NumPy array, the
        number of basis functions, and the basis object itself (so the caller
        can evaluate the very same basis elsewhere, e.g. for quadrature).
    """
    spline_basis = skfda_basis.BSplineBasis(knots=knots, order=degree + 1)
    regularization = L2Regularization(LinearDifferentialOperator(1))
    penalty = np.asarray(regularization.penalty_matrix(spline_basis))
    return penalty, spline_basis.n_basis, spline_basis
| "\n", | |
| "\n", | |
def _as_grid_by_basis(values, n_grid, n_basis):
    """Return basis evaluations as an ``(n_grid, n_basis)`` array.

    Accepted layouts: ``(n_grid, 1, n_basis)``, ``(n_basis, n_grid, 1)``,
    ``(n_grid, n_basis)`` and ``(n_basis, n_grid)``.

    Raises
    ------
    RuntimeError
        If ``values`` cannot be brought into ``(n_grid, n_basis)`` form.
    """
    arr = np.asarray(values)
    if arr.ndim == 3:
        if arr.shape == (n_grid, 1, n_basis):
            arr = arr[:, 0, :]
        elif arr.shape == (n_basis, n_grid, 1):
            arr = arr[:, :, 0].T
    elif arr.ndim == 2 and n_basis != n_grid and arr.shape == (n_basis, n_grid):
        # Unambiguous transposed 2-D layout (only when the two sizes differ).
        arr = arr.T
    if arr.shape != (n_grid, n_basis):
        raise RuntimeError(f"Cannot reshape Bp of shape {np.shape(values)}")
    return arr


def penalty_quadrature(basis, knots, n_grid=8000):
    """Estimate the first-derivative penalty matrix by numerical quadrature.

    Approximates ``P[j, k] = integral of B_j'(x) * B_k'(x) dx`` over
    ``[knots[0], knots[-1]]`` with the composite trapezoidal rule on a uniform
    grid of ``n_grid`` nodes.

    Parameters
    ----------
    basis : callable
        Object exposing ``n_basis`` and ``basis(x, derivative=1)`` (here the
        scikit-fda ``BSplineBasis`` returned by ``penalty_skfda``).
    knots : array-like
        Knot vector; only its first and last entries (the integration
        limits) are used.
    n_grid : int, default 8000
        Number of quadrature nodes; must be at least 2.

    Returns
    -------
    numpy.ndarray
        ``(K, K)`` matrix with ``K = basis.n_basis``.

    Raises
    ------
    ValueError
        If ``n_grid < 2``.
    RuntimeError
        If the derivative evaluation cannot be reshaped to ``(n_grid, K)``.
    """
    n_grid = int(n_grid)
    if n_grid < 2:
        raise ValueError(f"n_grid must be >= 2, got {n_grid}")

    K = basis.n_basis
    xg = np.linspace(knots[0], knots[-1], n_grid)
    dx = xg[1] - xg[0]

    # Derivative of every basis function at every node, as (n_grid, K).
    Bp = _as_grid_by_basis(basis(xg, derivative=1), n_grid, K)

    # Trapezoidal weights (half weight on both endpoints). A plain
    # sum(f) * dx would integrate over n_grid * dx = (b - a) + dx and add a
    # dx/2 * (f(a) + f(b)) bias to the entries of the boundary basis functions.
    weights = np.full(n_grid, dx)
    weights[0] = weights[-1] = dx / 2
    return (Bp.T * weights) @ Bp  # (K, K)
| "\n", | |
| "\n", | |
def plot_matrix(ax, matrix, title, fig):
    """Draw ``matrix`` as a heat map on ``ax`` with a symmetric colour scale.

    Colour limits are ``+/- max|matrix|`` (over finite entries), so zero sits
    at the centre of the diverging 'coolwarm' colormap and signs are easy to
    read.

    Parameters
    ----------
    ax : matplotlib.axes.Axes
        Axes to draw on.
    matrix : array-like
        2-D array; may be rectangular.
    title : str
        Axes title.
    fig : matplotlib.figure.Figure
        Figure owning ``ax``; used to attach the colorbar.

    Returns
    -------
    The image object returned by ``ax.matshow``.
    """
    matrix = np.asarray(matrix)

    finite_abs = np.abs(matrix[np.isfinite(matrix)])
    vmax = finite_abs.max() if finite_abs.size else 0.0
    if vmax == 0:
        # All-zero / all-non-finite input: use a unit scale so vmin < vmax.
        vmax = 1.0

    image = ax.matshow(matrix, cmap='coolwarm', vmin=-vmax, vmax=vmax)
    fig.colorbar(image, ax=ax, shrink=0.8)
    ax.set_title(title, fontsize=12)

    # Rows index the y-axis and columns the x-axis, so size ticks separately.
    n_rows, n_cols = matrix.shape
    ax.set_xticks(np.arange(n_cols))
    ax.set_yticks(np.arange(n_rows))
    ax.set_xticklabels(np.arange(n_cols))
    ax.set_yticklabels(np.arange(n_rows))
    # matshow draws column ticks above the image; move them below.
    ax.xaxis.tick_bottom()
    return image
| "\n", | |
| "\n", | |
# Main: compute both penalty matrices for one basis, compare them, and plot.
xi = np.array([0.00, 0.07, 0.21, 0.25, 0.60, 1.00])  # non-uniform knot vector
degree = 3  # cubic B-splines (basis order = degree + 1)

# Both penalties use the same basis object, so entries are directly comparable.
P_skfda, K, basis = penalty_skfda(xi, degree)
P_quad = penalty_quadrature(basis, xi)

# Sanity checks: both are K x K Gram matrices of B_j', hence symmetric.
assert P_skfda.shape == P_quad.shape == (K, K)
assert np.allclose(P_skfda, P_skfda.T) and np.allclose(P_quad, P_quad.T)

# Difference metrics, absolute and relative to the scikit-fda reference.
P_diff = P_quad - P_skfda
fro_err = np.linalg.norm(P_diff)
max_err = np.max(np.abs(P_diff))
rel_err = fro_err / np.linalg.norm(P_skfda)

print("Shape:", P_quad.shape)
print("Frobenius norm ||P_quad - P_skfda|| =", fro_err)
print("Relative Frobenius error =", rel_err)
print("Max abs diff =", max_err)

fig, axes = plt.subplots(1, 3, figsize=(21, 6))
fig.suptitle('Comparison of Penalty Matrices', fontsize=16)

plot_matrix(axes[0], P_quad, r'Quadrature Penalty Matrix ($P_{\rm quad}$)', fig)
plot_matrix(axes[1], P_skfda, r'skfda Penalty Matrix ($P_{\rm skfda}$)', fig)
plot_matrix(axes[2], P_diff, r'Difference ($P_{\rm quad} - P_{\rm skfda}$)', fig)

plt.tight_layout(rect=[0, 0, 1, 0.95])
plt.show()
| ] | |
| } | |
| ], | |
| "metadata": { | |
| "colab": { | |
| "provenance": [], | |
| "include_colab_link": true | |
| }, | |
| "kernelspec": { | |
| "display_name": "Python 3", | |
| "name": "python3" | |
| } | |
| }, | |
| "nbformat": 4, | |
| "nbformat_minor": 0 | |
| } |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment