Evaluate wav2vec 2.0 model with language model using torchaudio's CTC decoder
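The script below loads the WAV2VEC2_ASR_LARGE_10M pipeline, runs inference on a subset of LibriSpeech test-clean, scores a greedy (no-LM) baseline with jiwer, then decodes with torchaudio's ctc_decoder and a 4-gram KenLM, tuning LM_WEIGHT and WORD_SCORE via Bayesian optimization.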
import torch
import torchaudio
from torchaudio.models.decoder import ctc_decoder
from typing import List

bundle = torchaudio.pipelines.WAV2VEC2_ASR_LARGE_10M

acoustic_model = bundle.get_model()
acoustic_model.to('cuda')

ls = torchaudio.datasets.LIBRISPEECH('.', url='test-clean', download=True)
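# Each LIBRISPEECH item is a tuple:
# (waveform, sample_rate, transcript, speaker_id, chapter_id, utterance_id)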
data_loader = torch.utils.data.DataLoader(
    ls,
    batch_size=1,
    shuffle=False,
    num_workers=4)

def create_subset_data_loader(loader, size_of_subset):
    # Yield at most `size_of_subset` batches from `loader`
    count = 0
    for data in loader:
        if count == size_of_subset:
            break
        yield data
        count += 1
# Subset to smaller size for debugging
SUBSET_SIZE = 100
# SUBSET_SIZE = len(data_loader)

subset_loader = create_subset_data_loader(data_loader, SUBSET_SIZE)

tokens = [label.lower() for label in bundle.get_labels()]

from torchaudio.models.decoder import download_pretrained_files

files = download_pretrained_files("librispeech-4-gram")
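# `files` holds local paths to the pretrained LibriSpeech decoder assets:
# files.lexicon (word -> token spellings), files.tokens (token list),
# and files.lm (KenLM 4-gram binary)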
class GreedyCTCDecoder(torch.nn.Module):
    def __init__(self, labels, blank=0):
        super().__init__()
        self.labels = labels
        self.blank = blank

    def forward(self, emission: torch.Tensor) -> str:
        """Given a sequence emission over labels, get the best path

        Args:
            emission (Tensor): Logit tensor. Shape `[1, num_seq, num_label]`.

        Returns:
            str: The resulting transcript
        """
        indices = torch.argmax(emission[0], dim=-1)  # [num_seq,]
        indices = torch.unique_consecutive(indices, dim=-1)
        indices = [i for i in indices if i != self.blank]
        joined = "".join([self.labels[i] for i in indices])
        greedy_result = joined.replace("|", " ").strip().split()
        greedy_transcript = " ".join(greedy_result)
        return greedy_transcript
greedy_decoder = GreedyCTCDecoder(tokens)

import jiwer
from tqdm import tqdm
import numpy as np

logits = []
refs = []

for i, data in tqdm(enumerate(subset_loader), total=SUBSET_SIZE, ncols=100, desc='Inference on GPU'):
    emission, _ = acoustic_model(data[0][0].to('cuda'))
    emission = emission.cpu().detach()
    logits.append(emission)
    refs.append(' '.join(str(data[2][0]).lower().split()))

nolm_preds = [ greedy_decoder(l) for l in tqdm(logits, ncols=100, desc='No LM decoding') ]

print(f"Average WER HF, no LM: {round(jiwer.wer(refs, nolm_preds), 4)}")
from joblib import Parallel, delayed

def lm_eval(LM_WEIGHT, WORD_SCORE):
    # LM_WEIGHT = 3.23
    # WORD_SCORE = -0.26
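    # The lexicon and LM passed below are local, user-supplied files (a custom
    # 120k-word lexicon and a 4-gram KenLM binary), not the downloaded
    # `files.lexicon`/`files.lm`. The lexicon format torchaudio's ctc_decoder
    # expects is one word per line followed by its token spelling ending in
    # the word-boundary token, e.g. (illustrative):
    #   hello h e l l o |
    #   world w o r l d |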
    beam_search_decoder = ctc_decoder(
        lexicon="lexicon_120k-from-40MB-text.txt",
        tokens=files.tokens,
        lm="4-gram_400k.bin",
        nbest=3,
        beam_size=1500,
        lm_weight=LM_WEIGHT,
        word_score=WORD_SCORE,
    )

    def logits_to_preds(logits):
        beam_search_result = beam_search_decoder(logits)
        beam_search_transcript = " ".join(beam_search_result[0][0].words).strip()
        return beam_search_transcript

    from tqdm.contrib.concurrent import process_map

    # use chunksize of 1 to send 1 logits matrix per worker to process
    # (not optimal to send more since each does a slow beam search)
    # wilm_preds = process_map(logits_to_preds, logits, chunksize=1, ncols=100, desc=f"LM decoding (LM_WEIGHT {LM_WEIGHT}, W_SCORE {WORD_SCORE})")

    wilm_preds = Parallel(n_jobs=-1, verbose=0, prefer="threads")(delayed(logits_to_preds)(l) for l in tqdm(logits, ncols=100))
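    # prefer="threads" avoids re-pickling the decoder for every worker; if the
    # decoder does not release the GIL this will not actually run in parallel,
    # in which case the process_map variant above is the fallback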
    wilm_wer = round(jiwer.wer(refs, wilm_preds), 4)

    print(f"Average WER HF, with LM (LM_WEIGHT {LM_WEIGHT}, W_SCORE {WORD_SCORE}): {wilm_wer}")

    # import pandas as pd
    # print(pd.DataFrame({
    #     'ref' : refs,
    #     'pred' : wilm_preds
    # }))

    # BayesianOptimization maximizes, so return word accuracy (1 - WER)
    return 1 - wilm_wer
from bayes_opt import BayesianOptimization

# Bounded region of parameter space
pbounds = {
    'LM_WEIGHT': (0, 5),
    'WORD_SCORE': (-5, 5)
}

optimizer = BayesianOptimization(
    f=lm_eval,
    pbounds=pbounds,
    random_state=1,
)

optimizer.maximize(
    init_points=2,
    n_iter=10,
)
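# After the search finishes, the best setting found can be read back from
# bayes_opt's `optimizer.max` property, e.g. (illustrative output shape):
print(optimizer.max)  # {'target': 1 - best_WER, 'params': {'LM_WEIGHT': ..., 'WORD_SCORE': ...}}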
requirements.txt
bayesian-optimization==1.3.1
bitsandbytes==0.35.3
datasets==2.6.1
jiwer==2.5.1
omegaconf==2.2.3
pyctcdecode==0.4.0
torch==1.13.0
torchaudio==0.13.0
transformers==4.23.1