Skip to content

Instantly share code, notes, and snippets.

@ricardoV94
Created October 20, 2025 10:08
Show Gist options
  • Save ricardoV94/1f579574570e347b4422470d9cad114f to your computer and use it in GitHub Desktop.
Save ricardoV94/1f579574570e347b4422470d9cad114f to your computer and use it in GitHub Desktop.
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"id": "d9a581f1-a09f-450d-bb31-7776ebe7eb12",
"metadata": {},
"outputs": [],
"source": [
"from itertools import pairwise\n",
"from functools import partial\n",
"import numpy as np\n",
"import numba\n",
"\n",
"# Decorator shared by every scan variant below: nopython JIT with bounds checking disabled.\n",
"scan_njit = partial(numba.njit, boundscheck=False)\n",
"\n",
"# Per-step update applied by every scan (the scan's inner function).\n",
"# For the 1-D rows of xs that the scans pass in, x.T is x itself, so this is\n",
"# elementwise x**2 / max(x); it returns a newly allocated array.\n",
"@numba.njit(fastmath=True)\n",
"def inner_f(x):\n",
"    return (x * x.T) / np.max(x)\n",
"\n",
"@scan_njit\n",
"def scan0(n, xs):\n",
"    # Baseline: treat xs as a ring buffer and go through it on every step,\n",
"    # reading the current state from it and writing the next state back.\n",
"    buf_len = xs.shape[0]\n",
"    assert buf_len >= 1\n",
"    for step in range(n):\n",
"        xs[(step + 1) % buf_len] = inner_f(xs[step % buf_len])\n",
"    return xs\n",
"\n",
"@scan_njit\n",
"def scan1(n, xs):\n",
"    # Carry the state in a local variable and only write to xs:\n",
"    # the buffer is never read inside the loop.\n",
"    if n == 0:\n",
"        return xs\n",
"\n",
"    buf_len = xs.shape[0]\n",
"    assert buf_len >= 1\n",
"    state = xs[0]\n",
"    for step in range(1, n + 1):\n",
"        state = inner_f(state)\n",
"        xs[step % buf_len] = state\n",
"    return xs\n",
"\n",
"@scan_njit\n",
"def scan2(n, xs):\n",
"    # Length-1 buffer: keep the state in a local and store it once at the end.\n",
"    # Longer buffer: the same read-and-write-every-step loop as scan0.\n",
"    if n == 0:\n",
"        return xs\n",
"\n",
"    buf_len = xs.shape[0]\n",
"    assert buf_len >= 1\n",
"\n",
"    if buf_len != 1:\n",
"        for step in range(n):\n",
"            xs[(step + 1) % buf_len] = inner_f(xs[step % buf_len])\n",
"        return xs\n",
"\n",
"    state = xs[0]\n",
"    for _ in range(n):\n",
"        state = inner_f(state)\n",
"    xs[0] = state\n",
"    return xs\n",
"\n",
"@scan_njit\n",
"def scan3(n, xs):\n",
"    # The state always lives in a local variable and xs is never read in the loop.\n",
"    # Length-1 buffer: store only the final state, after the loop.\n",
"    # Longer buffer: store every new state, like scan1.\n",
"    if n == 0:\n",
"        return xs\n",
"\n",
"    buf_len = xs.shape[0]\n",
"    assert buf_len >= 1\n",
"    state = xs[0]\n",
"\n",
"    if buf_len == 1:\n",
"        for _ in range(n):\n",
"            state = inner_f(state)\n",
"        xs[0] = state\n",
"    else:\n",
"        for step in range(1, n + 1):\n",
"            state = inner_f(state)\n",
"            xs[step % buf_len] = state\n",
"    return xs\n",
"\n",
"# Every variant, in the order they are compared and timed below.\n",
"scans = (scan0, scan1, scan2, scan3)"
]
},
{
"cell_type": "markdown",
"id": "82a1cfbb-b838-42a4-be13-2168320c8a9d",
"metadata": {},
"source": [
"### Full trace is kept (the buffer has a slot for every step)"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "ed85cc23-d0fe-4a11-8822-6301d6317e9d",
"metadata": {},
"outputs": [],
"source": [
"n_steps = 300\n",
"xs = np.full((n_steps + 10, 30), np.nan)\n",
"xs[0] = np.array([0.1, 0.2, 0.3] * 10)\n",
"# Trigger JIT compilation and check that consecutive variants agree.\n",
"# The loop names must not be scan0/scan1: rebinding those globals would leave\n",
"# scan0 pointing at scan2 and scan1 at scan3 after the loop.\n",
"for scan_a, scan_b in pairwise(scans):\n",
"    np.testing.assert_allclose(\n",
"        scan_a(n_steps, xs.copy()),\n",
"        scan_b(n_steps, xs.copy()),\n",
"    )"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "b5e442df-b4f4-4642-9bea-47fbda32da57",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"37.7 μs ± 1.53 μs per loop (mean ± std. dev. of 14 runs, 10,000 loops each)\n",
"37.4 μs ± 1.01 μs per loop (mean ± std. dev. of 14 runs, 10,000 loops each)\n",
"38.3 μs ± 534 ns per loop (mean ± std. dev. of 14 runs, 10,000 loops each)\n",
"37.9 μs ± 922 ns per loop (mean ± std. dev. of 14 runs, 10,000 loops each)\n"
]
}
],
"source": [
"# Each scan overwrites its input in place, so every call gets a fresh copy;\n",
"# that copy is part of the timed statement for all variants alike.\n",
"for scan_fn in scans:\n",
"    %timeit -r 14 -n 10_000 scan_fn(n_steps, xs.copy())"
]
},
{
"cell_type": "markdown",
"id": "28cb96bb-83c9-4092-a78f-311fb20011ff",
"metadata": {},
"source": [
"### Only the last entry is kept (length-1 buffer)"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "3f8f6059-26d6-4b36-be61-156e4abc7ab8",
"metadata": {},
"outputs": [],
"source": [
"# Buffer of length 1: only the last state is kept.\n",
"ys = xs[:1]\n",
"# Check that consecutive variants agree on a length-1 buffer.\n",
"# Loop names differ from scan0/scan1 so those globals are not rebound.\n",
"for scan_a, scan_b in pairwise(scans):\n",
"    np.testing.assert_allclose(\n",
"        scan_a(n_steps, ys.copy()),\n",
"        scan_b(n_steps, ys.copy()),\n",
"    )"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "066f0e42-f97d-4967-b975-c455aa898c21",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"35.8 μs ± 456 ns per loop (mean ± std. dev. of 14 runs, 10,000 loops each)\n",
"36.6 μs ± 1.69 μs per loop (mean ± std. dev. of 14 runs, 10,000 loops each)\n",
"28.7 μs ± 251 ns per loop (mean ± std. dev. of 14 runs, 10,000 loops each)\n",
"29.8 μs ± 577 ns per loop (mean ± std. dev. of 14 runs, 10,000 loops each)\n"
]
}
],
"source": [
"# Each scan overwrites its input in place, so every call gets a fresh copy;\n",
"# that copy is part of the timed statement for all variants alike.\n",
"for scan_fn in scans:\n",
"    %timeit -r 14 -n 10_000 scan_fn(n_steps, ys.copy())"
]
},
{
"cell_type": "markdown",
"id": "70f977d5-0b72-4a01-b430-4ebee532f975",
"metadata": {},
"source": [
"## Scalar case, where each state has to be wrapped in a 0-d NumPy array"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "411fd138-fd3d-4fb7-9fa2-74d71f5f2b5f",
"metadata": {},
"outputs": [],
"source": [
"# Scalar per-step update (redefines inner_f from the first section): scale by\n",
"# 0.999 and wrap the product with np.array so an array, not a bare float, is returned.\n",
"@numba.njit(fastmath=True)\n",
"def inner_f(x):\n",
"    return np.array(x * 0.999)\n",
"\n",
"@scan_njit\n",
"def scan0(n, xs):\n",
"    # Baseline ring-buffer scan for scalar states: copy the current element into a\n",
"    # reusable 0-d array, apply inner_f to it and store the result on every step.\n",
"    buf_len = xs.shape[0]\n",
"    assert buf_len >= 1\n",
"    scratch = np.empty((), dtype=xs.dtype)\n",
"    for step in range(n):\n",
"        scratch[()] = xs[step % buf_len]\n",
"        xs[(step + 1) % buf_len] = inner_f(scratch)\n",
"    return xs\n",
"\n",
"@scan_njit\n",
"def scan1(n, xs):\n",
"    # Carry the scalar state as a 0-d array in a local; xs is only written to.\n",
"    if n == 0:\n",
"        return xs\n",
"\n",
"    buf_len = xs.shape[0]\n",
"    assert buf_len >= 1\n",
"    state = np.asarray(xs[0])\n",
"    for step in range(1, n + 1):\n",
"        state = inner_f(state)\n",
"        xs[step % buf_len] = state\n",
"    return xs\n",
"\n",
"@scan_njit\n",
"def scan2(n, xs):\n",
"    # Length-1 buffer: keep the state in a local 0-d array and store it once.\n",
"    # Longer buffer: the same read-and-write-every-step loop as scan0.\n",
"    if n == 0:\n",
"        return xs\n",
"\n",
"    buf_len = xs.shape[0]\n",
"    assert buf_len >= 1\n",
"\n",
"    if buf_len != 1:\n",
"        scratch = np.empty((), dtype=xs.dtype)\n",
"        for step in range(n):\n",
"            scratch[()] = xs[step % buf_len]\n",
"            xs[(step + 1) % buf_len] = inner_f(scratch)\n",
"        return xs\n",
"\n",
"    state = np.asarray(xs[0])\n",
"    for _ in range(n):\n",
"        state = inner_f(state)\n",
"    xs[0] = state\n",
"    return xs\n",
"\n",
"@scan_njit\n",
"def scan3(n, xs):\n",
"    # The scalar state always lives in a local 0-d array; xs is never read in the loop.\n",
"    # Length-1 buffer: store only the final state. Longer buffer: store every step.\n",
"    if n == 0:\n",
"        return xs\n",
"\n",
"    buf_len = xs.shape[0]\n",
"    assert buf_len >= 1\n",
"    state = np.asarray(xs[0])\n",
"\n",
"    if buf_len == 1:\n",
"        for _ in range(n):\n",
"            state = inner_f(state)\n",
"        xs[0] = state\n",
"    else:\n",
"        for step in range(1, n + 1):\n",
"            state = inner_f(state)\n",
"            xs[step % buf_len] = state\n",
"    return xs\n",
"\n",
"# Every variant, in the order they are compared and timed below.\n",
"scans = (scan0, scan1, scan2, scan3)"
]
},
{
"cell_type": "markdown",
"id": "939d5fdc-2b65-4e51-8ba2-8a82bd502b60",
"metadata": {},
"source": [
"### Full trace is kept (the buffer has a slot for every step)"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "f0d87666-5e95-47f2-a62b-3ec4c1d00628",
"metadata": {},
"outputs": [],
"source": [
"n_steps = 300\n",
"xs = np.full((n_steps + 1,), np.nan)\n",
"xs[0] = 900\n",
"# Trigger JIT compilation and check that consecutive variants agree.\n",
"# Loop names differ from scan0/scan1 so those globals are not rebound.\n",
"for scan_a, scan_b in pairwise(scans):\n",
"    np.testing.assert_allclose(\n",
"        scan_a(n_steps, xs.copy()),\n",
"        scan_b(n_steps, xs.copy()),\n",
"    )"
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "847d536f-46fd-46ec-b14f-80f76efef0b6",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"19.9 μs ± 536 ns per loop (mean ± std. dev. of 14 runs, 10,000 loops each)\n",
"23.6 μs ± 393 ns per loop (mean ± std. dev. of 14 runs, 10,000 loops each)\n",
"20.2 μs ± 452 ns per loop (mean ± std. dev. of 14 runs, 10,000 loops each)\n",
"23.9 μs ± 713 ns per loop (mean ± std. dev. of 14 runs, 10,000 loops each)\n"
]
}
],
"source": [
"# Each scan overwrites its input in place, so every call gets a fresh copy;\n",
"# that copy is part of the timed statement for all variants alike.\n",
"for scan_fn in scans:\n",
"    %timeit -r 14 -n 10_000 scan_fn(n_steps, xs.copy())"
]
},
{
"cell_type": "markdown",
"id": "2e6430cd-3a4a-4b01-80f5-e4892d5f76ba",
"metadata": {},
"source": [
"### Only the last entry is kept (length-1 buffer)"
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "98f52e61-e4a7-48c0-9429-1417c829aab3",
"metadata": {},
"outputs": [],
"source": [
"ys = xs[:1]\n",
"# Check that consecutive variants agree on a length-1 buffer.\n",
"# Loop names differ from scan0/scan1 so those globals are not rebound.\n",
"for scan_a, scan_b in pairwise(scans):\n",
"    np.testing.assert_allclose(\n",
"        scan_a(n_steps, ys.copy()),\n",
"        scan_b(n_steps, ys.copy()),\n",
"    )"
]
},
{
"cell_type": "code",
"execution_count": 10,
"id": "7bf7fced-dca5-4a88-92a1-4ff8657cf78a",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"20 μs ± 606 ns per loop (mean ± std. dev. of 14 runs, 10,000 loops each)\n",
"23.9 μs ± 641 ns per loop (mean ± std. dev. of 14 runs, 10,000 loops each)\n",
"9.01 μs ± 182 ns per loop (mean ± std. dev. of 14 runs, 10,000 loops each)\n",
"8.95 μs ± 87.2 ns per loop (mean ± std. dev. of 14 runs, 10,000 loops each)\n"
]
}
],
"source": [
"# Each scan overwrites its input in place, so every call gets a fresh copy;\n",
"# that copy is part of the timed statement for all variants alike.\n",
"for scan_fn in scans:\n",
"    %timeit -r 14 -n 10_000 scan_fn(n_steps, ys.copy())"
]
},
{
"cell_type": "markdown",
"id": "66716326-7ff9-4912-b712-15bac42c01c7",
"metadata": {},
"source": [
"## Scan with multiple taps (each step reads the previous two states)"
]
},
{
"cell_type": "code",
"execution_count": 11,
"id": "843e1796-eff9-4fac-a5ea-c2d954ae5201",
"metadata": {},
"outputs": [],
"source": [
"# Two-tap per-step update (redefines inner_f from the earlier sections):\n",
"# the new state is the mean of the two previous states.\n",
"@numba.njit(fastmath=True)\n",
"def inner_f(x, y):\n",
"    return (x + y) / 2\n",
"\n",
"@scan_njit\n",
"def scan0(n, xs):\n",
"    # Baseline two-tap ring-buffer scan: every step reads the two most recent\n",
"    # states from xs and writes the new state two slots ahead.\n",
"    buf_len = xs.shape[0]\n",
"    assert buf_len >= 2\n",
"    for step in range(n):\n",
"        older = xs[step % buf_len]\n",
"        newer = xs[(step + 1) % buf_len]\n",
"        xs[(step + 2) % buf_len] = inner_f(older, newer)\n",
"    return xs\n",
"\n",
"@scan_njit\n",
"def scan1(n, xs):\n",
"    # Carry the two most recent states in locals; xs is only written to.\n",
"    if n == 0:\n",
"        return xs\n",
"\n",
"    buf_len = xs.shape[0]\n",
"    assert buf_len >= 2\n",
"    older, newer = xs[0], xs[1]\n",
"\n",
"    for step in range(2, n + 2):\n",
"        older, newer = newer, inner_f(older, newer)\n",
"        xs[step % buf_len] = newer\n",
"    return xs\n",
"\n",
"@scan_njit\n",
"def scan2(n, xs):\n",
"    # Two-tap scan. With a length-2 buffer the two most recent states are kept in\n",
"    # locals and stored only once, after the loop; otherwise xs is read and written\n",
"    # on every step exactly like scan0.\n",
"    # n <= 0 runs no steps, so return xs untouched (the final slot arithmetic\n",
"    # below would misplace the states for negative n).\n",
"    if n <= 0:\n",
"        return xs\n",
"\n",
"    assert xs.shape[0] >= 2\n",
"    x_size = xs.shape[0]\n",
"\n",
"    if x_size == 2:\n",
"        x0 = xs[0]\n",
"        x1 = xs[1]\n",
"        for i in range(n):\n",
"            next_x = inner_f(x0, x1)\n",
"            x0 = x1\n",
"            x1 = next_x\n",
"\n",
"        # Store into the ring-buffer slots scan0 would have used: the last state\n",
"        # (written at step n - 1) goes to slot (n + 1) % 2 and the state before it\n",
"        # to slot n % 2. Always writing x0 to slot 0 and x1 to slot 1 swaps them\n",
"        # whenever n is odd.\n",
"        xs[n % 2] = x0\n",
"        xs[(n + 1) % 2] = x1\n",
"    else:\n",
"        for i in range(n):\n",
"            x0 = xs[i % x_size]\n",
"            x1 = xs[(i + 1) % x_size]\n",
"            next_x = inner_f(x0, x1)\n",
"            xs[(i + 2) % x_size] = next_x\n",
"    return xs\n",
"\n",
"@scan_njit\n",
"def scan3(n, xs):\n",
"    # Two-tap scan that always keeps the two most recent states in locals.\n",
"    # With a length-2 buffer they are stored once, after the loop; otherwise every\n",
"    # new state is written to xs, like scan1.\n",
"    # n <= 0 runs no steps, so return xs untouched (the final slot arithmetic\n",
"    # below would misplace the states for negative n).\n",
"    if n <= 0:\n",
"        return xs\n",
"\n",
"    assert xs.shape[0] >= 2\n",
"    x_size = xs.shape[0]\n",
"    x0 = xs[0]\n",
"    x1 = xs[1]\n",
"\n",
"    if x_size == 2:\n",
"        for i in range(n):\n",
"            next_x = inner_f(x0, x1)\n",
"            x0 = x1\n",
"            x1 = next_x\n",
"        # Store into the ring-buffer slots scan0 would have used: the last state\n",
"        # (written at step n - 1) goes to slot (n + 1) % 2 and the state before it\n",
"        # to slot n % 2. Always writing x0 to slot 0 and x1 to slot 1 swaps them\n",
"        # whenever n is odd.\n",
"        xs[n % 2] = x0\n",
"        xs[(n + 1) % 2] = x1\n",
"    else:\n",
"        for i in range(n):\n",
"            next_x = inner_f(x0, x1)\n",
"            x0 = x1\n",
"            x1 = next_x\n",
"            xs[(i + 2) % x_size] = next_x\n",
"    return xs\n",
"\n",
"# Every variant, in the order they are compared and timed below.\n",
"scans = (scan0, scan1, scan2, scan3)"
]
},
{
"cell_type": "markdown",
"id": "acacd657-38fd-4199-affc-c60e9ac636d3",
"metadata": {},
"source": [
"### Full trace is kept (the buffer has a slot for every step)"
]
},
{
"cell_type": "code",
"execution_count": 12,
"id": "87646adf-0d75-48c4-8061-4b925724483f",
"metadata": {},
"outputs": [],
"source": [
"n_steps = 300\n",
"xs = np.full((n_steps + 2, 3), np.nan)\n",
"xs[0] = [1, 2, 3]\n",
"xs[1] = [999, -999, 999]\n",
"# Trigger JIT compilation and check that consecutive variants agree.\n",
"# The loop variables must be used in the body; calling scan0/scan1 here would\n",
"# compare the same two functions three times and never check scan2 or scan3.\n",
"for scan_a, scan_b in pairwise(scans):\n",
"    np.testing.assert_allclose(\n",
"        scan_a(n_steps, xs.copy()),\n",
"        scan_b(n_steps, xs.copy()),\n",
"    )"
]
},
{
"cell_type": "code",
"execution_count": 13,
"id": "cfd19003-706b-432d-8a78-4fba488e5f4d",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"27.2 μs ± 753 ns per loop (mean ± std. dev. of 14 runs, 10,000 loops each)\n",
"23.1 μs ± 654 ns per loop (mean ± std. dev. of 14 runs, 10,000 loops each)\n",
"28 μs ± 5.48 μs per loop (mean ± std. dev. of 14 runs, 10,000 loops each)\n",
"24.7 μs ± 5.18 μs per loop (mean ± std. dev. of 14 runs, 10,000 loops each)\n"
]
}
],
"source": [
"# Each scan overwrites its input in place, so every call gets a fresh copy;\n",
"# that copy is part of the timed statement for all variants alike.\n",
"for scan_fn in scans:\n",
"    %timeit -r 14 -n 10_000 scan_fn(n_steps, xs.copy())"
]
},
{
"cell_type": "markdown",
"id": "3e4f5709-d351-4b28-8f3c-95ef0f4e872c",
"metadata": {},
"source": [
"### Only the last two entries are kept (length-2 buffer)"
]
},
{
"cell_type": "code",
"execution_count": 14,
"id": "935dc912-7cf2-47df-ad95-fb8cfcf75659",
"metadata": {},
"outputs": [],
"source": [
"# Buffer of length 2: only the last two states are kept.\n",
"ys = xs[:2]\n",
"# Check that consecutive variants agree; use the loop variables, not scan0/scan1,\n",
"# otherwise scan2 and scan3 are never checked.\n",
"for scan_a, scan_b in pairwise(scans):\n",
"    np.testing.assert_allclose(\n",
"        scan_a(n_steps, ys.copy()),\n",
"        scan_b(n_steps, ys.copy()),\n",
"    )"
]
},
{
"cell_type": "code",
"execution_count": 15,
"id": "f7095126-653f-403d-a8a0-ec1fcbcd02d2",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"26.1 μs ± 610 ns per loop (mean ± std. dev. of 14 runs, 10,000 loops each)\n",
"22.3 μs ± 378 ns per loop (mean ± std. dev. of 14 runs, 10,000 loops each)\n",
"17.6 μs ± 228 ns per loop (mean ± std. dev. of 14 runs, 10,000 loops each)\n",
"17.7 μs ± 356 ns per loop (mean ± std. dev. of 14 runs, 10,000 loops each)\n"
]
}
],
"source": [
"# Each scan overwrites its input in place, so every call gets a fresh copy;\n",
"# that copy is part of the timed statement for all variants alike.\n",
"for scan_fn in scans:\n",
"    %timeit -r 14 -n 10_000 scan_fn(n_steps, ys.copy())"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "b14b4563-0155-4c90-b970-e9c0fed82a85",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "pytensor-dev",
"language": "python",
"name": "pytensor-dev"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.8"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment