Last active
August 25, 2022 12:59
-
-
Save yoyolicoris/f63f601d62187562070a61377cec9bf8 to your computer and use it in GitHub Desktop.
This lfilter can propagate gradients to the filter coefficients.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import torch | |
import torch.nn as nn | |
import torch.nn.functional as F | |
from torchaudio.functional import lfilter as torch_lfilter | |
from torch.autograd import Function, gradcheck | |
class lfilter(Function):
    """Differentiable IIR filter ``y = B(z) / A(z) * x`` built on torchaudio's ``lfilter``.

    The forward pass runs the all-pole part through ``torch_lfilter`` (with an
    impulse numerator ``[1, 0, ..., 0]``) and then applies the FIR part ``b``
    with a causal ``conv1d``. The backward pass returns gradients for the
    signal ``x``, the denominator ``a`` and the numerator ``b``. ``a`` and
    ``b`` may have different lengths, because torchaudio only ever sees ``a``
    together with a same-sized impulse.

    Usage: ``y = lfilter.apply(x, a, b)``.
    """

    @staticmethod
    def forward(ctx, x: torch.Tensor, a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
        """Filter ``x`` along its last dimension.

        Args:
            x: input signal. Filtering runs over the last dimension, and all
                leading dimensions are treated as independent channels.
            a: denominator coefficients (indexed with ``a[0]``; presumably
                1-D -- TODO confirm).
            b: numerator coefficients, used as a flat FIR kernel.

        Returns:
            The filtered signal, with the same shape as ``x``.
        """
        with torch.no_grad():
            # Impulse numerator, so torch_lfilter applies only the recursive 1/A(z) part.
            impulse = torch.zeros_like(a)
            impulse[0] = 1
            xh = torch_lfilter(x, a, impulse, False)  # clamp=False
            # FIR stage: y[n] = sum_k b[k] * xh[n - k]. Left padding keeps it causal.
            # conv1d is a cross-correlation, hence the flipped kernel.
            y = xh.reshape(-1, 1, xh.shape[-1])
            y = F.pad(y, [b.numel() - 1, 0])
            y = F.conv1d(y, b.flip(0).view(1, 1, -1)).reshape(xh.shape)
        ctx.save_for_backward(x, a, b, xh)
        return y

    @staticmethod
    def backward(ctx, dy: torch.Tensor) -> tuple:
        """Return ``(dx, da, db)``, with ``None`` for inputs that need no gradient."""
        x, a, b, xh = ctx.saved_tensors
        dx = da = db = None
        length = dy.shape[-1]
        batch = x.numel() // x.shape[-1]
        with torch.no_grad():
            # Incoming gradients are not guaranteed to be contiguous (e.g. when the
            # output was transposed downstream), so reshape rather than view.
            dy_flat = dy.reshape(-1, 1, length)
            if ctx.needs_input_grad[2]:
                # dL/db[k] = sum_n dy[n] * xh[n - k], per channel via a grouped conv.
                db = F.conv1d(F.pad(xh.reshape(1, -1, length), [b.numel() - 1, 0]),
                              dy_flat,
                              groups=batch).sum((0, 1)).flip(0)
            if ctx.needs_input_grad[0] or ctx.needs_input_grad[1]:
                # Gradient w.r.t. the all-pole output: dxh[n] = sum_k b[k] * dy[n + k].
                dxh = F.conv1d(F.pad(dy_flat, [0, b.numel() - 1]),
                               b.view(1, 1, -1)).reshape(dy.shape)
                impulse = torch.zeros_like(a)
                if ctx.needs_input_grad[0]:
                    # The adjoint of a causal all-pole filter is the same filter run backwards in time.
                    impulse[0] = 1
                    dx = torch_lfilter(dxh.flip(-1), a, impulse, False).flip(-1)
                if ctx.needs_input_grad[1]:
                    # d xh / d a[k] = -(1/A(z)) z^{-k} xh, so correlate -(xh / A) with dxh.
                    impulse[0] = -1
                    dxhda = torch_lfilter(xh, a, impulse, False)
                    # Pad by len(a) - 1 so the result has exactly len(a) taps
                    # (padding by len(b) - 1 breaks when len(a) != len(b)).
                    da = F.conv1d(F.pad(dxhda.reshape(1, -1, length), [a.numel() - 1, 0]),
                                  dxh.reshape(-1, 1, length),
                                  groups=batch).sum((0, 1)).flip(0)
        return dx, da, db
if __name__ == '__main__':
    # Smoke test: profile one forward/backward pass, then check gradients numerically.
    # Fall back to CPU so the script also runs on machines without a GPU.
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    use_cuda = device == 'cuda'

    x = torch.randn(4, 256, device=device, dtype=torch.double)
    # Use a[0] = 1 and a[1], a[2] in [-0.5, 0.5). Such second-order denominators
    # are always stable (|a2| < 1 and |a1| < 1 + a2). Dividing random
    # coefficients by a possibly tiny a[0] could give an exploding filter
    # and a flaky gradcheck.
    a = torch.rand(3, device=device, dtype=torch.double) - 0.5
    a[0] = 1
    b = torch.rand(3, device=device, dtype=torch.double)
    a.requires_grad = True
    b.requires_grad = True
    x.requires_grad = True
    print(a, b)

    with torch.autograd.profiler.profile(use_cuda=use_cuda, profile_memory=True) as prof:
        y = lfilter.apply(x, a, b)
        loss = y.abs().sum()
        loss.backward()
    sort_key = "self_cuda_time_total" if use_cuda else "self_cpu_time_total"
    print(prof.key_averages().table(sort_by=sort_key, row_limit=5))
    print(gradcheck(lfilter.apply, (x, a, b), eps=1e-6))
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"cells": [ | |
{ | |
"cell_type": "code", | |
"execution_count": 1, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"/home/ycy/miniconda3/envs/hrtf_notebooks/lib/python3.8/site-packages/torchaudio/backend/utils.py:53: UserWarning: \"sox\" backend is being deprecated. The default backend will be changed to \"sox_io\" backend in 0.8.0 and \"sox\" backend will be removed in 0.9.0. Please migrate to \"sox_io\" backend. Please refer to https://github.com/pytorch/audio/issues/903 for the detail.\n", | |
" warnings.warn(\n" | |
] | |
} | |
], | |
"source": [ | |
"import torch\n", | |
"import torch.nn as nn\n", | |
"import torch.utils.benchmark as benchmark\n", | |
"from numpy.random import uniform\n", | |
"from differentiable_lfilter import lfilter\n", | |
"\n", | |
"from itertools import product" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"## Define simple second-order IIR\n" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 2, | |
"metadata": { | |
"pycharm": { | |
"name": "#%%\n" | |
} | |
}, | |
"outputs": [], | |
"source": [ | |
"#https://github.com/boris-kuz/differentiable_iir_filters/blob/master/differentiable_tdf2_model.py\n", | |
"class DTDFIICell(nn.Module):\n", | |
" def __init__(self):\n", | |
" super(DTDFIICell, self).__init__()\n", | |
" self.b0 = nn.Parameter(torch.FloatTensor([uniform(-1, 1)]))\n", | |
" self.b1 = nn.Parameter(torch.FloatTensor([uniform(-1, 1)]))\n", | |
" self.b2 = nn.Parameter(torch.FloatTensor([uniform(-1, 1)]))\n", | |
" self.a0 = nn.Parameter(torch.FloatTensor([1]))\n", | |
" self.a1 = nn.Parameter(torch.FloatTensor([uniform(-0.5, 0.5)]))\n", | |
" self.a2 = nn.Parameter(torch.FloatTensor([uniform(-0.5, 0.5)]))\n", | |
"\n", | |
" def _cat(self, vectors):\n", | |
" return torch.cat([v_.unsqueeze(-1) for v_ in vectors], dim=-1)\n", | |
"\n", | |
" def forward(self, input, v):\n", | |
" output = (input * self.b0 + v[:, 0]) / self.a0\n", | |
" v = self._cat([(input * self.b1 + v[:, 1] - output * self.a1), (input * self.b2 - output * self.a2)]) / self.a0\n", | |
" return output, v\n", | |
"\n", | |
" def init_states(self, size):\n", | |
" v = torch.zeros(size, 2).to(next(self.parameters()).device)\n", | |
" return v\n", | |
"\n", | |
"\n", | |
"class DTDFII(nn.Module):\n", | |
" def __init__(self):\n", | |
" super(DTDFII, self).__init__()\n", | |
" self.cell = DTDFIICell()\n", | |
"\n", | |
" def forward(self, input, initial_states=None):\n", | |
" batch_size = input.shape[0]\n", | |
" sequence_length = input.shape[1]\n", | |
"\n", | |
" if initial_states is None:\n", | |
" states = input.new_zeros(batch_size, 2)\n", | |
" else:\n", | |
" states = initial_states\n", | |
"\n", | |
" out_sequence = torch.zeros_like(input)\n", | |
" for s_idx in range(sequence_length):\n", | |
" out_sequence[:, s_idx], states = self.cell(input[:, s_idx].view(-1), states)\n", | |
"\n", | |
" if initial_states is None:\n", | |
" return out_sequence\n", | |
" else:\n", | |
" return out_sequence, states" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"pycharm": { | |
"name": "#%% md\n" | |
} | |
}, | |
"source": [ | |
"## Forward, CPU" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 3, | |
"metadata": { | |
"pycharm": { | |
"name": "#%%\n" | |
} | |
}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"Benchmarking on 2 threads\n", | |
"<torch.utils.benchmark.utils.common.Measurement object at 0x7f5f5d48fa30>\n", | |
"simple IIR: Implemented using for-loop\n", | |
" 223.83 ms\n", | |
" 1 measurement, 10 runs , 2 threads\n", | |
"<torch.utils.benchmark.utils.common.Measurement object at 0x7f5f5d48fcd0>\n", | |
"lfilter: Implemented using torchaudio.lfilter\n", | |
" 24.17 ms\n", | |
" 1 measurement, 10 runs , 2 threads\n" | |
] | |
} | |
], | |
"source": [ | |
"batch = 8\n", | |
"samples = 1024\n", | |
"\n", | |
"x = torch.randn(batch, samples)\n", | |
"base = DTDFII()\n", | |
"a = torch.cat([base.cell.a0, base.cell.a1, base.cell.a2]).detach()\n", | |
"b = torch.cat([base.cell.b0, base.cell.b1, base.cell.b2]).detach()\n", | |
"x.requires_grad = a.requires_grad = b.requires_grad = True\n", | |
"\n", | |
"num_threads = torch.get_num_threads()\n", | |
"print(f'Benchmarking on {num_threads} threads')\n", | |
"\n", | |
"\n", | |
"t0 = benchmark.Timer(\n", | |
" stmt='m(x)',\n", | |
" setup='',\n", | |
" globals={'x': x, 'm': base},\n", | |
" num_threads=num_threads,\n", | |
" label='simple IIR',\n", | |
" sub_label='Implemented using for-loop')\n", | |
"\n", | |
"t1 = benchmark.Timer(\n", | |
" stmt='lfilter.apply(x, a, b)',\n", | |
" setup='from differentiable_lfilter import lfilter',\n", | |
" globals={'x': x, 'a': a, 'b': b},\n", | |
" num_threads=num_threads,\n", | |
" label='lfilter',\n", | |
" sub_label='Implemented using torchaudio.lfilter')\n", | |
"\n", | |
"m0 = t0.timeit(10)\n", | |
"m1 = t1.timeit(10)\n", | |
"print(m0)\n", | |
"print(m1)\n" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"## Forward, GPU" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 4, | |
"metadata": { | |
"pycharm": { | |
"name": "#%%\n" | |
} | |
}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"<torch.utils.benchmark.utils.common.Measurement object at 0x7f5f5d48f7f0>\n", | |
"simple IIR: Implemented using for-loop\n", | |
" 392.72 ms\n", | |
" 1 measurement, 1 runs , 1 thread\n", | |
"Mean: 392.72 ms\n", | |
"Median: 392.72 ms\n", | |
"<torch.utils.benchmark.utils.common.Measurement object at 0x7f5f5d48f910>\n", | |
"lfilter: Implemented using torchaudio.lfilter\n", | |
" Median: 44.26 ms\n", | |
" IQR: 1.34 ms (43.56 to 44.90)\n", | |
" 5 measurements, 1 runs per measurement, 1 thread\n", | |
"Mean: 44.26 ms\n", | |
"Median: 44.26 ms\n" | |
] | |
} | |
], | |
"source": [ | |
"batch = 8\n", | |
"samples = 1024\n", | |
"\n", | |
"x = torch.randn(batch, samples, device='cuda')\n", | |
"base = DTDFII().to('cuda')\n", | |
"a = torch.cat([base.cell.a0, base.cell.a1, base.cell.a2]).detach()\n", | |
"b = torch.cat([base.cell.b0, base.cell.b1, base.cell.b2]).detach()\n", | |
"x.requires_grad = a.requires_grad = b.requires_grad = True\n", | |
"\n", | |
"\n", | |
"t0 = benchmark.Timer(\n", | |
" stmt='m(x)',\n", | |
" setup='',\n", | |
" globals={'x': x, 'm': base},\n", | |
" label='simple IIR',\n", | |
" sub_label='Implemented using for-loop')\n", | |
"\n", | |
"t1 = benchmark.Timer(\n", | |
" stmt='lfilter.apply(x, a, b)',\n", | |
" setup='from differentiable_lfilter import lfilter',\n", | |
" globals={'x': x, 'a': a, 'b': b},\n", | |
" label='lfilter',\n", | |
" sub_label='Implemented using torchaudio.lfilter')\n", | |
"\n", | |
"m0 = t0.blocked_autorange()\n", | |
"m1 = t1.blocked_autorange()\n", | |
"print(m0)\n", | |
"print(f\"Mean: {m0.mean * 1e3:6.2f} ms\")\n", | |
"print(f\"Median: {m0.median * 1e3:6.2f} ms\")\n", | |
"print(m1)\n", | |
"print(f\"Mean: {m1.mean * 1e3:6.2f} ms\")\n", | |
"print(f\"Median: {m1.median * 1e3:6.2f} ms\")" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"pycharm": { | |
"name": "#%% md\n" | |
} | |
}, | |
"source": [ | |
"## Forward+backward, GPU" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 5, | |
"metadata": { | |
"pycharm": { | |
"name": "#%%\n" | |
} | |
}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"<torch.utils.benchmark.utils.common.Measurement object at 0x7f5f5d38f5e0>\n", | |
"simple IIR: Implemented using for-loop\n", | |
" 486.54 ms\n", | |
" 1 measurement, 1 runs , 1 thread\n", | |
"Mean: 486.54 ms\n", | |
"Median: 486.54 ms\n", | |
"<torch.utils.benchmark.utils.common.Measurement object at 0x7f5f5d38f610>\n", | |
"lfilter: Implemented using torchaudio.lfilter\n", | |
" 35.69 ms\n", | |
" 1 measurement, 10 runs , 1 thread\n", | |
"Mean: 35.69 ms\n", | |
"Median: 35.69 ms\n" | |
] | |
} | |
], | |
"source": [ | |
"batch = 8\n", | |
"samples = 256\n", | |
"\n", | |
"x = torch.randn(batch, samples, device='cuda')\n", | |
"base = DTDFII().to('cuda')\n", | |
"a = torch.cat([base.cell.a0, base.cell.a1, base.cell.a2]).detach()\n", | |
"b = torch.cat([base.cell.b0, base.cell.b1, base.cell.b2]).detach()\n", | |
"x.requires_grad = a.requires_grad = b.requires_grad = True\n", | |
"\n", | |
"dummy = base(x)\n", | |
"dummy.sum().backward()\n", | |
"\n", | |
"dummy = lfilter.apply(x, a, b)\n", | |
"dummy.sum().backward()\n", | |
"\n", | |
"assert x.grad is not None\n", | |
"assert a.grad is not None\n", | |
"assert b.grad is not None\n", | |
"\n", | |
"t0 = benchmark.Timer(\n", | |
" stmt='y = m(x)\\ny.mean().backward()',\n", | |
" setup='',\n", | |
" globals={'x': x, 'm': base},\n", | |
" label='simple IIR',\n", | |
" sub_label='Implemented using for-loop')\n", | |
"\n", | |
"t1 = benchmark.Timer(\n", | |
" stmt='y = lfilter.apply(x, a, b)\\ny.mean().backward()',\n", | |
" setup='from differentiable_lfilter import lfilter',\n", | |
" globals={'x': x, 'a': a, 'b': b},\n", | |
" label='lfilter',\n", | |
" sub_label='Implemented using torchaudio.lfilter')\n", | |
"\n", | |
"m0 = t0.blocked_autorange()\n", | |
"m1 = t1.blocked_autorange()\n", | |
"print(m0)\n", | |
"print(f\"Mean: {m0.mean * 1e3:6.2f} ms\")\n", | |
"print(f\"Median: {m0.median * 1e3:6.2f} ms\")\n", | |
"print(m1)\n", | |
"print(f\"Mean: {m1.mean * 1e3:6.2f} ms\")\n", | |
"print(f\"Median: {m1.median * 1e3:6.2f} ms\")" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"# Different sizes, Forward+backward, GPU" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 6, | |
"metadata": { | |
"pycharm": { | |
"name": "#%%\n" | |
} | |
}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"[-------------- IIR filter -------------]\n", | |
" | for-loop | lfilter\n", | |
"1 threads: ------------------------------\n", | |
" [8, 16] | 30 | \u001b[34m\u001b[1m 6 \u001b[0m\u001b[0m\n", | |
" [8, 64] | \u001b[2m\u001b[91m 100 \u001b[0m\u001b[0m | 11 \n", | |
" [8, 256] | \u001b[31m\u001b[1m 400 \u001b[0m\u001b[0m | \u001b[31m\u001b[1m 34 \u001b[0m\u001b[0m\n", | |
" [8, 1024] | \u001b[31m\u001b[1m 1287 \u001b[0m\u001b[0m | \u001b[31m\u001b[1m 130 \u001b[0m\u001b[0m\n", | |
" [16, 16] | 30 | \u001b[34m\u001b[1m 6 \u001b[0m\u001b[0m\n", | |
" [16, 64] | \u001b[2m\u001b[91m 81 \u001b[0m\u001b[0m | 11 \n", | |
" [16, 256] | \u001b[31m\u001b[1m 400 \u001b[0m\u001b[0m | \u001b[31m\u001b[1m 34 \u001b[0m\u001b[0m\n", | |
" [16, 1024] | \u001b[31m\u001b[1m 1281 \u001b[0m\u001b[0m | \u001b[31m\u001b[1m 130 \u001b[0m\u001b[0m\n", | |
" [32, 16] | \u001b[92m\u001b[1m 20 \u001b[0m\u001b[0m | \u001b[92m\u001b[1m 6 \u001b[0m\u001b[0m\n", | |
" [32, 64] | \u001b[2m\u001b[91m 80 \u001b[0m\u001b[0m | \u001b[2m\u001b[91m 10 \u001b[0m\u001b[0m\n", | |
" [32, 256] | \u001b[31m\u001b[1m 400 \u001b[0m\u001b[0m | \u001b[31m\u001b[1m 40 \u001b[0m\u001b[0m\n", | |
" [32, 1024] | \u001b[31m\u001b[1m 1533 \u001b[0m\u001b[0m | \u001b[31m\u001b[1m 100 \u001b[0m\u001b[0m\n", | |
" [64, 16] | 20 | \u001b[92m\u001b[1m 6 \u001b[0m\u001b[0m\n", | |
" [64, 64] | \u001b[2m\u001b[91m 80 \u001b[0m\u001b[0m | 11 \n", | |
" [64, 256] | \u001b[31m\u001b[1m 320 \u001b[0m\u001b[0m | \u001b[31m\u001b[1m 40 \u001b[0m\u001b[0m\n", | |
" [64, 1024] | \u001b[31m\u001b[1m 1316 \u001b[0m\u001b[0m | \u001b[31m\u001b[1m 130 \u001b[0m\u001b[0m\n", | |
"\n", | |
"Times are in milliseconds (ms).\n", | |
"\n" | |
] | |
} | |
], | |
"source": [ | |
"base = DTDFII().to('cuda')\n", | |
"a = torch.cat([base.cell.a0, base.cell.a1, base.cell.a2]).detach()\n", | |
"b = torch.cat([base.cell.b0, base.cell.b1, base.cell.b2]).detach()\n", | |
"a.requires_grad = b.requires_grad = True\n", | |
"\n", | |
"results = []\n", | |
"batches = [8, 16, 32, 64]\n", | |
"samples = [16, 64, 256, 1024]\n", | |
"for batch, n in product(batches, samples):\n", | |
" label = 'IIR filter'\n", | |
" sub_label = f'[{batch}, {n}]'\n", | |
" x = torch.randn(batch, n, device='cuda')\n", | |
" x.requires_grad = True\n", | |
"\n", | |
" dummy = base(x)\n", | |
" dummy.sum().backward()\n", | |
"\n", | |
" dummy = lfilter.apply(x, a, b)\n", | |
" dummy.sum().backward()\n", | |
"\n", | |
" results.append(benchmark.Timer(\n", | |
" stmt='y = m(x)\\ny.mean().backward()',\n", | |
" setup='',\n", | |
" globals={'x': x, 'm': base},\n", | |
" label=label,\n", | |
" sub_label=sub_label,\n", | |
" description='for-loop',\n", | |
" ).blocked_autorange(min_run_time=1))\n", | |
" results.append(benchmark.Timer(\n", | |
" stmt='y = lfilter.apply(x, a, b)\\ny.mean().backward()',\n", | |
" setup='from differentiable_lfilter import lfilter',\n", | |
" globals={'x': x, 'a': a, 'b': b},\n", | |
" label=label,\n", | |
" sub_label=sub_label,\n", | |
" description='lfilter',\n", | |
" ).blocked_autorange(min_run_time=1))\n", | |
"\n", | |
"compare = benchmark.Compare(results)\n", | |
"compare.trim_significant_figures()\n", | |
"compare.colorize()\n", | |
"compare.print()" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 7, | |
"metadata": { | |
"pycharm": { | |
"name": "#%%\n" | |
} | |
}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"[----------------------- IIR filter -----------------------]\n", | |
" | for-loop | lfilter\n", | |
"1 threads: -------------------------------------------------\n", | |
" 442 x 169 | \u001b[31m\u001b[1m 200 \u001b[0m\u001b[0m | \u001b[31m\u001b[1m 25 \u001b[0m\u001b[0m\n", | |
" 36 x 244 | \u001b[31m\u001b[1m 300 \u001b[0m\u001b[0m | \u001b[31m\u001b[1m 40 \u001b[0m\u001b[0m\n", | |
" 26 x 848 | \u001b[31m\u001b[1m 1129 \u001b[0m\u001b[0m | \u001b[31m\u001b[1m 110 \u001b[0m\u001b[0m\n", | |
" 126 x 2653 | \u001b[31m\u001b[1m 3389 \u001b[0m\u001b[0m | \u001b[31m\u001b[1m 330 \u001b[0m\u001b[0m\n", | |
" 1201 x 755 (discontiguous) | \u001b[31m\u001b[1m 990 \u001b[0m\u001b[0m | \u001b[31m\u001b[1m 98 \u001b[0m\u001b[0m\n", | |
" 56 x 917 | \u001b[31m\u001b[1m 1166 \u001b[0m\u001b[0m | \u001b[31m\u001b[1m 120 \u001b[0m\u001b[0m\n", | |
" 324 x 463 | \u001b[31m\u001b[1m 600 \u001b[0m\u001b[0m | \u001b[31m\u001b[1m 60 \u001b[0m\u001b[0m\n", | |
" 97 x 639 | \u001b[31m\u001b[1m 801 \u001b[0m\u001b[0m | \u001b[31m\u001b[1m 81 \u001b[0m\u001b[0m\n", | |
" 56 x 4 (discontiguous) | \u001b[92m\u001b[1m 5 \u001b[0m\u001b[0m | \u001b[92m\u001b[1m 3 \u001b[0m\u001b[0m\n", | |
" 192 x 183 (discontiguous) | \u001b[31m\u001b[1m 254 \u001b[0m\u001b[0m | \u001b[31m\u001b[1m 30 \u001b[0m\u001b[0m\n", | |
"\n", | |
"Times are in milliseconds (ms).\n", | |
"\n" | |
] | |
} | |
], | |
"source": [ | |
"base = DTDFII().to('cuda')\n", | |
"a = torch.cat([base.cell.a0, base.cell.a1, base.cell.a2]).detach()\n", | |
"b = torch.cat([base.cell.b0, base.cell.b1, base.cell.b2]).detach()\n", | |
"a.requires_grad = b.requires_grad = True\n", | |
"\n", | |
"results = []\n", | |
"\n", | |
"example_fuzzer = benchmark.Fuzzer(\n", | |
" parameters = [\n", | |
" benchmark.FuzzedParameter('k0', minval=1, maxval=5000, distribution='loguniform'),\n", | |
" benchmark.FuzzedParameter('k1', minval=1, maxval=5000, distribution='loguniform'),\n", | |
" ],\n", | |
" tensors = [\n", | |
" benchmark.FuzzedTensor('x', size=('k0', 'k1'), min_elements=128, max_elements=1000000, probability_contiguous=0.6)\n", | |
" ],\n", | |
" seed=0,\n", | |
")\n", | |
"\n", | |
"for tensors, tensor_params, params in example_fuzzer.take(10):\n", | |
" sub_label=f\"{params['k0']:<6} x {params['k1']:<4} {'' if tensor_params['x']['is_contiguous'] else '(discontiguous)'}\"\n", | |
" \n", | |
" x = tensors['x']\n", | |
" x = x.cuda()\n", | |
" x.requires_grad = True\n", | |
"\n", | |
" dummy = base(x)\n", | |
" dummy.sum().backward()\n", | |
"\n", | |
" dummy = lfilter.apply(x, a, b)\n", | |
" dummy.sum().backward()\n", | |
" \n", | |
" label = 'IIR filter'\n", | |
"\n", | |
" results.append(benchmark.Timer(\n", | |
" stmt='y = m(x)\\ny.mean().backward()',\n", | |
" setup='',\n", | |
" globals={'x': x, 'm': base},\n", | |
" label=label,\n", | |
" sub_label=sub_label,\n", | |
" description='for-loop',\n", | |
" ).blocked_autorange(min_run_time=1))\n", | |
" results.append(benchmark.Timer(\n", | |
" stmt='y = lfilter.apply(x, a, b)\\ny.mean().backward()',\n", | |
" setup='from differentiable_lfilter import lfilter',\n", | |
" globals={'x': x, 'a': a, 'b': b},\n", | |
" label=label,\n", | |
" sub_label=sub_label,\n", | |
" description='lfilter',\n", | |
" ).blocked_autorange(min_run_time=1))\n", | |
"\n", | |
"compare = benchmark.Compare(results)\n", | |
"compare.trim_significant_figures()\n", | |
"compare.colorize()\n", | |
"compare.print()" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": { | |
"pycharm": { | |
"name": "#%%\n" | |
} | |
}, | |
"outputs": [], | |
"source": [] | |
} | |
], | |
"metadata": { | |
"kernelspec": { | |
"display_name": "Python 3", | |
"language": "python", | |
"name": "python3" | |
}, | |
"language_info": { | |
"codemirror_mode": { | |
"name": "ipython", | |
"version": 3 | |
}, | |
"file_extension": ".py", | |
"mimetype": "text/x-python", | |
"name": "python", | |
"nbconvert_exporter": "python", | |
"pygments_lexer": "ipython3", | |
"version": "3.8.5" | |
} | |
}, | |
"nbformat": 4, | |
"nbformat_minor": 1 | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
This custom backward function has been added to the torchaudio master branch.