Skip to content

Instantly share code, notes, and snippets.

@spezold
Last active October 23, 2025 09:13
Show Gist options
  • Save spezold/f7b9547efc4132cf1ca595a6628c4089 to your computer and use it in GitHub Desktop.
Save spezold/f7b9547efc4132cf1ca595a6628c4089 to your computer and use it in GitHub Desktop.
A demonstration of applying the parallel scan algorithm (Blelloch, 1990) to a first-order recurrence problem
"""
Demonstrate the application of the parallel scan algorithm with a first-order recurrence problem, as proposed by
Blelloch (1990). Harris et al. (2007) provide helpful illustrations and discuss a CUDA implementation; online version:
https://developer.nvidia.com/gpugems/gpugems3/part-vi-gpu-computing/chapter-39-parallel-prefix-sum-scan-cuda
- G. E. Blelloch, “Prefix Sums and Their Applications,” School of Computer Science, Carnegie Mellon University,
CMU-CS-90-190, Nov. 1990.
- M. Harris, S. Sengupta, and J. D. Owens, “Parallel prefix sum (scan) with CUDA,” GPU gems, vol. 3, no. 39, pp.
851–876, 2007.
"""
from collections.abc import Sequence
from copy import copy
from math import log2
from numbers import Integral
from typing import Protocol, TypeVar
# Generic operand/result type shared by the operator protocols and the scan functions
T = TypeVar("T")
class Op(Protocol):
    """
    Structural type of a binary operator: any callable that takes two operands (positionally) and returns a result.

    ``self`` must be declared explicitly; without it, the first operand would bind to the instance and the protocol
    would describe a one-argument callable instead of a binary operator.
    """

    def __call__(self, x: T, y: T, /) -> T: ...
class Plus(Op, Protocol):
    """
    Binary operator protocol for the "additive" operation.

    Must be associative: ``(x+y)+z==x+(y+z)``, where ``x+y := Plus(x, y)``.

    (``Protocol`` is listed explicitly among the bases so that this stays a structural type, i.e. plain callables
    such as lambdas are accepted wherever a ``Plus`` is expected.)
    """
class Dot(Op, Protocol):
    """
    Binary operator protocol for the "companion" operation that combines two "multiplicative" inputs.

    Must be associative: ``(x·y)·z==x·(y·z)``, where ``x·y := Dot(x, y)``.

    (``Protocol`` is listed explicitly among the bases so that this stays a structural type, i.e. plain callables
    such as lambdas are accepted wherever a ``Dot`` is expected.)
    """
class Cross(Op, Protocol):
    """
    Binary operator protocol that applies a "multiplicative" input (2nd operand) to an "additive" value (1st operand).

    Must be:

    - semiassociative with ``Dot``: ``(x×y)×z==x×(y·z)``, where ``x×y := Cross(x, y)`` and ``x·y := Dot(x, y)``
    - distributive over ``Plus`` in its 1st operand: ``(x+y)×z==(x×z)+(y×z)``, where ``x×y := Cross(x, y)`` and
      ``x+y := Plus(x, y)`` (this is the form needed for the combined operator of the first-order recurrence, which
      applies ``Cross`` to the result of a ``Plus``)

    (``Protocol`` is listed explicitly among the bases so that this stays a structural type, i.e. plain callables
    such as lambdas are accepted wherever a ``Cross`` is expected.)
    """
def is_power_of_2(n: int) -> bool:
    """
    Return True for ``n`` being a positive integral power of 2 (1, 2, 4, 8, ...), False otherwise.

    Any ``numbers.Integral`` is accepted, so NumPy integer scalars work as well. ``bool`` is rejected explicitly:
    it subclasses ``int``, so ``True`` would otherwise pass as ``1 == 2 ** 0``. Non-integral values (e.g. ``4.0``)
    yield False.
    """
    if isinstance(n, bool) or not isinstance(n, Integral):
        return False
    n = int(n)  # Normalize NumPy scalars etc. to a Python int before the bit trick
    # A positive power of 2 has exactly one bit set, so clearing its lowest set bit leaves 0
    return n > 0 and (n & (n - 1)) == 0
def s_s(a_s: Sequence[T], b_s: Sequence[T], *, plus: Plus, cross: Cross, dot: Dot | None = None,
        id_plus: T, id_cross: T, inclusive: bool = True) -> Sequence[tuple[T, T]]:
    """
    Solve the first-order recurrence

    1. ``h[0] = b[0]``,
    2. ``h[i] = (h[i-1] × a[i]) + b[i] (i > 0)``,

    with ``a × b := cross(a, b)`` and ``a + b := plus(a, b)``, by reducing it to a parallel scan (Blelloch, 1990,
    Sect. 1.4) over the pairs ``(a[i], b[i])``.

    The scan itself runs sequentially here, because the point is correctness, not speed. The sequence length
    must be a power of 2.

    :param a_s: "multiplicative" inputs
    :param b_s: "additive" inputs (same length as ``a_s``)
    :param plus: associative binary operator for the "additive" inputs
    :param cross: semiassociative binary operator that applies a "multiplicative" input to an "additive" value
        (see ``Cross`` for the formal requirements)
    :param dot: optional associative "companion operator" of ``cross`` that combines two "multiplicative" inputs
        (see ``Dot`` and ``Cross``). If None (the default), ``cross`` itself is used, which assumes that ``cross``
        is fully associative.
    :param id_plus: identity element of ``plus``
    :param id_cross: identity element of ``cross`` (and ``dot``)
    :param inclusive: if True (default), compute the inclusive scan defined above. If False, compute the exclusive
        scan: the result is shifted right by one element, starts with the identity pair, and lacks the final
        element. (The exclusive scan is actually the cheaper operation.)
    :return: sequence of two-tuples whose 2nd elements are the ``h[i] (i >= 0)`` of the recurrence (and whose 1st
        elements are the ``dot``-combined "multiplicative" inputs)
    """
    assert len(a_s) == len(b_s) and is_power_of_2(len(a_s))
    if dot is None:
        dot = cross  # No separate companion operator: treat `cross` as fully associative

    def bullet(c_i, c_j):
        # Combined operator on (multiplicative, additive) pairs, following Blelloch (1990), Sect. 1.4
        a_i, b_i = c_i
        a_j, b_j = c_j
        return dot(a_i, a_j), plus(cross(b_i, a_j), b_j)

    pairs = [(a, b) for a, b in zip(a_s, b_s)]
    identity = (id_cross, id_plus)  # Identity element of `bullet`
    return scan(c_s=pairs, plus=bullet, id_=identity, inclusive=inclusive)
def scan(c_s: Sequence[T], plus: Plus, id_: T, *, inclusive: bool = True) -> Sequence[T]:
    """
    Given the sequence ``c_s``, the ``plus`` operator, and its identity element ``id_``, solve the problem

    1. ``s[0] = c[0]``,
    2. ``s[i] = s[i - 1] + c[i] (i > 0)``,

    with the parallel scan algorithm, where ``a + b := plus(a, b)``.

    The parallel scan algorithm is actually implemented sequentially here, because the goal is correctness, not
    efficiency. The sequence length must be a power of 2.

    :param c_s: inputs; any sized, indexable sequence (e.g. list, tuple, or NumPy array). It is not modified.
    :param plus: associative (but not necessarily commutative) binary operator to be applied to the inputs
    :param id_: identity element of the binary operator
    :param inclusive: if True (default), compute the inclusive scan defined above. If False, compute the exclusive
        scan: the result is shifted right by one element, starts with the identity element, and lacks the final
        element. (The exclusive scan is actually the cheaper operation.)
    :return: list of resulting ``s[i] (i >= 0)``
    :raises AssertionError: if the length of ``c_s`` is not a power of 2 (only checked when assertions are enabled)
    """
    # Bind `n` outside of the assertion: `python -O` strips assert statements entirely, which would leave `n`
    # undefined for the code below
    n = len(c_s)
    assert is_power_of_2(n), f"Sequence length must be a power of 2, got {n}"
    # Work on a fresh list rather than `copy(c_s)`. `copy()` returns a tuple unchanged, so the item assignments below
    # would fail. It also keeps a NumPy array an array, for which `c_s[1:] + [...]` at the end would broadcast-add
    # instead of append.
    c_s = list(c_s)
    if inclusive:
        last_element_orig = c_s[-1]  # Store in temporary variable, as corresponding list element will be overwritten

    # Provide helpers for the up-sweep and down-sweep. `n == 2 ** num_levels` exactly; integer arithmetic avoids
    # relying on float `log2()` plus rounding.
    num_levels = n.bit_length() - 1
    range_d = range(num_levels)

    def idxs_lr(d: int):
        """Return the (left, right) index pairs that are combined on tree level ``d``"""
        step = 2 ** (d + 1)
        return ((i + 2 ** d - 1, i + step - 1) for i in range(0, n, step))

    # Implement the up-sweep (Blelloch, 1990, Fig. 1.2)
    for d in range_d:
        for idx_l, idx_r in idxs_lr(d):  # This can be parallelized
            c_s[idx_r] = plus(c_s[idx_l], c_s[idx_r])

    # Implement "clear" and down-sweep (Blelloch, 1990, Fig. 1.4). CAUTION: For non-commutative `plus()`, it is
    # crucial to always have `t` as the *2nd operand* in the last step! While this is not obvious from Fig. 1.4 (and
    # thus caused me some headache), it becomes clear on closer inspection: `t` holds the value of originally *larger*
    # index (cf. the visualization of steps in Harris et al., 2007, Fig. 39-4), so it must be "added" *last*.
    c_s[n - 1] = id_  # "Clear" (Fig. 1.4b)
    for d in reversed(range_d):
        for idx_l, idx_r in idxs_lr(d):  # This can be parallelized
            t = c_s[idx_l]
            c_s[idx_l] = c_s[idx_r]
            c_s[idx_r] = plus(c_s[idx_r], t)

    if inclusive:
        # Harris et al. (2007): "An inclusive scan can be generated from an exclusive scan by shifting the resulting
        # array left and inserting at the end the sum of the last element of the scan and the last element of the input
        # array." Likewise Blelloch (1990): "The scan can be generated from the prescan by shifting left, and inserting
        # at the end the sum of the last element of the prescan and the last element of the original vector."
        c_s = c_s[1:] + [plus(c_s[-1], last_element_orig)]
    return c_s
if __name__ == "__main__":
    import numpy as np

    rng = np.random.default_rng(seed=42)

    # 1) Check `scan()`, i.e. the bare parallel scan algorithm. Compute a cumulative sum over the example values of
    #    Blelloch (1990), Figs. 1.2+1.4, and compare it against NumPy's `cumsum()`.
    values = [3, 1, 7, 0, 4, 1, 6, 3]
    expected_sums = np.cumsum(values)
    scanned_sums = scan(values, plus=lambda x, y: x + y, id_=0)
    assert (expected_sums == scanned_sums).all()

    # 2) Check `s_s()`, i.e. the reduction of a first-order recurrence to a parallel scan. Compute an `n_steps`-fold
    #    recursive matrix-vector product with `dim×dim` matrices `A` and `dim`-element vectors `b`:
    #
    #        s[0] = b[0]
    #        s[i] = A[i] @ s[i - 1] + b[i]    (0 < i < n_steps)
    n_steps, dim = 16, 5
    b_s = rng.normal(size=(n_steps, dim))  # Drawn before the matrices, so the random values match the reference run
    a_s = rng.normal(size=(n_steps, dim, dim))

    def add(v1, v2):
        return np.add(v1, v2)

    def apply_matrix(v, m):
        # `cross` receives (vector, matrix) and left-multiplies the vector by the matrix
        return np.matmul(m, v)

    zero_vector, identity_matrix = np.zeros(dim), np.eye(dim)

    # Plain sequential evaluation of the recurrence as the reference
    reference = [b_s[0]]
    for a_i, b_i in zip(a_s[1:], b_s[1:]):
        reference.append(add(apply_matrix(reference[-1], a_i), b_i))

    scanned = [h for _, h in s_s(a_s=a_s, b_s=b_s, plus=add, cross=apply_matrix,
                                 id_plus=zero_vector, id_cross=identity_matrix)]
    assert np.allclose(reference, scanned)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment