Fine-tuning recipe for Citrinet models
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" | |
# Preparing the Tokenizer | |
Use the `create_tokenizer.py` script in order to prepare the tokenizer. | |
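
If `create_tokenizer.py` is not at hand, NeMo ships a script for the same purpose
(`scripts/process_asr_text_tokenizer.py` at the time of writing). The command below is only a
sketch; the script location and flag names may differ between NeMo versions. Keeping the
vocabulary size at 1024 lets the fine-tuning script below reuse the pretrained decoder weights.

python process_asr_text_tokenizer.py \
    --manifest="<PATH TO TRAIN MANIFEST>" \
    --data_root="<DIRECTORY TO TOKENIZER>" \
    --vocab_size=1024 \
    --tokenizer="spe" \
    --spe_type="unigram" \
    --log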
# Launch the fine-tuning script

HYDRA_FULL_ERROR=1 python finetune_model.py \
    --config-path="configs/" \
    --config-name="stt_en_citrinet_512" \
    +model_pretrained_name="stt_en_citrinet_512" \
    +freeze_encoder=false \
    model.tokenizer.dir="<DIRECTORY TO TOKENIZER (not the full path to the .model file, just the directory)>" \
    model.tokenizer.type="bpe" \
    model.train_ds.manifest_filepath="<PATH TO TRAIN MANIFEST>" \
    model.train_ds.batch_size=32 \
    +model.train_ds.num_workers=8 \
    +model.train_ds.pin_memory=true \
    model.validation_ds.manifest_filepath=["<PATH TO DEV SET>","<PATH TO TEST SET>"] \
    model.validation_ds.batch_size=8 \
    +model.validation_ds.num_workers=8 \
    +model.validation_ds.pin_memory=true \
    model.spec_augment.freq_masks=0 \
    model.spec_augment.time_masks=0 \
    model.optim.lr=0.01 \
    model.optim.name='novograd' \
    model.optim.betas=[0.8,0.25] \
    model.optim.weight_decay=0.001 \
    model.optim.sched.warmup_steps=1000 \
    model.optim.sched.min_lr=0.00001 \
    trainer.gpus=-1 \
    trainer.accelerator='ddp' \
    trainer.max_epochs=100 \
    trainer.check_val_every_n_epoch=1 \
    trainer.precision=32 \
    trainer.sync_batchnorm=false \
    trainer.benchmark=false \
    exp_manager.resume_if_exists=false \
    exp_manager.resume_ignore_no_checkpoint=false
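
The train/dev/test manifests referenced above follow NeMo's standard ASR manifest format: one JSON
object per line with `audio_filepath`, `duration` and `text` keys. The paths, durations and
transcripts below are placeholders for illustration only:

{"audio_filepath": "/data/audio/utt_0001.wav", "duration": 3.42, "text": "the transcript of the utterance"}
{"audio_filepath": "/data/audio/utt_0002.wav", "duration": 1.87, "text": "another transcript"}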
""" | |
import torch
import torch.nn as nn
import pytorch_lightning as pl
from omegaconf import OmegaConf, open_dict

from nemo.collections.asr.models.ctc_bpe_models import EncDecCTCModelBPE
from nemo.core.config import hydra_runner
from nemo.utils import logging
from nemo.utils.exp_manager import exp_manager


def enable_bn_se(m):
    """Re-enable training for BatchNorm and SqueezeExcite submodules of an otherwise frozen encoder."""
    if type(m) == nn.BatchNorm1d:
        m.train()
        for param in m.parameters():
            param.requires_grad_(True)

    if 'SqueezeExcite' in type(m).__name__:
        m.train()
        for param in m.parameters():
            param.requires_grad_(True)


@hydra_runner(config_path="configs/", config_name="stt_en_citrinet_512")
def main(cfg):
    logging.info(f'Hydra config: {OmegaConf.to_yaml(cfg)}')

    trainer = pl.Trainer(**cfg.trainer)
    exp_manager(trainer, cfg.get("exp_manager", None))

    # Pop the custom keys so the remaining config matches the model's expected schema
    with open_dict(cfg):
        model_name = cfg.pop('model_pretrained_name')
        freeze_encoder = cfg.pop('freeze_encoder', False)

    if 'stt_en_citrinet' not in model_name:
        raise ValueError(
            "`model_pretrained_name` must be a Citrinet model - `stt_en_citrinet_XYZ`, "
            "where XYZ can be {256, 512, 1024}"
        )
    # Load the pretrained checkpoint on CPU
    checkpoint = EncDecCTCModelBPE.from_pretrained(
        model_name, map_location=torch.device('cpu')
    )  # type: EncDecCTCModelBPE

    # Preserve the pretrained model's decoder weights
    decoder_ckpt_copy = checkpoint.decoder.state_dict()

    # Build the model to be fine-tuned from the provided config
    asr_model = EncDecCTCModelBPE(cfg=cfg.model, trainer=trainer)

    # Load the pretrained weights (partially / fully); this allows the decoder weights to be reused
    # when they have the same shape as the original Citrinet decoder (1024 subword encodings)
    asr_model.load_state_dict(checkpoint.state_dict(), strict=False)

    # Restore the preserved decoder weights explicitly if the shapes match
    if decoder_ckpt_copy['decoder_layers.0.weight'].shape == asr_model.decoder.decoder_layers[0].weight.shape:
        asr_model.decoder.load_state_dict(decoder_ckpt_copy)
        logging.info("\n")
        logging.info("Decoder shapes matched - restored weights from pretrained model")
        logging.info("\n")

    # Release checkpoint memory
    del checkpoint
    # If freezing the encoder, unfreeze the batch norm and squeeze-and-excite blocks
    # so they can still adapt during transfer learning
    if freeze_encoder:
        asr_model.encoder.freeze()
        asr_model.encoder.apply(enable_bn_se)

    # Train the model
    trainer.fit(asr_model)


if __name__ == '__main__':
    main()