Created
June 23, 2021 17:44
-
-
Save ShigekiKarita/c3b513ce5e3bed9aeda726c4d2c2e200 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| { | |
| "cells": [ | |
| { | |
| "cell_type": "markdown", | |
| "id": "0da2d661", | |
| "metadata": {}, | |
| "source": [ | |
| "# CTC implementation with PyTorch JIT\n", | |
| "\n", | |
| "Requires the PyTorch 1.9.0 and editdistance pip packages,\n", | |
| "and the ESPnet AN4 data prep outputs (wav.scp and text).\n", | |
| "\n", | |
| "author: Shigeki Karita ([email protected])\n", | |
| "\n", | |
| "https://pytorch.org/docs/stable/jit.html" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 1, | |
| "id": "171bf40b", | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "name": "stderr", | |
| "output_type": "stream", | |
| "text": [ | |
| "/mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/modules/rnn.py:62: UserWarning: dropout option adds dropout after all but last recurrent layer, so non-zero dropout expects num_layers greater than 1, but got dropout=0.5 and num_layers=1\n", | |
| " warnings.warn(\"dropout option adds dropout after all but last \"\n" | |
| ] | |
| }, | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "#params:\t 125\n", | |
| "out shape:\t torch.Size([1])\n" | |
| ] | |
| }, | |
| { | |
| "data": { | |
| "text/plain": [ | |
| "graph(%self : __torch__.SpeechModel,\n", | |
| " %x.1 : Tensor,\n", | |
| " %xlen.1 : Tensor,\n", | |
| " %y.1 : Tensor,\n", | |
| " %ylen.1 : Tensor):\n", | |
| " %5 : int[] = prim::Constant[value=[2]]()\n", | |
| " %6 : int[] = prim::Constant[value=[1]]()\n", | |
| " %7 : int[] = prim::Constant[value=[0]]()\n", | |
| " %8 : int[] = prim::Constant[value=[5]]()\n", | |
| " %9 : bool = prim::Constant[value=0]()\n", | |
| " %10 : int = prim::Constant[value=0]() # <ipython-input-1-14e2ee0fb7fa>:85:24\n", | |
| " %11 : int = prim::Constant[value=1]() # <ipython-input-1-14e2ee0fb7fa>:85:27\n", | |
| " %12 : str = prim::Constant[value=\"input.size(-1) must be equal to input_size. Expected {}, got {}\"]() # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/modules/rnn.py:206:16\n", | |
| " %13 : str = prim::Constant[value=\"input must have {} dimensions, got {}\"]() # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/modules/rnn.py:202:16\n", | |
| " %14 : int = prim::Constant[value=3]() # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/modules/rnn.py:199:63\n", | |
| " %15 : str = prim::Constant[value=\"Expected hidden[0] size {}, got {}\"]() # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/modules/rnn.py:622:31\n", | |
| " %16 : str = prim::Constant[value=\"Expected hidden[1] size {}, got {}\"]() # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/modules/rnn.py:624:31\n", | |
| " %17 : bool = prim::Constant[value=1]() # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:2359:66\n", | |
| " %18 : str = prim::Constant[value=\"Expected more than 1 value per channel when training, got input size {}\"]() # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:2247:25\n", | |
| " %19 : float = prim::Constant[value=1.0000000000000001e-05]() # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/modules/normalization.py:253:60\n", | |
| " %20 : float = prim::Constant[value=0.5]() # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/modules/dropout.py:58:32\n", | |
| " %21 : int = prim::Constant[value=2]() # <ipython-input-1-14e2ee0fb7fa>:69:53\n", | |
| " %22 : NoneType = prim::Constant()\n", | |
| " %23 : int = prim::Constant[value=-1]() # <ipython-input-1-14e2ee0fb7fa>:72:29\n", | |
| " %24 : __torch__.torch.nn.modules.container.Sequential = prim::GetAttr[name=\"encoder\"](%self)\n", | |
| " %25 : Tensor = aten::slice(%x.1, %10, %22, %22, %11) # <ipython-input-1-14e2ee0fb7fa>:69:25\n", | |
| " %314 : Tensor = prim::profile[profiled_type=Float(2, 1000, strides=[1000, 1], requires_grad=0, device=cpu)](%25)\n", | |
| " %26 : Tensor = aten::unsqueeze(%314, %11) # <ipython-input-1-14e2ee0fb7fa>:69:25\n", | |
| " %315 : Tensor = prim::profile[profiled_type=Float(2, 1, 1000, strides=[1000, 1000, 1], requires_grad=0, device=cpu)](%26)\n", | |
| " %27 : Tensor = aten::slice(%315, %21, %22, %22, %11) # <ipython-input-1-14e2ee0fb7fa>:69:25\n", | |
| " %28 : __torch__.torch.nn.modules.normalization.GroupNorm = prim::GetAttr[name=\"0\"](%24)\n", | |
| " %29 : __torch__.torch.nn.modules.conv.Conv1d = prim::GetAttr[name=\"1\"](%24)\n", | |
| " %30 : __torch__.torch.nn.modules.dropout.Dropout = prim::GetAttr[name=\"2\"](%24)\n", | |
| " %31 : __torch__.torch.nn.modules.normalization.GroupNorm = prim::GetAttr[name=\"3\"](%24)\n", | |
| " %32 : __torch__.torch.nn.modules.conv.___torch_mangle_0.Conv1d = prim::GetAttr[name=\"5\"](%24)\n", | |
| " %33 : __torch__.torch.nn.modules.dropout.Dropout = prim::GetAttr[name=\"6\"](%24)\n", | |
| " %34 : __torch__.torch.nn.modules.normalization.GroupNorm = prim::GetAttr[name=\"7\"](%24)\n", | |
| " %35 : __torch__.torch.nn.modules.conv.___torch_mangle_0.Conv1d = prim::GetAttr[name=\"9\"](%24)\n", | |
| " %36 : __torch__.torch.nn.modules.dropout.Dropout = prim::GetAttr[name=\"10\"](%24)\n", | |
| " %37 : __torch__.torch.nn.modules.normalization.GroupNorm = prim::GetAttr[name=\"11\"](%24)\n", | |
| " %38 : __torch__.torch.nn.modules.conv.___torch_mangle_0.Conv1d = prim::GetAttr[name=\"13\"](%24)\n", | |
| " %39 : __torch__.torch.nn.modules.dropout.Dropout = prim::GetAttr[name=\"14\"](%24)\n", | |
| " %40 : __torch__.torch.nn.modules.normalization.GroupNorm = prim::GetAttr[name=\"15\"](%24)\n", | |
| " %41 : __torch__.torch.nn.modules.conv.___torch_mangle_0.Conv1d = prim::GetAttr[name=\"17\"](%24)\n", | |
| " %42 : __torch__.torch.nn.modules.dropout.Dropout = prim::GetAttr[name=\"18\"](%24)\n", | |
| " %43 : __torch__.torch.nn.modules.normalization.GroupNorm = prim::GetAttr[name=\"19\"](%24)\n", | |
| " %44 : __torch__.torch.nn.modules.conv.___torch_mangle_1.Conv1d = prim::GetAttr[name=\"21\"](%24)\n", | |
| " %45 : __torch__.torch.nn.modules.dropout.Dropout = prim::GetAttr[name=\"22\"](%24)\n", | |
| " %46 : __torch__.torch.nn.modules.normalization.GroupNorm = prim::GetAttr[name=\"23\"](%24)\n", | |
| " %47 : __torch__.torch.nn.modules.conv.___torch_mangle_1.Conv1d = prim::GetAttr[name=\"25\"](%24)\n", | |
| " %48 : __torch__.torch.nn.modules.dropout.Dropout = prim::GetAttr[name=\"26\"](%24)\n", | |
| " %49 : __torch__.torch.nn.modules.normalization.GroupNorm = prim::GetAttr[name=\"27\"](%24)\n", | |
| " %50 : Tensor = prim::GetAttr[name=\"weight\"](%28)\n", | |
| " %51 : Tensor = prim::GetAttr[name=\"bias\"](%28)\n", | |
| " %316 : Tensor = prim::profile[profiled_type=Float(2, 1, 1000, strides=[1000, 1000, 1], requires_grad=0, device=cpu)](%27)\n", | |
| " %52 : int = aten::size(%316, %10) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:2358:24\n", | |
| " %317 : Tensor = prim::profile[profiled_type=Float(2, 1, 1000, strides=[1000, 1000, 1], requires_grad=0, device=cpu)](%27)\n", | |
| " %53 : int = aten::size(%317, %11) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:2358:40\n", | |
| " %54 : int = aten::mul(%52, %53) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:2358:24\n", | |
| " %55 : int = aten::floordiv(%54, %11) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:2358:24\n", | |
| " %56 : int[] = prim::ListConstruct(%55, %11)\n", | |
| " %318 : Tensor = prim::profile[profiled_type=Float(2, 1, 1000, strides=[1000, 1000, 1], requires_grad=0, device=cpu)](%27)\n", | |
| " %57 : int[] = aten::size(%318) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:2358:88\n", | |
| " %58 : int[] = aten::slice(%57, %21, %22, %11) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:2358:88\n", | |
| " %59 : int[] = aten::list(%58) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:2358:83\n", | |
| " %60 : int[] = aten::add(%56, %59) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:2358:23\n", | |
| " %size_prods.2 : int = aten::__getitem__(%60, %10) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:2243:17\n", | |
| " %62 : int = aten::len(%60) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:2244:19\n", | |
| " %63 : int = aten::sub(%62, %21) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:2244:19\n", | |
| " %size_prods.4 : int = prim::Loop(%63, %17, %size_prods.2) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:2244:4\n", | |
| " block0(%i.2 : int, %size_prods.12 : int):\n", | |
| " %67 : int = aten::add(%i.2, %21) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:2245:27\n", | |
| " %68 : int = aten::__getitem__(%60, %67) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:2245:22\n", | |
| " %size_prods.14 : int = aten::mul(%size_prods.12, %68) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:2245:8\n", | |
| " -> (%17, %size_prods.14)\n", | |
| " %70 : bool = aten::eq(%size_prods.4, %11) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:2246:7\n", | |
| " = prim::If(%70) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:2246:4\n", | |
| " block0():\n", | |
| " %71 : str = aten::format(%18, %60) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:2247:25\n", | |
| " = prim::RaiseException(%71) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:2247:8\n", | |
| " -> ()\n", | |
| " block1():\n", | |
| " -> ()\n", | |
| " %input.5 : Tensor = aten::group_norm(%27, %11, %50, %51, %19, %17) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:2359:11\n", | |
| " %73 : Tensor = prim::GetAttr[name=\"weight\"](%29)\n", | |
| " %74 : Tensor? = prim::GetAttr[name=\"bias\"](%29)\n", | |
| " %input.9 : Tensor = aten::conv1d(%input.5, %73, %74, %8, %7, %6, %11) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/modules/conv.py:294:15\n", | |
| " %76 : bool = prim::GetAttr[name=\"training\"](%30)\n", | |
| " %319 : Tensor = prim::profile[profiled_type=Float(2, 1, 199, strides=[199, 199, 1], requires_grad=1, device=cpu)](%input.9)\n", | |
| " %input.13 : Tensor = aten::dropout(%319, %20, %76) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:1168:60\n", | |
| " %78 : Tensor = prim::GetAttr[name=\"weight\"](%31)\n", | |
| " %79 : Tensor = prim::GetAttr[name=\"bias\"](%31)\n", | |
| " %320 : Tensor = prim::profile[profiled_type=Float(2, 1, 199, strides=[199, 199, 1], requires_grad=1, device=cpu)](%input.13)\n", | |
| " %80 : int = aten::size(%320, %10) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:2358:24\n", | |
| " %321 : Tensor = prim::profile[profiled_type=Float(2, 1, 199, strides=[199, 199, 1], requires_grad=1, device=cpu)](%input.13)\n", | |
| " %81 : int = aten::size(%321, %11) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:2358:40\n", | |
| " %82 : int = aten::mul(%80, %81) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:2358:24\n", | |
| " %83 : int = aten::floordiv(%82, %11) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:2358:24\n", | |
| " %84 : int[] = prim::ListConstruct(%83, %11)\n", | |
| " %322 : Tensor = prim::profile[profiled_type=Float(2, 1, 199, strides=[199, 199, 1], requires_grad=1, device=cpu)](%input.13)\n", | |
| " %85 : int[] = aten::size(%322) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:2358:88\n", | |
| " %86 : int[] = aten::slice(%85, %21, %22, %11) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:2358:88\n", | |
| " %87 : int[] = aten::list(%86) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:2358:83\n", | |
| " %88 : int[] = aten::add(%84, %87) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:2358:23\n", | |
| " %size_prods.16 : int = aten::__getitem__(%88, %10) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:2243:17\n", | |
| " %90 : int = aten::len(%88) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:2244:19\n", | |
| " %91 : int = aten::sub(%90, %21) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:2244:19\n", | |
| " %size_prods.18 : int = prim::Loop(%91, %17, %size_prods.16) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:2244:4\n", | |
| " block0(%i.4 : int, %size_prods.20 : int):\n", | |
| " %95 : int = aten::add(%i.4, %21) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:2245:27\n", | |
| " %96 : int = aten::__getitem__(%88, %95) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:2245:22\n", | |
| " %size_prods.22 : int = aten::mul(%size_prods.20, %96) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:2245:8\n", | |
| " -> (%17, %size_prods.22)\n", | |
| " %98 : bool = aten::eq(%size_prods.18, %11) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:2246:7\n", | |
| " = prim::If(%98) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:2246:4\n", | |
| " block0():\n", | |
| " %99 : str = aten::format(%18, %88) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:2247:25\n", | |
| " = prim::RaiseException(%99) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:2247:8\n", | |
| " -> ()\n", | |
| " block1():\n", | |
| " -> ()\n", | |
| " %323 : Tensor = prim::profile[profiled_type=Float(2, 1, 199, strides=[199, 199, 1], requires_grad=1, device=cpu)](%input.13)\n", | |
| " %input.17 : Tensor = aten::group_norm(%323, %11, %78, %79, %19, %17) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:2359:11\n", | |
| " %324 : Tensor = prim::profile[profiled_type=Float(2, 1, 199, strides=[199, 199, 1], requires_grad=1, device=cpu)](%input.17)\n", | |
| " %input.21 : Tensor = aten::gelu(%324) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:1555:11\n", | |
| " %102 : Tensor = prim::GetAttr[name=\"weight\"](%32)\n", | |
| " %103 : Tensor? = prim::GetAttr[name=\"bias\"](%32)\n", | |
| " %325 : Tensor = prim::profile[profiled_type=Float(2, 1, 199, strides=[199, 199, 1], requires_grad=1, device=cpu)](%input.21)\n", | |
| " %input.25 : Tensor = aten::conv1d(%325, %102, %103, %5, %7, %6, %11) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/modules/conv.py:294:15\n", | |
| " %105 : bool = prim::GetAttr[name=\"training\"](%33)\n", | |
| " %326 : Tensor = prim::profile[profiled_type=Float(2, 1, 99, strides=[99, 99, 1], requires_grad=1, device=cpu)](%input.25)\n", | |
| " %input.29 : Tensor = aten::dropout(%326, %20, %105) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:1168:60\n", | |
| " %107 : Tensor = prim::GetAttr[name=\"weight\"](%34)\n", | |
| " %108 : Tensor = prim::GetAttr[name=\"bias\"](%34)\n", | |
| " %327 : Tensor = prim::profile[profiled_type=Float(2, 1, 99, strides=[99, 99, 1], requires_grad=1, device=cpu)](%input.29)\n", | |
| " %109 : int = aten::size(%327, %10) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:2358:24\n", | |
| " %328 : Tensor = prim::profile[profiled_type=Float(2, 1, 99, strides=[99, 99, 1], requires_grad=1, device=cpu)](%input.29)\n", | |
| " %110 : int = aten::size(%328, %11) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:2358:40\n", | |
| " %111 : int = aten::mul(%109, %110) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:2358:24\n", | |
| " %112 : int = aten::floordiv(%111, %11) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:2358:24\n", | |
| " %113 : int[] = prim::ListConstruct(%112, %11)\n", | |
| " %329 : Tensor = prim::profile[profiled_type=Float(2, 1, 99, strides=[99, 99, 1], requires_grad=1, device=cpu)](%input.29)\n", | |
| " %114 : int[] = aten::size(%329) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:2358:88\n", | |
| " %115 : int[] = aten::slice(%114, %21, %22, %11) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:2358:88\n", | |
| " %116 : int[] = aten::list(%115) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:2358:83\n", | |
| " %117 : int[] = aten::add(%113, %116) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:2358:23\n", | |
| " %size_prods.24 : int = aten::__getitem__(%117, %10) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:2243:17\n", | |
| " %119 : int = aten::len(%117) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:2244:19\n", | |
| " %120 : int = aten::sub(%119, %21) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:2244:19\n", | |
| " %size_prods.26 : int = prim::Loop(%120, %17, %size_prods.24) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:2244:4\n", | |
| " block0(%i.6 : int, %size_prods.28 : int):\n", | |
| " %124 : int = aten::add(%i.6, %21) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:2245:27\n", | |
| " %125 : int = aten::__getitem__(%117, %124) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:2245:22\n", | |
| " %size_prods.30 : int = aten::mul(%size_prods.28, %125) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:2245:8\n", | |
| " -> (%17, %size_prods.30)\n", | |
| " %127 : bool = aten::eq(%size_prods.26, %11) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:2246:7\n", | |
| " = prim::If(%127) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:2246:4\n", | |
| " block0():\n", | |
| " %128 : str = aten::format(%18, %117) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:2247:25\n", | |
| " = prim::RaiseException(%128) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:2247:8\n", | |
| " -> ()\n", | |
| " block1():\n", | |
| " -> ()\n", | |
| " %330 : Tensor = prim::profile[profiled_type=Float(2, 1, 99, strides=[99, 99, 1], requires_grad=1, device=cpu)](%input.29)\n", | |
| " %input.33 : Tensor = aten::group_norm(%330, %11, %107, %108, %19, %17) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:2359:11\n", | |
| " %331 : Tensor = prim::profile[profiled_type=Float(2, 1, 99, strides=[99, 99, 1], requires_grad=1, device=cpu)](%input.33)\n", | |
| " %input.37 : Tensor = aten::gelu(%331) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:1555:11\n", | |
| " %131 : Tensor = prim::GetAttr[name=\"weight\"](%35)\n", | |
| " %132 : Tensor? = prim::GetAttr[name=\"bias\"](%35)\n", | |
| " %332 : Tensor = prim::profile[profiled_type=Float(2, 1, 99, strides=[99, 99, 1], requires_grad=1, device=cpu)](%input.37)\n", | |
| " %input.41 : Tensor = aten::conv1d(%332, %131, %132, %5, %7, %6, %11) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/modules/conv.py:294:15\n", | |
| " %134 : bool = prim::GetAttr[name=\"training\"](%36)\n", | |
| " %333 : Tensor = prim::profile[profiled_type=Float(2, 1, 49, strides=[49, 49, 1], requires_grad=1, device=cpu)](%input.41)\n", | |
| " %input.45 : Tensor = aten::dropout(%333, %20, %134) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:1168:60\n", | |
| " %136 : Tensor = prim::GetAttr[name=\"weight\"](%37)\n", | |
| " %137 : Tensor = prim::GetAttr[name=\"bias\"](%37)\n", | |
| " %334 : Tensor = prim::profile[profiled_type=Float(2, 1, 49, strides=[49, 49, 1], requires_grad=1, device=cpu)](%input.45)\n", | |
| " %138 : int = aten::size(%334, %10) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:2358:24\n", | |
| " %335 : Tensor = prim::profile[profiled_type=Float(2, 1, 49, strides=[49, 49, 1], requires_grad=1, device=cpu)](%input.45)\n", | |
| " %139 : int = aten::size(%335, %11) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:2358:40\n", | |
| " %140 : int = aten::mul(%138, %139) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:2358:24\n", | |
| " %141 : int = aten::floordiv(%140, %11) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:2358:24\n", | |
| " %142 : int[] = prim::ListConstruct(%141, %11)\n", | |
| " %336 : Tensor = prim::profile[profiled_type=Float(2, 1, 49, strides=[49, 49, 1], requires_grad=1, device=cpu)](%input.45)\n", | |
| " %143 : int[] = aten::size(%336) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:2358:88\n", | |
| " %144 : int[] = aten::slice(%143, %21, %22, %11) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:2358:88\n", | |
| " %145 : int[] = aten::list(%144) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:2358:83\n", | |
| " %146 : int[] = aten::add(%142, %145) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:2358:23\n", | |
| " %size_prods.32 : int = aten::__getitem__(%146, %10) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:2243:17\n", | |
| " %148 : int = aten::len(%146) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:2244:19\n", | |
| " %149 : int = aten::sub(%148, %21) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:2244:19\n", | |
| " %size_prods.34 : int = prim::Loop(%149, %17, %size_prods.32) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:2244:4\n", | |
| " block0(%i.8 : int, %size_prods.36 : int):\n", | |
| " %153 : int = aten::add(%i.8, %21) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:2245:27\n", | |
| " %154 : int = aten::__getitem__(%146, %153) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:2245:22\n", | |
| " %size_prods.38 : int = aten::mul(%size_prods.36, %154) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:2245:8\n", | |
| " -> (%17, %size_prods.38)\n", | |
| " %156 : bool = aten::eq(%size_prods.34, %11) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:2246:7\n", | |
| " = prim::If(%156) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:2246:4\n", | |
| " block0():\n", | |
| " %157 : str = aten::format(%18, %146) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:2247:25\n", | |
| " = prim::RaiseException(%157) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:2247:8\n", | |
| " -> ()\n", | |
| " block1():\n", | |
| " -> ()\n", | |
| " %337 : Tensor = prim::profile[profiled_type=Float(2, 1, 49, strides=[49, 49, 1], requires_grad=1, device=cpu)](%input.45)\n", | |
| " %input.49 : Tensor = aten::group_norm(%337, %11, %136, %137, %19, %17) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:2359:11\n", | |
| " %338 : Tensor = prim::profile[profiled_type=Float(2, 1, 49, strides=[49, 49, 1], requires_grad=1, device=cpu)](%input.49)\n", | |
| " %input.53 : Tensor = aten::gelu(%338) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:1555:11\n", | |
| " %160 : Tensor = prim::GetAttr[name=\"weight\"](%38)\n", | |
| " %161 : Tensor? = prim::GetAttr[name=\"bias\"](%38)\n", | |
| " %339 : Tensor = prim::profile[profiled_type=Float(2, 1, 49, strides=[49, 49, 1], requires_grad=1, device=cpu)](%input.53)\n", | |
| " %input.57 : Tensor = aten::conv1d(%339, %160, %161, %5, %7, %6, %11) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/modules/conv.py:294:15\n", | |
| " %163 : bool = prim::GetAttr[name=\"training\"](%39)\n", | |
| " %340 : Tensor = prim::profile[profiled_type=Float(2, 1, 24, strides=[24, 24, 1], requires_grad=1, device=cpu)](%input.57)\n", | |
| " %input.61 : Tensor = aten::dropout(%340, %20, %163) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:1168:60\n", | |
| " %165 : Tensor = prim::GetAttr[name=\"weight\"](%40)\n", | |
| " %166 : Tensor = prim::GetAttr[name=\"bias\"](%40)\n", | |
| " %341 : Tensor = prim::profile[profiled_type=Float(2, 1, 24, strides=[24, 24, 1], requires_grad=1, device=cpu)](%input.61)\n", | |
| " %167 : int = aten::size(%341, %10) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:2358:24\n", | |
| " %342 : Tensor = prim::profile[profiled_type=Float(2, 1, 24, strides=[24, 24, 1], requires_grad=1, device=cpu)](%input.61)\n", | |
| " %168 : int = aten::size(%342, %11) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:2358:40\n", | |
| " %169 : int = aten::mul(%167, %168) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:2358:24\n", | |
| " %170 : int = aten::floordiv(%169, %11) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:2358:24\n", | |
| " %171 : int[] = prim::ListConstruct(%170, %11)\n", | |
| " %343 : Tensor = prim::profile[profiled_type=Float(2, 1, 24, strides=[24, 24, 1], requires_grad=1, device=cpu)](%input.61)\n", | |
| " %172 : int[] = aten::size(%343) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:2358:88\n", | |
| " %173 : int[] = aten::slice(%172, %21, %22, %11) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:2358:88\n", | |
| " %174 : int[] = aten::list(%173) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:2358:83\n", | |
| " %175 : int[] = aten::add(%171, %174) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:2358:23\n", | |
| " %size_prods.40 : int = aten::__getitem__(%175, %10) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:2243:17\n", | |
| " %177 : int = aten::len(%175) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:2244:19\n", | |
| " %178 : int = aten::sub(%177, %21) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:2244:19\n", | |
| " %size_prods.42 : int = prim::Loop(%178, %17, %size_prods.40) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:2244:4\n", | |
| " block0(%i.10 : int, %size_prods.44 : int):\n", | |
| " %182 : int = aten::add(%i.10, %21) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:2245:27\n", | |
| " %183 : int = aten::__getitem__(%175, %182) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:2245:22\n", | |
| " %size_prods.46 : int = aten::mul(%size_prods.44, %183) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:2245:8\n", | |
| " -> (%17, %size_prods.46)\n", | |
| " %185 : bool = aten::eq(%size_prods.42, %11) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:2246:7\n", | |
| " = prim::If(%185) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:2246:4\n", | |
| " block0():\n", | |
| " %186 : str = aten::format(%18, %175) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:2247:25\n", | |
| " = prim::RaiseException(%186) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:2247:8\n", | |
| " -> ()\n", | |
| " block1():\n", | |
| " -> ()\n", | |
| " %344 : Tensor = prim::profile[profiled_type=Float(2, 1, 24, strides=[24, 24, 1], requires_grad=1, device=cpu)](%input.61)\n", | |
| " %input.65 : Tensor = aten::group_norm(%344, %11, %165, %166, %19, %17) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:2359:11\n", | |
| " %345 : Tensor = prim::profile[profiled_type=Float(2, 1, 24, strides=[24, 24, 1], requires_grad=1, device=cpu)](%input.65)\n", | |
| " %input.69 : Tensor = aten::gelu(%345) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:1555:11\n", | |
| " %189 : Tensor = prim::GetAttr[name=\"weight\"](%41)\n", | |
| " %190 : Tensor? = prim::GetAttr[name=\"bias\"](%41)\n", | |
| " %346 : Tensor = prim::profile[profiled_type=Float(2, 1, 24, strides=[24, 24, 1], requires_grad=1, device=cpu)](%input.69)\n", | |
| " %input.73 : Tensor = aten::conv1d(%346, %189, %190, %5, %7, %6, %11) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/modules/conv.py:294:15\n", | |
| " %192 : bool = prim::GetAttr[name=\"training\"](%42)\n", | |
| " %347 : Tensor = prim::profile[profiled_type=Float(2, 1, 11, strides=[11, 11, 1], requires_grad=1, device=cpu)](%input.73)\n", | |
| " %input.77 : Tensor = aten::dropout(%347, %20, %192) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:1168:60\n", | |
| " %194 : Tensor = prim::GetAttr[name=\"weight\"](%43)\n", | |
| " %195 : Tensor = prim::GetAttr[name=\"bias\"](%43)\n", | |
| " %348 : Tensor = prim::profile[profiled_type=Float(2, 1, 11, strides=[11, 11, 1], requires_grad=1, device=cpu)](%input.77)\n", | |
| " %196 : int = aten::size(%348, %10) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:2358:24\n", | |
| " %349 : Tensor = prim::profile[profiled_type=Float(2, 1, 11, strides=[11, 11, 1], requires_grad=1, device=cpu)](%input.77)\n", | |
| " %197 : int = aten::size(%349, %11) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:2358:40\n", | |
| " %198 : int = aten::mul(%196, %197) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:2358:24\n", | |
| " %199 : int = aten::floordiv(%198, %11) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:2358:24\n", | |
| " %200 : int[] = prim::ListConstruct(%199, %11)\n", | |
| " %350 : Tensor = prim::profile[profiled_type=Float(2, 1, 11, strides=[11, 11, 1], requires_grad=1, device=cpu)](%input.77)\n", | |
| " %201 : int[] = aten::size(%350) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:2358:88\n", | |
| " %202 : int[] = aten::slice(%201, %21, %22, %11) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:2358:88\n", | |
| " %203 : int[] = aten::list(%202) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:2358:83\n", | |
| " %204 : int[] = aten::add(%200, %203) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:2358:23\n", | |
| " %size_prods.48 : int = aten::__getitem__(%204, %10) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:2243:17\n", | |
| " %206 : int = aten::len(%204) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:2244:19\n", | |
| " %207 : int = aten::sub(%206, %21) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:2244:19\n", | |
| " %size_prods.50 : int = prim::Loop(%207, %17, %size_prods.48) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:2244:4\n", | |
| " block0(%i.12 : int, %size_prods.52 : int):\n", | |
| " %211 : int = aten::add(%i.12, %21) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:2245:27\n", | |
| " %212 : int = aten::__getitem__(%204, %211) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:2245:22\n", | |
| " %size_prods.54 : int = aten::mul(%size_prods.52, %212) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:2245:8\n", | |
| " -> (%17, %size_prods.54)\n", | |
| " %214 : bool = aten::eq(%size_prods.50, %11) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:2246:7\n", | |
| " = prim::If(%214) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:2246:4\n", | |
| " block0():\n", | |
| " %215 : str = aten::format(%18, %204) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:2247:25\n", | |
| " = prim::RaiseException(%215) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:2247:8\n", | |
| " -> ()\n", | |
| " block1():\n", | |
| " -> ()\n", | |
| " %351 : Tensor = prim::profile[profiled_type=Float(2, 1, 11, strides=[11, 11, 1], requires_grad=1, device=cpu)](%input.77)\n", | |
| " %input.81 : Tensor = aten::group_norm(%351, %11, %194, %195, %19, %17) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:2359:11\n", | |
| " %352 : Tensor = prim::profile[profiled_type=Float(2, 1, 11, strides=[11, 11, 1], requires_grad=1, device=cpu)](%input.81)\n", | |
| " %input.85 : Tensor = aten::gelu(%352) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:1555:11\n", | |
| " %218 : Tensor = prim::GetAttr[name=\"weight\"](%44)\n", | |
| " %219 : Tensor? = prim::GetAttr[name=\"bias\"](%44)\n", | |
| " %353 : Tensor = prim::profile[profiled_type=Float(2, 1, 11, strides=[11, 11, 1], requires_grad=1, device=cpu)](%input.85)\n", | |
| " %input.89 : Tensor = aten::conv1d(%353, %218, %219, %5, %7, %6, %11) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/modules/conv.py:294:15\n", | |
| " %221 : bool = prim::GetAttr[name=\"training\"](%45)\n", | |
| " %354 : Tensor = prim::profile[profiled_type=Float(2, 1, 5, strides=[5, 5, 1], requires_grad=1, device=cpu)](%input.89)\n", | |
| " %input.93 : Tensor = aten::dropout(%354, %20, %221) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:1168:60\n", | |
| " %223 : Tensor = prim::GetAttr[name=\"weight\"](%46)\n", | |
| " %224 : Tensor = prim::GetAttr[name=\"bias\"](%46)\n", | |
| " %355 : Tensor = prim::profile[profiled_type=Float(2, 1, 5, strides=[5, 5, 1], requires_grad=1, device=cpu)](%input.93)\n", | |
| " %225 : int = aten::size(%355, %10) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:2358:24\n", | |
| " %356 : Tensor = prim::profile[profiled_type=Float(2, 1, 5, strides=[5, 5, 1], requires_grad=1, device=cpu)](%input.93)\n", | |
| " %226 : int = aten::size(%356, %11) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:2358:40\n", | |
| " %227 : int = aten::mul(%225, %226) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:2358:24\n", | |
| " %228 : int = aten::floordiv(%227, %11) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:2358:24\n", | |
| " %229 : int[] = prim::ListConstruct(%228, %11)\n", | |
| " %357 : Tensor = prim::profile[profiled_type=Float(2, 1, 5, strides=[5, 5, 1], requires_grad=1, device=cpu)](%input.93)\n", | |
| " %230 : int[] = aten::size(%357) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:2358:88\n", | |
| " %231 : int[] = aten::slice(%230, %21, %22, %11) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:2358:88\n", | |
| " %232 : int[] = aten::list(%231) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:2358:83\n", | |
| " %233 : int[] = aten::add(%229, %232) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:2358:23\n", | |
| " %size_prods.56 : int = aten::__getitem__(%233, %10) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:2243:17\n", | |
| " %235 : int = aten::len(%233) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:2244:19\n", | |
| " %236 : int = aten::sub(%235, %21) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:2244:19\n", | |
| " %size_prods.58 : int = prim::Loop(%236, %17, %size_prods.56) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:2244:4\n", | |
| " block0(%i.14 : int, %size_prods.60 : int):\n", | |
| " %240 : int = aten::add(%i.14, %21) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:2245:27\n", | |
| " %241 : int = aten::__getitem__(%233, %240) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:2245:22\n", | |
| " %size_prods.62 : int = aten::mul(%size_prods.60, %241) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:2245:8\n", | |
| " -> (%17, %size_prods.62)\n", | |
| " %243 : bool = aten::eq(%size_prods.58, %11) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:2246:7\n", | |
| " = prim::If(%243) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:2246:4\n", | |
| " block0():\n", | |
| " %244 : str = aten::format(%18, %233) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:2247:25\n", | |
| " = prim::RaiseException(%244) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:2247:8\n", | |
| " -> ()\n", | |
| " block1():\n", | |
| " -> ()\n", | |
| " %358 : Tensor = prim::profile[profiled_type=Float(2, 1, 5, strides=[5, 5, 1], requires_grad=1, device=cpu)](%input.93)\n", | |
| " %input.97 : Tensor = aten::group_norm(%358, %11, %223, %224, %19, %17) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:2359:11\n", | |
| " %359 : Tensor = prim::profile[profiled_type=Float(2, 1, 5, strides=[5, 5, 1], requires_grad=1, device=cpu)](%input.97)\n", | |
| " %input.101 : Tensor = aten::gelu(%359) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:1555:11\n", | |
| " %247 : Tensor = prim::GetAttr[name=\"weight\"](%47)\n", | |
| " %248 : Tensor? = prim::GetAttr[name=\"bias\"](%47)\n", | |
| " %360 : Tensor = prim::profile[profiled_type=Float(2, 1, 5, strides=[5, 5, 1], requires_grad=1, device=cpu)](%input.101)\n", | |
| " %input.105 : Tensor = aten::conv1d(%360, %247, %248, %5, %7, %6, %11) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/modules/conv.py:294:15\n", | |
| " %250 : bool = prim::GetAttr[name=\"training\"](%48)\n", | |
| " %361 : Tensor = prim::profile[profiled_type=Float(2, 1, 2, strides=[2, 2, 1], requires_grad=1, device=cpu)](%input.105)\n", | |
| " %input.109 : Tensor = aten::dropout(%361, %20, %250) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:1168:60\n", | |
| " %252 : Tensor = prim::GetAttr[name=\"weight\"](%49)\n", | |
| " %253 : Tensor = prim::GetAttr[name=\"bias\"](%49)\n", | |
| " %362 : Tensor = prim::profile[profiled_type=Float(2, 1, 2, strides=[2, 2, 1], requires_grad=1, device=cpu)](%input.109)\n", | |
| " %254 : int = aten::size(%362, %10) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:2358:24\n", | |
| " %363 : Tensor = prim::profile[profiled_type=Float(2, 1, 2, strides=[2, 2, 1], requires_grad=1, device=cpu)](%input.109)\n", | |
| " %255 : int = aten::size(%363, %11) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:2358:40\n", | |
| " %256 : int = aten::mul(%254, %255) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:2358:24\n", | |
| " %257 : int = aten::floordiv(%256, %11) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:2358:24\n", | |
| " %258 : int[] = prim::ListConstruct(%257, %11)\n", | |
| " %364 : Tensor = prim::profile[profiled_type=Float(2, 1, 2, strides=[2, 2, 1], requires_grad=1, device=cpu)](%input.109)\n", | |
| " %259 : int[] = aten::size(%364) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:2358:88\n", | |
| " %260 : int[] = aten::slice(%259, %21, %22, %11) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:2358:88\n", | |
| " %261 : int[] = aten::list(%260) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:2358:83\n", | |
| " %262 : int[] = aten::add(%258, %261) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:2358:23\n", | |
| " %size_prods.1 : int = aten::__getitem__(%262, %10) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:2243:17\n", | |
| " %264 : int = aten::len(%262) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:2244:19\n", | |
| " %265 : int = aten::sub(%264, %21) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:2244:19\n", | |
| " %size_prods : int = prim::Loop(%265, %17, %size_prods.1) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:2244:4\n", | |
| " block0(%i.1 : int, %size_prods.11 : int):\n", | |
| " %269 : int = aten::add(%i.1, %21) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:2245:27\n", | |
| " %270 : int = aten::__getitem__(%262, %269) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:2245:22\n", | |
| " %size_prods.5 : int = aten::mul(%size_prods.11, %270) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:2245:8\n", | |
| " -> (%17, %size_prods.5)\n", | |
| " %272 : bool = aten::eq(%size_prods, %11) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:2246:7\n", | |
| " = prim::If(%272) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:2246:4\n", | |
| " block0():\n", | |
| " %273 : str = aten::format(%18, %262) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:2247:25\n", | |
| " = prim::RaiseException(%273) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:2247:8\n", | |
| " -> ()\n", | |
| " block1():\n", | |
| " -> ()\n", | |
| " %365 : Tensor = prim::profile[profiled_type=Float(2, 1, 2, strides=[2, 2, 1], requires_grad=1, device=cpu)](%input.109)\n", | |
| " %input.113 : Tensor = aten::group_norm(%365, %11, %252, %253, %19, %17) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:2359:11\n", | |
| " %366 : Tensor = prim::profile[profiled_type=Float(2, 1, 2, strides=[2, 2, 1], requires_grad=1, device=cpu)](%input.113)\n", | |
| " %input.117 : Tensor = aten::gelu(%366) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:1555:11\n", | |
| " %367 : Tensor = prim::profile[profiled_type=Float(2, 1, 2, strides=[2, 2, 1], requires_grad=1, device=cpu)](%input.117)\n", | |
| " %h.1 : Tensor = aten::transpose(%367, %11, %21) # <ipython-input-1-14e2ee0fb7fa>:69:12\n", | |
| " %277 : __torch__.torch.nn.modules.rnn.LSTM = prim::GetAttr[name=\"lstm\"](%self)\n", | |
| " %368 : Tensor = prim::profile[profiled_type=Float(2, 2, 1, strides=[2, 1, 2], requires_grad=1, device=cpu)](%h.1)\n", | |
| " %max_batch_size.1 : int = aten::size(%368, %10) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/modules/rnn.py:658:29\n", | |
| " %369 : Tensor = prim::profile[profiled_type=Float(2, 2, 1, strides=[2, 1, 2], requires_grad=1, device=cpu)](%h.1)\n", | |
| " %279 : int = prim::dtype(%369)\n", | |
| " %370 : Tensor = prim::profile[profiled_type=Float(2, 2, 1, strides=[2, 1, 2], requires_grad=1, device=cpu)](%h.1)\n", | |
| " %280 : Device = prim::device(%370)\n", | |
| " %281 : int[] = prim::ListConstruct(%11, %max_batch_size.1, %11)\n", | |
| " %h_zeros.1 : Tensor = aten::zeros(%281, %279, %22, %280, %22) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/modules/rnn.py:665:22\n", | |
| " %c_zeros.1 : Tensor = aten::zeros(%281, %279, %22, %280, %22) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/modules/rnn.py:668:22\n", | |
| " %371 : Tensor = prim::profile[profiled_type=Float(2, 2, 1, strides=[2, 1, 2], requires_grad=1, device=cpu)](%h.1)\n", | |
| " %284 : int = aten::dim(%371) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/modules/rnn.py:200:11\n", | |
| " %285 : bool = aten::ne(%284, %14) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/modules/rnn.py:200:11\n", | |
| " = prim::If(%285) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/modules/rnn.py:200:8\n", | |
| " block0():\n", | |
| " %286 : str = aten::format(%13, %14, %284) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/modules/rnn.py:202:16\n", | |
| " = prim::RaiseException(%286) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/modules/rnn.py:201:12\n", | |
| " -> ()\n", | |
| " block1():\n", | |
| " -> ()\n", | |
| " %372 : Tensor = prim::profile[profiled_type=Float(2, 2, 1, strides=[2, 1, 2], requires_grad=1, device=cpu)](%h.1)\n", | |
| " %287 : int = aten::size(%372, %23) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/modules/rnn.py:204:30\n", | |
| " %288 : bool = aten::ne(%11, %287) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/modules/rnn.py:204:11\n", | |
| " = prim::If(%288) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/modules/rnn.py:204:8\n", | |
| " block0():\n", | |
| " %289 : str = aten::format(%12, %11, %287) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/modules/rnn.py:206:16\n", | |
| " = prim::RaiseException(%289) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/modules/rnn.py:205:12\n", | |
| " -> ()\n", | |
| " block1():\n", | |
| " -> ()\n", | |
| " %expected_hidden_size.3 : (int, int, int) = prim::TupleConstruct(%11, %max_batch_size.1, %11)\n", | |
| " %373 : Tensor = prim::profile[profiled_type=Float(1, 2, 1, strides=[2, 1, 1], requires_grad=0, device=cpu)](%h_zeros.1)\n", | |
| " %291 : int[] = aten::size(%373) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/modules/rnn.py:225:11\n", | |
| " %292 : int[] = prim::ListConstruct(%11, %max_batch_size.1, %11)\n", | |
| " %293 : bool = aten::ne(%291, %292) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/modules/rnn.py:225:11\n", | |
| " = prim::If(%293) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/modules/rnn.py:225:8\n", | |
| " block0():\n", | |
| " %294 : int[] = aten::list(%291) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/modules/rnn.py:226:64\n", | |
| " %295 : str = aten::format(%15, %expected_hidden_size.3, %294) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/modules/rnn.py:226:31\n", | |
| " = prim::RaiseException(%295) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/modules/rnn.py:226:12\n", | |
| " -> ()\n", | |
| " block1():\n", | |
| " -> ()\n", | |
| " %374 : Tensor = prim::profile[profiled_type=Float(1, 2, 1, strides=[2, 1, 1], requires_grad=0, device=cpu)](%c_zeros.1)\n", | |
| " %296 : int[] = aten::size(%374) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/modules/rnn.py:225:11\n", | |
| " %297 : bool = aten::ne(%296, %292) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/modules/rnn.py:225:11\n", | |
| " = prim::If(%297) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/modules/rnn.py:225:8\n", | |
| " block0():\n", | |
| " %298 : int[] = aten::list(%296) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/modules/rnn.py:226:64\n", | |
| " %299 : str = aten::format(%16, %expected_hidden_size.3, %298) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/modules/rnn.py:226:31\n", | |
| " = prim::RaiseException(%299) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/modules/rnn.py:226:12\n", | |
| " -> ()\n", | |
| " block1():\n", | |
| " -> ()\n", | |
| " %300 : Tensor[] = prim::GetAttr[name=\"_flat_weights\"](%277)\n", | |
| " %301 : bool = prim::GetAttr[name=\"training\"](%277)\n", | |
| " %375 : Tensor = prim::profile[profiled_type=Float(1, 2, 1, strides=[2, 1, 1], requires_grad=0, device=cpu)](%h_zeros.1)\n", | |
| " %376 : Tensor = prim::profile[profiled_type=Float(1, 2, 1, strides=[2, 1, 1], requires_grad=0, device=cpu)](%c_zeros.1)\n", | |
| " %302 : Tensor[] = prim::ListConstruct(%375, %376)\n", | |
| " %377 : Tensor = prim::profile[profiled_type=Float(2, 2, 1, strides=[2, 1, 2], requires_grad=1, device=cpu)](%h.1)\n", | |
| " %303 : Tensor, %304 : Tensor, %305 : Tensor = aten::lstm(%377, %302, %300, %17, %11, %20, %301, %9, %17) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/modules/rnn.py:679:21\n", | |
| " %306 : __torch__.torch.nn.modules.linear.Linear = prim::GetAttr[name=\"fc\"](%self)\n", | |
| " %307 : Tensor = prim::GetAttr[name=\"weight\"](%306)\n", | |
| " %308 : Tensor = prim::GetAttr[name=\"bias\"](%306)\n", | |
| " %378 : Tensor = prim::profile[profiled_type=Float(2, 2, 1, strides=[1, 2, 1], requires_grad=1, device=cpu)](%303)\n", | |
| " %379 : Tensor = prim::profile[profiled_type=Float(30, 1, strides=[1, 1], requires_grad=1, device=cpu)](%307)\n", | |
| " %380 : Tensor = prim::profile[profiled_type=Float(30, strides=[1], requires_grad=1, device=cpu)](%308)\n", | |
| " %h.9 : Tensor = aten::linear(%378, %379, %380) # /mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/functional.py:1847:11\n", | |
| " %381 : Tensor = prim::profile[profiled_type=Float(2, 2, 30, strides=[60, 30, 1], requires_grad=1, device=cpu)](%h.9)\n", | |
| " %z.1 : Tensor = aten::log_softmax(%381, %23, %22) # <ipython-input-1-14e2ee0fb7fa>:72:15\n", | |
| " %382 : Tensor = prim::profile[profiled_type=Float(2, 2, 30, strides=[60, 30, 1], requires_grad=1, device=cpu)](%z.1)\n", | |
| " %h.2 : Tensor = aten::transpose(%382, %10, %11) # <ipython-input-1-14e2ee0fb7fa>:85:12\n", | |
| " %383 : Tensor = prim::profile[profiled_type=Float(2, 2, 30, strides=[30, 60, 1], requires_grad=1, device=cpu)](%h.2)\n", | |
| " %312 : Tensor = aten::ctc_loss(%383, %y.1, %xlen.1, %ylen.1, %10, %11, %9) # <ipython-input-1-14e2ee0fb7fa>:86:15\n", | |
| " %384 : Tensor = prim::profile[profiled_type=Float(2, 2, 30, strides=[60, 30, 1], requires_grad=1, device=cpu)](%z.1)\n", | |
| " %313 : (Tensor, Tensor) = prim::TupleConstruct(%312, %384)\n", | |
| " = prim::profile()\n", | |
| " return (%313)" | |
| ] | |
| }, | |
| "execution_count": 1, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "# Jittable CTC model def\n", | |
| "import torch\n", | |
| "from torch import nn\n", | |
| "\n", | |
| "\n", | |
| "class SpeechModel(nn.Module):\n", | |
| " def __init__(self, n_vocab, n_hid, n_feat=1, dropout=0.5, act=nn.GELU):\n", | |
| " super().__init__()\n", | |
| " self.encoder = nn.Sequential(\n", | |
| " # Wav2vec 2.0 style encoder https://arxiv.org/abs/2006.11477\n", | |
| " nn.GroupNorm(1, n_feat),\n", | |
| " \n", | |
| " nn.Conv1d(n_feat, n_hid, 10, 5),\n", | |
| " nn.Dropout(dropout),\n", | |
| " nn.GroupNorm(1, n_hid),\n", | |
| " act(),\n", | |
| "\n", | |
| " nn.Conv1d(n_hid, n_hid, 3, 2),\n", | |
| " nn.Dropout(dropout),\n", | |
| " nn.GroupNorm(1, n_hid),\n", | |
| " act(),\n", | |
| " \n", | |
| " nn.Conv1d(n_hid, n_hid, 3, 2),\n", | |
| " nn.Dropout(dropout),\n", | |
| " nn.GroupNorm(1, n_hid),\n", | |
| " act(),\n", | |
| " \n", | |
| " nn.Conv1d(n_hid, n_hid, 3, 2),\n", | |
| " nn.Dropout(dropout),\n", | |
| " nn.GroupNorm(1, n_hid),\n", | |
| " act(),\n", | |
| " \n", | |
| " nn.Conv1d(n_hid, n_hid, 3, 2),\n", | |
| " nn.Dropout(dropout),\n", | |
| " nn.GroupNorm(1, n_hid),\n", | |
| " act(),\n", | |
| " \n", | |
| " nn.Conv1d(n_hid, n_hid, 2, 2),\n", | |
| " nn.Dropout(dropout),\n", | |
| " nn.GroupNorm(1, n_hid),\n", | |
| " act(),\n", | |
| "\n", | |
| " nn.Conv1d(n_hid, n_hid, 2, 2),\n", | |
| " nn.Dropout(dropout),\n", | |
| " nn.GroupNorm(1, n_hid),\n", | |
| " act(),\n", | |
| " )\n", | |
| " self.lstm = nn.LSTM(n_hid, n_hid, 1, dropout=dropout, batch_first=True)\n", | |
| " self.fc = nn.Linear(n_hid, n_vocab)\n", | |
| "\n", | |
| " @torch.jit.ignore # Cannot jit?\n", | |
| " def convlen(self, xlen):\n", | |
| " xlen = xlen.float()\n", | |
| " for m in self.encoder:\n", | |
| " if isinstance(m, nn.Conv1d):\n", | |
| " xlen = torch.floor((xlen - m.kernel_size[0]) / m.stride[0] + 1)\n", | |
| " return xlen.long()\n", | |
| " \n", | |
| " @torch.jit.export\n", | |
| " def inference(self, x):\n", | |
| " \"\"\"Transcribe one wav sequence to ids.\"\"\"\n", | |
| " ids = self.logprob(x[None])[0].argmax(-1)\n", | |
| " ids = torch.unique_consecutive(ids)\n", | |
| " return ids[ids != 0]\n", | |
| "\n", | |
| " @torch.jit.export\n", | |
| " def logprob(self, x):\n", | |
| " \"\"\"Predicts log softmax distribution (batch, time, class).\"\"\"\n", | |
| " h = self.encoder(x[:, None, :]).transpose(1, 2)\n", | |
| " h, _ = self.lstm(h)\n", | |
| " h = self.fc(h)\n", | |
| " return h.log_softmax(-1)\n", | |
| " \n", | |
| " @torch.jit.export\n", | |
| " def forward(self, x, xlen, y, ylen):\n", | |
| " \"\"\"Computes CTC loss.\n", | |
| " \n", | |
| " Args:\n", | |
| " x: input feature, float (batch, time1)\n", | |
| " xlen: encoded lengths, long (batch)\n", | |
| " y: target ids, long (batch, time2)\n", | |
| " ylen: target lengths, long (batch)\n", | |
| " \"\"\"\n", | |
| " z = self.logprob(x) # (batch, time1, n_vocab)\n", | |
| " h = z.transpose(0, 1) # (time1, batch, n_vocab)\n", | |
| " return torch.ctc_loss(h, y, xlen, ylen), z\n", | |
| "\n", | |
| "\n", | |
| "model = torch.jit.script(SpeechModel(30, 1))\n", | |
| "x = torch.randn(2, 1000)\n", | |
| "xlen = model.convlen(torch.tensor([1000, 900]).long())\n", | |
| "y = torch.ones(2, 10).long()\n", | |
| "ylen = torch.tensor([10, 9]).long()\n", | |
| "print(\"#params:\\t\", sum(v.numel() for v in model.state_dict().values()))\n", | |
| "print(\"out shape:\\t\", model.inference(x[0]).shape)\n", | |
| "model.graph_for(x, xlen, y, ylen)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 2, | |
| "id": "38f14b44", | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "name": "stderr", | |
| "output_type": "stream", | |
| "text": [ | |
| "848it [00:00, 528494.77it/s]\n", | |
| "848it [00:00, 4886.44it/s]\n", | |
| "100it [00:00, 412014.15it/s]\n", | |
| "100it [00:00, 4852.50it/s]\n", | |
| "130it [00:00, 388638.29it/s]\n", | |
| "130it [00:00, 4795.81it/s]\n" | |
| ] | |
| } | |
| ], | |
| "source": [ | |
| "# Dataset\n", | |
| "import os\n", | |
| "import librosa\n", | |
| "from tqdm import tqdm\n", | |
| "\n", | |
| "\n", | |
| "class ESPnetDataset(torch.utils.data.Dataset):\n", | |
| "    \"\"\"In-memory dataset over an ESPnet-style data directory.\n", | |
| "\n", | |
| "    Reads the ``text`` and ``wav.scp`` files under ``root`` and loads every\n", | |
| "    listed waveform with librosa while constructing. ``tokenizer`` is None\n", | |
| "    until a Tokenizer is built over this dataset; __getitem__ needs it.\n", | |
| "\n", | |
| "    Args:\n", | |
| "        root: directory containing ``text`` and ``wav.scp``.\n", | |
| "        sr: sampling rate handed to librosa.load.\n", | |
| "    \"\"\"\n", | |
| "\n", | |
| "    def __init__(self, root, sr=16000):\n", | |
| "        self.root = root\n", | |
| "        self.sr = sr\n", | |
| "        self.texts = {}  # utterance id -- transcript\n", | |
| "        self.wavs = {}  # utterance id -- 1-D waveform tensor\n", | |
| "        self.ids = []  # utterance ids in the order they appear in ``text``\n", | |
| "        self.tokenizer = None\n", | |
| "        self.vocab = set()  # every character seen in the transcripts\n", | |
| "        with open(os.path.join(root, \"text\"), \"r\") as f:\n", | |
| "            for line in tqdm(f):\n", | |
| "                fields = line.split(maxsplit=1)\n", | |
| "                if not fields:\n", | |
| "                    continue  # blank line\n", | |
| "                utt_id = fields[0]\n", | |
| "                # A line may carry an id with an empty transcript; unpacking the\n", | |
| "                # split into two names would raise ValueError on such a line.\n", | |
| "                transcript = fields[1].strip() if len(fields) == 2 else \"\"\n", | |
| "                self.ids.append(utt_id)\n", | |
| "                self.texts[utt_id] = transcript\n", | |
| "                self.vocab.update(transcript)\n", | |
| "        with open(os.path.join(root, \"wav.scp\"), \"r\") as f:\n", | |
| "            for line in tqdm(f):\n", | |
| "                fields = line.split()\n", | |
| "                if not fields:\n", | |
| "                    continue  # blank line\n", | |
| "                # A piped entry (\"id command ... path |\") keeps the audio path\n", | |
| "                # second to last; a plain \"id path\" entry keeps it last.\n", | |
| "                rel_path = fields[-2] if fields[-1] == \"|\" else fields[-1]\n", | |
| "                # Paths in wav.scp are resolved two directories above root.\n", | |
| "                path = os.path.join(root, \"../..\", rel_path)\n", | |
| "                self.wavs[fields[0]] = torch.as_tensor(librosa.load(path, sr=sr)[0])\n", | |
| "\n", | |
| "    def __len__(self):\n", | |
| "        \"\"\"Number of utterances listed in ``text``.\"\"\"\n", | |
| "        return len(self.ids)\n", | |
| "\n", | |
| "    def __getitem__(self, idx):\n", | |
| "        \"\"\"Returns (waveform, encoded transcript) for the idx-th utterance.\"\"\"\n", | |
| "        k = self.ids[idx]\n", | |
| "        wav = self.wavs[k]\n", | |
| "        return wav, self.tokenizer.encode(self.texts[k])\n", | |
| " \n", | |
| "\n", | |
| "class Tokenizer:\n", | |
| "    \"\"\"Character-level tokenizer shared by several datasets.\n", | |
| "\n", | |
| "    Id 0 is '[blank]' and id 1 is '[unk]'; the remaining ids are assigned to\n", | |
| "    the sorted union of the datasets' ``vocab`` sets. Every dataset passed in\n", | |
| "    gets its ``tokenizer`` attribute set to this instance.\n", | |
| "    \"\"\"\n", | |
| "\n", | |
| "    def __init__(self, *datasets):\n", | |
| "        vocabs = []\n", | |
| "        for d in datasets:\n", | |
| "            vocabs.append(d.vocab)\n", | |
| "            d.tokenizer = self\n", | |
| "        # set().union(...) also works when no dataset is given,\n", | |
| "        # whereas set.union(*[]) raises TypeError.\n", | |
| "        tokens = sorted(set().union(*vocabs))\n", | |
| "        self.s2i = {'[blank]': 0, '[unk]': 1}\n", | |
| "        self.i2s = {0: '[blank]', 1: '[unk]'}\n", | |
| "        for i, token in enumerate(tokens, len(self.s2i)):\n", | |
| "            self.i2s[i] = token\n", | |
| "            self.s2i[token] = i\n", | |
| "\n", | |
| "    def __len__(self):\n", | |
| "        \"\"\"Vocabulary size, including '[blank]' and '[unk]'.\"\"\"\n", | |
| "        return len(self.s2i)\n", | |
| "\n", | |
| "    def encode(self, s):\n", | |
| "        \"\"\"Maps a string to a 1-D long tensor of ids, one per character.\n", | |
| "\n", | |
| "        Characters outside the vocabulary map to the '[unk]' id instead of\n", | |
| "        raising KeyError. The dtype is fixed so that an empty string yields\n", | |
| "        an empty long tensor rather than an empty float tensor.\n", | |
| "        \"\"\"\n", | |
| "        unk = self.s2i['[unk]']\n", | |
| "        return torch.tensor([self.s2i.get(c, unk) for c in s], dtype=torch.long)\n", | |
| "\n", | |
| "    def decode(self, e):\n", | |
| "        \"\"\"Joins the tokens for a sequence of ids (tensor elements or ints).\"\"\"\n", | |
| "        return \"\".join(self.i2s[int(i)] for i in e)\n", | |
| "\n", | |
| " \n", | |
| "# Load every split, then build one tokenizer over all of them; the Tokenizer\n", | |
| "# constructor also attaches itself to each dataset as its tokenizer.\n", | |
| "trainset = ESPnetDataset('../egs/an4/asr1/data/train_nodev/')\n", | |
| "devset = ESPnetDataset('../egs/an4/asr1/data/train_dev/')\n", | |
| "testset = ESPnetDataset('../egs/an4/asr1/data/test')\n", | |
| "tokenizer = Tokenizer(trainset, devset, testset)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 4, | |
| "id": "eb1be0e8", | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "name": "stderr", | |
| "output_type": "stream", | |
| "text": [ | |
| "/mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/cuda/__init__.py:83: UserWarning: \n", | |
| " Found GPU%d %s which is of cuda capability %d.%d.\n", | |
| " PyTorch no longer supports this GPU because it is too old.\n", | |
| " The minimum cuda capability supported by this library is %d.%d.\n", | |
| " \n", | |
| " warnings.warn(old_gpu_warn.format(d, name, major, minor, min_arch // 10, min_arch % 10))\n" | |
| ] | |
| }, | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "max xlen: 102400\n", | |
| "max ylen: 57\n", | |
| "#vocab: 29\n" | |
| ] | |
| }, | |
| { | |
| "name": "stderr", | |
| "output_type": "stream", | |
| "text": [ | |
| "/mnt/ae0fa180-1e4d-4945-8f63-f64615d07b30/repos/espnet/tools/venv/lib/python3.8/site-packages/torch/nn/modules/module.py:1051: UserWarning: RNN module weights are not part of single contiguous chunk of memory. This means they need to be compacted at every call, possibly greatly increasing memory usage. To compact weights again call flatten_parameters(). (Triggered internally at /opt/conda/conda-bld/pytorch_1623448278899/work/aten/src/ATen/native/cudnn/RNN.cpp:924.)\n", | |
| " return forward_call(*input, **kwargs)\n" | |
| ] | |
| }, | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "epoch: 0\n", | |
| "sec/epoch: 3.954852342605591\n", | |
| "train loss: tensor(5.2251)\n", | |
| "train cer: tensor(0.9882)\n", | |
| "dev loss: tensor(3.0169)\n", | |
| "dev cer: tensor(1.)\n", | |
| "epoch: 5\n", | |
| "sec/epoch: 3.640742063522339\n", | |
| "train loss: tensor(2.9098)\n", | |
| "train cer: tensor(1.)\n", | |
| "dev loss: tensor(2.9442)\n", | |
| "dev cer: tensor(1.)\n", | |
| "epoch: 10\n", | |
| "sec/epoch: 3.664271354675293\n", | |
| "train loss: tensor(2.7860)\n", | |
| "train cer: tensor(0.9860)\n", | |
| "dev loss: tensor(2.7893)\n", | |
| "dev cer: tensor(0.9835)\n", | |
| "epoch: 15\n", | |
| "sec/epoch: 3.6453144550323486\n", | |
| "train loss: tensor(2.6531)\n", | |
| "train cer: tensor(0.9509)\n", | |
| "dev loss: tensor(2.6637)\n", | |
| "dev cer: tensor(0.9455)\n", | |
| "epoch: 20\n", | |
| "sec/epoch: 3.6640310287475586\n", | |
| "train loss: tensor(2.4639)\n", | |
| "train cer: tensor(0.9403)\n", | |
| "dev loss: tensor(2.5285)\n", | |
| "dev cer: tensor(0.9422)\n", | |
| "epoch: 25\n", | |
| "sec/epoch: 3.6508846282958984\n", | |
| "train loss: tensor(2.2603)\n", | |
| "train cer: tensor(0.7138)\n", | |
| "dev loss: tensor(2.3292)\n", | |
| "dev cer: tensor(0.7413)\n", | |
| "epoch: 30\n", | |
| "sec/epoch: 3.669067859649658\n", | |
| "train loss: tensor(2.0850)\n", | |
| "train cer: tensor(0.6624)\n", | |
| "dev loss: tensor(2.1549)\n", | |
| "dev cer: tensor(0.6562)\n", | |
| "epoch: 35\n", | |
| "sec/epoch: 3.6770923137664795\n", | |
| "train loss: tensor(1.8892)\n", | |
| "train cer: tensor(0.5947)\n", | |
| "dev loss: tensor(2.0249)\n", | |
| "dev cer: tensor(0.6347)\n", | |
| "epoch: 40\n", | |
| "sec/epoch: 3.6583163738250732\n", | |
| "train loss: tensor(1.6512)\n", | |
| "train cer: tensor(0.4884)\n", | |
| "dev loss: tensor(1.8241)\n", | |
| "dev cer: tensor(0.5076)\n", | |
| "epoch: 45\n", | |
| "sec/epoch: 3.678729295730591\n", | |
| "train loss: tensor(1.4269)\n", | |
| "train cer: tensor(0.3909)\n", | |
| "dev loss: tensor(1.7563)\n", | |
| "dev cer: tensor(0.4751)\n", | |
| "epoch: 50\n", | |
| "sec/epoch: 3.658202886581421\n", | |
| "train loss: tensor(1.2555)\n", | |
| "train cer: tensor(0.3106)\n", | |
| "dev loss: tensor(1.5976)\n", | |
| "dev cer: tensor(0.3817)\n", | |
| "epoch: 55\n", | |
| "sec/epoch: 3.6857497692108154\n", | |
| "train loss: tensor(1.0824)\n", | |
| "train cer: tensor(0.2570)\n", | |
| "dev loss: tensor(1.4412)\n", | |
| "dev cer: tensor(0.3340)\n", | |
| "epoch: 60\n", | |
| "sec/epoch: 3.6613011360168457\n", | |
| "train loss: tensor(0.9420)\n", | |
| "train cer: tensor(0.2111)\n", | |
| "dev loss: tensor(1.4324)\n", | |
| "dev cer: tensor(0.3054)\n", | |
| "epoch: 65\n", | |
| "sec/epoch: 3.6761670112609863\n", | |
| "train loss: tensor(0.8255)\n", | |
| "train cer: tensor(0.1789)\n", | |
| "dev loss: tensor(1.3962)\n", | |
| "dev cer: tensor(0.2882)\n", | |
| "epoch: 70\n", | |
| "sec/epoch: 3.682213306427002\n", | |
| "train loss: tensor(0.7396)\n", | |
| "train cer: tensor(0.1536)\n", | |
| "dev loss: tensor(1.3824)\n", | |
| "dev cer: tensor(0.2519)\n", | |
| "epoch: 75\n", | |
| "sec/epoch: 3.653740882873535\n", | |
| "train loss: tensor(0.6194)\n", | |
| "train cer: tensor(0.1323)\n", | |
| "dev loss: tensor(1.5055)\n", | |
| "dev cer: tensor(0.2020)\n", | |
| "epoch: 80\n", | |
| "sec/epoch: 3.6865620613098145\n", | |
| "train loss: tensor(0.5580)\n", | |
| "train cer: tensor(0.1158)\n", | |
| "dev loss: tensor(1.4025)\n", | |
| "dev cer: tensor(0.2222)\n", | |
| "epoch: 85\n", | |
| "sec/epoch: 3.6628165245056152\n", | |
| "train loss: tensor(0.4663)\n", | |
| "train cer: tensor(0.0924)\n", | |
| "dev loss: tensor(1.4166)\n", | |
| "dev cer: tensor(0.1741)\n", | |
| "epoch: 90\n", | |
| "sec/epoch: 3.6864466667175293\n", | |
| "train loss: tensor(0.4085)\n", | |
| "train cer: tensor(0.0800)\n", | |
| "dev loss: tensor(1.4796)\n", | |
| "dev cer: tensor(0.1821)\n", | |
| "epoch: 95\n", | |
| "sec/epoch: 3.664776563644409\n", | |
| "train loss: tensor(0.3649)\n", | |
| "train cer: tensor(0.0709)\n", | |
| "dev loss: tensor(1.4671)\n", | |
| "dev cer: tensor(0.2182)\n", | |
| "epoch: 100\n", | |
| "sec/epoch: 3.677140235900879\n", | |
| "train loss: tensor(0.3022)\n", | |
| "train cer: tensor(0.0581)\n", | |
| "dev loss: tensor(1.4241)\n", | |
| "dev cer: tensor(0.1634)\n" | |
| ] | |
| } | |
| ], | |
| "source": [ | |
| "# Training\n", | |
| "import time\n", | |
| "import itertools\n", | |
| "import editdistance\n", | |
| "\n", | |
| "# hyperparams\n", | |
| "n_batch = 16  # utterances per mini-batch\n", | |
| "n_epoch = 100  # the training loop runs epochs 0..n_epoch inclusive\n", | |
| "n_hid = 128  # passed to SpeechModel as n_hid\n", | |
| "lr = 0.001  # Adam learning rate\n", | |
| "dropout = 0.0\n", | |
| "clip = 1.0  # maximum gradient norm used for clipping\n", | |
| "# Fall back to CPU so the notebook still runs on a machine without CUDA.\n", | |
| "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n", | |
| "# Longest waveform and longest transcript over all three splits.\n", | |
| "max_xlen = max(map(len, itertools.chain(trainset.wavs.values(), devset.wavs.values(), testset.wavs.values())))\n", | |
| "max_ylen = max(map(len, itertools.chain(trainset.texts.values(), devset.texts.values(), testset.texts.values())))\n", | |
| "print(\"max xlen:\", max_xlen)\n", | |
| "print(\"max ylen:\", max_ylen)\n", | |
| "print(\"#vocab:\", len(tokenizer))\n", | |
| "\n", | |
| "def collate(batch):\n", | |
| "    \"\"\"Pads a list of (waveform, token ids) pairs into fixed-width tensors.\n", | |
| "\n", | |
| "    Args:\n", | |
| "        batch: list of (x, y) pairs, each a 1-D waveform and a 1-D id tensor.\n", | |
| "\n", | |
| "    Returns:\n", | |
| "        xpad: float (len(batch), max_xlen), waveforms zero-padded on the right\n", | |
| "        xlen: long (len(batch)), unpadded waveform lengths\n", | |
| "        ypad: long (len(batch), max_ylen), token ids zero-padded on the right\n", | |
| "        ylen: long (len(batch)), unpadded target lengths\n", | |
| "    \"\"\"\n", | |
| "    # Size by the actual batch rather than the global n_batch, so a short final\n", | |
| "    # batch still yields tensors whose batch dimension matches xlen and ylen.\n", | |
| "    xpad = torch.zeros(len(batch), max_xlen)\n", | |
| "    # Targets are ids, so keep them integral instead of float.\n", | |
| "    ypad = torch.zeros(len(batch), max_ylen, dtype=torch.long)\n", | |
| "    xlen, ylen = [], []\n", | |
| "    for i, (x, y) in enumerate(batch):\n", | |
| "        xpad[i, :len(x)] = x\n", | |
| "        ypad[i, :len(y)] = y\n", | |
| "        xlen.append(len(x))\n", | |
| "        ylen.append(len(y))\n", | |
| "    return xpad, torch.tensor(xlen), ypad, torch.tensor(ylen)\n", | |
| "\n", | |
| "\n", | |
| "def cer(pred, plen, ypad, ylen):\n", | |
| "    \"\"\"Character error rate of the best-path (argmax) decoding of a batch.\n", | |
| "\n", | |
| "    Args:\n", | |
| "        pred: per-frame scores, (batch, frames, n_vocab)\n", | |
| "        plen: number of valid frames per utterance, (batch)\n", | |
| "        ypad: padded reference ids, (batch, time2)\n", | |
| "        ylen: reference lengths, (batch)\n", | |
| "\n", | |
| "    Returns:\n", | |
| "        0-dim float tensor: total edit distance divided by the total number\n", | |
| "        of reference tokens (divided by 1 when there are no reference tokens).\n", | |
| "    \"\"\"\n", | |
| "    ids = pred.argmax(-1).cpu()\n", | |
| "    err = 0\n", | |
| "    n = 0\n", | |
| "    for p, pl, y, yl in zip(ids, plen, ypad, ylen):\n", | |
| "        pl, yl = int(pl), int(yl)\n", | |
| "        p = torch.unique_consecutive(p[:pl])  # collapse repeated frames\n", | |
| "        p = p[p != 0]  # filter blank\n", | |
| "        # editdistance hashes each element, and tensor elements hash by object\n", | |
| "        # identity rather than by value, so compare plain Python ints instead.\n", | |
| "        err += editdistance.eval(p.tolist(), y[:yl].long().tolist())\n", | |
| "        n += yl\n", | |
| "    return torch.tensor(err / max(n, 1))\n", | |
| "\n", | |
| "\n", | |
| "train_loader = torch.utils.data.DataLoader(trainset, n_batch, collate_fn=collate, shuffle=True, drop_last=True)\n", | |
| "dev_loader = torch.utils.data.DataLoader(devset, n_batch, collate_fn=collate, shuffle=False, drop_last=True)\n", | |
| "test_loader = torch.utils.data.DataLoader(testset, n_batch, collate_fn=collate, shuffle=False, drop_last=True)\n", | |
| "\n", | |
| "model = torch.jit.script(SpeechModel(len(tokenizer), n_hid=n_hid, dropout=dropout))\n", | |
| "model.to(device)\n", | |
| "optimizer = torch.optim.Adam(model.parameters(), lr=lr)\n", | |
| "for epoch in range(n_epoch + 1):\n", | |
| "    report = epoch % 5 == 0  # metrics are printed every fifth epoch\n", | |
| "\n", | |
| "    # One pass over the training set with gradient clipping.\n", | |
| "    model.train()\n", | |
| "    loss_hist = []\n", | |
| "    cer_hist = []\n", | |
| "    start = time.time()\n", | |
| "    for xpad, xlen, ypad, ylen in train_loader:\n", | |
| "        optimizer.zero_grad()\n", | |
| "        # forward() takes the lengths produced by convlen, not the raw ones.\n", | |
| "        hlen = model.convlen(xlen)\n", | |
| "        loss, pred = model(xpad.to(device), hlen.to(device),\n", | |
| "                           ypad.to(device), ylen.to(device))\n", | |
| "        loss.backward()\n", | |
| "        nn.utils.clip_grad_norm_(model.parameters(), clip)\n", | |
| "        optimizer.step()\n", | |
| "        cer_hist.append(cer(pred, hlen, ypad, ylen))\n", | |
| "        # Store a plain float so the history does not keep every loss tensor\n", | |
| "        # (and its autograd history) alive for the whole epoch.\n", | |
| "        loss_hist.append(loss.item())\n", | |
| "    if report:\n", | |
| "        print(\"epoch:\", epoch)\n", | |
| "        print(\"sec/epoch:\", time.time() - start)\n", | |
| "        print(\"train loss:\", torch.mean(torch.tensor(loss_hist)).float())\n", | |
| "        print(\"train cer:\", torch.mean(torch.tensor(cer_hist)).float())\n", | |
| "\n", | |
| "    # Evaluate on the dev set without building autograd graphs.\n", | |
| "    loss_hist = []\n", | |
| "    cer_hist = []\n", | |
| "    model.eval()\n", | |
| "    with torch.no_grad():\n", | |
| "        for xpad, xlen, ypad, ylen in dev_loader:\n", | |
| "            hlen = model.convlen(xlen)\n", | |
| "            loss, pred = model(xpad.to(device), hlen.to(device),\n", | |
| "                               ypad.to(device), ylen.to(device))\n", | |
| "            cer_hist.append(cer(pred, hlen, ypad, ylen))\n", | |
| "            loss_hist.append(loss.item())\n", | |
| "    if report:\n", | |
| "        print(\"dev loss:\", torch.mean(torch.tensor(loss_hist)).float())\n", | |
| "        print(\"dev cer:\", torch.mean(torch.tensor(cer_hist)).float())\n" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 5, | |
| "id": "0525926a", | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "pred: \n", | |
| "truth: YES\n", | |
| "pred: FIFB SFTY SIX\n", | |
| "truth: FIFTY ONE FIFTY SIX\n", | |
| "pred: U FIVE TWO ONE SEVEN\n", | |
| "truth: ONE FIVE TWO ONE SEVEN\n", | |
| "pred: ETER FIVE\n", | |
| "truth: ENTER FIVE\n", | |
| "pred: UBOU J IU TWE\n", | |
| "truth: RUBOUT J U I P THREE TWO EIGHT\n", | |
| "pred: J LR E\n", | |
| "truth: J E N N I F E R\n" | |
| ] | |
| } | |
| ], | |
| "source": [ | |
| "# Print prediction vs. reference for the first utterance of each of the\n", | |
| "# first six dev batches, running the trained model on CPU.\n", | |
| "model.cpu()\n", | |
| "model.eval()\n", | |
| "with torch.no_grad():\n", | |
| "    for i, (xs, xlen, ys, ylen) in zip(range(6), dev_loader):\n", | |
| "        x = xs[0, :xlen[0]]  # strip the padding from the waveform\n", | |
| "        pred = model.inference(x)\n", | |
| "        print(\"pred: \", tokenizer.decode(pred))\n", | |
| "        print(\"truth:\", tokenizer.decode(ys[0, :ylen[0]]))" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 7, | |
| "id": "742e96c7", | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "# Serialize the scripted model to ctc.pt in the working directory.\n", | |
| "torch.jit.save(model, 'ctc.pt')" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 8, | |
| "id": "e2f818e1", | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "# Load the TorchScript module back from ctc.pt.\n", | |
| "loaded = torch.jit.load('ctc.pt')" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 10, | |
| "id": "360beaf7", | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "pred: \n", | |
| "truth: YES\n", | |
| "pred: FIFB SFTY SIX\n", | |
| "truth: FIFTY ONE FIFTY SIX\n", | |
| "pred: U FIVE TWO ONE SEVEN\n", | |
| "truth: ONE FIVE TWO ONE SEVEN\n", | |
| "pred: ETER FIVE\n", | |
| "truth: ENTER FIVE\n", | |
| "pred: UBOU J IU TWE\n", | |
| "truth: RUBOUT J U I P THREE TWO EIGHT\n", | |
| "pred: J LR E\n", | |
| "truth: J E N N I F E R\n" | |
| ] | |
| } | |
| ], | |
| "source": [ | |
| "# Same check as for the in-memory model, but with the module loaded from disk:\n", | |
| "# first utterance of each of the first six dev batches.\n", | |
| "with torch.no_grad():\n", | |
| "    for i, (xs, xlen, ys, ylen) in zip(range(6), dev_loader):\n", | |
| "        x = xs[0, :xlen[0]]  # strip the padding from the waveform\n", | |
| "        pred = loaded.inference(x)\n", | |
| "        print(\"pred: \", tokenizer.decode(pred))\n", | |
| "        print(\"truth:\", tokenizer.decode(ys[0, :ylen[0]]))" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 11, | |
| "id": "f2e5dcbd", | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "Thu Jun 24 02:31:24 2021 \r\n", | |
| "+-----------------------------------------------------------------------------+\r\n", | |
| "| NVIDIA-SMI 460.80 Driver Version: 460.80 CUDA Version: 11.2 |\r\n", | |
| "|-------------------------------+----------------------+----------------------+\r\n", | |
| "| GPU Name Persistence-M| Bus-Id Disp.A | Volatile Uncorr. ECC |\r\n", | |
| "| Fan Temp Perf Pwr:Usage/Cap| Memory-Usage | GPU-Util Compute M. |\r\n", | |
| "| | | MIG M. |\r\n", | |
| "|===============================+======================+======================|\r\n", | |
| "| 0 GeForce GTX 760 Off | 00000000:01:00.0 N/A | N/A |\r\n", | |
| "| 36% 48C P8 N/A / N/A | 121MiB / 1996MiB | N/A Default |\r\n", | |
| "| | | N/A |\r\n", | |
| "+-------------------------------+----------------------+----------------------+\r\n", | |
| "| 1 GeForce GTX 108... Off | 00000000:02:00.0 Off | N/A |\r\n", | |
| "| 0% 51C P8 10W / 250W | 2099MiB / 11178MiB | 0% Default |\r\n", | |
| "| | | N/A |\r\n", | |
| "+-------------------------------+----------------------+----------------------+\r\n", | |
| " \r\n", | |
| "+-----------------------------------------------------------------------------+\r\n", | |
| "| Processes: |\r\n", | |
| "| GPU GI CI PID Type Process name GPU Memory |\r\n", | |
| "| ID ID Usage |\r\n", | |
| "|=============================================================================|\r\n", | |
| "| 1 N/A N/A 1456 G /usr/lib/xorg/Xorg 4MiB |\r\n", | |
| "| 1 N/A N/A 25056 C ...net/tools/venv/bin/python 2091MiB |\r\n", | |
| "+-----------------------------------------------------------------------------+\r\n" | |
| ] | |
| } | |
| ], | |
| "source": [ | |
| "# Show the current GPU status with nvidia-smi.\n", | |
| "!nvidia-smi" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": null, | |
| "id": "1c686df1", | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [] | |
| } | |
| ], | |
| "metadata": { | |
| "kernelspec": { | |
| "display_name": "Python 3", | |
| "language": "python", | |
| "name": "python3" | |
| }, | |
| "language_info": { | |
| "codemirror_mode": { | |
| "name": "ipython", | |
| "version": 3 | |
| }, | |
| "file_extension": ".py", | |
| "mimetype": "text/x-python", | |
| "name": "python", | |
| "nbconvert_exporter": "python", | |
| "pygments_lexer": "ipython3", | |
| "version": "3.8.5" | |
| } | |
| }, | |
| "nbformat": 4, | |
| "nbformat_minor": 5 | |
| } |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment