Evaluate wav2vec 2.0 model with language model using torchaudio's CTC decoder
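Evaluates torchaudio's WAV2VEC2_ASR_LARGE_10M model on LibriSpeech test-clean, comparing greedy CTC decoding against beam-search decoding with a 4-gram language model, then uses Bayesian optimization to tune the decoder's lm_weight and word_score.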
import torch
import torchaudio
from torchaudio.models.decoder import ctc_decoder
bundle = torchaudio.pipelines.WAV2VEC2_ASR_LARGE_10M
acoustic_model = bundle.get_model()
acoustic_model.to('cuda')
ls = torchaudio.datasets.LIBRISPEECH('.', url='test-clean', download=True)
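# Each LIBRISPEECH item is a tuple:
# (waveform, sample_rate, transcript, speaker_id, chapter_id, utterance_id)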
data_loader = torch.utils.data.DataLoader(
    ls,
    batch_size=1,
    shuffle=False,
    num_workers=4,
)
def create_subset_data_loader(loader, size_of_subset):
    # Yield at most `size_of_subset` batches from the loader
    for count, data in enumerate(loader):
        if count == size_of_subset:
            break
        yield data
# Subset to smaller size for debugging
SUBSET_SIZE = 100
# SUBSET_SIZE=len(data_loader)
subset_loader = create_subset_data_loader(data_loader, SUBSET_SIZE)
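# Note: subset_loader is a generator, so it can only be iterated once;
# logits and reference transcripts are cached below for repeated decoding runs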
tokens = [label.lower() for label in bundle.get_labels()]
from torchaudio.models.decoder import download_pretrained_files
files = download_pretrained_files("librispeech-4-gram")
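# `files` holds paths to the pretrained LibriSpeech decoder assets
# (files.lexicon, files.tokens, files.lm); only files.tokens is used below,
# since the lexicon and LM are swapped out for custom local files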
class GreedyCTCDecoder(torch.nn.Module):
    def __init__(self, labels, blank=0):
        super().__init__()
        self.labels = labels
        self.blank = blank

    def forward(self, emission: torch.Tensor) -> str:
        """Given a sequence emission over labels, get the best path

        Args:
            emission (Tensor): Logit tensors. Shape `[1, num_seq, num_label]`.

        Returns:
            str: The resulting transcript
        """
        indices = torch.argmax(emission[0], dim=-1)  # [num_seq,]
        indices = torch.unique_consecutive(indices, dim=-1)
        indices = [i for i in indices if i != self.blank]
        joined = "".join([self.labels[i] for i in indices])
        greedy_result = joined.replace("|", " ").strip().split()
        greedy_transcript = " ".join(greedy_result)
        return greedy_transcript
greedy_decoder = GreedyCTCDecoder(tokens)
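# Quick sanity check (hypothetical input): greedy-decode a random emission
# dummy_emission = torch.randn(1, 100, len(tokens))
# print(greedy_decoder(dummy_emission))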
import jiwer
from tqdm import tqdm
logits = []
refs = []
for data in tqdm(subset_loader, total=SUBSET_SIZE, ncols=100, desc='Inference on GPU'):
    emission, _ = acoustic_model(data[0][0].to('cuda'))
    emission = emission.cpu().detach()
    logits.append(emission)
    refs.append(' '.join(str(data[2][0]).lower().split()))
nolm_preds = [greedy_decoder(l) for l in tqdm(logits, ncols=100, desc='No LM decoding')]
print(f"Average WER, no LM: {round(jiwer.wer(refs, nolm_preds), 4)}")
from joblib import Parallel, delayed

def lm_eval(LM_WEIGHT, WORD_SCORE):
    # Example values:
    # LM_WEIGHT = 3.23
    # WORD_SCORE = -0.26
    beam_search_decoder = ctc_decoder(
        lexicon="lexicon_120k-from-40MB-text.txt",
        tokens=files.tokens,
        lm="4-gram_400k.bin",
        nbest=3,
        beam_size=1500,
        lm_weight=LM_WEIGHT,
        word_score=WORD_SCORE,
    )

    def logits_to_preds(logits):
        beam_search_result = beam_search_decoder(logits)
        beam_search_transcript = " ".join(beam_search_result[0][0].words).strip()
        return beam_search_transcript

    from tqdm.contrib.concurrent import process_map
    # use chunksize of 1 to send 1 logits matrix per worker to process
    # (not optimal to send more since each does a slow beam search)
    # wilm_preds = process_map(logits_to_preds, logits, chunksize=1, ncols=100, desc=f"LM decoding (LM_WEIGHT {LM_WEIGHT}, W_SCORE {WORD_SCORE})")
    wilm_preds = Parallel(n_jobs=-1, verbose=0, prefer="threads")(delayed(logits_to_preds)(l) for l in tqdm(logits, ncols=100))

    wilm_wer = round(jiwer.wer(refs, wilm_preds), 4)
    print(f"Average WER, with LM (LM_WEIGHT {LM_WEIGHT}, W_SCORE {WORD_SCORE}): {wilm_wer}")

    # import pandas as pd
    # print(pd.DataFrame({
    #     'ref' : refs,
    #     'pred' : wilm_preds
    # }))

    # BayesianOptimization maximizes its objective, so return word accuracy (1 - WER)
    return 1 - wilm_wer
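# Example single run with fixed values (matching the commented defaults above):
# lm_eval(LM_WEIGHT=3.23, WORD_SCORE=-0.26)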
from bayes_opt import BayesianOptimization

# Bounded region of parameter space
pbounds = {
    'LM_WEIGHT': (0, 5),
    'WORD_SCORE': (-5, 5),
}

optimizer = BayesianOptimization(
    f=lm_eval,
    pbounds=pbounds,
    random_state=1,
)

optimizer.maximize(
    init_points=2,
    n_iter=10,
)
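# The best parameters found are available after optimization, e.g.:
# print(optimizer.max)  # {'target': 1 - WER, 'params': {'LM_WEIGHT': ..., 'WORD_SCORE': ...}}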
requirements.txt

bayesian-optimization==1.3.1
bitsandbytes==0.35.3
datasets==2.6.1
jiwer==2.5.1
omegaconf==2.2.3
pyctcdecode==0.4.0
torch==1.13.0
torchaudio==0.13.0
transformers==4.23.1