Last active
October 4, 2025 17:41
-
-
Save ricardoV94/d75d42d5c3cdde8827135bcbefc9405a to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| { | |
| "cells": [ | |
| { | |
| "cell_type": "code", | |
| "execution_count": 1, | |
| "id": "067a3cc1-1ed2-4c28-a113-0bd24c433f45", | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "import pytensor.tensor as pt\n", | |
| "import numpy as np" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 2, | |
| "id": "3fb48d6c-5d8b-46ae-a09d-e6622669b72e", | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "def fill_triangular_spiral(x_raveled, n=None, upper=False):\n", | |
| " # https://github.com/tensorflow/probability/blob/a26f4cbe5ce1549767e13798d9bf5032dac4257b/tensorflow_probability/python/math/linalg.py#L925\n", | |
| " x_raveled = pt.as_tensor(x_raveled)\n", | |
| " *batch_shape, m = x_raveled.shape\n", | |
| " if n is None:\n", | |
| " n = pt.cast(pt.sqrt(0.25 + 2 * m) - 0.5, \"int32\")\n", | |
| " tail = x_raveled[..., n:]\n", | |
| " \n", | |
| " def reverse(x):\n", | |
| " return x[..., ::-1]\n", | |
| " \n", | |
| " if upper:\n", | |
| " xc = pt.concatenate([x_raveled, reverse(tail)])\n", | |
| " else:\n", | |
| " xc = pt.concatenate([tail, reverse(x_raveled)])\n", | |
| " \n", | |
| " y = pt.reshape(xc, (*batch_shape, n, n))\n", | |
| " return pt.triu(y) if upper else pt.tril(y)\n", | |
| "\n", | |
| "def inverse_fill_triangular_spiral(x, m=None, upper=False):\n", | |
| " x = pt.as_tensor(x)\n", | |
| " *batch_shape, n, n = x.shape\n", | |
| " \n", | |
| " if m is None:\n", | |
| " m = pt.cast((n * (n + 1)) // 2, \"int32\")\n", | |
| " \n", | |
| " ndim = x.ndim\n", | |
| " if upper:\n", | |
| " initial_elements = x[..., 0, :]\n", | |
| " triangular_portion = x[..., 1:, :]\n", | |
| " else:\n", | |
| " initial_elements = pt.flip(x[..., -1, :], axis=-1)\n", | |
| " triangular_portion = x[..., :-1, :]\n", | |
| " # return initial_elements, triangular_portion\n", | |
| " rotated_triangular_portion = pt.flip(triangular_portion, axis=(-1, -2))\n", | |
| " consolidated_matrix = triangular_portion + rotated_triangular_portion\n", | |
| " end_sequence = pt.reshape(\n", | |
| " consolidated_matrix,\n", | |
| " (*batch_shape, pt.cast(n * (n -1), \"int64\")),\n", | |
| " )\n", | |
| " y = pt.concatenate([initial_elements, end_sequence[..., :m - n]], axis=-1)\n", | |
| " return y \n", | |
| "\n", | |
| "\n", | |
| "def test_fill_triangular_spiral():\n", | |
| " x_unconstrained = np.array([1, 2, 3, 4, 5, 6])\n", | |
| " x_constrained_lower = np.array([[4, 0, 0], [6, 5, 0], [3, 2, 1],])\n", | |
| " x_constrained_upper = np.array([[1, 2, 3], [0, 5, 6], [0, 0, 4],])\n", | |
| "\n", | |
| " np.testing.assert_allclose(\n", | |
| " fill_triangular_spiral(x_unconstrained).eval(),\n", | |
| " x_constrained_lower,\n", | |
| " )\n", | |
| " \n", | |
| " np.testing.assert_allclose(\n", | |
| " fill_triangular_spiral(x_unconstrained, upper=True).eval(),\n", | |
| " x_constrained_upper,\n", | |
| " )\n", | |
| "\n", | |
| " np.testing.assert_allclose(\n", | |
| " inverse_fill_triangular_spiral(x_constrained_lower).eval(),\n", | |
| " x_unconstrained, \n", | |
| " )\n", | |
| " \n", | |
| " np.testing.assert_allclose(\n", | |
| " inverse_fill_triangular_spiral(x_constrained_upper, upper=True).eval(),\n", | |
| " x_unconstrained, \n", | |
| " )" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 3, | |
| "id": "95512889-d6e6-4573-aec5-11a2e0bc0eed", | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "test_fill_triangular_spiral()" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 4, | |
| "id": "2e078480-52e4-46da-aed3-f01fb2a3d1b7", | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "def backward(self, x, *inputs):\n", | |
| " x = pt.as_tensor(x)\n", | |
| " *batch_shape, core_shape = x.shape\n", | |
| "\n", | |
| " y = fill_triangular_spiral(x)\n", | |
| " n = y.shape[-1]\n", | |
| "\n", | |
| " # Pad zeros on the top row and right column.\n", | |
| " paddings = [*([(0, 0)] * (y.ndim - 2)), [1, 0], [0, 1]]\n", | |
| " y = pt.pad(y, paddings)\n", | |
| " \n", | |
| " # Set diagonal to 1s.\n", | |
| " arange = pt.arange(n + 1)\n", | |
| " y = y[..., arange, arange].set(1)\n", | |
| "\n", | |
| " # Normalize each row to have Euclidean (L2) norm 1.\n", | |
| " y /= pt.linalg.norm(y, axis=-1, ord=2)[..., None]\n", | |
| " return y\n", | |
| "\n", | |
| "# y: constrained -> x: unconstrained\n", | |
| "def forward(self, y, *inputs):\n", | |
| " y = pt.as_tensor(y)\n", | |
| " *batch_shape, n = y.shape\n", | |
| "\n", | |
| " # Extract the reciprocal of the row norms from the diagonal.\n", | |
| " diag = pt.diagonal(y, axis1=-2, axis2=-1)[..., None]\n", | |
| "\n", | |
| " # Set the diagonal to 0s.\n", | |
| " arange = pt.arange(n)\n", | |
| " y = y[..., arange, arange].set(0)\n", | |
| " \n", | |
| " # Multiply with the norm (or divide by its reciprocal) to recover the\n", | |
| " # unconstrained reals in the (strictly) lower triangular part.\n", | |
| " x = y / diag\n", | |
| "\n", | |
| " # Remove the first row and last column before inverting the FillTriangular\n", | |
| " # transformation.\n", | |
| " return inverse_fill_triangular_spiral(x[..., 1:, :-1])\n", | |
| "\n", | |
| "def log_jac_det(self, x, *inputs):\n", | |
| " y = backward(None, x, *inputs)\n", | |
| " n = y.shape[-1]\n", | |
| " return -pt.sum(\n", | |
| " pt.arange(2, 2 + n, dtype=x.dtype) * pt.log(pt.diagonal(y, axis1=-2, axis2=-1)), \n", | |
| " axis=-1,\n", | |
| " )" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": null, | |
| "id": "afd21694-f720-44ed-9761-b9423c67276b", | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 5, | |
| "id": "c9a1dda2-62a9-40ed-bbc8-2ad373265019", | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "x_unconstrained = pt.as_tensor([2, 2, 1])\n", | |
| "x_constrained = pt.as_tensor(\n", | |
| " [[ 1. , 0. , 0. ],\n", | |
| " [ 0.70710678, 0.70710678, 0. ],\n", | |
| " [ 0.66666667, 0.66666667, 0.33333333]]\n", | |
| ")" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 6, | |
| "id": "d8bd955a-fe73-4b27-9449-1fb546443e0d", | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/plain": [ | |
| "array([2., 2., 1.])" | |
| ] | |
| }, | |
| "execution_count": 6, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "forward(None, backward(None, x_unconstrained)).eval()" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 7, | |
| "id": "02176309-cca5-496b-a958-3599c26ddda1", | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/plain": [ | |
| "array([[1. , 0. , 0. ],\n", | |
| " [0.70710678, 0.70710678, 0. ],\n", | |
| " [0.66666667, 0.66666667, 0.33333333]])" | |
| ] | |
| }, | |
| "execution_count": 7, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "backward(None, forward(None, x_constrained)).eval()" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 8, | |
| "id": "379d291e-a399-4a7f-973c-8418283292d9", | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/plain": [ | |
| "array(5.43416993)" | |
| ] | |
| }, | |
| "execution_count": 8, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "log_jac_det(None, x_unconstrained).eval()" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 9, | |
| "id": "437f9b78-f4d4-4c90-b720-78217834f8ee", | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/plain": [ | |
| "array([[1. , 0. , 0. ],\n", | |
| " [0.70710678, 0.70710678, 0. ],\n", | |
| " [0.66666667, 0.66666667, 0.33333333]])" | |
| ] | |
| }, | |
| "execution_count": 9, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "backward(None, x_unconstrained).eval()" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 10, | |
| "id": "5b983dd8-8a70-4595-8f3e-da495eb95782", | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/plain": [ | |
| "array([[ 0. , 0. , 0. ],\n", | |
| " [ 1.41421356, -1.41421356, 0. ],\n", | |
| " [ 3.00000003, 3.00000003, -12.0000003 ]])" | |
| ] | |
| }, | |
| "execution_count": 10, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "pt.grad(forward(None, x_constrained).sum(), x_constrained).eval()" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 11, | |
| "id": "e5f1947f-0a17-42a7-84cc-86eeeed26ff2", | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/plain": [ | |
| "array([2, 2, 1])" | |
| ] | |
| }, | |
| "execution_count": 11, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "x_unconstrained.eval()" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 12, | |
| "id": "d4723307-78ed-40bf-ac22-d8fff5a15e8b", | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/plain": [ | |
| "array([-3.70370370e-02, -3.70370370e-02, 1.11022302e-16])" | |
| ] | |
| }, | |
| "execution_count": 12, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "x = pt.vector(\"x\")\n", | |
| "pt.grad(backward(None, x).sum(), x).eval({x: x_unconstrained.eval()})" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 13, | |
| "id": "79225742-9dc9-410e-93dc-32cbeb073be7", | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/plain": [ | |
| "array([0., 0., 0.])" | |
| ] | |
| }, | |
| "execution_count": 13, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "pt.grad(backward(None, x_unconstrained).sum(), x_unconstrained).eval()" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 14, | |
| "id": "6232b0b0-9261-433b-8af5-6d8fcfff9e73", | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "from pytensor.graph import rewrite_graph\n", | |
| "\n", | |
| "out = rewrite_graph(\n", | |
| " pt.jacobian(backward(None, x), x, vectorize=False),\n", | |
| " include=(\"fast_run\",), exclude=(\"inplace\",),\n", | |
| ").eval({x: x_unconstrained.eval()})\n", | |
| "# np.log(np.abs(out)).sum()" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 15, | |
| "id": "ce47e271-cdd5-45e2-911e-7aa87fad88f9", | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/plain": [ | |
| "array([[[ 0. , 0. , 0. ],\n", | |
| " [ 0. , 0. , 0. ],\n", | |
| " [ 0. , 0. , 0. ]],\n", | |
| "\n", | |
| " [[ 0. , 0. , 0.35355339],\n", | |
| " [ 0. , 0. , -0.35355339],\n", | |
| " [ 0. , 0. , 0. ]],\n", | |
| "\n", | |
| " [[-0.14814815, 0.18518519, 0. ],\n", | |
| " [ 0.18518519, -0.14814815, 0. ],\n", | |
| " [-0.07407407, -0.07407407, 0. ]]])" | |
| ] | |
| }, | |
| "execution_count": 15, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "out" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 16, | |
| "id": "54081057-be31-4e66-a1a0-929b462958d3", | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/plain": [ | |
| "array([[ 0. , 0. , 0. ],\n", | |
| " [ 1.41421356, -1.41421356, 0. ],\n", | |
| " [ 3.00000003, 3.00000003, -12.0000003 ]])" | |
| ] | |
| }, | |
| "execution_count": 16, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "X = pt.matrix(\"X\")\n", | |
| "out = pt.jacobian(\n", | |
| " forward(None, X).sum(), \n", | |
| " X, \n", | |
| " vectorize=True,\n", | |
| ").eval({X: x_constrained.eval()})\n", | |
| "out" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": null, | |
| "id": "91cbf60a-0d7a-4c47-b0fe-12b6d84cb5d8", | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [] | |
| } | |
| ], | |
| "metadata": { | |
| "kernelspec": { | |
| "display_name": "pymc-dev", | |
| "language": "python", | |
| "name": "pymc-dev" | |
| }, | |
| "language_info": { | |
| "codemirror_mode": { | |
| "name": "ipython", | |
| "version": 3 | |
| }, | |
| "file_extension": ".py", | |
| "mimetype": "text/x-python", | |
| "name": "python", | |
| "nbconvert_exporter": "python", | |
| "pygments_lexer": "ipython3", | |
| "version": "3.12.8" | |
| } | |
| }, | |
| "nbformat": 4, | |
| "nbformat_minor": 5 | |
| } |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment