Last active
October 23, 2025 09:13
-
-
Save spezold/f7b9547efc4132cf1ca595a6628c4089 to your computer and use it in GitHub Desktop.
A demonstration of applying the parallel scan algorithm (Blelloch, 1990) to a first-order recursive problem
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| """ | |
| Demonstrate the application of the parallel scan algorithm with a first-order recurrence problem, as proposed by | |
| Blelloch (1990). Harris et al. (2007) provide helpful illustrations and discuss a CUDA implementation; online version: | |
| https://developer.nvidia.com/gpugems/gpugems3/part-vi-gpu-computing/chapter-39-parallel-prefix-sum-scan-cuda | |
| - G. E. Blelloch, “Prefix Sums and Their Applications,” School of Computer Science, Carnegie Mellon University, | |
| CMU-CS-90-190, Nov. 1990. | |
| - M. Harris, S. Sengupta, and J. D. Owens, “Parallel prefix sum (scan) with CUDA,” GPU gems, vol. 3, no. 39, pp. | |
| 851–876, 2007. | |
| """ | |
| from collections.abc import Sequence | |
| from copy import copy | |
| from math import log2 | |
| from typing import Protocol, TypeVar | |
T = TypeVar("T")


class Op(Protocol):
    """Structural type of a binary operator: a callable combining two values of type ``T`` into one ``T``."""

    # `self` must be declared explicitly; without it, the first operand would be bound as the instance and the
    # protocol would describe a callable taking a single argument.
    def __call__(self, x: T, y: T, /) -> T: ...
# `Protocol` is listed again as an explicit base: a class that merely subclasses a protocol is an ordinary (nominal)
# class, so plain functions/lambdas would not be accepted where a `Plus` is expected.
class Plus(Op, Protocol):
    """Must be associative: ``(x+y)+z==x+(y+z)``, where ``x+y := plus(x, y)`` for an operator ``plus`` of this type"""
# `Protocol` is listed again as an explicit base: a class that merely subclasses a protocol is an ordinary (nominal)
# class, so plain functions/lambdas would not be accepted where a `Dot` is expected.
class Dot(Op, Protocol):
    """Must be associative: ``(x·y)·z==x·(y·z)``, where ``x·y := dot(x, y)`` for an operator ``dot`` of this type"""
# `Protocol` is listed again as an explicit base: a class that merely subclasses a protocol is an ordinary (nominal)
# class, so plain functions/lambdas would not be accepted where a `Cross` is expected.
class Cross(Op, Protocol):
    """
    Must be:

    - semiassociative with ``Dot``: ``(x×y)×z==x×(y·z)``, where ``x×y := cross(x, y)`` and ``x·y := dot(x, y)``
    - distributive over ``Plus``: ``x×(y+z)==(x×y)+(x×z)``, where ``x×y := cross(x, y)`` and ``x+y := plus(x, y)``

    for operators ``cross``, ``dot``, and ``plus`` of type ``Cross``, ``Dot``, and ``Plus``, respectively.
    """
def is_power_of_2(n: int) -> bool:
    """
    Return True for ``n`` being a positive power of 2 (1, 2, 4, 8, ...), False otherwise.

    Any ``numbers.Integral`` value is accepted (Python ``int`` as well as integer types registered with that ABC, such
    as NumPy's integer scalars); values of any other type (e.g. ``float``) yield False.

    :param n: value to be checked
    :return: True if ``n`` is integral, positive, and has exactly one bit set; False otherwise
    """
    from numbers import Integral  # Local import keeps this helper self-contained

    if not isinstance(n, Integral):
        return False
    n = int(n)  # Normalize, so that both the bit trick and the returned bool are plain Python
    # A power of 2 has exactly one bit set; `n & (n - 1)` clears the lowest set bit
    return n > 0 and (n & (n - 1)) == 0
def s_s(a_s: Sequence[T], b_s: Sequence[T], *, plus: Plus, cross: Cross, dot: Dot | None = None,
        id_plus: T, id_cross: T, inclusive: bool = True) -> Sequence[tuple[T, T]]:
    """
    Solve a first-order recurrence by rewriting it as a scan over pairs and delegating to ``scan()``.

    With ``a × b := cross(a, b)`` and ``a + b := plus(a, b)``, the recurrence is

    1. ``h[0] = b[0]``,
    2. ``h[i] = (h[i-1] × a[i]) + b[i]  (i > 0)``.

    The inputs are zipped into pairs ``c[i] = (a[i], b[i])``, which are combined with the pair operator
    ``c_i • c_j := (a_i · a_j, (b_i × a_j) + b_j)``, where ``a · b := dot(a, b)``; its identity element is
    ``(id_cross, id_plus)``. Both sequences must have the same length, and that length must be a power of 2.

    :param a_s: "multiplicative" inputs
    :param b_s: "additive" inputs
    :param plus: associative binary operator used to combine the "additive" parts
    :param cross: semiassociative binary operator applying a "multiplicative" input (see ``Cross``)
    :param dot: optional associative companion operator of ``cross`` (see ``Dot`` and ``Cross``); if None (default),
        ``cross`` itself is used, which implies that ``cross`` is fully associative
    :param id_plus: identity element of ``plus``
    :param id_cross: identity element of ``cross`` (and ``dot``)
    :param inclusive: passed on to ``scan()``: True (default) for an inclusive scan, False for an exclusive scan
    :return: sequence of two-tuples as produced by ``scan()``; the 2nd element of each tuple holds the ``h[i]``
    """
    assert len(a_s) == len(b_s) and is_power_of_2(len(a_s))
    if dot is None:
        dot = cross

    # Pair operator following Blelloch (1990), Sect. 1.4
    def bullet(c_i, c_j):
        a_i, b_i = c_i
        a_j, b_j = c_j
        return dot(a_i, a_j), plus(cross(b_i, a_j), b_j)

    pairs = list(zip(a_s, b_s))
    return scan(c_s=pairs, plus=bullet, id_=(id_cross, id_plus), inclusive=inclusive)
def scan(c_s: Sequence[T], plus: Plus, id_: T, *, inclusive: bool = True) -> Sequence[T]:
    """
    Given the sequence ``c_s``, the ``plus`` operator, and its identity element ``id_``, solve the problem

    1. ``s[0] = c[0]``,
    2. ``s[i] = s[i - 1] + c[i]  (i > 0)``,

    with the parallel scan algorithm, where ``a + b := plus(a, b)``.

    Note that the parallel scan algorithm is actually implemented sequentially here, since the goal is not efficiency
    but correctness. Also note that the sequence length must be a power of 2. The input is not modified; all work
    happens on a list copy, so any sequence type (e.g. a tuple) may be passed.

    :param c_s: inputs
    :param plus: associative binary operator to be applied to the inputs
    :param id_: identity element of the binary operator
    :param inclusive: if False, perform an exclusive scan, i.e. the result sequence is right-shifted by one element,
        with the identity element prepended and the previously last element missing; if True (default), perform an
        inclusive scan as defined above (note that the exclusive scan is actually the cheaper operation)
    :return: list of resulting ``s[i] (i >= 0)``
    :raises AssertionError: if the length of ``c_s`` is not a power of 2 (only while assertions are enabled)
    """
    # Bind `n` *outside* the assertion: a walrus inside `assert` is stripped together with the assertion under
    # `python -O`, which would leave `n` undefined for the code below
    n = len(c_s)
    assert is_power_of_2(n), "sequence length must be a power of 2"

    # Work on a list copy: guarantees item assignment and list concatenation regardless of the input's type
    work = list(c_s)
    last_element_orig = work[-1]  # Needed for the inclusive result, as the list element will be overwritten

    # Helpers for the up-sweep and down-sweep: tree levels and the (left, right) index pairs of one level
    levels = range(n.bit_length() - 1)  # Exact integer log2 for a power of 2

    def idxs_lr(d):
        half, full = 1 << d, 1 << (d + 1)
        return ((i + half - 1, i + full - 1) for i in range(0, n, full))

    # Implement the up-sweep (Blelloch, 1990, Fig. 1.2)
    for d in levels:
        for idx_l, idx_r in idxs_lr(d):  # This can be parallelized
            work[idx_r] = plus(work[idx_l], work[idx_r])

    # Implement "clear" and down-sweep (Blelloch, 1990, Fig. 1.4). CAUTION: For non-commutative `plus()`, it is
    # crucial to always have `t` as the *2nd operand* in the last step: `t` holds the value of the originally *larger*
    # index (cf. the visualization of steps in Harris et al., 2007, Fig. 39-4), so it must be "added" *last*.
    work[n - 1] = id_  # "Clear" (Fig. 1.4b)
    for d in reversed(levels):
        for idx_l, idx_r in idxs_lr(d):  # This can be parallelized
            t = work[idx_l]
            work[idx_l] = work[idx_r]
            work[idx_r] = plus(work[idx_r], t)

    if not inclusive:
        return work
    # Harris et al. (2007): "An inclusive scan can be generated from an exclusive scan by shifting the resulting
    # array left and inserting at the end the sum of the last element of the scan and the last element of the input
    # array."
    return work[1:] + [plus(work[-1], last_element_orig)]
if __name__ == "__main__":
    from itertools import accumulate
    import numpy as np

    rng = np.random.default_rng(seed=42)

    # Check 1: `scan()` itself, i.e. the actual parallel scan algorithm. Compute a cumulative sum over the values
    # from Blelloch (1990), Figs. 1.2+1.4, and compare it with NumPy's result.
    values = [3, 1, 7, 0, 4, 1, 6, 3]
    expected_sums = np.cumsum(values)
    scanned_sums = scan(values, plus=lambda a, b: a + b, id_=0)
    assert (expected_sums == scanned_sums).all()

    # Check 2: `s_s()`, i.e. the transformation of the recursive problem into one for the parallel scan algorithm.
    # Compute an `n`-step recursive matrix-vector product (`d×d` matrices `A`, `d`-element vectors `b`)
    #
    #   s[0] = b[0]
    #   s[i] = A[i] @ s[i - 1] + b[i]  (0 < i < n)
    n, d = 16, 5
    b_s = rng.normal(size=(n, d))  # Drawn first, then the matrices
    a_s = rng.normal(size=(n, d, d))

    def pls(v1, v2):
        return np.add(v1, v2)

    def crs(v, m):
        return np.matmul(m, v)

    id_pls, id_crs = np.zeros(d), np.eye(d)

    def step(s, a_b):
        a, b = a_b
        return pls(crs(s, a), b)

    # Sequential reference implementation
    expected_states = list(accumulate(zip(a_s[1:], b_s[1:]), step, initial=b_s[0]))
    scanned_states = [
        h for _, h in s_s(a_s=a_s, b_s=b_s, plus=pls, cross=crs, id_plus=id_pls, id_cross=id_crs)
    ]
    assert np.allclose(expected_states, scanned_states)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment