Created
October 20, 2025 10:08
-
-
Save ricardoV94/1f579574570e347b4422470d9cad114f to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| { | |
| "cells": [ | |
| { | |
| "cell_type": "code", | |
| "execution_count": 1, | |
| "id": "d9a581f1-a09f-450d-bb31-7776ebe7eb12", | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "from itertools import pairwise\n", | |
| "from functools import partial\n", | |
| "import numpy as np\n", | |
| "import numba\n", | |
| "\n", | |
# Decorator applied to every scan variant below: numba.njit with bounds checking disabled.
scan_njit = partial(numba.njit, boundscheck=False)
| "\n", | |
@numba.njit(fastmath=True)
def inner_f(x):
    """One scan step: return ``(x * x.T) / np.max(x)``.

    For the 1-D rows this notebook passes in, ``x.T`` is ``x`` itself.
    """
    peak = np.max(x)
    squared = x * x.T
    return squared / peak
| "\n", | |
@scan_njit
def scan0(n, xs):
    """Baseline scan: read the previous state from ``xs`` and write the new one
    back into ``xs`` on every iteration.

    ``xs`` is treated as a circular buffer along axis 0: step ``i`` reads slot
    ``i % xs.shape[0]`` and writes slot ``(i + 1) % xs.shape[0]``.
    Mutates and returns ``xs``.
    """
    buf_len = xs.shape[0]
    assert buf_len >= 1
    for step in range(n):
        src = step % buf_len
        dst = (step + 1) % buf_len
        xs[dst] = inner_f(xs[src])
    return xs
| "\n", | |
@scan_njit
def scan1(n, xs):
    """Carry the state in a local; only write (never read) ``xs`` in the loop.

    Step ``i`` stores its result in slot ``(i + 1) % xs.shape[0]``.
    Mutates and returns ``xs``; returns it untouched when ``n == 0``.
    """
    if n != 0:
        buf_len = xs.shape[0]
        assert buf_len >= 1
        state = xs[0]
        for step in range(n):
            state = inner_f(state)
            xs[(step + 1) % buf_len] = state
    return xs
| "\n", | |
@scan_njit
def scan2(n, xs):
    """Specialize on buffer length.

    Length 1 (only the last state kept): iterate on a local and write
    ``xs[0]`` once at the end. Otherwise: read and write ``xs`` on every
    iteration, as in ``scan0``. Mutates and returns ``xs``.
    """
    if n == 0:
        return xs

    buf_len = xs.shape[0]
    assert buf_len >= 1

    if buf_len != 1:
        for step in range(n):
            xs[(step + 1) % buf_len] = inner_f(xs[step % buf_len])
        return xs

    carry = xs[0]
    for _ in range(n):
        carry = inner_f(carry)
    xs[0] = carry
    return xs
| "\n", | |
@scan_njit
def scan3(n, xs):
    """Always carry the state in a local.

    When ``xs`` has a single slot, write it once after the loop. Otherwise
    write slot ``(i + 1) % xs.shape[0]`` on every step ``i``.
    Mutates and returns ``xs``.
    """
    if n == 0:
        return xs

    buf_len = xs.shape[0]
    assert buf_len >= 1
    state = xs[0]

    if buf_len == 1:
        for _ in range(n):
            state = inner_f(state)
        xs[0] = state
        return xs

    for step in range(n):
        state = inner_f(state)
        xs[(step + 1) % buf_len] = state
    return xs
| "\n", | |
# Variants under comparison; the check and %timeit cells below iterate this tuple in order.
scans = (scan0, scan1, scan2, scan3)
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "id": "82a1cfbb-b838-42a4-be13-2168320c8a9d", | |
| "metadata": {}, | |
| "source": [ | |
| "### Full trace is kept" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 2, | |
| "id": "ed85cc23-d0fe-4a11-8822-6301d6317e9d", | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
n_steps = 300
xs = np.full((n_steps + 10, 30), np.nan)
xs[0] = np.array([0.1, 0.2, 0.3] * 10)
# Call every variant once (triggers JIT compilation before the timing cell)
# and check that adjacent variants agree. The loop names deliberately do not
# reuse `scan0`/`scan1`, which would rebind the global functions.
for scan_a, scan_b in pairwise(scans):
    np.testing.assert_allclose(
        scan_a(n_steps, xs.copy()),
        scan_b(n_steps, xs.copy()),
    )
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 3, | |
| "id": "b5e442df-b4f4-4642-9bea-47fbda32da57", | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "37.7 μs ± 1.53 μs per loop (mean ± std. dev. of 14 runs, 10,000 loops each)\n", | |
| "37.4 μs ± 1.01 μs per loop (mean ± std. dev. of 14 runs, 10,000 loops each)\n", | |
| "38.3 μs ± 534 ns per loop (mean ± std. dev. of 14 runs, 10,000 loops each)\n", | |
| "37.9 μs ± 922 ns per loop (mean ± std. dev. of 14 runs, 10,000 loops each)\n" | |
| ] | |
| } | |
| ], | |
| "source": [ | |
# One timing line per variant, in the order of `scans` (scan0..scan3).
# Each measurement includes `xs.copy()`, since the scans write into their input.
# Variants should already be JIT-compiled by the preceding check cell; otherwise
# compilation time is folded into the first loop of the measurement.
for scan in scans:
    %timeit -r 14 -n 10_000 scan(n_steps, xs.copy())
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "id": "28cb96bb-83c9-4092-a78f-311fb20011ff", | |
| "metadata": {}, | |
| "source": [ | |
| "### Only the last entry is kept (buffer of length 1)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 4, | |
| "id": "3f8f6059-26d6-4b36-be61-156e4abc7ab8", | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
# Buffer of length 1: only the latest state is kept.
ys = xs[:1]
# Call every variant once (JIT warm-up) and check that adjacent variants agree.
# Loop names avoid rebinding the global `scan0`/`scan1` functions.
for scan_a, scan_b in pairwise(scans):
    np.testing.assert_allclose(
        scan_a(n_steps, ys.copy()),
        scan_b(n_steps, ys.copy()),
    )
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 5, | |
| "id": "066f0e42-f97d-4967-b975-c455aa898c21", | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "35.8 μs ± 456 ns per loop (mean ± std. dev. of 14 runs, 10,000 loops each)\n", | |
| "36.6 μs ± 1.69 μs per loop (mean ± std. dev. of 14 runs, 10,000 loops each)\n", | |
| "28.7 μs ± 251 ns per loop (mean ± std. dev. of 14 runs, 10,000 loops each)\n", | |
| "29.8 μs ± 577 ns per loop (mean ± std. dev. of 14 runs, 10,000 loops each)\n" | |
| ] | |
| } | |
| ], | |
| "source": [ | |
# One timing line per variant, in the order of `scans` (scan0..scan3).
# Each measurement includes `ys.copy()`, since the scans write into their input.
for scan in scans:
    %timeit -r 14 -n 10_000 scan(n_steps, ys.copy())
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "id": "70f977d5-0b72-4a01-b430-4ebee532f975", | |
| "metadata": {}, | |
| "source": [ | |
| "## Scalar state that has to be wrapped in a 0-d NumPy array" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 6, | |
| "id": "411fd138-fd3d-4fb7-9fa2-74d71f5f2b5f", | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
@numba.njit(fastmath=True)
def inner_f(x):
    """Decay step: return ``x * 0.999`` wrapped with ``np.array``."""
    decayed = x * 0.999
    return np.array(decayed)
| "\n", | |
@scan_njit
def scan0(n, xs):
    """Baseline scalar scan: copy each state from ``xs`` into a reusable 0-d
    scratch array, step it, and write the result back on every iteration.

    Step ``i`` reads slot ``i % xs.shape[0]`` and writes slot
    ``(i + 1) % xs.shape[0]``. Mutates and returns ``xs``.
    """
    buf_len = xs.shape[0]
    assert buf_len >= 1
    scratch = np.empty((), dtype=xs.dtype)
    for step in range(n):
        scratch[()] = xs[step % buf_len]
        xs[(step + 1) % buf_len] = inner_f(scratch)
    return xs
| "\n", | |
@scan_njit
def scan1(n, xs):
    """Carry the state as a 0-d array local; only write ``xs`` in the loop.

    Step ``i`` stores its result in slot ``(i + 1) % xs.shape[0]``.
    Mutates and returns ``xs``; returns it untouched when ``n == 0``.
    """
    if n != 0:
        buf_len = xs.shape[0]
        assert buf_len >= 1
        state = np.asarray(xs[0])
        for step in range(n):
            state = inner_f(state)
            xs[(step + 1) % buf_len] = state
    return xs
| "\n", | |
@scan_njit
def scan2(n, xs):
    """Specialize on buffer length.

    Length 1: iterate on a 0-d array local and write ``xs[0]`` once at the end.
    Otherwise: copy through a 0-d scratch array and read/write ``xs`` on every
    iteration, as in ``scan0``. Mutates and returns ``xs``.
    """
    if n == 0:
        return xs

    buf_len = xs.shape[0]
    assert buf_len >= 1

    if buf_len != 1:
        scratch = np.empty((), dtype=xs.dtype)
        for step in range(n):
            scratch[()] = xs[step % buf_len]
            xs[(step + 1) % buf_len] = inner_f(scratch)
        return xs

    carry = np.asarray(xs[0])
    for _ in range(n):
        carry = inner_f(carry)
    xs[0] = carry
    return xs
| "\n", | |
@scan_njit
def scan3(n, xs):
    """Always carry the state as a 0-d array local.

    When ``xs`` has a single slot, write it once after the loop. Otherwise
    write slot ``(i + 1) % xs.shape[0]`` on every step ``i``.
    Mutates and returns ``xs``.
    """
    if n == 0:
        return xs

    buf_len = xs.shape[0]
    assert buf_len >= 1
    state = np.asarray(xs[0])

    if buf_len == 1:
        for _ in range(n):
            state = inner_f(state)
        xs[0] = state
        return xs

    for step in range(n):
        state = inner_f(state)
        xs[(step + 1) % buf_len] = state
    return xs
| "\n", | |
# Scalar-state variants; the check and %timeit cells below iterate this tuple in order.
scans = (scan0, scan1, scan2, scan3)
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "id": "939d5fdc-2b65-4e51-8ba2-8a82bd502b60", | |
| "metadata": {}, | |
| "source": [ | |
| "### Full trace is kept" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 7, | |
| "id": "f0d87666-5e95-47f2-a62b-3ec4c1d00628", | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
n_steps = 300
xs = np.full((n_steps + 1,), np.nan)
xs[0] = 900
# Call every variant once (JIT warm-up) and check that adjacent variants agree.
# Loop names avoid rebinding the global `scan0`/`scan1` functions.
for scan_a, scan_b in pairwise(scans):
    np.testing.assert_allclose(
        scan_a(n_steps, xs.copy()),
        scan_b(n_steps, xs.copy()),
    )
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 8, | |
| "id": "847d536f-46fd-46ec-b14f-80f76efef0b6", | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "19.9 μs ± 536 ns per loop (mean ± std. dev. of 14 runs, 10,000 loops each)\n", | |
| "23.6 μs ± 393 ns per loop (mean ± std. dev. of 14 runs, 10,000 loops each)\n", | |
| "20.2 μs ± 452 ns per loop (mean ± std. dev. of 14 runs, 10,000 loops each)\n", | |
| "23.9 μs ± 713 ns per loop (mean ± std. dev. of 14 runs, 10,000 loops each)\n" | |
| ] | |
| } | |
| ], | |
| "source": [ | |
# One timing line per variant, in the order of `scans` (scan0..scan3).
# Each measurement includes `xs.copy()`, since the scans write into their input.
for scan in scans:
    %timeit -r 14 -n 10_000 scan(n_steps, xs.copy())
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "id": "2e6430cd-3a4a-4b01-80f5-e4892d5f76ba", | |
| "metadata": {}, | |
| "source": [ | |
| "### Only the last entry is kept (buffer of length 1)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 9, | |
| "id": "98f52e61-e4a7-48c0-9429-1417c829aab3", | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
# Buffer of length 1: only the latest state is kept.
ys = xs[:1]
# Call every variant once (JIT warm-up) and check that adjacent variants agree.
# Loop names avoid rebinding the global `scan0`/`scan1` functions.
for scan_a, scan_b in pairwise(scans):
    np.testing.assert_allclose(
        scan_a(n_steps, ys.copy()),
        scan_b(n_steps, ys.copy()),
    )
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 10, | |
| "id": "7bf7fced-dca5-4a88-92a1-4ff8657cf78a", | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "20 μs ± 606 ns per loop (mean ± std. dev. of 14 runs, 10,000 loops each)\n", | |
| "23.9 μs ± 641 ns per loop (mean ± std. dev. of 14 runs, 10,000 loops each)\n", | |
| "9.01 μs ± 182 ns per loop (mean ± std. dev. of 14 runs, 10,000 loops each)\n", | |
| "8.95 μs ± 87.2 ns per loop (mean ± std. dev. of 14 runs, 10,000 loops each)\n" | |
| ] | |
| } | |
| ], | |
| "source": [ | |
# One timing line per variant, in the order of `scans` (scan0..scan3).
# Each measurement includes `ys.copy()`, since the scans write into their input.
for scan in scans:
    %timeit -r 14 -n 10_000 scan(n_steps, ys.copy())
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "id": "66716326-7ff9-4912-b712-15bac42c01c7", | |
| "metadata": {}, | |
| "source": [ | |
| "## Scan with multiple taps" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 11, | |
| "id": "843e1796-eff9-4fac-a5ea-c2d954ae5201", | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
@numba.njit(fastmath=True)
def inner_f(x, y):
    """Two-tap step: the mean of the two previous states."""
    total = x + y
    return total / 2
| "\n", | |
@scan_njit
def scan0(n, xs):
    """Baseline two-tap scan: read both taps from ``xs`` and write the result
    back on every iteration.

    Step ``i`` reads slots ``i`` and ``i + 1`` and writes slot ``i + 2``, all
    modulo ``xs.shape[0]``. Mutates and returns ``xs``.
    """
    buf_len = xs.shape[0]
    assert buf_len >= 2
    for step in range(n):
        older = xs[step % buf_len]
        newer = xs[(step + 1) % buf_len]
        xs[(step + 2) % buf_len] = inner_f(older, newer)
    return xs
| "\n", | |
@scan_njit
def scan1(n, xs):
    """Carry both taps in locals; only write ``xs`` inside the loop.

    Step ``i`` stores its result in slot ``(i + 2) % xs.shape[0]``.
    Mutates and returns ``xs``; returns it untouched when ``n == 0``.
    """
    if n != 0:
        buf_len = xs.shape[0]
        assert buf_len >= 2
        older, newer = xs[0], xs[1]
        for step in range(n):
            older, newer = newer, inner_f(older, newer)
            xs[(step + 2) % buf_len] = newer
    return xs
| "\n", | |
@scan_njit
def scan2(n, xs):
    """Two-tap scan specialized on buffer length.

    When ``xs`` holds exactly two rows (only the taps are kept), iterate on
    locals and write both taps back once at the end. Otherwise read and write
    ``xs`` on every iteration, as in ``scan0``.

    Step ``i`` combines slots ``i`` and ``i + 1`` and stores the result in slot
    ``(i + 2) % xs.shape[0]``. Mutates and returns ``xs``; returns it untouched
    when ``n == 0``.
    """
    if n == 0:
        return xs

    assert xs.shape[0] >= 2
    x_size = xs.shape[0]

    if x_size == 2:
        x0 = xs[0]
        x1 = xs[1]
        for i in range(n):
            next_x = inner_f(x0, x1)
            x0 = x1
            x1 = next_x
        # Place the taps where the circular-buffer variants leave them: the
        # newest value (step n - 1) lands in slot (n + 1) % x_size and the one
        # before it in slot n % x_size. A fixed xs[0]/xs[1] order would only be
        # correct for even n.
        xs[n % x_size] = x0
        xs[(n + 1) % x_size] = x1
    else:
        for i in range(n):
            x0 = xs[i % x_size]
            x1 = xs[(i + 1) % x_size]
            next_x = inner_f(x0, x1)
            xs[(i + 2) % x_size] = next_x
    return xs
| "\n", | |
@scan_njit
def scan3(n, xs):
    """Two-tap scan that always carries both taps in locals.

    When ``xs`` holds exactly two rows, write both taps back once after the
    loop. Otherwise write slot ``(i + 2) % xs.shape[0]`` on every step ``i``.
    Mutates and returns ``xs``; returns it untouched when ``n == 0``.
    """
    if n == 0:
        return xs

    assert xs.shape[0] >= 2
    x_size = xs.shape[0]
    x0 = xs[0]
    x1 = xs[1]

    if x_size == 2:
        for i in range(n):
            next_x = inner_f(x0, x1)
            x0 = x1
            x1 = next_x
        # Match the circular-buffer layout of the per-step writers: newest value
        # in slot (n + 1) % x_size, previous one in slot n % x_size (a fixed
        # xs[0]/xs[1] order is only right for even n).
        xs[n % x_size] = x0
        xs[(n + 1) % x_size] = x1
    else:
        for i in range(n):
            next_x = inner_f(x0, x1)
            x0 = x1
            x1 = next_x
            xs[(i + 2) % x_size] = next_x
    return xs
| "\n", | |
# Two-tap variants; the check and %timeit cells below iterate this tuple in order.
scans = (scan0, scan1, scan2, scan3)
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "id": "acacd657-38fd-4199-affc-c60e9ac636d3", | |
| "metadata": {}, | |
| "source": [ | |
| "### Full trace is kept" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 12, | |
| "id": "87646adf-0d75-48c4-8061-4b925724483f", | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
n_steps = 300
xs = np.full((n_steps + 2, 3), np.nan)
xs[0] = [1, 2, 3]
xs[1] = [999, -999, 999]
# Call every variant once (so the timing cell below does not include JIT
# compilation) and check that each adjacent pair of variants agrees.
for scan_a, scan_b in pairwise(scans):
    np.testing.assert_allclose(
        scan_a(n_steps, xs.copy()),
        scan_b(n_steps, xs.copy()),
    )
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 13, | |
| "id": "cfd19003-706b-432d-8a78-4fba488e5f4d", | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "27.2 μs ± 753 ns per loop (mean ± std. dev. of 14 runs, 10,000 loops each)\n", | |
| "23.1 μs ± 654 ns per loop (mean ± std. dev. of 14 runs, 10,000 loops each)\n", | |
| "28 μs ± 5.48 μs per loop (mean ± std. dev. of 14 runs, 10,000 loops each)\n", | |
| "24.7 μs ± 5.18 μs per loop (mean ± std. dev. of 14 runs, 10,000 loops each)\n" | |
| ] | |
| } | |
| ], | |
| "source": [ | |
# One timing line per variant, in the order of `scans` (scan0..scan3).
# Each measurement includes `xs.copy()`, since the scans write into their input.
# Variants not already called by an earlier cell are JIT-compiled on their first
# call here, which would inflate that variant's spread.
for scan in scans:
    %timeit -r 14 -n 10_000 scan(n_steps, xs.copy())
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "id": "3e4f5709-d351-4b28-8f3c-95ef0f4e872c", | |
| "metadata": {}, | |
| "source": [ | |
| "### Only the last two entries (the taps) are kept (buffer of length 2)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 14, | |
| "id": "935dc912-7cf2-47df-ad95-fb8cfcf75659", | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
# Buffer of length 2: only the two taps (latest states) are kept.
ys = xs[:2]
# Call every variant once (JIT warm-up) and check that each adjacent pair agrees.
for scan_a, scan_b in pairwise(scans):
    np.testing.assert_allclose(
        scan_a(n_steps, ys.copy()),
        scan_b(n_steps, ys.copy()),
    )
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 15, | |
| "id": "f7095126-653f-403d-a8a0-ec1fcbcd02d2", | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "26.1 μs ± 610 ns per loop (mean ± std. dev. of 14 runs, 10,000 loops each)\n", | |
| "22.3 μs ± 378 ns per loop (mean ± std. dev. of 14 runs, 10,000 loops each)\n", | |
| "17.6 μs ± 228 ns per loop (mean ± std. dev. of 14 runs, 10,000 loops each)\n", | |
| "17.7 μs ± 356 ns per loop (mean ± std. dev. of 14 runs, 10,000 loops each)\n" | |
| ] | |
| } | |
| ], | |
| "source": [ | |
# One timing line per variant, in the order of `scans` (scan0..scan3).
# Each measurement includes `ys.copy()`, since the scans write into their input.
for scan in scans:
    %timeit -r 14 -n 10_000 scan(n_steps, ys.copy())
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": null, | |
| "id": "b14b4563-0155-4c90-b970-e9c0fed82a85", | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [] | |
| } | |
| ], | |
| "metadata": { | |
| "kernelspec": { | |
| "display_name": "pytensor-dev", | |
| "language": "python", | |
| "name": "pytensor-dev" | |
| }, | |
| "language_info": { | |
| "codemirror_mode": { | |
| "name": "ipython", | |
| "version": 3 | |
| }, | |
| "file_extension": ".py", | |
| "mimetype": "text/x-python", | |
| "name": "python", | |
| "nbconvert_exporter": "python", | |
| "pygments_lexer": "ipython3", | |
| "version": "3.12.8" | |
| } | |
| }, | |
| "nbformat": 4, | |
| "nbformat_minor": 5 | |
| } |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment