Insanely fast whisper on Modal
import base64
import tempfile
from typing import Optional

from pydantic import BaseModel

from modal import Image, Secret, Stub, build, enter, gpu, web_endpoint
whisper_image = (
    Image.micromamba()
    .apt_install("ffmpeg", "ninja-build", "git")
    .micromamba_install(
        "cudatoolkit=11.8",
        "cudnn=8.1.0",
        "cuda-nvcc",
        channels=["conda-forge", "nvidia"],
    )
    .pip_install(
        "torch==2.0.1",
        "tqdm==4.66.1",
        "more-itertools==10.1.0",
        "transformers==4.37.2",
        "ffmpeg-python==0.2.0",
        "openai-whisper==20231106",
        "optimum==1.14.0",
        "pyannote-audio==3.1.0",
        "rich==13.7.0",
    )
    # flash-attn needs `packaging` at build time, so install it before
    # running pip with --no-build-isolation.
    .pip_install("packaging")
    .run_commands("pip install flash-attn==2.5.2 --no-build-isolation")
    .run_commands(
        "git clone https://github.com/Vaibhavs10/insanely-fast-whisper.git",
        # Pin the commit version.
        "cd insanely-fast-whisper && git checkout ff0df400f4aed859375c2507ebcc21fe5f9b99e0",
    )
)
# Named to have the label prefix "modal-labs--instant".
stub = Stub("instant-whisper")

with whisper_image.imports():
    import sys

    import torch

    # The folders are missing `__init__.py` files, so we need to add them to the path.
    sys.path.append("/insanely-fast-whisper/src/insanely_fast_whisper/utils")
    from diarize import (
        diarize_audio as diarize_audio_func,
        post_process_segments_and_transcripts,
        preprocess_inputs,
    )
    from pyannote.audio import Pipeline
    from transformers import (
        WhisperFeatureExtractor,
        WhisperForConditionalGeneration,
        WhisperTokenizerFast,
        pipeline,
    )
class TranscriptionRequest(BaseModel):
    audio: str  # base64-encoded audio as a data URL
    language: Optional[str] = None
    diarize_audio: bool = False
    batch_size: int = 24
@stub.cls(
    gpu=gpu.A10G(),
    # To avoid excessive cold starts, we set the idle timeout to two minutes.
    container_idle_timeout=120,
    keep_warm=1,
    image=whisper_image,
    # TODO: reconcile Hugging Face secrets.
    secrets=[Secret.from_name("huggingface-secret-2")],
)
class Model:
    @build()
    @enter()
    def setup(self):
        model_id = "openai/whisper-large-v3"
        torch_dtype = torch.float16
        self.device = "cuda:0"
        model = WhisperForConditionalGeneration.from_pretrained(
            model_id,
            torch_dtype=torch_dtype,
            # Enable Flash Attention 2 at load time. Passing it through the
            # pipeline's `model_kwargs` has no effect when a preloaded model
            # is supplied (see the discussion below).
            attn_implementation="flash_attention_2",
        ).to(self.device)
        tokenizer = WhisperTokenizerFast.from_pretrained(model_id)
        feature_extractor = WhisperFeatureExtractor.from_pretrained(model_id)
        self.pipe = pipeline(
            "automatic-speech-recognition",
            model=model,
            tokenizer=tokenizer,
            feature_extractor=feature_extractor,
            torch_dtype=torch_dtype,
            device=self.device,
        )
        # Note: this model is gated. Accept its user conditions on the Hugging
        # Face Hub and make a valid token available, or `from_pretrained`
        # returns None (see the discussion below).
        self.diarization_pipeline = Pipeline.from_pretrained("pyannote/speaker-diarization-3.1")
        self.diarization_pipeline.to(torch.device(self.device))
    @web_endpoint(method="POST", label="instant-whisper")
    def transcribe(self, request: TranscriptionRequest):
        """Transcribes, and optionally diarizes, a single audio file."""
        with tempfile.NamedTemporaryFile(suffix=".wav") as temp_audio:
            # The audio arrives as a data URL; the payload follows the comma.
            audio_data = base64.b64decode(request.audio.split(",")[1])
            temp_audio.write(audio_data)
            temp_audio.flush()  # ensure all bytes are on disk before ffmpeg reads the file
            outputs = self.pipe(
                temp_audio.name,
                chunk_length_s=30,
                batch_size=request.batch_size,
                generate_kwargs={
                    "task": "transcribe",
                    "language": None if request.language == "" else request.language,
                },
                return_timestamps="word",
            )
            if not request.diarize_audio:
                return outputs

            inputs, diarizer_inputs = preprocess_inputs(inputs=temp_audio.name)
            segments = diarize_audio_func(diarizer_inputs, self.diarization_pipeline)
            segmented_transcript = post_process_segments_and_transcripts(
                segments, outputs["chunks"], group_by_speaker=False
            )
            outputs["chunks"] = segmented_transcript
            return outputs
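For reference, here is a minimal sketch of how a client might call the deployed endpoint. This is not part of the gist: the URL is an assumption based on the stub label (check the `modal deploy` output for the real one), and the response handling assumes the standard transformers ASR pipeline output shape.

import base64

import requests  # any HTTP client works

with open("sample.wav", "rb") as f:
    audio_b64 = base64.b64encode(f.read()).decode()

resp = requests.post(
    # Hypothetical deployment URL; yours will use your workspace name.
    "https://modal-labs--instant-whisper.modal.run",
    json={
        # The handler splits on the comma, so send a data URL.
        "audio": f"data:audio/wav;base64,{audio_b64}",
        "language": "en",
        "diarize_audio": False,
        "batch_size": 24,
    },
)
resp.raise_for_status()
print(resp.json()["text"])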
Thanks @jflam for the catch and investigation! Did not realize FA2 was not enabled, my bad.
It is indeed quite annoying if it changes behavior once serialized. We'll look into it ourselves to see if it's fixable.
The code doesn't work anymore; there's an error on line 89 with .to(self.device):

AttributeError("'NoneType' object has no attribute 'to'")
Encountered the same issue. You need to:
- Accept the pyannote/segmentation-3.0 user conditions
- Accept the pyannote/speaker-diarization-3.1 user conditions
- Create an access token at hf.co/settings/tokens
- Set up that secret on Modal and authenticate with it like so:

self.diarization_pipeline = Pipeline.from_pretrained(
    "pyannote/speaker-diarization-3.1",
    # `use_auth_token` expects the token string itself, not a Modal Secret
    # object; the env var name depends on how your secret is defined.
    use_auth_token=os.environ["HF_TOKEN"],
)
thank you for this!
Thank you for helping me get Flash Attention 2 installed on Modal, and for some cool tips (I didn't know about whisper_image.imports()!)

You do have a bug, though: model_kwargs is not enabling Flash Attention 2 correctly. The correct parameter is "attn_implementation": "flash_attention_2". In my test case I could see the effect: transcription time dropped from 117s to 90s, and word timestamps were still correctly emitted.

However, I want to point out that return_timestamps="word" is not compatible with Flash Attention 2 in the case where you serialize the pipeline to disk via a @build method to reduce startup time. If I load the model in the transcribe method, it works fine and returns word timestamps. When I try to load a serialized pipeline, it throws a "WhisperFlashAttention2 attention does not support output_attentions" exception. I have no idea why, but this is my observation this morning. I really do not like the layers of wrappers on top of wrappers that this world seems to enjoy.

Unless you're doing a bunch of batch transcriptions where you can amortize the model load time, the inability to serialize the pipeline into the image makes enabling Flash Attention 2 not worthwhile, since the model load time will eliminate any perf gains from FA2.
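For anyone chasing the same issue, here is a minimal sketch of loading Whisper with FA2 actually enabled, assuming transformers >= 4.36 (the model id and dtype mirror the gist; the sanity check relies on a private config attribute, so treat it as a debugging aid):

import torch
from transformers import WhisperForConditionalGeneration

model = WhisperForConditionalGeneration.from_pretrained(
    "openai/whisper-large-v3",
    torch_dtype=torch.float16,
    # Applied at load time, not via the pipeline's model_kwargs.
    attn_implementation="flash_attention_2",
).to("cuda:0")

# Sanity check: should print "flash_attention_2".
print(model.config._attn_implementation)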