Skip to content

Instantly share code, notes, and snippets.

@pszemraj
Created November 9, 2023 03:14
Show Gist options
  • Save pszemraj/b1caed06b8b0fada4c9ec9f33ec5351f to your computer and use it in GitHub Desktop.
Save pszemraj/b1caed06b8b0fada4c9ec9f33ec5351f to your computer and use it in GitHub Desktop.
# Function to process audio using distil-whisper
import logging
from pathlib import Path
from typing import Any, Callable, Optional, Union

import torch
from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor, pipeline
def _build_asr_pipeline(model_id: str, device: str, torch_dtype: Any) -> Any:
    """Load ``model_id`` and wrap it in a chunked automatic-speech-recognition pipeline."""
    model = AutoModelForSpeechSeq2Seq.from_pretrained(
        model_id, torch_dtype=torch_dtype, low_cpu_mem_usage=True, use_safetensors=True
    )
    model.to(device)
    processor = AutoProcessor.from_pretrained(model_id)
    return pipeline(
        "automatic-speech-recognition",
        model=model,
        tokenizer=processor.tokenizer,
        feature_extractor=processor.feature_extractor,
        max_new_tokens=128,
        chunk_length_s=15,
        batch_size=16,
        torch_dtype=torch_dtype,
        device=device,
    )


def process_audio_distil_whisper(
    audio_path: Union[str, Path],
    output_dir: Union[str, Path],
    recompute: bool = False,
    min_chars: int = 1000,
    remove_unprocessable: bool = True,
    device: str = "cuda:0" if torch.cuda.is_available() else "cpu",
    torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
    model_id: str = "distil-whisper/distil-medium.en",
    *,
    asr_pipeline: Optional[Callable[[str], Any]] = None,
) -> Optional[Path]:
    """Transcribe one audio file with distil-whisper and save the text.

    The transcript is written to ``output_dir / (audio_path.stem + ".txt")``.

    Args:
        audio_path: Audio file to transcribe. It is passed to the pipeline as a
            filename, so the pipeline itself decodes and resamples it (the
            transformers ASR pipeline relies on ffmpeg for this).
        output_dir: Directory for the transcript; created if it does not exist.
        recompute: Re-transcribe even if the transcript already exists.
        min_chars: Transcripts shorter than this many characters are rejected.
        remove_unprocessable: Delete ``audio_path`` when transcription raises or
            the transcript is shorter than ``min_chars``.
        device: Device for the model (used only when ``asr_pipeline`` is None).
        torch_dtype: Model dtype (used only when ``asr_pipeline`` is None).
        model_id: Hugging Face model id (used only when ``asr_pipeline`` is None).
        asr_pipeline: Optional pre-built callable that takes a filename and
            returns a dict with a ``"text"`` key, e.g. a transformers
            ``"automatic-speech-recognition"`` pipeline. Pass one when
            transcribing many files so the model is not reloaded on every call.

    Returns:
        The resolved transcript path, or None if the audio file is missing,
        transcription raised, or the transcript was too short.
    """
    audio_path = Path(audio_path)
    output_dir = Path(output_dir)
    output_path = output_dir / (audio_path.stem + ".txt")

    if not audio_path.exists():
        logging.warning(f"Audio file {audio_path} does not exist. Returning None")
        return None
    if output_path.exists():
        if not recompute:
            logging.debug(
                f"Skipping {audio_path} as output already exists. Use --recompute to override."
            )
            # Resolved for consistency with the success path below.
            return output_path.resolve()
        # It is the stale transcript (not the audio) that gets removed.
        logging.info(f"found {output_path}, removing and re-processing")
        output_path.unlink()

    if asr_pipeline is None:
        asr_pipeline = _build_asr_pipeline(model_id, device, torch_dtype)

    try:
        # Hand over the path, not raw file bytes: a dict input's "array" must be
        # an already-decoded waveform, whereas a filename lets the pipeline
        # decode the container and resample to the model's sampling rate.
        result = asr_pipeline(str(audio_path))
        text_transcript = result["text"]
    except Exception:
        # Broad catch is deliberate: any failure marks the file unprocessable.
        logging.exception(f"Transcription of {audio_path} failed")
        if remove_unprocessable:
            audio_path.unlink()
        return None

    if len(text_transcript) < min_chars:
        logging.warning(
            f"Transcript of {audio_path.name} shorter than {min_chars} chars, skipping"
        )
        if remove_unprocessable:
            audio_path.unlink()
        return None

    output_dir.mkdir(parents=True, exist_ok=True)
    output_path.write_text(text_transcript, encoding="utf-8")
    return output_path.resolve()
# The part of loading the audio file and preparing it for the model might need additional steps
# depending on the audio file format and the expected input format of the model.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment