@yoyolicoris
Last active June 30, 2025 21:08
Notebook for Block-based Fast Differentiable IIR in PyTorch
{
"cells": [
{
"cell_type": "markdown",
"id": "170fc102",
"metadata": {},
"source": [
"# Block-based Fast Differentiable IIR in PyTorch\n",
"\n",
"I recently came across a presentation by Andres Ezequiel Viso from GPU Audio at ADC 2022, in which he talked about how they accelerate IIR filters on the GPU.\n",
"The approach they use is to formulate the IIR filter as a state-space model (SSM) and augment the transition matrix so that each step processes multiple samples at once.\n",
"The primary speedup stems from the fact that GPUs are very good at performing large matrix multiplications, and the SSM formulation enables us to leverage this capability.\n",
"\n",
"<iframe width=\"1024px\" height=\"576px\"\n",
"src=\"https://www.youtube.com/embed/UmYnoFo0Bb8?start=1356\"\n",
"allowfullscreen>\n",
"</iframe><br>\n",
"\n",
"Speeding up IIR filters while maintaining differentiability has always been my interest.\n",
"The most recent method I worked on is from my recent [submission](https://arxiv.org/abs/2504.14735) to DAFx 25, where my co-author Ben proposed using parallel associative scan to speed up the recursion on the GPU.\n",
"Nevertheless, since PyTorch does not have a built-in associative scan operator (in contrast to JAX), we must implement custom kernels for it, which is non-trivial.\n",
"It also requires that the filter has distinct poles so that the state-space transition matrix is diagonalisable.\n",
"The method that GPU Audio presented appears to be feasible solely using the PyTorch Python API and doesn't have the restrictions I mentioned; thus, I decided to benchmark it and see how it performs.\n",
"\n",
"Since it's just a proof of concept, the filter I'm going to test is a **time-invariant all-pole IIR filter**, which is the minimal case of a recursive filter.\n",
"This allows us to leverage some special optimisations that won't work with time-varying general IIR filters, but that won't affect the main idea I'm going to present here.\n"
]
},
{
"cell_type": "markdown",
"id": "b5b10fde",
"metadata": {},
"source": [
"## Naive implementation of an all-pole IIR filter\n",
"\n",
"The difference equation of an $M$-th order all-pole IIR filter is given by:\n",
"\n",
"$$\n",
"y[n] = x[n] -\\sum_{m=1}^{M} a_m y[n-m].\n",
"$$\n",
"\n",
"Let's implement this in PyTorch:"
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "c379420f",
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"from torch import Tensor\n",
"\n",
"\n",
"@torch.jit.script\n",
"def naive_allpole(x: Tensor, a: Tensor) -> Tensor:\n",
" \"\"\"\n",
" Naive all-pole filter implementation.\n",
"\n",
" Args:\n",
" x (Tensor): Input signal.\n",
" a (Tensor): All-pole coefficients.\n",
"\n",
" Returns:\n",
" Tensor: Filtered output signal.\n",
" \"\"\"\n",
" assert x.dim() == 2, \"Input signal must be a 2D tensor (batch_size, signal_length)\"\n",
" assert a.dim() == 1, \"All-pole coefficients must be a 1D tensor\"\n",
"\n",
" # list to store output at each time step\n",
" output = []\n",
" # assume initial condition is zero\n",
" zi = x.new_zeros(x.size(0), a.size(0))\n",
"\n",
" for xt in x.unbind(1):\n",
" # use addmv for efficient matrix-vector multiplication\n",
" yt = torch.addmv(xt, zi, a, alpha=-1.0)\n",
" output.append(yt)\n",
"\n",
" # update the state for the next time step\n",
" zi = torch.cat([yt.unsqueeze(1), zi[:, :-1]], dim=1)\n",
"\n",
" return torch.stack(output, dim=1)"
]
},
{
"cell_type": "markdown",
"id": "ca6c42ed",
"metadata": {},
"source": [
"In this implementation, I didn't use any in-place operations for speedup since it would break the differentiability of the function.\n",
"This naive implementation is not very efficient, as `torch.addmv` and `torch.cat` are called at each time step. \n",
"Typically, the audio signal is hundreds of thousands of samples long, resulting in a significant amount of function call overhead.\n",
"For details, please take a look at my [tutorial on differentiable IIR filters](https://intro2ddsp.github.io/filters/iir_torch.html) at ISMIR 2023.\n",
"\n",
"Notice that I used `torch.jit.script` to compile the function for some slight speedup.\n",
"I tried the newer compilation feature `torch.compile`, but it didn't work.\n",
"The compilation hangs forever, I don't know why...\n",
"I never found `torch.compile` to be useful in my research projects, and `torch.jit.*` has proven to be way more reliable.\n",
"\n",
"Let's benchmark its speed on my Ubuntu with an Intel i7-7700K.\n",
"We'll use a batch size of 8, a signal length of 16384, and $M=2$, which is a reasonable setting for audio processing."
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "d49eec12",
"metadata": {},
"outputs": [],
"source": [
"from torch.utils.benchmark import Timer\n",
"\n",
"batch_size = 8\n",
"signal_length = 16384\n",
"order = 2\n",
"\n",
"def order2a(order: int) -> Tensor:\n",
" a = torch.randn(order)\n",
" # simple way to ensure stability\n",
" a = a / a.abs().sum()\n",
" return a\n",
"\n",
"a = order2a(order)\n",
"x = torch.randn(batch_size, signal_length)"
]
},
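{
"cell_type": "markdown",
"id": "7c1e9a10",
"metadata": {},
"source": [
"Why does dividing by `a.abs().sum()` help? If $\\sum_m |a_m| \\le 1$, any root of $z^M + a_1 z^{M-1} + \\cdots + a_M$ with $|z| > 1$ would give $|z|^M = |\\sum_m a_m z^{M-m}| \\le |z|^{M-1}$, a contradiction, so every pole lies inside or on the unit circle (a pole can land exactly on it).\n",
"Below is a minimal, dependency-free sketch that checks this numerically for $M=2$; `pole_radii` is a hypothetical helper and not part of the benchmarked code.\n",
"\n",
"```python\n",
"import cmath\n",
"import random\n",
"\n",
"\n",
"def pole_radii(a1: float, a2: float) -> tuple:\n",
"    # magnitudes of the roots of z^2 + a1*z + a2 (the poles for M = 2)\n",
"    disc = cmath.sqrt(a1 * a1 - 4.0 * a2)\n",
"    return abs((-a1 + disc) / 2.0), abs((-a1 - disc) / 2.0)\n",
"\n",
"\n",
"random.seed(0)\n",
"worst = 0.0\n",
"for _ in range(1000):\n",
"    a1, a2 = random.gauss(0.0, 1.0), random.gauss(0.0, 1.0)\n",
"    scale = abs(a1) + abs(a2)  # same normalisation as order2a\n",
"    worst = max(worst, *pole_radii(a1 / scale, a2 / scale))\n",
"\n",
"print(worst)  # largest pole radius seen; bounded by 1 up to rounding\n",
"```\n",
"\n",
"For example, `pole_radii(-0.5, -0.5)` has one pole exactly at radius 1, so this normalisation guarantees marginal stability at worst rather than a strict margin."
]
},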
{
"cell_type": "code",
"execution_count": 3,
"id": "71b56202",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<torch.utils.benchmark.utils.common.Measurement object at 0x7f5b4423b260>\n",
"naive_allpole\n",
"Naive All-Pole Filter\n",
" Median: 168.93 ms\n",
" IQR: 0.54 ms (168.57 to 169.11)\n",
" 6 measurements, 1 runs per measurement, 4 threads"
]
},
"execution_count": 3,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"naive_allpole_t = Timer(\n",
" stmt=\"naive_allpole(x, a)\",\n",
" globals={\"naive_allpole\": naive_allpole, \"x\": x, \"a\": a},\n",
" label=\"naive_allpole\",\n",
" description=\"Naive All-Pole Filter\",\n",
" num_threads=4,\n",
")\n",
"naive_allpole_t.blocked_autorange(min_run_time=1.0)"
]
},
{
"cell_type": "markdown",
"id": "af322a8a",
"metadata": {},
"source": [
"182.59 ms is relatively slow, but it is expected.\n",
"\n",
"## State-space model formulation\n",
"\n",
"Before we proceed to showing the sample unrolling trick, let's first introduce the state-space model (SSM) formulation of the all-pole IIR filter.\n",
"The model is similar to the one in my previous blogpost on [TDF-II filter](https://iamycy.github.io/posts/2025/04/differentiable-tdf-ii/):\n",
"\n",
"$$\n",
"\\begin{align}\n",
"\\mathbf{h}[n] &= \\begin{bmatrix}\n",
" -a_1 & -a_2 & \\cdots & -a_{M-1} & -a_M \\\\\n",
" 1 & 0 &\\cdots & 0 & 0 \\\\\n",
" 0 & 1 & \\cdots & 0 & 0 \\\\\n",
" \\vdots & \\vdots & \\ddots & \\vdots & \\vdots \\\\\n",
" 0 & 0 & \\cdots & 1 & 0 \\\\\n",
"\\end{bmatrix} \\mathbf{h}[n-1] + \\begin{bmatrix}\n",
" 1 \\\\\n",
" 0 \\\\\n",
" 0 \\\\\n",
" \\vdots \\\\\n",
" 0 \\\\\n",
"\\end{bmatrix} x[n] \\\\\n",
"&= \\mathbf{A} \\mathbf{h}[n-1] + \\mathbf{B} x[n] \\\\\n",
"\n",
"y[n] &= \\mathbf{B}^\\top \\mathbf{h}[n].\n",
"\\end{align}\n",
"$$\n",
"\n",
"Here, I simplified the original SSM by omitting the direct path, as it can be derived from the state vector (for the all-pole filter only).\n",
"Below is the PyTorch implementation of it:"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "cde4d9f0",
"metadata": {},
"outputs": [],
"source": [
"@torch.jit.script\n",
"def a2companion(a: Tensor) -> Tensor:\n",
" \"\"\"\n",
" Convert all-pole coefficients to a companion matrix.\n",
"\n",
" Args:\n",
" a (Tensor): All-pole coefficients.\n",
"\n",
" Returns:\n",
" Tensor: Companion matrix.\n",
" \"\"\"\n",
" assert a.dim() == 1, \"All-pole coefficients must be a 1D tensor\"\n",
" order = a.size(0)\n",
" c = torch.diag(a.new_ones(order - 1), -1)\n",
" c[0, :] = -a\n",
" return c\n",
"\n",
"\n",
"@torch.jit.script\n",
"def state_space_allpole(x: Tensor, a: Tensor) -> Tensor:\n",
" \"\"\"\n",
" State-space implementation of all-pole filtering.\n",
"\n",
" Args:\n",
" x (Tensor): Input signal.\n",
" a (Tensor): All-pole coefficients.\n",
"\n",
" Returns:\n",
" Tensor: Filtered output signal.\n",
" \"\"\"\n",
" assert x.dim() == 2, \"Input signal must be a 2D tensor (batch_size, signal_length)\"\n",
" assert a.dim() == 1, \"All-pole coefficients must be a 1D tensor\"\n",
"\n",
" c = a2companion(a).T\n",
"\n",
" output = []\n",
" # assume initial condition is zero\n",
" h = x.new_zeros(x.size(0), c.size(0))\n",
"\n",
" # B * x\n",
" x = torch.cat(\n",
" [x.unsqueeze(-1), x.new_zeros(x.size(0), x.size(1), c.size(0) - 1)], dim=2\n",
" )\n",
"\n",
" for xt in x.unbind(1):\n",
" h = torch.addmm(xt, h, c)\n",
" # B^T @ h\n",
" output.append(h[:, 0])\n",
" return torch.stack(output, dim=1)"
]
},
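{
"cell_type": "markdown",
"id": "3b7d5e21",
"metadata": {},
"source": [
"To make the state-space recursion concrete, here is a small pure-Python sketch (no PyTorch) that runs both the difference equation and the companion-matrix recursion on a unit impulse and shows that they agree.\n",
"The names `direct_allpole`, `companion`, and `ssm_allpole` are illustrative assumptions, not functions from this notebook.\n",
"\n",
"```python\n",
"def direct_allpole(x, a):\n",
"    # y[n] = x[n] - sum_m a[m-1] * y[n-m], with zero initial state\n",
"    y = []\n",
"    for n, xn in enumerate(x):\n",
"        acc = xn\n",
"        for m, am in enumerate(a, start=1):\n",
"            if n - m >= 0:\n",
"                acc -= am * y[n - m]\n",
"        y.append(acc)\n",
"    return y\n",
"\n",
"\n",
"def companion(a):\n",
"    # companion matrix A: first row is -a, ones on the sub-diagonal\n",
"    order = len(a)\n",
"    A = [[0.0] * order for _ in range(order)]\n",
"    A[0] = [-am for am in a]\n",
"    for i in range(1, order):\n",
"        A[i][i - 1] = 1.0\n",
"    return A\n",
"\n",
"\n",
"def ssm_allpole(x, a):\n",
"    # h[n] = A h[n-1] + B x[n], y[n] = first entry of h[n]\n",
"    A = companion(a)\n",
"    order = len(a)\n",
"    h = [0.0] * order\n",
"    y = []\n",
"    for xn in x:\n",
"        h = [sum(A[i][j] * h[j] for j in range(order)) for i in range(order)]\n",
"        h[0] += xn  # B = [1, 0, ..., 0]^T\n",
"        y.append(h[0])\n",
"    return y\n",
"\n",
"\n",
"a = [-0.5, 0.25]\n",
"x = [1.0, 0.0, 0.0, 0.0]  # unit impulse\n",
"print(direct_allpole(x, a))  # [1.0, 0.5, 0.0, -0.125]\n",
"print(ssm_allpole(x, a))  # [1.0, 0.5, 0.0, -0.125]\n",
"```\n",
"\n",
"The shift rows of the companion matrix do the job that `torch.cat` does in the naive version: they move the previous outputs one slot down the state vector."
]
},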
{
"cell_type": "markdown",
"id": "04e71034",
"metadata": {},
"source": [
"`a2companion` converts the all-pole coefficients to a [companion matrix](https://en.wikipedia.org/wiki/Companion_matrix), which is $\\mathbf{A}$ in the SSM formulation.\n",
"\n",
"Before we benchmark the speed of this implementation, let's predict how fast it will be.\n",
"Intuitively, since the complexity of vector-dot product is $O(M)$ and matrix-vector multiplication is $O(M^2)$, the SSM implementation uses more computational resources, so it should be slower than the naive implementation.\n",
"Let's benchmark its speed:"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "1e24c12e",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<torch.utils.benchmark.utils.common.Measurement object at 0x7f5a02eaf4a0>\n",
"state_space_allpole\n",
"State-Space All-Pole Filter\n",
" Median: 118.41 ms\n",
" IQR: 1.17 ms (117.79 to 118.96)\n",
" 9 measurements, 1 runs per measurement, 4 threads"
]
},
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"state_space_allpole_t = Timer(\n",
" stmt=\"state_space_allpole(x, a)\",\n",
" globals={\"state_space_allpole\": state_space_allpole, \"x\": x, \"a\": a},\n",
" label=\"state_space_allpole\",\n",
" description=\"State-Space All-Pole Filter\",\n",
" num_threads=4,\n",
")\n",
"state_space_allpole_t.blocked_autorange(min_run_time=1.0)"
]
},
{
"cell_type": "markdown",
"id": "48ea53a8",
"metadata": {},
"source": [
"Interestingly, the SSM implementation is approximately 50 ms faster."
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "9d529e45",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"WARNING:2025-06-28 23:59:13 247710:247710 init.cpp:107] function cbapi->getCuptiStatus() failed with error CUPTI_ERROR_INVALID_DEVICE (2)\n",
"WARNING:2025-06-28 23:59:13 247710:247710 init.cpp:108] CUPTI initialization failed - CUDA profiler activities will be missing\n",
"INFO:2025-06-28 23:59:13 247710:247710 init.cpp:110] If you see CUPTI_ERROR_INSUFFICIENT_PRIVILEGES, refer to https://developer.nvidia.com/nvidia-development-tools-solutions-err-nvgpuctrperm-cupti\n"
]
}
],
"source": [
"from torch.profiler import profile, ProfilerActivity\n",
"\n",
"with profile(\n",
" activities=[ProfilerActivity.CPU],\n",
" profile_memory=True,\n",
" record_shapes=True,\n",
" # with_flops=True\n",
") as naive_prof:\n",
" naive_allpole(x, a)\n",
"\n",
"with profile(\n",
" activities=[ProfilerActivity.CPU],\n",
" profile_memory=True,\n",
" # record_shapes=True,\n",
" # with_flops=True\n",
") as state_space_prof:\n",
" state_space_allpole(x, a)"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "6285468a",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"-------------------- ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ -------------------------------------------- \n",
" Name Self CPU % Self CPU CPU total % CPU total CPU time avg CPU Mem Self CPU Mem # of Calls Input Shapes \n",
"-------------------- ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ -------------------------------------------- \n",
" naive_allpole 16.56% 79.351ms 100.00% 479.094ms 479.094ms 512.00 Kb -175.50 Kb 1 [[8, 16384], [2]] \n",
" aten::cat 19.26% 92.271ms 36.35% 174.143ms 10.629us 175.44 Kb 175.44 Kb 16384 [[], [], [], []] \n",
" aten::narrow 6.17% 29.578ms 17.09% 81.872ms 2.499us 0 b 0 b 32768 [[8, 2], [], [], []] \n",
" aten::slice 12.73% 60.993ms 15.56% 74.563ms 2.275us 0 b 0 b 32768 [[], [], [8, 1], [8, 2], [], [], [], []] \n",
" aten::slice 8.24% 39.476ms 10.92% 52.294ms 1.596us 0 b 0 b 32768 [[8, 2], [], [], [], []] \n",
" aten::addmv 9.52% 45.632ms 9.52% 45.632ms 2.785us 512.00 Kb 512.00 Kb 16384 [[], [], [8], [8, 2], [2], [], []] \n",
" aten::stack 4.34% 20.770ms 9.21% 44.138ms 44.138ms 0 b -512.00 Kb 1 [[], []] \n",
" aten::unsqueeze 5.89% 28.220ms 7.30% 34.989ms 2.136us 0 b 0 b 16384 [[], [], [8], []] \n",
" aten::as_strided 5.51% 26.388ms 5.51% 26.388ms 0.403us 0 b 0 b 65536 [[8, 2], [], [], []] \n",
" aten::unbind 1.78% 8.525ms 5.48% 26.238ms 26.238ms 0 b 0 b 1 [[8, 16384], []] \n",
" aten::unsqueeze 3.02% 14.462ms 3.97% 19.041ms 1.162us 0 b 0 b 16384 [[8], []] \n",
" aten::select 2.90% 13.905ms 3.70% 17.713ms 1.081us 0 b 0 b 16384 [[8, 16384], [], []] \n",
" aten::as_strided 2.37% 11.349ms 2.37% 11.349ms 0.346us 0 b 0 b 32768 [[8], [], [], []] \n",
" aten::cat 0.90% 4.312ms 0.90% 4.327ms 4.327ms 512.00 Kb 512.00 Kb 1 [[], []] \n",
" aten::as_strided 0.80% 3.810ms 0.80% 3.810ms 0.233us 0 b 0 b 16385 [[8, 16384], [], [], []] \n",
" aten::new_zeros 0.00% 20.973us 0.01% 39.671us 39.671us 64 b 0 b 1 [[8, 16384], [], [], [], [], []] \n",
" aten::new_empty 0.00% 6.574us 0.00% 17.806us 17.806us 64 b 0 b 1 [[8, 16384], [], [], [], [], []] \n",
" aten::narrow 0.00% 5.425us 0.00% 14.761us 14.761us 0 b 0 b 1 [[8, 16384], [], [], []] \n",
" aten::empty 0.00% 11.232us 0.00% 11.232us 11.232us 64 b 64 b 1 [[], [], [], [], [], []] \n",
" aten::slice 0.00% 6.981us 0.00% 9.336us 9.336us 0 b 0 b 1 [[8, 16384], [], [], [], []] \n",
"-------------------- ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ -------------------------------------------- \n",
"Self CPU time total: 479.094ms\n",
"\n"
]
}
],
"source": [
"print(\n",
" naive_prof.key_averages(group_by_input_shape=True).table(\n",
" sort_by=\"cpu_time_total\", row_limit=20\n",
" )\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "0711f5ba",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"----------------------- ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ \n",
" Name Self CPU % Self CPU CPU total % CPU total CPU time avg CPU Mem Self CPU Mem # of Calls \n",
"----------------------- ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ \n",
" state_space_allpole 13.88% 31.898ms 100.00% 229.844ms 229.844ms 512.00 Kb -1.19 Mb 1 \n",
" aten::mm 19.88% 45.686ms 20.74% 47.671ms 2.910us 1023.94 Kb 1023.94 Kb 16384 \n",
" aten::select 15.34% 35.262ms 20.52% 47.156ms 1.439us 0 b 0 b 32769 \n",
" aten::add 17.37% 39.914ms 17.37% 39.914ms 2.436us 189.94 Kb 189.94 Kb 16384 \n",
" aten::stack 7.75% 17.812ms 13.18% 30.287ms 30.287ms -512.00 Kb -1.00 Mb 1 \n",
" aten::unbind 4.60% 10.564ms 11.41% 26.230ms 26.230ms 0 b 0 b 1 \n",
" aten::slice 7.37% 16.945ms 9.59% 22.031ms 1.344us 0 b 0 b 16387 \n",
" aten::as_strided 9.11% 20.930ms 9.11% 20.930ms 0.319us 0 b 0 b 65543 \n",
" aten::unsqueeze 2.35% 5.400ms 4.07% 9.345ms 0.570us 0 b 0 b 16385 \n",
" aten::cat 1.43% 3.286ms 1.44% 3.302ms 1.651ms 1.00 Mb 1.00 Mb 2 \n",
" aten::resolve_conj 0.86% 1.985ms 0.86% 1.985ms 0.061us 0 b 0 b 32768 \n",
" aten::new_zeros 0.00% 3.646us 0.04% 82.891us 41.445us 512.06 Kb 0 b 2 \n",
" aten::zero_ 0.00% 1.303us 0.03% 76.858us 25.619us 0 b 0 b 3 \n",
" aten::fill_ 0.03% 76.828us 0.03% 76.828us 38.414us 0 b 0 b 2 \n",
" aten::diag 0.00% 3.150us 0.02% 39.884us 39.884us 12 b -4 b 1 \n",
" aten::diag_embed 0.01% 15.298us 0.02% 36.734us 36.734us 16 b 0 b 1 \n",
" aten::new_ones 0.00% 6.760us 0.01% 20.771us 20.771us 4 b 0 b 1 \n",
" aten::narrow 0.00% 9.106us 0.01% 16.032us 8.016us 0 b 0 b 2 \n",
" aten::new_empty 0.00% 4.355us 0.01% 15.543us 5.181us 512.07 Kb 0 b 3 \n",
" aten::copy_ 0.01% 14.018us 0.01% 14.018us 7.009us -8 b -8 b 2 \n",
"----------------------- ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ \n",
"Self CPU time total: 229.844ms\n",
"\n"
]
}
],
"source": [
"print(state_space_prof.key_averages().table(sort_by=\"cpu_time_total\", row_limit=20))"
]
},
{
"cell_type": "markdown",
"id": "2586986e",
"metadata": {},
"source": [
"By using `torch.profiler.profile`, I found that, in the naive implementation, `torch.cat` for updating the last M outputs accounts for a significant portion of the total time (~20%).\n",
"The actual computation, `torch.addmv`, takes only about 10% of the time.\n",
"Regarding memory usage, the most memory-intensive operation is `torch.addmv`, which consumes approximately 512 Kb of memory.\n",
"In contrast, the SSM implementation uses more memory (> 1 Mb) due to matrix multiplication, but roughly 38% of the time is spent on filtering since it no longer has to call `torch.cat` at each time step.\n",
"The state vector (a.k.a the last M outputs) is automatically updated during the matrix multiplication.\n",
"\n",
"**Conclusion**: Tensor concatenation (including `torch.cat` and `torch.stack`) is computationally expensive, and it is advisable to avoid it whenever possible.\n",
"\n",
"## Unrolling the SSM\n",
"\n",
"Now we can apply the unrolling trick to the SSM implementation.\n",
"The idea is to divide the input signal into blocks of size $T$ and perform the recursion on these blocks instead of processing them sample-by-sample.\n",
"Each recursion takes the last vector state $\\mathbf{h}[n-1]$ and predicts the next $T$ states $[\\mathbf{h}[n], \\mathbf{h}[n+1], \\ldots, \\mathbf{h}[n+T-1]]^\\top$ at once.\n",
"To see how to calculate these states, let's unroll the SSM recursion for $T$ steps:\n",
"\n",
"$$\n",
"\\begin{align}\n",
"\\mathbf{h}[n] &= \\mathbf{A} \\mathbf{h}[n-1] + \\mathbf{B} x[n] \\\\\n",
"\\mathbf{h}[n+1] &= \\mathbf{A} \\mathbf{h}[n] + \\mathbf{B} x[n+1] \\\\\n",
"&= \\mathbf{A} (\\mathbf{A} \\mathbf{h}[n-1] + \\mathbf{B} x[n]) + \\mathbf{B} x[n+1] \\\\\n",
"&= \\mathbf{A}^2 \\mathbf{h}[n-1] + \\mathbf{A} \\mathbf{B} x[n] + \\mathbf{B} x[n+1] \\\\\n",
"\\mathbf{h}[n+2] &= \\mathbf{A} \\mathbf{h}[n+1] + \\mathbf{B} x[n+2] \\\\\n",
"&= \\mathbf{A} (\\mathbf{A}^2 \\mathbf{h}[n-1] + \\mathbf{A} \\mathbf{B} x[n] + \\mathbf{B} x[n+1]) + \\mathbf{B} x[n+2] \\\\\n",
"&= \\mathbf{A}^3 \\mathbf{h}[n-1] + \\mathbf{A}^2 \\mathbf{B} x[n] + \\mathbf{A} \\mathbf{B} x[n+1] + \\mathbf{B} x[n] \\\\\n",
"& \\vdots \\\\\n",
"\\mathbf{h}[n+T-1] &= \\mathbf{A}^{T} \\mathbf{h}[n-1] + \\sum_{t=0}^{T-1} \\mathbf{A}^{T - t -1} \\mathbf{B} x[n+t] \\\\\n",
"\\end{align}\n",
"$$\n",
"\n",
"We can rewrite the above equation in matrix form as follows:\n",
"\n",
"$$\n",
"\\begin{align}\n",
"\\begin{bmatrix}\n",
" \\mathbf{h}[n] \\\\\n",
" \\mathbf{h}[n+1] \\\\\n",
" \\vdots \\\\\n",
" \\mathbf{h}[n+T-1]\n",
"\\end{bmatrix} &= \\begin{bmatrix}\n",
" \\mathbf{A} \\\\\n",
" \\mathbf{A}^2 \\\\\n",
" \\vdots \\\\\n",
" \\mathbf{A}^T \\\\\n",
"\\end{bmatrix} \\mathbf{h}[n-1]\n",
"+ \\begin{bmatrix}\n",
" \\mathbf{I} & 0 & \\cdots & 0 \\\\\n",
" \\mathbf{A} & \\mathbf{I} & \\cdots & 0 \\\\\n",
" \\vdots & \\vdots & \\ddots & \\vdots \\\\\n",
" \\mathbf{A}^{T-1} & \\mathbf{A}^{T-2} & \\cdots & \\mathbf{I}\n",
"\\end{bmatrix}\n",
"\\begin{bmatrix}\n",
" \\mathbf{B}x[n] \\\\\n",
" \\mathbf{B}x[n+1] \\\\\n",
" \\vdots \\\\\n",
" \\mathbf{B}x[n+T-1]\n",
"\\end{bmatrix} \\\\\n",
"& = \\begin{bmatrix}\n",
" \\mathbf{A} \\\\\n",
" \\mathbf{A}^2 \\\\\n",
" \\vdots \\\\\n",
" \\mathbf{A}^T \\\\\n",
"\\end{bmatrix} \\mathbf{h}[n-1]\n",
"+ \\begin{bmatrix}\n",
" \\mathbf{I}_{.1} & 0 & \\cdots & 0 \\\\\n",
" \\mathbf{A}_{.1} & \\mathbf{I}_{.1} & \\cdots & 0 \\\\\n",
" \\vdots & \\vdots & \\ddots & \\vdots \\\\\n",
" \\mathbf{A}_{.1}^{T-1} & \\mathbf{A}_{.1}^{T-2} & \\cdots & \\mathbf{I}_{.1}\n",
"\\end{bmatrix}\n",
"\\begin{bmatrix}\n",
" x[n] \\\\\n",
" x[n+1] \\\\\n",
" \\vdots \\\\\n",
" x[n+T-1]\n",
"\\end{bmatrix} \\\\\n",
"& = \\mathbf{M} \\mathbf{h}[n-1] + \\mathbf{V} \\begin{bmatrix}\n",
" x[n] \\\\\n",
" x[n+1] \\\\\n",
" \\vdots \\\\\n",
" x[n+T-1]\n",
"\\end{bmatrix} \\\\\n",
"\\end{align}\n",
"$$\n",
"\n",
"Notice that in the second line, I utilised the fact that $\\mathbf{B}$ has only one non-zero entry to simplify the matrix.\n",
"(This is not possible if the filter is not strictly all-pole.)\n",
"$\\mathbf{I}_{.1}$ denotes the first column of the identity matrix and so on.\n",
"\n",
"Now, the number of autoregressive steps is reduced from $T$ to $\\frac{N}{T}$ and the matrix multiplication is done in parallel for every $T$ samples.\n",
"There are added costs for pre-computing the transition matrix $\\mathbf{M}$ and the input matrix $\\mathbf{V}$, though.\n",
"However, as long as the extra cost is relatively small compared to the cost of $T$ autoregressive steps, we should observe a speedup.\n",
"\n",
"Here's the PyTorch implementation of the unrolled SSM:"
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "1669e150",
"metadata": {},
"outputs": [],
"source": [
"@torch.jit.script\n",
"def state_space_allpole_unrolled(\n",
" x: Tensor, a: Tensor, unroll_factor: int = 1\n",
") -> Tensor:\n",
" \"\"\"\n",
" Unrolled state-space implementation of all-pole filtering.\n",
"\n",
" Args:\n",
" x (Tensor): Input signal.\n",
" a (Tensor): All-pole coefficients.\n",
" unroll_factor (int): Factor by which to unroll the loop.\n",
"\n",
" Returns:\n",
" Tensor: Filtered output signal.\n",
" \"\"\"\n",
" if unroll_factor == 1:\n",
" return state_space_allpole(x, a)\n",
" elif unroll_factor < 1:\n",
" raise ValueError(\"Unroll factor must be >= 1\")\n",
"\n",
" assert x.dim() == 2, \"Input signal must be a 2D tensor (batch_size, signal_length)\"\n",
" assert a.dim() == 1, \"All-pole coefficients must be a 1D tensor\"\n",
" assert (\n",
" x.size(1) % unroll_factor == 0\n",
" ), \"Signal length must be divisible by unroll factor\"\n",
"\n",
" c = a2companion(a)\n",
"\n",
" # create an initial identity matrix\n",
" initial = torch.eye(c.size(0), device=c.device, dtype=c.dtype)\n",
" c_list = [initial]\n",
" # TODO: use parallel scan to improve speed\n",
" for _ in range(unroll_factor):\n",
" c_list.append(c_list[-1] @ c)\n",
"\n",
" # c_list = [I c c^2 ... c^unroll_factor]\n",
" M = torch.cat(c_list[1:], dim=0).T\n",
" flatten_c_list = torch.cat(\n",
" [c.new_zeros(c.size(0) * (unroll_factor - 1))]\n",
" + [xx[:, 0] for xx in c_list[:-1]],\n",
" dim=0,\n",
" )\n",
" V = flatten_c_list.unfold(0, c.size(0) * unroll_factor, c.size(0)).flip(0)\n",
"\n",
" # divide the input signal into blocks of size unroll_factor\n",
" unrolled_x = x.unflatten(1, (-1, unroll_factor)) @ V\n",
"\n",
" output = []\n",
" # assume initial condition is zero\n",
" h = x.new_zeros(x.size(0), c.size(0))\n",
" for xt in unrolled_x.unbind(1):\n",
" h = torch.addmm(xt, h, M)\n",
" # B^T @ h\n",
" output.append(h[:, :: c.size(0)])\n",
" h = h[\n",
" :, -c.size(0) :\n",
" ] # take the last state vector as the initial condition for the next step\n",
" return torch.cat(output, dim=1)"
]
},
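{
"cell_type": "markdown",
"id": "5d2a8f34",
"metadata": {},
"source": [
"The block recursion can also be sketched without PyTorch.\n",
"The following is a minimal illustration under the column-vector convention of the equations above; the helper names (`matmul`, `block_matrices`, `blocked_allpole`) are hypothetical, and the layout of $\\mathbf{M}$ and $\\mathbf{V}$ follows the math rather than the transposed layout used in `state_space_allpole_unrolled`.\n",
"\n",
"```python\n",
"def matmul(P, Q):\n",
"    # plain-Python product of two matrices stored as lists of rows\n",
"    return [[sum(p * q for p, q in zip(row, col)) for col in zip(*Q)] for row in P]\n",
"\n",
"\n",
"def block_matrices(A, T):\n",
"    # M stacks A^1..A^T; block (i, j) of V is the first column of A^(i-j) for j <= i\n",
"    m = len(A)\n",
"    eye = [[float(i == j) for j in range(m)] for i in range(m)]\n",
"    powers = [eye]  # powers[k] == A^k\n",
"    for _ in range(T):\n",
"        powers.append(matmul(powers[-1], A))\n",
"    M = [row for k in range(1, T + 1) for row in powers[k]]\n",
"    V = [\n",
"        [powers[i - j][r][0] if j <= i else 0.0 for j in range(T)]\n",
"        for i in range(T)\n",
"        for r in range(m)\n",
"    ]\n",
"    return M, V\n",
"\n",
"\n",
"def blocked_allpole(x, A, T):\n",
"    # filter x in blocks of T samples; len(x) must be a multiple of T\n",
"    m = len(A)\n",
"    M, V = block_matrices(A, T)\n",
"    h = [0.0] * m\n",
"    y = []\n",
"    for start in range(0, len(x), T):\n",
"        block = x[start : start + T]\n",
"        states = [\n",
"            sum(M[r][j] * h[j] for j in range(m))\n",
"            + sum(V[r][t] * block[t] for t in range(T))\n",
"            for r in range(T * m)\n",
"        ]\n",
"        y.extend(states[0::m])  # first entry of each state is an output sample\n",
"        h = states[-m:]  # last state seeds the next block\n",
"    return y\n",
"\n",
"\n",
"A = [[0.5, -0.25], [1.0, 0.0]]  # companion matrix for a = [-0.5, 0.25]\n",
"x = [1.0, 0.0, 0.0, 0.0, 0.0, 0.0]  # unit impulse\n",
"print(blocked_allpole(x, A, T=3))  # same samples as the one-step recursion\n",
"```\n",
"\n",
"With `T=3` the loop over `x` runs twice instead of six times; each pass produces three states from one matrix-vector product."
]
},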
{
"cell_type": "markdown",
"id": "dfa4cd9f",
"metadata": {},
"source": [
"The `unroll_factor` parameter controls the number of samples to process in parallel.\n",
"If it is set to 1, the function is the original SSM implementation.\n",
"\n",
"Let's first make sure that the unrolled SSM implementation is equivalent to the original one."
]
},
{
"cell_type": "code",
"execution_count": 10,
"id": "637a2c76",
"metadata": {},
"outputs": [],
"source": [
"y1 = naive_allpole(x, a)\n",
"y2 = state_space_allpole_unrolled(x, a, unroll_factor=8)\n",
"assert torch.allclose(y1, y2, atol=5e-6), \"Outputs are not close enough\"\n",
"# print(y2[0, -10:] - y1[0, -10:])"
]
},
{
"cell_type": "markdown",
"id": "db1f3402",
"metadata": {},
"source": [
"Now let's benchmark the speed of the unrolled SSM implementation.\n",
"We'll use `unroll_factor=128` since I already tested that it is the optimal value :)"
]
},
{
"cell_type": "code",
"execution_count": 11,
"id": "d63b82d7",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<torch.utils.benchmark.utils.common.Measurement object at 0x7f5a01d75160>\n",
"state_space_allpole_unrolled\n",
"State-Space All-Pole Filter Unrolled\n",
" Median: 1.89 ms\n",
" IQR: 0.08 ms (1.88 to 1.96)\n",
" 6 measurements, 100 runs per measurement, 4 threads"
]
},
"execution_count": 11,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"state_space_allpole_unrolled_t = Timer(\n",
" stmt=\"state_space_allpole_unrolled(x, a, unroll_factor=unroll_factor)\",\n",
" globals={\n",
" \"state_space_allpole_unrolled\": state_space_allpole_unrolled,\n",
" \"x\": x,\n",
" \"a\": a,\n",
" \"unroll_factor\": 128,\n",
" },\n",
" label=\"state_space_allpole_unrolled\",\n",
" description=\"State-Space All-Pole Filter Unrolled\",\n",
" num_threads=4,\n",
")\n",
"state_space_allpole_unrolled_t.blocked_autorange(min_run_time=1.0)"
]
},
{
"cell_type": "markdown",
"id": "b0ee10d8",
"metadata": {},
"source": [
"1.91 ms! What sorcery is this? That's a whopping 70x speedup compared to the standard SSM implementation!"
]
},
{
"cell_type": "code",
"execution_count": 12,
"id": "3b519301",
"metadata": {},
"outputs": [],
"source": [
"with profile(\n",
" activities=[ProfilerActivity.CPU], \n",
" profile_memory=True,\n",
" # record_shapes=True,\n",
" # with_flops=True\n",
") as unrolled_prof:\n",
" state_space_allpole_unrolled(x, a, unroll_factor=128)"
]
},
{
"cell_type": "code",
"execution_count": 13,
"id": "9fa4d5de",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"-------------------------------- ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ \n",
" Name Self CPU % Self CPU CPU total % CPU total CPU time avg CPU Mem Self CPU Mem # of Calls \n",
"-------------------------------- ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ \n",
" state_space_allpole_unrolled 16.38% 840.515us 100.00% 5.133ms 5.133ms 512.00 Kb -1.08 Mb 1 \n",
" aten::mm 25.38% 1.303ms 26.33% 1.351ms 5.258us 2.00 Mb 2.00 Mb 257 \n",
" aten::slice 16.89% 867.017us 22.32% 1.145ms 1.781us 0 b 0 b 643 \n",
" aten::matmul 1.82% 93.325us 17.67% 906.709us 7.029us 898.00 Kb -128.00 Kb 129 \n",
" aten::cat 8.45% 433.905us 13.68% 702.018us 234.006us -511.00 Kb -511.00 Kb 3 \n",
" aten::add 8.78% 450.775us 8.78% 450.775us 3.522us 80.00 Kb 80.00 Kb 128 \n",
" aten::select 6.62% 339.652us 8.57% 439.668us 1.711us 0 b 0 b 257 \n",
" aten::as_strided 7.47% 383.528us 7.47% 383.528us 0.424us 0 b 0 b 904 \n",
" aten::narrow 2.01% 103.309us 5.22% 268.113us 2.062us 0 b 0 b 130 \n",
" aten::unbind 1.68% 86.022us 4.96% 254.820us 254.820us 0 b 0 b 1 \n",
" aten::resolve_conj 0.95% 48.722us 0.95% 48.722us 0.076us 0 b 0 b 643 \n",
" aten::flip 0.52% 26.760us 0.87% 44.541us 44.541us 126.01 Kb -1.99 Kb 1 \n",
" aten::eye 0.52% 26.449us 0.85% 43.830us 21.915us 32 b 0 b 2 \n",
" aten::diag 0.07% 3.375us 0.73% 37.458us 37.458us 12 b -4 b 1 \n",
" aten::diag_embed 0.18% 9.223us 0.66% 34.083us 34.083us 16 b 0 b 1 \n",
" aten::new_ones 0.17% 8.569us 0.49% 25.367us 25.367us 4 b 0 b 1 \n",
" aten::new_empty 0.12% 6.200us 0.41% 20.969us 6.990us 1.06 Kb 0 b 3 \n",
" aten::empty_like 0.07% 3.456us 0.34% 17.331us 17.331us 128.00 Kb 0 b 1 \n",
" aten::empty 0.33% 16.792us 0.33% 16.792us 3.358us 1.07 Kb 1.07 Kb 5 \n",
" aten::copy_ 0.32% 16.361us 0.32% 16.361us 8.180us -8 b -8 b 2 \n",
"-------------------------------- ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ \n",
"Self CPU time total: 5.133ms\n",
"\n"
]
}
],
"source": [
"print(\n",
" unrolled_prof.key_averages().table(sort_by=\"cpu_time_total\", row_limit=20)\n",
")"
]
},
{
"cell_type": "markdown",
"id": "27d9d59c",
"metadata": {},
"source": [
"A closer look at the profiling results shows that in total, 38% of the time is spent on matrix multiplication and addition.\n",
"The speedup comes with a cost of increased memory usage, requiring more than 2 MB for filtering.\n",
"Not a significant cost for modern Hardwares.\n",
"\n",
"For convenience, I ran the above benchmarks using the CPU, which has very limited parallelism compared to the GPU.\n",
"Thus, the significant speedup we observe indicates that function call overhead is the major bottleneck for running recursions.\n",
"\n",
"\n",
"## More comparison\n",
"\n",
"Since $T$ is an essential parameter for the unrolled SSM, I did some benchmarks to see how it affects the speed.\n",
"\n",
"### Varying sequence length\n",
"\n",
"In this benchmark, I fixed the batch size to 8 and the order to 2, and varied the sequence length from 4096 to 262144.\n",
"The results suggest that the best unroll factor increases as the sequence length increases, and it's very likely to be $\\sqrt{N}$.\n",
"Additionally, the longer the sequence length, the greater the speedup we achieve from the unrolled SSM."
]
},
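{
"cell_type": "markdown",
"id": "8e4b6c52",
"metadata": {},
"source": [
"One hedged way to rationalise the $\\sqrt{N}$ guess is to count sequential loop iterations: `state_space_allpole_unrolled` runs $T$ iterations to build the matrix powers and $N/T$ iterations for the block recursion, and $T + N/T$ is minimised at $T = \\sqrt{N}$.\n",
"This assumes the running time is dominated by per-iteration overhead, which is only an assumption and not something the benchmarks prove; `sequential_steps` is a hypothetical helper.\n",
"\n",
"```python\n",
"import math\n",
"\n",
"\n",
"def sequential_steps(n: int, t: int) -> int:\n",
"    # loop iterations: t matrix powers plus n / t block updates\n",
"    return t + n // t\n",
"\n",
"\n",
"factors = [32, 64, 128, 256, 512]\n",
"for n in [4096, 16384, 65536, 262144]:\n",
"    best = min(factors, key=lambda t: sequential_steps(n, t))\n",
"    print(n, best, math.isqrt(n))\n",
"```\n",
"\n",
"For the four lengths benchmarked here, this count is minimised at 64, 128, 256 and 512, i.e. exactly $\\sqrt{N}$; the real optimum also depends on how the cost of each matrix product grows with $T$."
]
},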
{
"cell_type": "code",
"execution_count": 14,
"id": "84311402",
"metadata": {},
"outputs": [],
"source": [
"import matplotlib.pyplot as plt\n",
"from torch.utils.benchmark import Compare\n",
"from tqdm import tqdm"
]
},
{
"cell_type": "code",
"execution_count": 15,
"id": "5d5d38a1",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 6/6 [00:07<00:00, 1.24s/it]\n",
"100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 6/6 [00:08<00:00, 1.47s/it]\n",
"100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 6/6 [00:08<00:00, 1.36s/it]\n",
"100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 6/6 [00:11<00:00, 1.84s/it]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"[--------------------- State-Space All-Pole Unrolled ----------------------]\n",
" | 4096 | 16384 | 65536 | 262144 \n",
"4 threads: -----------------------------------------------------------------\n",
" unroll factor: 1 | 27191.0 | 118571.2 | 513106.7 | 2159240.8\n",
" unroll factor: 32 | 1201.2 | 4142.5 | 16838.3 | 69831.7\n",
" unroll factor: 64 | 889.3 | 2456.3 | 9469.4 | 38095.0\n",
" unroll factor: 128 | 954.5 | 1896.8 | 6388.2 | 24019.4\n",
" unroll factor: 256 | 1418.1 | 2108.4 | 5675.1 | 18562.1\n",
" unroll factor: 512 | 2571.3 | 3314.9 | 6876.7 | 20691.5\n",
"\n",
"Times are in microseconds (us).\n",
"\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"\n"
]
}
],
"source": [
"factors = [1, 32, 64, 128, 256, 512]\n",
"signal_lengths = [4096, 16384, 65536, 262144]\n",
"order = 2\n",
"a = order2a(order)\n",
"\n",
"results = []\n",
"\n",
"label = \"State-Space All-Pole Unrolled\"\n",
"for signal_length in signal_lengths:\n",
" x = torch.randn(batch_size, signal_length)\n",
" for unroll_factor in tqdm(factors):\n",
" sub_label = f\"unroll factor: {unroll_factor}\"\n",
" results.append(\n",
" Timer(\n",
" stmt=\"state_space_allpole_unrolled(x, a, unroll_factor=unroll_factor)\",\n",
" globals={\n",
" \"state_space_allpole_unrolled\": state_space_allpole_unrolled,\n",
" \"x\": x,\n",
" \"a\": a,\n",
" \"unroll_factor\": unroll_factor,\n",
" },\n",
" num_threads=4,\n",
" label=label,\n",
" sub_label=sub_label,\n",
" description=f\"{signal_length}\",\n",
" ).blocked_autorange(min_run_time=1)\n",
" )\n",
"\n",
"compare = Compare(results)\n",
"compare.print()\n",
"\n",
"naive_result = [x.median for x in results]"
]
},
{
"cell_type": "code",
"execution_count": 16,
"id": "f6f35b44",
"metadata": {},
"outputs": [
{
"data": {
"image/png": "",
"text/plain": [
"<Figure size 640x480 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"fig, ax = plt.subplots()\n",
"for i, signal_length in enumerate(signal_lengths):\n",
" baseline, *rest = naive_result[i * len(factors) : (i + 1) * len(factors)]\n",
" normalise = [baseline / x for x in rest]\n",
" ax.plot(\n",
" factors[1:],\n",
" normalise,\n",
" marker=\"o\",\n",
" label=f\"N={signal_length}\",\n",
" )\n",
"\n",
"ax.set_title(f\"M={order}, batch size={batch_size}\")\n",
"ax.set_yscale(\"log\")\n",
"ax.set_xscale(\"log\")\n",
"ax.set_xticks(factors[1:])\n",
"ax.get_xaxis().set_major_formatter(plt.ScalarFormatter())\n",
"ax.legend()\n",
"ax.set_xlabel(\"Unroll Factor\")\n",
"ax.set_ylabel(\"Speedup (vs standard SSM)\")\n",
"ax.grid()\n",
"plt.tight_layout()\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"id": "0c431b72",
"metadata": {},
"source": [
"### Varying filter order\n",
"\n",
"To examine the impact of filter order on speed, I set the batch size to 8 and the sequence length to 16384, and then varied the filter order from 2 to 16.\n",
"My hypothesis that the best factor is $\\sqrt{N}$ appears to still hold, but the peak gradually shifts to the left as the order increases.\n",
"Moreover, the speedup is less significant for higher orders, which is expected as the $\\mathbf{V}$ matrix becomes larger."
]
},
{
"cell_type": "code",
"execution_count": 17,
"id": "09ee2e5b",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
" 0%| | 0/6 [00:00<?, ?it/s]"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 6/6 [00:08<00:00, 1.42s/it]\n",
"100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 6/6 [00:08<00:00, 1.37s/it]\n",
"100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 6/6 [00:08<00:00, 1.47s/it]\n",
"100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 6/6 [00:07<00:00, 1.31s/it]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"[--------------- State-Space All-Pole Unrolled ----------------]\n",
" | 2 | 4 | 8 | 16 \n",
"4 threads: -----------------------------------------------------\n",
" unroll factor: 1 | 120.3 | 118.3 | 120.3 | 126.9\n",
" unroll factor: 32 | 4.0 | 4.4 | 5.5 | 7.5\n",
" unroll factor: 64 | 2.4 | 2.8 | 3.7 | 5.6\n",
" unroll factor: 128 | 1.9 | 2.2 | 3.2 | 5.2\n",
" unroll factor: 256 | 2.1 | 2.6 | 3.8 | 6.2\n",
" unroll factor: 512 | 35.3 | 4.4 | 6.4 | 10.1\n",
"\n",
"Times are in milliseconds (ms).\n",
"\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"\n"
]
}
],
"source": [
"factors = [1, 32, 64, 128, 256, 512]\n",
"signal_length = 16384\n",
"orders = [2, 4, 8, 16]\n",
"batch_size = 8\n",
"\n",
"x = torch.randn(batch_size, signal_length)\n",
"results = []\n",
"\n",
"label = \"State-Space All-Pole Unrolled\"\n",
"for order in orders:\n",
" a = order2a(order)\n",
" for unroll_factor in tqdm(factors):\n",
" sub_label = f\"unroll factor: {unroll_factor}\"\n",
" results.append(\n",
" Timer(\n",
" stmt=\"state_space_allpole_unrolled(x, a, unroll_factor=unroll_factor)\",\n",
" globals={\n",
" \"state_space_allpole_unrolled\": state_space_allpole_unrolled,\n",
" \"x\": x,\n",
" \"a\": a,\n",
" \"unroll_factor\": unroll_factor,\n",
" },\n",
" num_threads=4,\n",
" label=label,\n",
" sub_label=sub_label,\n",
" description=f\"{order}\",\n",
" ).blocked_autorange(min_run_time=1)\n",
" )\n",
"\n",
"compare = Compare(results)\n",
"compare.print()\n",
"\n",
"factor_vs_order = [x.median for x in results]"
]
},
{
"cell_type": "code",
"execution_count": 18,
"id": "81260990",
"metadata": {},
"outputs": [
{
"data": {
"image/png": "",
"text/plain": [
"<Figure size 640x480 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"fig, ax = plt.subplots()\n",
"for i, order in enumerate(orders):\n",
" baseline, *rest = factor_vs_order[i * len(factors) : (i + 1) * len(factors)]\n",
" normalise = [baseline / x for x in rest]\n",
" ax.plot(\n",
" factors[1:],\n",
" normalise,\n",
" marker=\"o\",\n",
" label=f\"M={order}\",\n",
" )\n",
"ax.set_title(f\"N={signal_length}, batch size={batch_size}\")\n",
"ax.set_yscale(\"log\")\n",
"ax.set_xscale(\"log\")\n",
"ax.set_xticks(factors[1:])\n",
"ax.get_xaxis().set_major_formatter(plt.ScalarFormatter())\n",
"ax.legend()\n",
"ax.set_xlabel(\"Unroll Factor\")\n",
"ax.set_ylabel(\"Speedup (vs standard SSM)\")\n",
"ax.grid()\n",
"plt.tight_layout()\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"id": "bee81bfb",
"metadata": {},
"source": [
"### Varying batch size\n",
"\n",
"The speedup is less as the batch size increases, which is expected.\n",
"However, the peak of the best unroll factor also shifts slightly to the left as the batch size increases."
]
},
{
"cell_type": "code",
"execution_count": 19,
"id": "8453522b",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
" 0%| | 0/6 [00:00<?, ?it/s]"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 6/6 [00:08<00:00, 1.33s/it]\n",
"100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 6/6 [00:08<00:00, 1.47s/it]\n",
"100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 6/6 [00:08<00:00, 1.46s/it]\n",
"100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 6/6 [00:08<00:00, 1.47s/it]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"[--------------- State-Space All-Pole Unrolled ----------------]\n",
" | 2 | 8 | 32 | 64 \n",
"4 threads: -----------------------------------------------------\n",
" unroll factor: 1 | 116.7 | 118.7 | 128.3 | 139.1\n",
" unroll factor: 32 | 3.9 | 4.4 | 7.0 | 9.9\n",
" unroll factor: 64 | 2.3 | 2.8 | 5.1 | 8.2\n",
" unroll factor: 128 | 1.7 | 2.2 | 5.0 | 8.0\n",
" unroll factor: 256 | 1.9 | 2.6 | 6.1 | 9.9\n",
" unroll factor: 512 | 8.1 | 39.9 | 150.2 | 294.8\n",
"\n",
"Times are in milliseconds (ms).\n",
"\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"\n"
]
}
],
"source": [
"factors = [1, 32, 64, 128, 256, 512]\n",
"signal_length = 16384\n",
"batch_sizes = [2, 8, 32, 64]\n",
"order = 4\n",
"results = []\n",
"a = order2a(order)\n",
"\n",
"label = \"State-Space All-Pole Unrolled\"\n",
"for batch_size in batch_sizes:\n",
" x = torch.randn(batch_size, signal_length)\n",
" for unroll_factor in tqdm(factors):\n",
" sub_label = f\"unroll factor: {unroll_factor}\"\n",
" results.append(\n",
" Timer(\n",
" stmt=\"state_space_allpole_unrolled(x, a, unroll_factor=unroll_factor)\",\n",
" globals={\n",
" \"state_space_allpole_unrolled\": state_space_allpole_unrolled,\n",
" \"x\": x,\n",
" \"a\": a,\n",
" \"unroll_factor\": unroll_factor,\n",
" },\n",
" num_threads=4,\n",
" label=label,\n",
" sub_label=sub_label,\n",
" description=f\"{batch_size}\",\n",
" ).blocked_autorange(min_run_time=1)\n",
" )\n",
"\n",
"compare = Compare(results)\n",
"compare.print()\n",
"\n",
"factor_vs_batch_size = [x.median for x in results]"
]
},
{
"cell_type": "code",
"execution_count": 20,
"id": "dded0ac0",
"metadata": {},
"outputs": [
{
"data": {
"image/png": "",
"text/plain": [
"<Figure size 640x480 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"fig, ax = plt.subplots()\n",
"for i, batch_size in enumerate(batch_sizes):\n",
" baseline, *rest = factor_vs_batch_size[i * len(factors) : (i + 1) * len(factors)]\n",
" normalise = [baseline / x for x in rest]\n",
" ax.plot(\n",
" factors[1:],\n",
" normalise,\n",
" marker=\"o\",\n",
" label=f\"batch size={batch_size}\",\n",
" )\n",
"ax.set_title(f\"N={signal_length}, M={order}\")\n",
"ax.set_yscale(\"log\")\n",
"ax.set_xscale(\"log\")\n",
"ax.set_xticks(factors[1:])\n",
"ax.get_xaxis().set_major_formatter(plt.ScalarFormatter())\n",
"ax.legend()\n",
"ax.set_xlabel(\"Unroll Factor\")\n",
"ax.set_ylabel(\"Speedup (vs standard SSM)\")\n",
"ax.grid()\n",
"plt.tight_layout()\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"id": "ccf1ae2e",
"metadata": {},
"source": [
"### Memory usage\n",
"\n",
"To observe how memory usage changes in a differentiable training context, I ran the unrolled SSM on a 5060 Ti, allowing me to use `torch.cuda.max_memory_allocated()` to measure memory usage.\n",
"When batch size is 1, as expected, the memory usage grows quadratically with the unroll factor, due to the creation of the $\\mathbf{V}$ matrix.\n",
"When using a larger batch size (32 in this case), this cost becomes less significant compared to the more memory used for the input signal."
]
},
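{
"cell_type": "markdown",
"id": "b4e1d7a2",
"metadata": {},
"source": [
"As a back-of-the-envelope illustration of the quadratic growth, consider a single dense $T \\times T$ block stored in single precision (4 bytes per entry). The exact shape of $\\mathbf{V}$ follows its definition earlier in this notebook; the block here is only an assumed stand-in to show the scaling:\n",
"\n",
"$$\n",
"\\text{bytes}(T) = 4T^2, \\qquad \\text{bytes}(256) = 4 \\cdot 256^2 = 256\\ \\text{KiB}, \\qquad \\text{bytes}(1024) = 4 \\cdot 1024^2 = 4\\ \\text{MiB}.\n",
"$$\n",
"\n",
"Doubling the unroll factor therefore quadruples this term, while the memory for the input signal stays proportional to the batch size times $N$."
]
},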
{
"cell_type": "code",
"execution_count": 21,
"id": "98fa398c",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"'NVIDIA GeForce RTX 5060 Ti'"
]
},
"execution_count": 21,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"torch.cuda.get_device_name()"
]
},
{
"cell_type": "code",
"execution_count": 22,
"id": "02ec6fd5",
"metadata": {},
"outputs": [],
"source": [
"factors = [1, 32, 64, 128, 256, 512, 1024]\n",
"signal_length = 65536\n",
"batch_size = 1\n",
"orders = [2, 4, 8, 16]\n",
"cuda_result = []\n",
"x = torch.randn(batch_size, signal_length, requires_grad=True).cuda()\n",
"\n",
"label = \"State-Space All-Pole Unrolled\"\n",
"for order in orders:\n",
" for unroll_factor in factors:\n",
" a = order2a(order).cuda()\n",
" a.requires_grad = True\n",
" sub_label = f\"unroll factor: {unroll_factor}\"\n",
" torch.cuda.reset_peak_memory_stats()\n",
" y = state_space_allpole_unrolled(x, a, unroll_factor=unroll_factor)\n",
" peak_memory = torch.cuda.max_memory_allocated() / (1024 * 1024)\n",
" cuda_result.append(peak_memory)\n",
" y.detach_()\n"
]
},
{
"cell_type": "code",
"execution_count": 23,
"id": "c26c71b0",
"metadata": {},
"outputs": [
{
"data": {
"image/png": "",
"text/plain": [
"<Figure size 640x480 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"fig, ax = plt.subplots()\n",
"for i, order in enumerate(orders):\n",
" # baseline, *rest = cuda_result[i * len(factors) : (i + 1) * len(factors)]\n",
" # normalise = [x / baseline for x in rest]\n",
" ax.plot(\n",
" # factors[1:],\n",
" # normalise,\n",
" factors,\n",
" cuda_result[i * len(factors) : (i + 1) * len(factors)],\n",
" marker=\"o\",\n",
" label=f\"M={order}\",\n",
" )\n",
"ax.set_title(f\"N={signal_length}, batch size={batch_size}\")\n",
"# ax.set_yscale(\"log\")\n",
"# ax.set_xscale(\"log\")\n",
"ax.set_xticks(factors)\n",
"ax.get_xaxis().set_major_formatter(plt.ScalarFormatter())\n",
"ax.legend()\n",
"ax.set_xlabel(\"Unroll Factor\")\n",
"ax.set_ylabel(\"Peak Memory Usage (MB)\")\n",
"ax.grid()\n",
"plt.tight_layout()\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"id": "9a3202e0",
"metadata": {},
"source": [
"## Discussion\n",
"\n",
"So far, we have seen that the unrolled SSM can achieve a significant speedup for IIR filtering in PyTorch.\n",
"However, determining the best unrolling factor automatically is still unclear.\n",
"From the benchmarks I did on an i7 CPU, it seems that the optimal $T^*$ is $\\sqrt{N}\\alpha$ and $0 < \\alpha \\leq 1$ is given by a function of the filter order and batch size.\n",
"Since I also observe similar behaviour on the GPU, it is likely that this hypothesis holds true for other hardware as well.\n",
"\n",
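"As a rough check of this hypothesis against the signal-length sweep above ($M = 2$), the fastest of the tested factors (32 to 512, powers of two only) in each column can be compared with $\\sqrt{N}$. The $\\alpha$ values below are simply read off that one table, so they are an illustration rather than a fitted model:\n",
"\n",
"$$\n",
"\\begin{aligned}\n",
"N &= 4096: & \\sqrt{N} &= 64, & T^* &= 64 & \\Rightarrow \\alpha &= 1 \\\\\n",
"N &= 16384: & \\sqrt{N} &= 128, & T^* &= 128 & \\Rightarrow \\alpha &= 1 \\\\\n",
"N &= 65536: & \\sqrt{N} &= 256, & T^* &= 256 & \\Rightarrow \\alpha &= 1 \\\\\n",
"N &= 262144: & \\sqrt{N} &= 512, & T^* &= 256 & \\Rightarrow \\alpha &= 0.5\n",
"\\end{aligned}\n",
"$$\n",
"\n",
"Because only powers of two were tested, the true optimum for the longest signal may lie anywhere between the neighbouring factors, so $\\alpha = 0.5$ there is a coarse estimate.\n",
"\n",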
"One thing I didn't mention is numerical accuracy.\n",
"If $|\\mathbf{A}|$ is very small, the precomputed exponentials $\\mathbf{A}^T \\to \\mathbf{0}$ which may not be accurately represented in floating point, especially in deep learning applications we use single precision a lot.\n",
"This is less of a problem for the standard SSM, since at each time step, the input is mixed with the state vector, which could help cancel out the numerical errors.\n",
"\n",
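"To get a feel for the scale involved, suppose (purely as an illustrative assumption) that the entries of $\\mathbf{A}^T$ decay like $\\rho^T$ with $\\rho = 0.5$, and take $T = 256$:\n",
"\n",
"$$\n",
"\\rho^T = 0.5^{256} = 2^{-256} \\approx 8.6 \\times 10^{-78} \\ll 1.4 \\times 10^{-45},\n",
"$$\n",
"\n",
"where $1.4 \\times 10^{-45}$ is approximately the smallest positive subnormal number in single precision. Such a value underflows to zero in float32, whereas double precision (smallest normal number about $2.2 \\times 10^{-308}$) can still represent it.\n",
"\n",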
"The idea should apply when there are zeros in the filter.\n",
"$\\mathbf{B}$ will not be a simple one-hot vector anymore so $\\mathbf{V}$ has to be a full $MT \\times MT$ square matrix.\n",
"Time-varying filters will benefit less from the unrolling trick since $\\mathbf{V}$ will also be time-varying, and computing $\\frac{N}{T}$ such matrices in advance increases the cost.\n",
"\n",
"\n",
"## Conclusion & Thoughts\n",
"\n",
"In this post, I demonstrate that the unrolling trick can significantly accelerate differentiable IIR filtering in PyTorch.\n",
"The extra memory cost is less of a problem for large batch sizes.\n",
"Although the filter I tested is a simple all-pole filter, it's trivial to extend the idea to general IIR filters.\n",
"\n",
"This method might help address one of the issues for future TorchAudio, after the Meta developers [announced](https://github.com/pytorch/audio/issues/3902) their future plan for it.\n",
"In the next major release, all the specialised kernels written in C++, including the `lfilter` I contributed years ago, will be removed from TorchAudio.\n",
"The filter I presented here is written entirely in Python, and it could serve as a straightforward drop-in replacement for the current compiled `lfilter` implementation."
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.10"
}
},
"nbformat": 4,
"nbformat_minor": 5
}