Created
November 9, 2023 03:14
-
-
Save pszemraj/b1caed06b8b0fada4c9ec9f33ec5351f to your computer and use it in GitHub Desktop.
# Function to process audio using distil-whisper
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from pathlib import Path | |
import logging | |
from typing import Optional, Union | |
import torch | |
from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor, pipeline | |
# Function to process audio using distil-whisper | |
def process_audio_distil_whisper(
    audio_path: Union[str, Path],
    output_dir: Union[str, Path],
    recompute: bool = False,
    min_chars: int = 1000,
    remove_unprocessable: bool = True,
    device: str = "cuda:0" if torch.cuda.is_available() else "cpu",
    torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
    model_id: str = "distil-whisper/distil-medium.en",
) -> Optional[Path]:
    """Transcribe one audio file with a distil-whisper model and save the text.

    The transcript is written to ``output_dir / (audio_path.stem + ".txt")``.

    Args:
        audio_path: Audio file to transcribe.
        output_dir: Directory that receives the ``.txt`` transcript. It is
            created (including parents) if it does not exist yet.
        recompute: When False, an existing transcript is reused and no model
            is loaded. When True, an existing transcript is deleted and the
            audio is transcribed again.
        min_chars: Transcripts shorter than this many characters are
            discarded and not written to disk.
        remove_unprocessable: When True, the *source audio file* is deleted
            if transcription raises or the transcript is shorter than
            ``min_chars``. Pass False to keep the inputs untouched.
        device: Device the model and pipeline run on.
        torch_dtype: dtype used to load the model and run the pipeline.
        model_id: Model identifier passed to ``from_pretrained``.

    Returns:
        The path of the transcript file (resolved when freshly written, as
        constructed when an existing transcript is reused), or None when the
        audio file is missing, transcription fails, or the transcript is too
        short.

    Note:
        The model and processor are loaded on every call that actually
        transcribes; errors raised while loading them propagate to the caller.
    """
    audio_path = Path(audio_path)
    output_dir = Path(output_dir)
    output_path = output_dir / (audio_path.stem + ".txt")

    if not audio_path.exists():
        logging.warning(f"Audio file {audio_path} does not exist. Returning None")
        return None

    if not recompute and output_path.exists():
        logging.debug(
            f"Skipping {audio_path} as output already exists. Use --recompute to override."
        )
        return output_path

    if output_path.exists():
        # Only reachable with recompute=True: drop the stale transcript first.
        logging.info(f"found {output_path}, removing and re-processing")
        output_path.unlink()

    # Initialize the model and processor
    model = AutoModelForSpeechSeq2Seq.from_pretrained(
        model_id, torch_dtype=torch_dtype, low_cpu_mem_usage=True, use_safetensors=True
    )
    model.to(device)
    processor = AutoProcessor.from_pretrained(model_id)

    # Initialize the pipeline
    pipe = pipeline(
        "automatic-speech-recognition",
        model=model,
        tokenizer=processor.tokenizer,
        feature_extractor=processor.feature_extractor,
        max_new_tokens=128,
        chunk_length_s=15,
        batch_size=16,
        torch_dtype=torch_dtype,
        device=device,
    )

    try:
        # Hand the pipeline the file path so it decodes and resamples the audio
        # itself. The raw bytes of an encoded file are not a waveform, so they
        # must not be passed as the "array" of a {"array", "sampling_rate"} dict.
        result = pipe(str(audio_path))
        text_transcript = result["text"]
    except Exception as e:
        logging.error(f"Transcription failed: {e}")
        if remove_unprocessable:
            audio_path.unlink()
        return None

    if len(text_transcript) < min_chars:
        logging.warning(
            f"Transcript of {audio_path.name} shorter than {min_chars} chars, skipping"
        )
        if remove_unprocessable:
            audio_path.unlink()
        return None

    # The output directory may not exist yet; create it rather than failing
    # after the (expensive) transcription has already succeeded.
    output_dir.mkdir(parents=True, exist_ok=True)
    output_path.write_text(text_transcript, encoding="utf-8")
    return output_path.resolve()
# Note: the audio must be decoded into a waveform before it reaches the model. Depending on the
# file format, this may require extra steps (for example, ffmpeg for compressed formats).
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment