NeMo Conformer Transducer on MCV Hindi
Gist titu1994/080c5387c4c02b41ce79dd4405d87104, last active November 7, 2022 03:36
{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "provenance": [],
      "collapsed_sections": []
    },
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    },
    "language_info": {
      "name": "python"
    },
    "accelerator": "GPU",
    "gpuClass": "standard"
  },
  "cells": [
    {
      "cell_type": "markdown",
      "source": [
        "# NeMo Conformer Transducer on MCV Hindi"
      ],
      "metadata": {
        "id": "nqaj9eMjDPCB"
      }
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "L4y7itGOancP"
      },
      "outputs": [],
      "source": [
        "!add-apt-repository -y ppa:jonathonf/ffmpeg-4\n",
        "!apt update\n",
        "!apt install -y ffmpeg\n",
        "\n",
        "!apt-get install -y libsndfile1\n",
        "!pip install git+https://github.com/NVIDIA/NeMo.git@main#egg=nemo_toolkit[all]"
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "# Install the Hugging Face datasets and evaluate libraries\n",
        "!pip install --upgrade datasets evaluate"
      ],
      "metadata": {
        "id": "rUo4fO_CdVba"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "!pip install --upgrade numba"
      ],
      "metadata": {
        "id": "cyziQX_2a3uN"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "# **Note**\n",
        "\n",
        "The next cell restarts the kernel so that the upgraded numba is loaded. After the restart, continue with the cell that follows it."
      ],
      "metadata": {
        "id": "ODOq5qcaz0_f"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "import os\n",
        "\n",
        "print(\"Restarting kernel to load the updated numba!\")\n",
        "os.kill(os.getpid(), 9)"
      ],
      "metadata": {
        "id": "MU2tNk-5bNYf"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "Log in to **Hugging Face** to download the Mozilla Common Voice (MCV) dataset. While signed in, you must first visit https://huggingface.co/datasets/mozilla-foundation/common_voice_11_0 and accept the MCV terms; otherwise the dataset cannot be downloaded.\n"
      ],
      "metadata": {
        "id": "VoHxgw9WzzMY"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "from huggingface_hub import notebook_login\n",
        "\n",
        "notebook_login()"
      ],
      "metadata": {
        "id": "bmxjLWoUbW1n"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "import nemo\n",
        "import numba\n",
        "print(\"NeMo: \", nemo.__version__)\n",
        "print(\"Numba: \", numba.__version__)"
      ],
      "metadata": {
        "id": "6m2R-rOGIqd1"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "# Prepare the dataset\n",
        "\n",
        "Because Colab's CPU is relatively slow, this data preprocessing step can take upwards of 20 minutes."
      ],
      "metadata": {
        "id": "0cniyrD80EuR"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "# Prepare HF dataset for NeMo usage\n",
        "import os\n",
        "import shutil\n",
        "\n",
        "if not os.path.exists(\"scripts/\"):\n",
        "    os.makedirs(\"scripts/\")\n",
        "\n",
        "HF_SCRIPT = \"https://raw.githubusercontent.com/NVIDIA/NeMo/main/scripts/speech_recognition/convert_hf_dataset_to_nemo.py\"\n",
        "if not os.path.exists(\"scripts/convert_hf_dataset_to_nemo.py\"):\n",
        "    !wget -P scripts/ $HF_SCRIPT"
      ],
      "metadata": {
        "id": "Ez-7BbNXeASc"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# Optionally, delete the data directory to reprocess\n",
        "# if os.path.exists(\"/content/datasets/Hindi/\"):\n",
        "#     shutil.rmtree(\"/content/datasets/Hindi/\")\n",
        "\n",
        "os.makedirs(\"/content/datasets/Hindi\", exist_ok=True)\n",
        "\n",
        "!python scripts/convert_hf_dataset_to_nemo.py \\\n",
        "    output_dir=\"/content/datasets/Hindi/Train/\" \\\n",
        "    path=\"mozilla-foundation/common_voice_11_0\" \\\n",
        "    name=\"hi\" \\\n",
        "    split=\"train+validation\" \\\n",
        "    use_auth_token=True\n",
        "\n",
        "!python scripts/convert_hf_dataset_to_nemo.py \\\n",
        "    output_dir=\"/content/datasets/Hindi/Test/\" \\\n",
        "    path=\"mozilla-foundation/common_voice_11_0\" \\\n",
        "    name=\"hi\" \\\n",
        "    split=\"test\" \\\n",
        "    use_auth_token=True\n"
      ],
      "metadata": {
        "id": "0IkiRv4JdqmI"
      },
      "execution_count": null,
      "outputs": []
    },
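    {
      "cell_type": "markdown",
      "source": [
        "Optionally, sanity-check the conversion by counting utterances and hours of audio in each manifest. This is a sketch rather than part of the original workflow: it assumes each manifest line is a JSON object with a `duration` field in seconds (as NeMo manifests normally have) and that the converter's output files end in `_manifest.json`. The helper `summarize_manifest` is hypothetical."
      ],
      "metadata": {}
    },
    {
      "cell_type": "code",
      "source": [
        "import glob\n",
        "import json\n",
        "import os\n",
        "\n",
        "\n",
        "def summarize_manifest(path):\n",
        "    # Hypothetical helper: count utterances and total audio duration (hours).\n",
        "    num_utts, total_sec = 0, 0.0\n",
        "    with open(path, 'r', encoding='utf-8') as f:\n",
        "        for line in f:\n",
        "            line = line.strip()\n",
        "            if not line:\n",
        "                continue\n",
        "            entry = json.loads(line)\n",
        "            num_utts += 1\n",
        "            total_sec += float(entry.get('duration', 0.0))\n",
        "    return num_utts, total_sec / 3600.0\n",
        "\n",
        "\n",
        "# Assumption: converted manifests end in '_manifest.json' somewhere below this directory.\n",
        "pattern = '/content/datasets/Hindi/**/*_manifest.json'\n",
        "for path in sorted(glob.glob(pattern, recursive=True)):\n",
        "    n, hours = summarize_manifest(path)\n",
        "    print(f'{os.path.basename(path)}: {n} utterances, {hours:.2f} hours')"
      ],
      "metadata": {},
      "execution_count": null,
      "outputs": []
    },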
    {
      "cell_type": "markdown",
      "source": [
        "## Download scripts for inference and evaluation (only needed if you want to use them later with the downloaded `.nemo` file)"
      ],
      "metadata": {
        "id": "uQUUJRbe0MX1"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "if not os.path.exists(\"scripts/transcribe_speech.py\"):\n",
        "    !wget -P scripts/ https://raw.githubusercontent.com/NVIDIA/NeMo/main/examples/asr/transcribe_speech.py\n",
        "\n",
        "if not os.path.exists(\"scripts/speech_to_text_eval.py\"):\n",
        "    !wget -P scripts/ https://raw.githubusercontent.com/NVIDIA/NeMo/main/examples/asr/speech_to_text_eval.py"
      ],
      "metadata": {
        "id": "biHE7M2sofcJ"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "# Process the final manifest files"
      ],
      "metadata": {
        "id": "ovPj0H9P0SRX"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "TRAIN_MANIFEST = \"/content/datasets/Hindi/Train/mozilla-foundation/common_voice_11_0/hi/train+validation/train+validation_mozilla-foundation_common_voice_11_0_manifest.json\"\n",
        "TEST_MANIFEST = \"/content/datasets/Hindi/Test/mozilla-foundation/common_voice_11_0/hi/test/test_mozilla-foundation_common_voice_11_0_manifest.json\""
      ],
      "metadata": {
        "id": "KyxmothcpYy0"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "from tqdm.auto import tqdm\n",
        "import json\n",
        "\n",
        "def read_manifest(path):\n",
        "    manifest = []\n",
        "    with open(path, 'r', encoding='utf-8') as f:\n",
        "        for line in tqdm(f, desc=\"Reading manifest data\"):\n",
        "            line = line.replace(\"\\n\", \"\")\n",
        "            data = json.loads(line)\n",
        "            manifest.append(data)\n",
        "    return manifest\n",
        "\n",
        "\n",
        "def write_processed_manifest(data, original_path):\n",
        "    original_manifest_name = os.path.basename(original_path)\n",
        "    new_manifest_name = original_manifest_name.replace(\".json\", \"_processed.json\")\n",
        "\n",
        "    manifest_dir = os.path.split(original_path)[0]\n",
        "    filepath = os.path.join(manifest_dir, new_manifest_name)\n",
        "    with open(filepath, 'w', encoding='utf-8') as f:\n",
        "        for datum in tqdm(data, desc=\"Writing manifest data\"):\n",
        "            datum = json.dumps(datum)\n",
        "            f.write(f\"{datum}\\n\")\n",
        "    print(f\"Finished writing manifest: {filepath}\")\n",
        "    return filepath"
      ],
      "metadata": {
        "id": "FmP6h-J-o4mo"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "train_manifest_data = read_manifest(TRAIN_MANIFEST)\n",
        "test_manifest_data = read_manifest(TEST_MANIFEST)"
      ],
      "metadata": {
        "id": "pWe3r8vSrh_8"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "The Hugging Face dataset stores each transcript under the \"sentence\" key, but NeMo expects it under \"text\", so we remap it.\n",
        "\n",
        "We also rewrite the audio file paths to use the `.wav` extension (this may not be needed, depending on which NeMo version is installed)."
      ],
      "metadata": {
        "id": "XXm6YTyO0WwI"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "for sample in train_manifest_data:\n",
        "    sample['text'] = sample['sentence']\n",
        "    sample['audio_filepath'] = os.path.splitext(sample['audio_filepath'])[0] + '.wav'\n",
        "\n",
        "write_processed_manifest(train_manifest_data, TRAIN_MANIFEST)\n",
        "\n",
        "for sample in test_manifest_data:\n",
        "    sample['text'] = sample['sentence']\n",
        "    sample['audio_filepath'] = os.path.splitext(sample['audio_filepath'])[0] + '.wav'\n",
        "\n",
        "write_processed_manifest(test_manifest_data, TEST_MANIFEST)"
      ],
      "metadata": {
        "id": "euUTOuZjroSO"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "TRAIN_MANIFEST_CLEANED = \"/content/datasets/Hindi/Train/mozilla-foundation/common_voice_11_0/hi/train+validation/train+validation_mozilla-foundation_common_voice_11_0_manifest_processed.json\"\n",
        "TEST_MANIFEST_CLEANED = \"/content/datasets/Hindi/Test/mozilla-foundation/common_voice_11_0/hi/test/test_mozilla-foundation_common_voice_11_0_manifest_processed.json\""
      ],
      "metadata": {
        "id": "yY2KkUAqrste"
      },
      "execution_count": null,
      "outputs": []
    },
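    {
      "cell_type": "markdown",
      "source": [
        "Before training, it can help to confirm that the processed manifests are usable. The sketch below is not part of the original notebook; it assumes NeMo's ASR datasets need at least `audio_filepath`, `duration`, and `text` in each entry, and `check_manifest` is a hypothetical helper. Missing audio files would suggest that the `.wav` extension rewrite above does not match the files the converter actually wrote."
      ],
      "metadata": {}
    },
    {
      "cell_type": "code",
      "source": [
        "import json\n",
        "import os\n",
        "\n",
        "\n",
        "def check_manifest(path, required_keys=('audio_filepath', 'duration', 'text'), max_report=5):\n",
        "    # Hypothetical helper: flag entries missing required keys or pointing to missing audio.\n",
        "    bad_keys, missing_audio, total = [], [], 0\n",
        "    with open(path, 'r', encoding='utf-8') as f:\n",
        "        for line_no, line in enumerate(f, start=1):\n",
        "            if not line.strip():\n",
        "                continue\n",
        "            entry = json.loads(line)\n",
        "            total += 1\n",
        "            absent = [key for key in required_keys if key not in entry]\n",
        "            if absent:\n",
        "                bad_keys.append((line_no, absent))\n",
        "            elif not os.path.exists(entry['audio_filepath']):\n",
        "                missing_audio.append((line_no, entry['audio_filepath']))\n",
        "    print(f'{os.path.basename(path)}: {total} entries, '\n",
        "          f'{len(bad_keys)} with missing keys, {len(missing_audio)} with missing audio')\n",
        "    for line_no, detail in (bad_keys + missing_audio)[:max_report]:\n",
        "        print('  line', line_no, '->', detail)\n",
        "    return not bad_keys and not missing_audio\n",
        "\n",
        "\n",
        "for manifest_path in (TRAIN_MANIFEST_CLEANED, TEST_MANIFEST_CLEANED):\n",
        "    check_manifest(manifest_path)"
      ],
      "metadata": {},
      "execution_count": null,
      "outputs": []
    },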
    {
      "cell_type": "markdown",
      "source": [
        "# Set up Model and Trainer"
      ],
      "metadata": {
        "id": "bsSpB-qW0oUs"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "# Imports for loading the model and building the trainer\n",
        "import torch\n",
        "import pytorch_lightning as ptl\n",
        "import nemo.collections.asr as nemo_asr\n",
        "from nemo.utils import logging"
      ],
      "metadata": {
        "id": "N6APlPUatTYC"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "## Load Conformer Transducer Large"
      ],
      "metadata": {
        "id": "8GqmOuKY0qwD"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "model = nemo_asr.models.ASRModel.from_pretrained(\"nvidia/stt_en_conformer_transducer_large\")\n"
      ],
      "metadata": {
        "id": "zdl-R5cAtiSu"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "## Set up Trainer for training"
      ],
      "metadata": {
        "id": "VaUpxSIT0tcj"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "if torch.cuda.is_available():\n",
        "    accelerator = 'gpu'\n",
        "else:\n",
        "    accelerator = 'cpu'\n",
        "\n",
        "trainer = ptl.Trainer(\n",
        "    devices=1,\n",
        "    accelerator=accelerator,\n",
        "    max_epochs=None,\n",
        "    max_steps=5000,\n",
        "    accumulate_grad_batches=1,\n",
        "    enable_checkpointing=False,\n",
        "    logger=False,\n",
        "    log_every_n_steps=10,\n",
        "    val_check_interval=500,\n",
        ")\n",
        "\n",
        "model.set_trainer(trainer)"
      ],
      "metadata": {
        "id": "YwhxmNiYt3R1"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "## Build Tokenizer\n",
        "\n",
        "Since we are transferring from English to Hindi, we need to train a new Hindi tokenizer, swap it into the model, restore the pretrained decoder and joint weights, and then fine-tune on Hindi (cross-language transfer learning)."
      ],
      "metadata": {
        "id": "1j9FWrVC00Iy"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "# Vocab size should match the original model's vocab size in order to easily load all the parameters of the pretrained model.\n",
        "VOCAB_SIZE = model.tokenizer.vocab_size"
      ],
      "metadata": {
        "id": "Ya1aNBLut3dy"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "if not os.path.exists(\"scripts/process_asr_text_tokenizer.py\"):\n",
        "    !wget -P scripts/ https://raw.githubusercontent.com/NVIDIA/NeMo/main/scripts/tokenizers/process_asr_text_tokenizer.py"
      ],
      "metadata": {
        "id": "HSwtyhWAstgg"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "tokenizer_dir = \"/content/tokenizer/\""
      ],
      "metadata": {
        "id": "fIoXz8xa7It5"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "if not os.path.exists(tokenizer_dir):\n",
        "    os.makedirs(tokenizer_dir)\n",
        "\n",
        "!python scripts/process_asr_text_tokenizer.py \\\n",
        "    --manifest=$TRAIN_MANIFEST_CLEANED,$TEST_MANIFEST_CLEANED \\\n",
        "    --vocab_size=$VOCAB_SIZE \\\n",
        "    --data_root=$tokenizer_dir \\\n",
        "    --tokenizer=\"spe\" \\\n",
        "    --spe_type=\"bpe\" \\\n",
        "    --spe_character_coverage=1.0 \\\n",
        "    --no_lower_case \\\n",
        "    --log"
      ],
      "metadata": {
        "id": "Nh7ZnK0Dt2r-"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "TOKENIZER_DIR = os.path.join(tokenizer_dir, f\"tokenizer_spe_bpe_v{VOCAB_SIZE}\")\n",
        "print(\"Tokenizer dir:\", TOKENIZER_DIR)"
      ],
      "metadata": {
        "id": "MV-TkY5Bwecq"
      },
      "execution_count": null,
      "outputs": []
    },
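    {
      "cell_type": "markdown",
      "source": [
        "To see how the new tokenizer segments Hindi text, the SentencePiece model can be loaded directly. This sketch assumes the script above saved it as `tokenizer.model` inside `TOKENIZER_DIR` and that the `sentencepiece` package (installed with NeMo) is available; the sample sentence is simply the first transcript in the processed training manifest."
      ],
      "metadata": {}
    },
    {
      "cell_type": "code",
      "source": [
        "import json\n",
        "import os\n",
        "\n",
        "import sentencepiece as spm\n",
        "\n",
        "# Assumption: the tokenizer script stores the SentencePiece model as tokenizer.model.\n",
        "sp = spm.SentencePieceProcessor(model_file=os.path.join(TOKENIZER_DIR, 'tokenizer.model'))\n",
        "print('Vocab size:', sp.get_piece_size())\n",
        "\n",
        "with open(TRAIN_MANIFEST_CLEANED, 'r', encoding='utf-8') as f:\n",
        "    sample_text = json.loads(f.readline())['text']\n",
        "\n",
        "print('Text:  ', sample_text)\n",
        "print('Pieces:', sp.encode(sample_text, out_type=str))"
      ],
      "metadata": {},
      "execution_count": null,
      "outputs": []
    },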
    {
      "cell_type": "markdown",
      "source": [
        "# Change vocabulary and restore weights"
      ],
      "metadata": {
        "id": "rzHv263I0-2Q"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "# Preserve the decoder and joint parameters so they can be restored after the vocabulary change\n",
        "pretrained_decoder = model.decoder.state_dict()\n",
        "pretrained_joint = model.joint.state_dict()"
      ],
      "metadata": {
        "id": "zu7TgLVrtNCE"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "model.change_vocabulary(new_tokenizer_dir=TOKENIZER_DIR, new_tokenizer_type=\"bpe\")"
      ],
      "metadata": {
        "id": "QR1uNSXswpvV"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# Insert preserved model weights (load_state_dict raises an error if the shapes no longer match)\n",
        "model.decoder.load_state_dict(pretrained_decoder)\n",
        "logging.info(\"Decoder shapes matched - restored weights from pre-trained model\")\n",
        "\n",
        "model.joint.load_state_dict(pretrained_joint)\n",
        "logging.info(\"Joint shapes matched - restored weights from pre-trained model\")"
      ],
      "metadata": {
        "id": "spXPSzo_wtuu"
      },
      "execution_count": null,
      "outputs": []
    },
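    {
      "cell_type": "markdown",
      "source": [
        "Because the vocabulary size is unchanged, the preserved decoder and joint tensors should load without errors. If you experiment with a different `VOCAB_SIZE`, `load_state_dict` raises a size-mismatch error instead (PyTorch raises on shape mismatches even with `strict=False`). The sketch below lists mismatched tensors and shows one way to copy over only the compatible ones; `mismatched_keys` and `load_compatible` are hypothetical helpers, not NeMo APIs."
      ],
      "metadata": {}
    },
    {
      "cell_type": "code",
      "source": [
        "def mismatched_keys(pretrained_state, current_state):\n",
        "    # Hypothetical helper: names that are absent or have a different shape in the current module.\n",
        "    return [\n",
        "        name for name, tensor in pretrained_state.items()\n",
        "        if name not in current_state or current_state[name].shape != tensor.shape\n",
        "    ]\n",
        "\n",
        "\n",
        "def load_compatible(module, pretrained_state):\n",
        "    # Copy only tensors whose shapes match; return the names that were skipped.\n",
        "    skipped = set(mismatched_keys(pretrained_state, module.state_dict()))\n",
        "    compatible = {k: v for k, v in pretrained_state.items() if k not in skipped}\n",
        "    module.load_state_dict(compatible, strict=False)\n",
        "    return sorted(skipped)\n",
        "\n",
        "\n",
        "print('Decoder mismatches:', mismatched_keys(pretrained_decoder, model.decoder.state_dict()))\n",
        "print('Joint mismatches:', mismatched_keys(pretrained_joint, model.joint.state_dict()))\n",
        "# Example (only needed if shapes differ): skipped = load_compatible(model.joint, pretrained_joint)"
      ],
      "metadata": {},
      "execution_count": null,
      "outputs": []
    },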
    {
      "cell_type": "markdown",
      "source": [
        "## Set up Train/Val/Test Data loaders"
      ],
      "metadata": {
        "id": "dHk0ZZoS1GGd"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "import copy\n",
        "from omegaconf import OmegaConf, open_dict"
      ],
      "metadata": {
        "id": "pfZeruM6xXu4"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "cfg = copy.deepcopy(model.cfg)"
      ],
      "metadata": {
        "id": "4PS8PJOQxlxa"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "print(OmegaConf.to_yaml(cfg.train_ds))"
      ],
      "metadata": {
        "id": "KFaSciu_wxzU"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# Set up train, validation, and test configs\n",
        "with open_dict(cfg):\n",
        "    # Train dataset\n",
        "    cfg.train_ds.manifest_filepath = TRAIN_MANIFEST_CLEANED\n",
        "    cfg.train_ds.batch_size = 8\n",
        "    cfg.train_ds.is_tarred = False\n",
        "\n",
        "    # Validation dataset\n",
        "    cfg.validation_ds.manifest_filepath = TEST_MANIFEST_CLEANED\n",
        "    cfg.validation_ds.batch_size = 8\n",
        "\n",
        "    # Test dataset\n",
        "    cfg.test_ds.manifest_filepath = TEST_MANIFEST_CLEANED\n",
        "    cfg.test_ds.batch_size = 8\n"
      ],
      "metadata": {
        "id": "RSIQXyhbxgaQ"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# Set up the model's data loaders with the new configs\n",
        "model.setup_training_data(cfg.train_ds)\n",
        "model.setup_multiple_validation_data(cfg.validation_ds)\n",
        "model.setup_multiple_test_data(cfg.test_ds)"
      ],
      "metadata": {
        "id": "Mn9dsXoryELk"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "## Set up Optimizer / Scheduler"
      ],
      "metadata": {
        "id": "R14DmDa-1Jpu"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "print(OmegaConf.to_yaml(cfg.optim))\n"
      ],
      "metadata": {
        "id": "qkc4QfYlyGLa"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "with open_dict(model.cfg.optim):\n",
        "    # Note: the LR follows the recommendation of using 1/10th of the original training LR for fine-tuning.\n",
        "    # This may not be the best hyperparameter combination; no search was done to optimize training.\n",
        "    model.cfg.optim.lr = 0.5  # Noam LR scaling factor; the peak LR ends up close to 0.001.\n",
        "    model.cfg.optim.weight_decay = 0.001\n",
        "    model.cfg.optim.sched.warmup_steps = 500\n",
        "    model.cfg.optim.sched.warmup_ratio = None\n",
        "    model.cfg.optim.sched.min_lr = 1e-6\n",
        "\n",
        "model.setup_optimization(model.cfg.optim);"
      ],
      "metadata": {
        "id": "CDN4PsW1yKor"
      },
      "execution_count": null,
      "outputs": []
    },
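    {
      "cell_type": "markdown",
      "source": [
        "The comment above says the peak LR ends up close to 0.001. Assuming the scheduler follows the standard Noam formula, `lr(step) = scale * d_model^-0.5 * min(step^-0.5, step * warmup_steps^-1.5)`, the peak occurs at `step = warmup_steps`, so it can be checked by hand. In this sketch `noam_lr` is an illustrative re-implementation, not NeMo's scheduler code, and `d_model` falls back to 512 (an assumption for Conformer-Large) if the encoder config does not define it."
      ],
      "metadata": {}
    },
    {
      "cell_type": "code",
      "source": [
        "def noam_lr(step, scale, d_model, warmup_steps, min_lr=0.0):\n",
        "    # Standard Noam schedule: linear warmup, then inverse-square-root decay.\n",
        "    step = max(step, 1)\n",
        "    lr = scale * d_model ** -0.5 * min(step ** -0.5, step * warmup_steps ** -1.5)\n",
        "    # Assumption: min_lr only acts as a floor after warmup.\n",
        "    return max(lr, min_lr) if step > warmup_steps else lr\n",
        "\n",
        "\n",
        "scale = model.cfg.optim.lr\n",
        "warmup = model.cfg.optim.sched.warmup_steps\n",
        "d_model = model.cfg.encoder.get('d_model', 512)  # assumption: 512 for Conformer-Large\n",
        "\n",
        "# With scale=0.5, warmup=500 and d_model=512 the peak is about 9.88e-04.\n",
        "print(f'Peak LR (step {warmup}): {noam_lr(warmup, scale, d_model, warmup):.2e}')\n",
        "for step in (100, 1000, 5000):\n",
        "    print(f'step {step}: {noam_lr(step, scale, d_model, warmup, min_lr=1e-6):.2e}')"
      ],
      "metadata": {},
      "execution_count": null,
      "outputs": []
    },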
    {
      "cell_type": "markdown",
      "source": [
        "## Set up SpecAugment"
      ],
      "metadata": {
        "id": "2KWRh_lr1UDn"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "with open_dict(model.cfg.spec_augment):\n",
        "    model.cfg.spec_augment.freq_masks = 0  # Can be 2; then applies 2 frequency masks\n",
        "    model.cfg.spec_augment.freq_width = 25\n",
        "    model.cfg.spec_augment.time_masks = 0  # Can be 10; then applies 10 time masks, each up to 5% of the sequence length\n",
        "    model.cfg.spec_augment.time_width = 0.05\n",
        "\n",
        "model.spec_augmentation = model.from_config_dict(model.cfg.spec_augment)"
      ],
      "metadata": {
        "id": "h7Rmu938yWSU"
      },
      "execution_count": null,
      "outputs": []
    },
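    {
      "cell_type": "markdown",
      "source": [
        "For intuition about what these four parameters control, here is a simplified masking routine applied to a random spectrogram, using the values suggested in the comments (2 frequency masks, 10 time masks). It is not NeMo's implementation; it assumes `freq_width` is a number of feature bins and a float `time_width` is a fraction of the sequence length, and the 80 x 1000 feature shape is arbitrary."
      ],
      "metadata": {}
    },
    {
      "cell_type": "code",
      "source": [
        "import numpy as np\n",
        "\n",
        "\n",
        "def spec_augment(spec, freq_masks, freq_width, time_masks, time_width, rng):\n",
        "    # Simplified SpecAugment: zero out random frequency bands and time spans.\n",
        "    spec = spec.copy()\n",
        "    num_freq, num_time = spec.shape\n",
        "    for _ in range(freq_masks):\n",
        "        width = int(rng.integers(0, freq_width + 1))\n",
        "        start = int(rng.integers(0, max(1, num_freq - width)))\n",
        "        spec[start:start + width, :] = 0.0\n",
        "    # A float time_width is treated as a fraction of the sequence length.\n",
        "    max_width = int(time_width * num_time) if isinstance(time_width, float) else time_width\n",
        "    for _ in range(time_masks):\n",
        "        width = int(rng.integers(0, max_width + 1))\n",
        "        start = int(rng.integers(0, max(1, num_time - width)))\n",
        "        spec[:, start:start + width] = 0.0\n",
        "    return spec\n",
        "\n",
        "\n",
        "rng = np.random.default_rng(0)\n",
        "features = rng.standard_normal((80, 1000))  # assumed: 80 feature bins x 1000 frames\n",
        "masked = spec_augment(features, freq_masks=2, freq_width=25, time_masks=10, time_width=0.05, rng=rng)\n",
        "print('Fraction of values masked:', float((masked == 0.0).mean()))"
      ],
      "metadata": {},
      "execution_count": null,
      "outputs": []
    },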
    {
      "cell_type": "markdown",
      "source": [
        "## Set up Metric\n",
        "\n",
        "We don't need CER, and we disable prediction logging to keep validation fast."
      ],
      "metadata": {
        "id": "VvZo1XDU1Yok"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "#@title Metric\n",
        "use_cer = False #@param [\"False\", \"True\"] {type:\"raw\"}\n",
        "log_prediction = False #@param [\"False\", \"True\"] {type:\"raw\"}\n",
        "\n",
        "model.wer.use_cer = use_cer\n",
        "model.wer.log_prediction = log_prediction\n",
        "\n",
        "# Re-assign the config so the updated settings are captured in the logged hyperparameters (e.g., on WandB)\n",
        "model.cfg = model.cfg"
      ],
      "metadata": {
        "id": "Xjrel830yh13"
      },
      "execution_count": null,
      "outputs": []
    },
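    {
      "cell_type": "markdown",
      "source": [
        "The `val_wer` metric is word error rate: the word-level edit distance (substitutions + deletions + insertions) between hypothesis and reference, divided by the number of reference words; CER applies the same formula to characters. A minimal pure-Python sketch for reference (the function names are ours, and NeMo's implementation may differ in details such as text normalization):"
      ],
      "metadata": {}
    },
    {
      "cell_type": "code",
      "source": [
        "def edit_distance(ref, hyp):\n",
        "    # Levenshtein distance between two token sequences, using a single DP row.\n",
        "    dist = list(range(len(hyp) + 1))\n",
        "    for i, r in enumerate(ref, start=1):\n",
        "        prev_diag, dist[0] = dist[0], i\n",
        "        for j, h in enumerate(hyp, start=1):\n",
        "            cur = min(dist[j] + 1,  # deletion\n",
        "                      dist[j - 1] + 1,  # insertion\n",
        "                      prev_diag + (r != h))  # substitution or match\n",
        "            prev_diag, dist[j] = dist[j], cur\n",
        "    return dist[-1]\n",
        "\n",
        "\n",
        "def word_error_rate(references, hypotheses, use_cer=False):\n",
        "    # Corpus-level WER (or CER when use_cer=True): total edits / total reference tokens.\n",
        "    edits, tokens = 0, 0\n",
        "    for ref, hyp in zip(references, hypotheses):\n",
        "        ref_tokens = list(ref) if use_cer else ref.split()\n",
        "        hyp_tokens = list(hyp) if use_cer else hyp.split()\n",
        "        edits += edit_distance(ref_tokens, hyp_tokens)\n",
        "        tokens += len(ref_tokens)\n",
        "    return edits / max(tokens, 1)\n",
        "\n",
        "\n",
        "# One substitution out of five reference words, so this prints 0.2\n",
        "print(word_error_rate(['मैं घर जा रहा हूँ'], ['मैं घर जा रही हूँ']))"
      ],
      "metadata": {},
      "execution_count": null,
      "outputs": []
    },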
    {
      "cell_type": "markdown",
      "source": [
        "## WandB Logging\n",
        "\n",
        "Strongly recommended for tracking longer experiments. Replace `{API_KEY}` with your WandB API key. The experiment manager below creates a WandB logger, so either log in here or set `create_wandb_logger=False`."
      ],
      "metadata": {
        "id": "bjt_bGTu1gQr"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "# !wandb login {API_KEY}"
      ],
      "metadata": {
        "id": "YcrgVfcKysZ0"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "## Set up Experiment Manager"
      ],
      "metadata": {
        "id": "BX2leHQj1sXY"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "from nemo.utils import exp_manager\n",
        "\n",
        "# Environment variable generally used for multi-node, multi-GPU training.\n",
        "# In notebook environments, this flag is unnecessary and can cause logs of multiple training runs to overwrite each other.\n",
        "os.environ.pop('NEMO_EXPM_VERSION', None)\n",
        "\n",
        "config = exp_manager.ExpManagerConfig(\n",
        "    exp_dir='experiments/lang-hindi/',\n",
        "    name=\"ASR-Model-Language-Hindi\",\n",
        "    checkpoint_callback_params=exp_manager.CallbackParams(\n",
        "        monitor=\"val_wer\",\n",
        "        mode=\"min\",\n",
        "        always_save_nemo=True,\n",
        "    ),\n",
        "    create_wandb_logger=True,\n",
        "    wandb_logger_kwargs=OmegaConf.create({\n",
        "        \"name\": \"RNNT-Hindi\",\n",
        "        \"project\": \"RNNT-Hindi\",\n",
        "    }),\n",
        ")\n",
        "\n",
        "config = OmegaConf.structured(config)\n",
        "logdir = exp_manager.exp_manager(trainer, config)\n"
      ],
      "metadata": {
        "id": "M6tsgUIFy8Ue"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "## Train\n",
        "\n",
        "The model should reach around 40% WER after 30-40 minutes of training (total notebook runtime, including about 30 minutes of data loading, is about 1 hour)."
      ],
      "metadata": {
        "id": "HxHFOEaI1vgI"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "%%time\n",
        "trainer.fit(model)"
      ],
      "metadata": {
        "id": "oPQE9Ek5y5P9"
      },
      "execution_count": null,
      "outputs": []
    },
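    {
      "cell_type": "markdown",
      "source": [
        "Once training finishes, the experiment manager has written checkpoints under `logdir`. Because the exact sub-directory layout can differ between NeMo versions, this sketch simply searches `logdir` recursively for `.nemo` and `.ckpt` files."
      ],
      "metadata": {}
    },
    {
      "cell_type": "code",
      "source": [
        "import glob\n",
        "import os\n",
        "\n",
        "# Search recursively, since the checkpoint layout below logdir may vary between versions.\n",
        "for pattern in ('**/*.nemo', '**/*.ckpt'):\n",
        "    for path in sorted(glob.glob(os.path.join(str(logdir), pattern), recursive=True)):\n",
        "        print(f'{path} ({os.path.getsize(path) / 1e6:.1f} MB)')"
      ],
      "metadata": {},
      "execution_count": null,
      "outputs": []
    },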
    {
      "cell_type": "markdown",
      "source": [
        "## Validation\n",
        "\n",
        "Run a final validation pass over the validation set (here, the MCV Hindi test split)."
      ],
      "metadata": {
        "id": "yqOU-7A880jJ"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "trainer.validate(model)"
      ],
      "metadata": {
        "id": "-s8ytf6m8vqZ"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "## Beam Search (Optional)\n",
        "\n",
        "Transducer greedy search is often on par with beam search, but Indic languages seem to benefit significantly from beam search.\n",
        "\n",
        "As an experiment, we evaluate the model with beam search using the Modified Adaptive Expansion Search (`maes`) algorithm."
      ],
      "metadata": {
        "id": "KCoinkE888jh"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "decoding = model.cfg.decoding\n",
        "print(OmegaConf.to_yaml(decoding))"
      ],
      "metadata": {
        "id": "6PO7Cmy28-qi"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "Change the decoding strategy to Modified Adaptive Expansion Search with a beam size of 4.\n",
        "\n",
        "**Note**: RNNT models don't need very large beam sizes, and large beams may dramatically slow down evaluation."
      ],
      "metadata": {
        "id": "Cs5fj-qv9Km9"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "decoding.strategy = \"maes\"  # can be beam, tsd, alsd, maes\n",
        "decoding.beam.beam_size = 4  # higher beam sizes may be significantly slower!\n",
        "\n",
        "decoding.fused_batch_size = -1  # disable RNNT fused batches during beam search inference\n",
        "\n",
        "model.change_decoding_strategy(decoding)"
      ],
      "metadata": {
        "id": "u35SnJNR9Gop"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# Evaluate again with beam search; the score will show up on WandB as \"test_wer\" or similar.\n",
        "trainer.test(model)"
      ],
      "metadata": {
        "id": "sVc2XfUq9c01"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "model.save_to(\"stt_hi_conformer_transducer_large.nemo\")  # Remember to download the file from the file browser on the left"
      ],
      "metadata": {
        "id": "iuxLRjt9zbcr"
      },
      "execution_count": null,
      "outputs": []
    },
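    {
      "cell_type": "markdown",
      "source": [
        "To reuse the exported model later, it can be restored with `ASRModel.restore_from` and used for transcription. This is a sketch assuming the `.nemo` file is in the working directory and `TEST_MANIFEST_CLEANED` is still defined. The return type of `transcribe` differs between NeMo versions (older RNNT models return a `(best_hypotheses, all_hypotheses)` tuple, newer ones may return hypothesis objects), so the code tries to handle both."
      ],
      "metadata": {}
    },
    {
      "cell_type": "code",
      "source": [
        "import json\n",
        "\n",
        "import nemo.collections.asr as nemo_asr\n",
        "\n",
        "restored = nemo_asr.models.ASRModel.restore_from('stt_hi_conformer_transducer_large.nemo')\n",
        "restored.eval()\n",
        "\n",
        "# Take the first three utterances of the processed test manifest.\n",
        "with open(TEST_MANIFEST_CLEANED, 'r', encoding='utf-8') as f:\n",
        "    samples = [json.loads(line) for _, line in zip(range(3), f)]\n",
        "\n",
        "outputs = restored.transcribe([s['audio_filepath'] for s in samples], batch_size=1)\n",
        "# Assumption: some NeMo versions return (best_hypotheses, all_hypotheses) for RNNT models.\n",
        "if isinstance(outputs, tuple):\n",
        "    outputs = outputs[0]\n",
        "for sample, hyp in zip(samples, outputs):\n",
        "    print('REF:', sample['text'])\n",
        "    # Newer versions may return Hypothesis objects with a .text field instead of strings.\n",
        "    print('HYP:', getattr(hyp, 'text', hyp))"
      ],
      "metadata": {},
      "execution_count": null,
      "outputs": []
    },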
    {
      "cell_type": "code",
      "source": [],
      "metadata": {
        "id": "NjzIeW3mhwgS"
      },
      "execution_count": null,
      "outputs": []
    }
  ]
}