Skip to content

Instantly share code, notes, and snippets.

@ricardoV94
Last active October 4, 2025 17:41
Show Gist options
  • Save ricardoV94/d75d42d5c3cdde8827135bcbefc9405a to your computer and use it in GitHub Desktop.
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"id": "067a3cc1-1ed2-4c28-a113-0bd24c433f45",
"metadata": {},
"outputs": [],
"source": [
"import pytensor.tensor as pt\n",
"import numpy as np"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "3fb48d6c-5d8b-46ae-a09d-e6622669b72e",
"metadata": {},
"outputs": [],
"source": [
"def fill_triangular_spiral(x_raveled, n=None, upper=False):\n",
"    \"\"\"Fill (a batch of) triangular matrices from vectors, in spiral order.\n",
"\n",
"    PyTensor port of ``tfp.math.fill_triangular``:\n",
"    https://github.com/tensorflow/probability/blob/a26f4cbe5ce1549767e13798d9bf5032dac4257b/tensorflow_probability/python/math/linalg.py#L925\n",
"\n",
"    For example, ``[1, 2, 3, 4, 5, 6]`` becomes ``[[4, 0, 0], [6, 5, 0], [3, 2, 1]]``\n",
"    (lower) or ``[[1, 2, 3], [0, 5, 6], [0, 0, 4]]`` (``upper=True``).\n",
"\n",
"    Parameters\n",
"    ----------\n",
"    x_raveled : tensor_like\n",
"        Array of shape ``(*batch_shape, m)`` with ``m == n * (n + 1) // 2``.\n",
"    n : int or integer scalar tensor, optional\n",
"        Side length of the output; inferred from ``m`` when ``None``.\n",
"    upper : bool, default False\n",
"        Build an upper- instead of a lower-triangular matrix.\n",
"\n",
"    Returns\n",
"    -------\n",
"    TensorVariable\n",
"        Triangular matrix of shape ``(*batch_shape, n, n)``.\n",
"    \"\"\"\n",
"    x_raveled = pt.as_tensor(x_raveled)\n",
"    *batch_shape, m = x_raveled.shape\n",
"    if n is None:\n",
"        # Solve m = n * (n + 1) / 2 for n.\n",
"        n = pt.cast(pt.sqrt(0.25 + 2 * m) - 0.5, \"int32\")\n",
"    tail = x_raveled[..., n:]\n",
"\n",
"    def reverse(x):\n",
"        return x[..., ::-1]\n",
"\n",
"    # Join along the last (core) axis; concatenate's default axis=0 would\n",
"    # mix up the batch dimensions of batched inputs.\n",
"    if upper:\n",
"        xc = pt.concatenate([x_raveled, reverse(tail)], axis=-1)\n",
"    else:\n",
"        xc = pt.concatenate([tail, reverse(x_raveled)], axis=-1)\n",
"\n",
"    # xc has m + (m - n) == n * n entries: reshape to square, then zero out\n",
"    # the unwanted triangle.\n",
"    y = pt.reshape(xc, (*batch_shape, n, n))\n",
"    return pt.triu(y) if upper else pt.tril(y)\n",
"\n",
"def inverse_fill_triangular_spiral(x, m=None, upper=False):\n",
"    \"\"\"Flatten (a batch of) triangular matrices; inverse of ``fill_triangular_spiral``.\n",
"\n",
"    PyTensor port of ``tfp.math.fill_triangular_inverse``.\n",
"\n",
"    Parameters\n",
"    ----------\n",
"    x : tensor_like\n",
"        Triangular matrix of shape ``(*batch_shape, n, n)``.\n",
"    m : int or integer scalar tensor, optional\n",
"        Length of the output; defaults to ``n * (n + 1) // 2``.\n",
"    upper : bool, default False\n",
"        Whether ``x`` is upper- (True) or lower-triangular (False).\n",
"\n",
"    Returns\n",
"    -------\n",
"    TensorVariable\n",
"        Array of shape ``(*batch_shape, m)``.\n",
"    \"\"\"\n",
"    x = pt.as_tensor(x)\n",
"    *batch_shape, _, n = x.shape\n",
"\n",
"    if m is None:\n",
"        m = pt.cast((n * (n + 1)) // 2, \"int32\")\n",
"\n",
"    # The first n outputs are one full row; the other n - 1 rows hold the rest.\n",
"    if upper:\n",
"        initial_elements = x[..., 0, :]\n",
"        triangular_portion = x[..., 1:, :]\n",
"    else:\n",
"        initial_elements = pt.flip(x[..., -1, :], axis=-1)\n",
"        triangular_portion = x[..., :-1, :]\n",
"\n",
"    # Adding the remaining rows to their 180-degree rotation (a flip of both\n",
"    # trailing axes) packs the remaining elements into a dense block; the\n",
"    # first m - n entries of its flattening are the rest of the output.\n",
"    rotated_triangular_portion = pt.flip(triangular_portion, axis=(-1, -2))\n",
"    consolidated_matrix = triangular_portion + rotated_triangular_portion\n",
"    end_sequence = pt.reshape(\n",
"        consolidated_matrix,\n",
"        (*batch_shape, pt.cast(n * (n - 1), \"int64\")),\n",
"    )\n",
"    return pt.concatenate([initial_elements, end_sequence[..., : m - n]], axis=-1)\n",
"\n",
"\n",
"def test_fill_triangular_spiral():\n",
"    \"\"\"Check both fill directions and both inverses on a fixed 3x3 example.\"\"\"\n",
"    flat = np.array([1, 2, 3, 4, 5, 6])\n",
"    expected_lower = np.array([[4, 0, 0], [6, 5, 0], [3, 2, 1]])\n",
"    expected_upper = np.array([[1, 2, 3], [0, 5, 6], [0, 0, 4]])\n",
"\n",
"    # vector -> triangular matrix\n",
"    np.testing.assert_allclose(fill_triangular_spiral(flat).eval(), expected_lower)\n",
"    np.testing.assert_allclose(\n",
"        fill_triangular_spiral(flat, upper=True).eval(), expected_upper\n",
"    )\n",
"\n",
"    # triangular matrix -> vector\n",
"    np.testing.assert_allclose(\n",
"        inverse_fill_triangular_spiral(expected_lower).eval(), flat\n",
"    )\n",
"    np.testing.assert_allclose(\n",
"        inverse_fill_triangular_spiral(expected_upper, upper=True).eval(), flat\n",
"    )"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "95512889-d6e6-4573-aec5-11a2e0bc0eed",
"metadata": {},
"outputs": [],
"source": [
"test_fill_triangular_spiral()"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "2e078480-52e4-46da-aed3-f01fb2a3d1b7",
"metadata": {},
"outputs": [],
"source": [
"def backward(self, x, *inputs):\n",
"    \"\"\"Map unconstrained reals to the Cholesky factor of a correlation matrix.\n",
"\n",
"    Mirrors the forward pass of TFP's ``CorrelationCholesky`` bijector.\n",
"\n",
"    Parameters\n",
"    ----------\n",
"    self\n",
"        Unused; keeps the signature of a transform method.\n",
"    x : tensor_like\n",
"        Unconstrained array of shape ``(*batch_shape, n * (n - 1) // 2)``.\n",
"    *inputs\n",
"        Unused.\n",
"\n",
"    Returns\n",
"    -------\n",
"    TensorVariable\n",
"        Lower-triangular ``(*batch_shape, n, n)`` matrix with a positive\n",
"        diagonal and rows of unit Euclidean norm.\n",
"    \"\"\"\n",
"    x = pt.as_tensor(x)\n",
"\n",
"    y = fill_triangular_spiral(x)\n",
"    n = y.shape[-1]\n",
"\n",
"    # Pad zeros on the top row and right column, so the filled values end up\n",
"    # strictly below the diagonal.\n",
"    paddings = [*([(0, 0)] * (y.ndim - 2)), [1, 0], [0, 1]]\n",
"    y = pt.pad(y, paddings)\n",
"\n",
"    # Set the diagonal (now of length n + 1) to 1s.\n",
"    arange = pt.arange(n + 1)\n",
"    y = y[..., arange, arange].set(1)\n",
"\n",
"    # Normalize each row to have Euclidean (L2) norm 1.\n",
"    return y / pt.linalg.norm(y, axis=-1, ord=2)[..., None]\n",
"\n",
"# y: constrained Cholesky factor -> x: unconstrained reals\n",
"def forward(self, y, *inputs):\n",
"    \"\"\"Map a correlation Cholesky factor back to unconstrained reals.\n",
"\n",
"    Inverse of ``backward``; mirrors the inverse pass of TFP's\n",
"    ``CorrelationCholesky`` bijector.\n",
"\n",
"    Parameters\n",
"    ----------\n",
"    self\n",
"        Unused; keeps the signature of a transform method.\n",
"    y : tensor_like\n",
"        Lower-triangular ``(*batch_shape, n, n)`` matrix with a positive\n",
"        diagonal and unit-norm rows, as produced by ``backward``.\n",
"    *inputs\n",
"        Unused.\n",
"\n",
"    Returns\n",
"    -------\n",
"    TensorVariable\n",
"        Unconstrained array of shape ``(*batch_shape, n * (n - 1) // 2)``.\n",
"    \"\"\"\n",
"    y = pt.as_tensor(y)\n",
"    # y is a (batch of) square matrices, so only the last axis gives n.\n",
"    n = y.shape[-1]\n",
"\n",
"    # ``backward`` set the diagonal to 1 before dividing each row by its norm,\n",
"    # so the diagonal now holds the reciprocal of each row norm.\n",
"    diag = pt.diagonal(y, axis1=-2, axis2=-1)[..., None]\n",
"\n",
"    # Set the diagonal to 0s.\n",
"    arange = pt.arange(n)\n",
"    y = y[..., arange, arange].set(0)\n",
"\n",
"    # Dividing by the reciprocal norm recovers the unconstrained reals in the\n",
"    # strictly lower-triangular part.\n",
"    x = y / diag\n",
"\n",
"    # Drop the first row and last column (all zero at this point) before\n",
"    # inverting the triangular fill.\n",
"    return inverse_fill_triangular_spiral(x[..., 1:, :-1])\n",
"\n",
"def log_jac_det(self, x, *inputs):\n",
"    \"\"\"Log absolute Jacobian determinant of ``backward`` at ``x``.\n",
"\n",
"    Row ``i`` (0-indexed) of ``y = backward(x)`` equals\n",
"    ``[u, 1, 0, ..., 0] / sqrt(1 + |u|**2)`` for the ``i`` unconstrained values\n",
"    ``u`` placed in that row, so the map from ``u`` to the row's off-diagonal\n",
"    entries has determinant ``y_ii ** (i + 2)``. Summing over rows gives\n",
"    ``sum_i (i + 2) * log(y_ii)``, which matches TFP's ``CorrelationCholesky``\n",
"    forward log-det-Jacobian; it is <= 0 since every ``y_ii`` lies in (0, 1].\n",
"\n",
"    Parameters\n",
"    ----------\n",
"    self\n",
"        Unused; keeps the signature of a transform method.\n",
"    x : tensor_like\n",
"        Unconstrained array of shape ``(*batch_shape, n * (n - 1) // 2)``.\n",
"    *inputs\n",
"        Forwarded to ``backward``.\n",
"\n",
"    Returns\n",
"    -------\n",
"    TensorVariable\n",
"        Array of shape ``batch_shape``.\n",
"    \"\"\"\n",
"    y = backward(None, x, *inputs)\n",
"    n = y.shape[-1]\n",
"    # Take the dtype from the float-valued y: x may be an integer tensor.\n",
"    weights = pt.arange(2, 2 + n, dtype=y.dtype)\n",
"    log_diag = pt.log(pt.diagonal(y, axis1=-2, axis2=-1))\n",
"    # No leading minus: the negated sum is the log-det of ``forward`` instead.\n",
"    return pt.sum(weights * log_diag, axis=-1)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "afd21694-f720-44ed-9761-b9423c67276b",
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": 5,
"id": "c9a1dda2-62a9-40ed-bbc8-2ad373265019",
"metadata": {},
"outputs": [],
"source": [
"# An unconstrained vector and the correlation Cholesky factor that backward maps it to.\n",
"x_unconstrained = pt.as_tensor([2, 2, 1])\n",
"x_constrained = pt.as_tensor(\n",
"    [\n",
"        [1.0, 0.0, 0.0],\n",
"        [0.70710678, 0.70710678, 0.0],\n",
"        [0.66666667, 0.66666667, 0.33333333],\n",
"    ]\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "d8bd955a-fe73-4b27-9449-1fb546443e0d",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"array([2., 2., 1.])"
]
},
"execution_count": 6,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"forward(None, backward(None, x_unconstrained)).eval()"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "02176309-cca5-496b-a958-3599c26ddda1",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"array([[1. , 0. , 0. ],\n",
" [0.70710678, 0.70710678, 0. ],\n",
" [0.66666667, 0.66666667, 0.33333333]])"
]
},
"execution_count": 7,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"backward(None, forward(None, x_constrained)).eval()"
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "379d291e-a399-4a7f-973c-8418283292d9",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"array(5.43416993)"
]
},
"execution_count": 8,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"log_jac_det(None, x_unconstrained).eval()"
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "437f9b78-f4d4-4c90-b720-78217834f8ee",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"array([[1. , 0. , 0. ],\n",
" [0.70710678, 0.70710678, 0. ],\n",
" [0.66666667, 0.66666667, 0.33333333]])"
]
},
"execution_count": 9,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"backward(None, x_unconstrained).eval()"
]
},
{
"cell_type": "code",
"execution_count": 10,
"id": "5b983dd8-8a70-4595-8f3e-da495eb95782",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"array([[ 0. , 0. , 0. ],\n",
" [ 1.41421356, -1.41421356, 0. ],\n",
" [ 3.00000003, 3.00000003, -12.0000003 ]])"
]
},
"execution_count": 10,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"pt.grad(forward(None, x_constrained).sum(), x_constrained).eval()"
]
},
{
"cell_type": "code",
"execution_count": 11,
"id": "e5f1947f-0a17-42a7-84cc-86eeeed26ff2",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"array([2, 2, 1])"
]
},
"execution_count": 11,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"x_unconstrained.eval()"
]
},
{
"cell_type": "code",
"execution_count": 12,
"id": "d4723307-78ed-40bf-ac22-d8fff5a15e8b",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"array([-3.70370370e-02, -3.70370370e-02, 1.11022302e-16])"
]
},
"execution_count": 12,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"x = pt.vector(\"x\")\n",
"pt.grad(backward(None, x).sum(), x).eval({x: x_unconstrained.eval()})"
]
},
{
"cell_type": "code",
"execution_count": 13,
"id": "79225742-9dc9-410e-93dc-32cbeb073be7",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"array([0., 0., 0.])"
]
},
"execution_count": 13,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"pt.grad(backward(None, x_unconstrained).sum(), x_unconstrained).eval()"
]
},
{
"cell_type": "code",
"execution_count": 14,
"id": "6232b0b0-9261-433b-8af5-6d8fcfff9e73",
"metadata": {},
"outputs": [],
"source": [
"from pytensor.graph import rewrite_graph\n",
"\n",
"out = rewrite_graph(\n",
" pt.jacobian(backward(None, x), x, vectorize=False),\n",
" include=(\"fast_run\",), exclude=(\"inplace\",),\n",
").eval({x: x_unconstrained.eval()})\n",
"# np.log(np.abs(out)).sum()"
]
},
{
"cell_type": "code",
"execution_count": 15,
"id": "ce47e271-cdd5-45e2-911e-7aa87fad88f9",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"array([[[ 0. , 0. , 0. ],\n",
" [ 0. , 0. , 0. ],\n",
" [ 0. , 0. , 0. ]],\n",
"\n",
" [[ 0. , 0. , 0.35355339],\n",
" [ 0. , 0. , -0.35355339],\n",
" [ 0. , 0. , 0. ]],\n",
"\n",
" [[-0.14814815, 0.18518519, 0. ],\n",
" [ 0.18518519, -0.14814815, 0. ],\n",
" [-0.07407407, -0.07407407, 0. ]]])"
]
},
"execution_count": 15,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"out"
]
},
{
"cell_type": "code",
"execution_count": 16,
"id": "54081057-be31-4e66-a1a0-929b462958d3",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"array([[ 0. , 0. , 0. ],\n",
" [ 1.41421356, -1.41421356, 0. ],\n",
" [ 3.00000003, 3.00000003, -12.0000003 ]])"
]
},
"execution_count": 16,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"X = pt.matrix(\"X\")\n",
"out = pt.jacobian(\n",
" forward(None, X).sum(), \n",
" X, \n",
" vectorize=True,\n",
").eval({X: x_constrained.eval()})\n",
"out"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "91cbf60a-0d7a-4c47-b0fe-12b6d84cb5d8",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "pymc-dev",
"language": "python",
"name": "pymc-dev"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.8"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment