Last active
August 17, 2024 04:40
-
-
Save johnmeade/c4de49429502771a304dd7b82e864838 to your computer and use it in GitHub Desktop.
Multi-language ASR using Huggingface transformer models.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
""" | |
Multi-language ASR using Huggingface transformer models. | |
Python dependencies: | |
pip install transformers==4.5.0 librosa soundfile torch | |
""" | |
from typing import NamedTuple | |
from functools import lru_cache | |
from transformers import Wav2Vec2Tokenizer, Wav2Vec2ForCTC | |
from librosa.core import resample | |
import soundfile as sf | |
import torch | |
# A single ASR model entry: a three-letter language code (e.g. "eng") paired
# with the Huggingface model-hub path that serves that language.
Model = NamedTuple("Model", [("lang", str), ("name", str)])
# Registered (language, model-path) pairs. Find and add more language paths here:
# https://huggingface.co/models?filter=wav2vec2,pytorch&pipeline_tag=automatic-speech-recognition
_MODEL_SPECS = [
    ("eng", "facebook/wav2vec2-large-960h-lv60-self"),
    ("fra", "facebook/wav2vec2-large-xlsr-53-french"),
    ("spa", "facebook/wav2vec2-large-xlsr-53-spanish"),
    ("nld", "facebook/wav2vec2-large-xlsr-53-dutch"),
    ("deu", "facebook/wav2vec2-large-xlsr-53-german"),
]
MODELS = {Model(lang=lang, name=name) for lang, name in _MODEL_SPECS}

# Sample rate (Hz) that the wav2vec2 models expect their input audio to use.
W2V_SR = 16_000
@lru_cache()
def _get_models(lang):
    """Load the (tokenizer, model) pair registered for a language code.

    Results are memoized via ``lru_cache``, so repeated calls for the same
    language reuse the already-loaded model instead of reloading it from
    disk/network.

    Args:
        lang: language code matching one of the ``Model.lang`` entries
            in ``MODELS``.

    Returns:
        A ``(tokenizer, model)`` tuple.

    Raises:
        ValueError: if no model is registered for ``lang``.
    """
    # find model name
    matches = [m.name for m in MODELS if m.lang == lang]
    # Fix: the original `if not any(matches)` tests the truthiness of the
    # strings themselves, so a registered-but-empty name would falsely raise.
    # Test the list for emptiness instead, and name the offending language.
    if not matches:
        raise ValueError(f"Could not find a model for language {lang!r}")
    name = matches[0]
    # load model and tokenizer
    tokenizer = Wav2Vec2Tokenizer.from_pretrained(name)
    model = Wav2Vec2ForCTC.from_pretrained(name)
    return tokenizer, model
def _loadwav(wavfn):
    """Load an audio file as a mono waveform at the wav2vec2 sample rate.

    Args:
        wavfn: path to an audio file (any format ``soundfile`` can read).

    Returns:
        A 1-D float sample array resampled to ``W2V_SR`` Hz.
    """
    # load wav
    wav, sr = sf.read(wavfn)
    # ensure mono by keeping the first channel.
    # NOTE(review): averaging channels may be preferable for stereo sources,
    # but channel 0 preserves the existing behavior.
    if wav.ndim > 1:
        wav = wav[:, 0]
    # Resample to the rate the models expect. Keyword arguments are required
    # by librosa >= 0.10 (positional sr args were removed) and are accepted
    # by earlier versions as well.
    if sr != W2V_SR:
        wav = resample(wav, orig_sr=sr, target_sr=W2V_SR)
    return wav
def rec_files(wavfns, lang):
    """Transcribe a batch of audio files in the given language.

    Args:
        wavfns: iterable of audio file paths.
        lang: language code with a model registered in ``MODELS``.

    Returns:
        A list of transcription strings, one per input file.

    Raises:
        ValueError: if no model is registered for ``lang``.
    """
    # load and normalize all waveforms up front
    wavs = [_loadwav(fn) for fn in wavfns]
    # load (cached) tokenizer + model for this language
    tokenizer, model = _get_models(lang)
    # batch-tokenize, padding every clip to the length of the longest one
    input_values = tokenizer(wavs, return_tensors="pt", padding="longest").input_values
    # Inference only: disable autograd so the forward pass does not build a
    # graph, cutting memory use substantially for long audio batches.
    with torch.no_grad():
        logits = model(input_values).logits
    # greedy CTC decode: argmax over the vocabulary, then collapse to text
    predicted_ids = torch.argmax(logits, dim=-1)
    transcription = tokenizer.batch_decode(predicted_ids)
    return transcription
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment