{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Compute ASR Normalization statistics"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"import nemo\n",
"import nemo.collections.asr as nemo_asr"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Setup the dataset paths\n",
"\n",
"NeMo supports \",\" syntax to concatenate train ASR manifest files."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# TRAIN PATHS\n",
"TRAIN_CLEAN_100_PATH = \"train_clean_100.json\"\n",
"TRAIN_CLEAN_360_PATH = \"train_clean_360.json\"\n",
"TRAIN_OTHER_500_PATH = \"train_other_500.json\"\n",
"\n",
"manifest_filepath = \",\".join([TRAIN_CLEAN_100_PATH, TRAIN_CLEAN_360_PATH, TRAIN_OTHER_500_PATH])"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Load the model for which we want to compute statistics on the dataset\n",
"\n",
"Since each model can potentially have different preprocessing, it's better to load the model we want and then observe its config file."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"for info in nemo_asr.models.EncDecCTCModel.list_available_models():\n",
" print(info)\n",
" print()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"model = nemo_asr.models.EncDecCTCModel.from_pretrained(\"QuartzNet15x5Base-En\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import copy\n",
"from omegaconf import OmegaConf\n",
"\n",
"cfg = copy.deepcopy(model.cfg)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Prepare dataset + dataloader"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Update validation dataset config\n",
"OmegaConf.set_struct(cfg, False)\n",
"\n",
"train_cfg = OmegaConf.create({\n",
" \"manifest_filepath\": manifest_filepath,\n",
" \"max_duration\": 16.7,\n",
" \"batch_size\": 128, # <-- change batch size to fit your GPU memory\n",
" \"sample_rate\": cfg.preprocessor.params.sample_rate,\n",
" \"labels\": cfg.decoder.params.vocabulary,\n",
" \"shuffle\": False,\n",
" \"num_workers\": 4, # <-- Change number of workers ~ number of physical CPU cores\n",
" \"pin_memory\": True,\n",
"})\n",
"\n",
"# We are using the test data loader, but with train dataset\n",
"cfg.test_ds = train_cfg\n",
"\n",
"OmegaConf.set_struct(cfg, True)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Setup the data loader (this will take some time)\n",
"model.setup_multiple_test_data(cfg.test_ds)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Extract the data loader\n",
"dataloader = model._test_dl\n",
"print(\"Number of steps in dataloader :\", len(dataloader))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Prepare helper method to compute statistics"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from nemo.collections.asr.parts.features import FilterbankFeatures, normalize_batch, CONSTANT"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# See the source code of the forward step of normalization function\n",
"# %psource normalize_batch"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Lets create a replica of the above, but it instead returns the mean and std\n",
"@torch.no_grad()\n",
"def compute_statistics(x, seq_len, normalize_type):\n",
" if normalize_type == \"per_feature\":\n",
" x_mean = torch.zeros((seq_len.shape[0], x.shape[1]), dtype=x.dtype, device=x.device)\n",
" x_std = torch.zeros((seq_len.shape[0], x.shape[1]), dtype=x.dtype, device=x.device)\n",
" for i in range(x.shape[0]):\n",
" if x[i, :, : seq_len[i]].shape[1] == 1:\n",
" raise ValueError(\n",
" \"normalize_batch with `per_feature` normalize_type received a tensor of length 1. This will result \"\n",
" \"in torch.std() returning nan\"\n",
" )\n",
" x_mean[i, :] = x[i, :, : seq_len[i]].mean(dim=1)\n",
" x_std[i, :] = x[i, :, : seq_len[i]].std(dim=1)\n",
" # make sure x_std is not zero\n",
" x_std += CONSTANT\n",
" return x_mean.unsqueeze(2), x_std.unsqueeze(2)\n",
" \n",
" elif normalize_type == \"all_features\":\n",
" x_mean = torch.zeros(seq_len.shape, dtype=x.dtype, device=x.device)\n",
" x_std = torch.zeros(seq_len.shape, dtype=x.dtype, device=x.device)\n",
" for i in range(x.shape[0]):\n",
" x_mean[i] = x[i, :, : seq_len[i].item()].mean()\n",
" x_std[i] = x[i, :, : seq_len[i].item()].std()\n",
" # make sure x_std is not zero\n",
" x_std += CONSTANT\n",
" return x_mean.view(-1, 1, 1), x_std.view(-1, 1, 1)\n",
" elif \"fixed_mean\" in normalize_type and \"fixed_std\" in normalize_type:\n",
" x_mean = torch.tensor(normalize_type[\"fixed_mean\"], device=x.device)\n",
" x_std = torch.tensor(normalize_type[\"fixed_std\"], device=x.device)\n",
" return x_mean.view(x.shape[0], x.shape[1]).unsqueeze(2), x_std.view(x.shape[0], x.shape[1]).unsqueeze(2)\n",
" else:\n",
" return x"
]
},
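{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a quick sanity check (not part of the original workflow), we can verify `compute_statistics` against direct `torch` calls on a random batch. The shapes below are arbitrary illustrative values."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Sanity check on random data (illustrative shapes: batch=4, features=64, time=100)\n",
"_x = torch.randn(4, 64, 100)\n",
"_seq_len = torch.tensor([100, 80, 60, 100])\n",
"\n",
"_mean, _std = compute_statistics(_x, _seq_len, normalize_type=\"per_feature\")\n",
"\n",
"# Compare the first sample against a direct computation (std has CONSTANT added)\n",
"assert torch.allclose(_mean[0, :, 0], _x[0, :, :100].mean(dim=1), atol=1e-5)\n",
"assert torch.allclose(_std[0, :, 0], _x[0, :, :100].std(dim=1) + CONSTANT, atol=1e-5)\n",
"print(\"compute_statistics matches a direct torch computation\")"
]
},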
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Subclass FilterbankFeatures to return the statistics\n",
"\n",
"The filterbank featurizer does not return the mean and std values, but the actual normalized tensor \"x\".\n",
"\n",
"Since we have no need for \"x\" itself, we subclass and overwrite its forward() step to match the original, but at the end we return the computed mean and std instead of the normalized \"x\"."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"class ExtractFilterbankFeatures(FilterbankFeatures):\n",
" \n",
" @torch.no_grad()\n",
" def forward(self, x, seq_len):\n",
" seq_len = self.get_seq_len(seq_len.float())\n",
"\n",
" # dither\n",
" if self.dither > 0:\n",
" x += self.dither * torch.randn_like(x)\n",
"\n",
" # do preemphasis\n",
" if self.preemph is not None:\n",
" x = torch.cat((x[:, 0].unsqueeze(1), x[:, 1:] - self.preemph * x[:, :-1]), dim=1)\n",
"\n",
" # disable autocast to get full range of stft values\n",
" with torch.cuda.amp.autocast(enabled=False):\n",
" x = self.stft(x)\n",
"\n",
" # torch returns real, imag; so convert to magnitude\n",
" if not self.stft_conv:\n",
" x = torch.sqrt(x.pow(2).sum(-1))\n",
"\n",
" # get power spectrum\n",
" if self.mag_power != 1.0:\n",
" x = x.pow(self.mag_power)\n",
"\n",
" # dot with filterbank energies\n",
" x = torch.matmul(self.fb.to(x.dtype), x)\n",
"\n",
" # log features if required\n",
" if self.log:\n",
" if self.log_zero_guard_type == \"add\":\n",
" x = torch.log(x + self.log_zero_guard_value_fn(x))\n",
" elif self.log_zero_guard_type == \"clamp\":\n",
" x = torch.log(torch.clamp(x, min=self.log_zero_guard_value_fn(x)))\n",
" else:\n",
" raise ValueError(\"log_zero_guard_type was not understood\")\n",
"\n",
" # frame splicing if required\n",
" if self.frame_splicing > 1:\n",
" x = splice_frames(x, self.frame_splicing)\n",
"\n",
" # Return Normalization values\n",
" if self.normalize:\n",
" mean, std = compute_statistics(x, seq_len, normalize_type=self.normalize)\n",
" return mean, std\n",
" else:\n",
" raise RuntimeError('This class is meant only to extract normalization values')"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# See preprocessor kwargs\n",
"print(OmegaConf.to_yaml(cfg.preprocessor))"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Create an object of above using parameters from config\n",
"\n",
"sample_rate = cfg.preprocessor.params.sample_rate\n",
"\n",
"featurizer_cfg = OmegaConf.create(dict(\n",
" sample_rate=sample_rate,\n",
" n_window_size=int(cfg.preprocessor.params.window_size * sample_rate),\n",
" n_window_stride=int(cfg.preprocessor.params.window_stride * sample_rate),\n",
" window=cfg.preprocessor.params.window,\n",
" normalize=cfg.preprocessor.params.normalize,\n",
" n_fft=cfg.preprocessor.params.n_fft,\n",
" nfilt=cfg.preprocessor.params.features,\n",
" dither=0.0, # <-- We do not want dither when computing statistics\n",
" pad_to=0, # <-- Nor do we want padding to impact the statistics\n",
" frame_splicing=cfg.preprocessor.params.frame_splicing,\n",
" stft_conv=cfg.preprocessor.params.stft_conv\n",
"))\n",
"\n",
"featurizer = ExtractFilterbankFeatures(**featurizer_cfg)\n",
"featurizer = featurizer.to(model.device)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Computing online statistics\n",
"\n",
"We use [Welford's algorithm](https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance) to compute the global average of the `mean` and `std` vectors. \n",
"\n",
"An online algorithm is preferred as it can be run over datasets of any size while using constant memory."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"@torch.no_grad()\n",
"@torch.jit.script\n",
"def welford_step(mean: torch.Tensor, std: torch.Tensor, mean_buffer: torch.Tensor, std_buffer: torch.Tensor, count: torch.Tensor):\n",
" # Get back to [B, F] shape, compute Welford step over B.\n",
" mean = mean.squeeze(-1)\n",
" std = std.squeeze(-1)\n",
" \n",
" for i in range(mean.size(0)):\n",
" count += 1\n",
"\n",
" delta = (mean[i, :] - mean_buffer)\n",
" mean_buffer.add_(delta / count)\n",
"\n",
" delta = (std[i, :] - std_buffer)\n",
" std_buffer.add_(delta / count)\n",
" \n",
" return count, mean_buffer, std_buffer"
]
},
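{
"cell_type": "markdown",
"metadata": {},
"source": [
"A minimal check (illustrative, not from the original workflow): after feeding a single batch of per-sample statistics through `welford_step`, the running buffers should equal the direct average over the batch dimension."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Quick check: Welford running average vs. a direct mean over a random \"batch\"\n",
"_B, _F = 32, 64\n",
"_means = torch.randn(_B, _F, 1)  # shaped like the featurizer output [B, F, 1]\n",
"_stds = torch.rand(_B, _F, 1)\n",
"\n",
"_count = torch.tensor(0, dtype=torch.int64)\n",
"_mean_buf = torch.zeros(_F)\n",
"_std_buf = torch.zeros(_F)\n",
"\n",
"_count, _mean_buf, _std_buf = welford_step(_means, _stds, _mean_buf, _std_buf, _count)\n",
"\n",
"assert torch.allclose(_mean_buf, _means.squeeze(-1).mean(dim=0), atol=1e-5)\n",
"assert torch.allclose(_std_buf, _stds.squeeze(-1).mean(dim=0), atol=1e-5)\n",
"print(\"welford_step matches direct averaging\")"
]
},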
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"from tqdm import tqdm\n",
"import os\n",
"\n",
"\n",
"@torch.no_grad()\n",
"def compute_dataset_statistics(model, dataloader, normalize_type='per_feature', stats_path=\"dataset_statistics\"):\n",
" dl = iter(dataloader)\n",
" \n",
" count = None\n",
" mean_buffer = None\n",
" std_buffer = None\n",
"\n",
" for batch in tqdm(range(len(dataloader))):\n",
" # Uncomment these two lines to run for a few steps only\n",
" # if batch >= 5:\n",
" # break\n",
" \n",
" # Get next batch\n",
" input_signal, input_signal_length, _, _ = next(dl)\n",
" \n",
" # Move data to model's device\n",
" input_signal = input_signal.to(model.device)\n",
" input_signal_length = input_signal_length.to(model.device)\n",
" \n",
" # Compute statistics for this batch\n",
" mean, std = featurizer(input_signal, input_signal_length)\n",
" \n",
" # Prepare buffers for Welford's algorithm step (if first step only)\n",
" if count is None:\n",
" count = torch.tensor(0, device=model.device, dtype=torch.int64)\n",
" \n",
" if mean_buffer is None:\n",
" mean_buffer = torch.zeros(mean.shape[1], device=mean.device, dtype=mean.dtype)\n",
" \n",
" if std_buffer is None:\n",
" std_buffer = torch.zeros(std.shape[1], device=std.device, dtype=std.dtype)\n",
" \n",
" # Compute Welford's step to update buffers\n",
" count, mean_buffer, std_buffer = welford_step(mean, std, mean_buffer, std_buffer, count)\n",
" \n",
" # Preserve buffers\n",
" mean_buffer = mean_buffer.to('cpu').numpy()\n",
" std_buffer = std_buffer.to('cpu').numpy()\n",
" \n",
" if not os.path.exists(stats_path):\n",
" os.makedirs(stats_path)\n",
" \n",
" np.save(os.path.join(stats_path, \"mean.npy\"), mean_buffer)\n",
" np.save(os.path.join(stats_path, \"std.npy\"), std_buffer)\n",
" print(f\"Statistics saved at {stats_path}\")\n",
" \n",
" return mean_buffer, std_buffer"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"if torch.cuda.is_available():\n",
" device = torch.device('cuda')\n",
"else:\n",
" device = torch.device('cpu')\n",
"\n",
"model = model.to(device)\n",
"\n",
"stats_path = \"dataset_statistics\"\n",
"compute_dataset_statistics(model, dataloader, stats_path=stats_path)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"mean = np.load(os.path.join(stats_path, \"mean.npy\"))\n",
"std = np.load(os.path.join(stats_path, \"std.npy\"))"
]
},
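{
"cell_type": "markdown",
"metadata": {},
"source": [
"The saved statistics can then feed the `fixed_mean` / `fixed_std` normalization branch handled by `compute_statistics` above. Below is a rough sketch of wiring them into the preprocessor config; it assumes the preprocessor's `normalize` argument accepts a dict with `fixed_mean` / `fixed_std` keys, so verify the exact expected format against your NeMo version before relying on it."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Sketch: plug the dataset statistics back into the preprocessor config.\n",
"# NOTE: this assumes `normalize` accepts a dict with `fixed_mean` / `fixed_std`,\n",
"# matching the branch in compute_statistics above; check your NeMo version.\n",
"OmegaConf.set_struct(cfg, False)\n",
"cfg.preprocessor.params.normalize = {\n",
"    \"fixed_mean\": mean.tolist(),\n",
"    \"fixed_std\": std.tolist(),\n",
"}\n",
"OmegaConf.set_struct(cfg, True)\n",
"\n",
"print(OmegaConf.to_yaml(cfg.preprocessor.params.normalize))"
]
},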
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3.7.6 64-bit ('NeMo': conda)",
"language": "python",
"name": "python37664bitnemoconda43f94a748a2e4953b0129556ecdf4f62"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.6"
}
},
"nbformat": 4,
"nbformat_minor": 4
}
@ridasaleem0

Hey, I've been exploring your notebook to compute normalization statistics for the Citrinet model. Can you clarify how to use the manifest path? Do we need to download the datasets?

@titu1994 (Author)

Yes, you will need the dataset plus its manifest file in order to calculate the dataset statistics.

@ridasaleem0

Okay. I am basically looking to calculate dataset statistics for the "stt_zh_citrinet_1024_gamma_0_25_1.0.0" model. Since it was trained on the Multilingual LibriSpeech English corpus (pre-training) and the Aishell-2 corpus (fine-tuning), I am not sure where to get the manifest file for it.

@titu1994 (Author)

For such models, if you don't have the datasets, it might be worthwhile to simply run https://github.com/NVIDIA/NeMo/blob/main/tutorials/asr/Streaming_ASR.ipynb with large buffer sizes (it won't require pre-calculating the dataset statistics then).

@ridasaleem0

Can this be used for real-time ASR with a microphone? I am specifically looking for an offline microphone ASR solution.

@titu1994 (Author)

Not for real-time. Jarvis would be the proper production toolkit for streaming (real-time) ASR. In NeMo we have buffered audio (the notebook above), and streaming audio support is not perfect.

@ridasaleem0

I have a Jetson Xavier and a Nano, and as far as I know, Jarvis is not compatible with Jetson for now.

@titu1994 (Author)

It is not compatible for now. I don't think NeMo supports ASR on Jetson either.

@ridasaleem0

Ah, seems like a pickle. Anyway, thank you so much for your assistance.
