import io
from typing import Any, Dict, List

import moviepy.editor as mp
import numpy as np
import streamlit as st
import torch
import torchaudio
import torchaudio.transforms as T
from streamlit_mic_recorder import mic_recorder
from transformers import AutomaticSpeechRecognitionPipeline, AutoModelForSpeechSeq2Seq, AutoProcessor, pipeline

# Streamlit app
st.title("CrisperWhisper++ 🦻")
st.subheader("Caution when using. Make sure you can handle the crispness. ⚠️")
st.write("🎙️ Record an audio to transcribe or 📁 upload an audio file.")

# Model ID input
model_id = st.text_input("Enter model ID (e.g., openai/whisper-small)", value="openai/whisper-small")
if not model_id:
    st.warning("Please enter a model ID to proceed.")
    st.stop()

# Device and dtype setup
device = "cuda:0" if torch.cuda.is_available() else "cpu"
torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
st.write(f"Using device: {device}")

# Load model and processor with caching and error handling
@st.cache_resource
def load_model_and_processor(model_id: str):
    try:
        model = AutoModelForSpeechSeq2Seq.from_pretrained(
            model_id, torch_dtype=torch_dtype, low_cpu_mem_usage=True, use_safetensors=True
        )
        model.to(device)
        model.generation_config.median_filter_width = 3
        processor = AutoProcessor.from_pretrained(model_id)
        return model, processor
    except Exception as e:
        st.error(f"Failed to load model: {e}")
        raise

# Setup pipeline
@st.cache_resource
def setup_pipeline(_model: AutoModelForSpeechSeq2Seq, _processor: AutoProcessor) -> AutomaticSpeechRecognitionPipeline:
    try:
        return pipeline(
            "automatic-speech-recognition",
            model=_model,
            tokenizer=_processor.tokenizer,
            feature_extractor=_processor.feature_extractor,
            chunk_length_s=30,
            batch_size=1,
            return_timestamps=True,
            torch_dtype=torch_dtype,
            device=device,
        )
    except Exception as e:
        st.error(f"Failed to setup pipeline: {e}")
        raise

# Load model only when ready; keep the pipeline in session state so it survives
# the reruns Streamlit triggers on later widget interactions (recording/upload).
if st.button("Load Model"):
    with st.spinner("Loading model... This may take a moment."):
        model, processor = load_model_and_processor(model_id)
        st.session_state["pipe"] = setup_pipeline(model, processor)
    st.success("Model loaded successfully!")
elif "pipe" not in st.session_state:
    st.info("Click 'Load Model' to initialize the transcription system.")
pipe = st.session_state.get("pipe")

# Audio processing functions
def process_audio_bytes(audio_bytes: bytes) -> torch.Tensor:
    try:
        audio_stream = io.BytesIO(audio_bytes)
        waveform, sr = torchaudio.load(audio_stream)
        transform = T.Resample(sr, 16000)
        waveform = transform(waveform)
        waveform = waveform / waveform.abs().max()
        torchaudio.save("sample.wav", waveform, 16000)
        return waveform
    except Exception as e:
        st.error(f"Audio processing failed: {e}")
        raise

def transcribe(audio_bytes: bytes, pipe: AutomaticSpeechRecognitionPipeline) -> Dict[str, Any]:
    waveform = process_audio_bytes(audio_bytes)
    transcription = pipe(waveform[0].numpy(), return_timestamps="word")
    return transcription

def timestamps_to_vtt(timestamps: List[Dict[str, Any]]) -> str:
    vtt_content = "WEBVTT\n\n"
    for word in timestamps:
        start_time, end_time = word["timestamp"]
        start_time_str = f"{int(start_time // 3600):02d}:{int(start_time // 60 % 60):02d}:{start_time % 60:06.3f}"
        end_time_str = f"{int(end_time // 3600):02d}:{int(end_time // 60 % 60):02d}:{end_time % 60:06.3f}"
        vtt_content += f"{start_time_str} --> {end_time_str}\n{word['text']}\n\n"
    return vtt_content

def wav_to_black_mp4(wav_path: str, output_path: str, fps: int = 25) -> None:
    waveform, sample_rate = torchaudio.load(wav_path)
    duration = waveform.shape[1] / sample_rate
    audio = mp.AudioFileClip(wav_path)
    black_clip = mp.ColorClip((256, 250), color=(0, 0, 0), duration=duration)
    final_clip = black_clip.set_audio(audio)
    final_clip.write_videofile(output_path, fps=fps, logger=None)

# Audio input
audio = mic_recorder(start_prompt="Start recording", stop_prompt="Stop recording", key="recorder", format="wav")
audio_bytes = audio["bytes"] if audio and "bytes" in audio else None
audio_file = st.file_uploader("Or upload an audio file", type=["wav", "mp3", "ogg"])
if audio_file is not None:
    audio_bytes = audio_file.getvalue()

# Process audio if available
if audio_bytes and pipe:
    try:
        with st.spinner("Transcribing audio..."):
            transcription = transcribe(audio_bytes, pipe)
            vtt = timestamps_to_vtt(transcription["chunks"])
            with open("subtitles.vtt", "w") as file:
                file.write(vtt)
            wav_to_black_mp4("sample.wav", "video.mp4")
        st.video("video.mp4", subtitles="subtitles.vtt")
        st.subheader("Transcription:")
        st.markdown(
            f"<div style='background-color: #f0f0f0; padding: 10px; border-radius: 5px;'><p style='font-size: 16px; color: #333;'>{transcription['text']}</p></div>",
            unsafe_allow_html=True,
        )
    except Exception as e:
        st.error(f"An error occurred during transcription: {e}")
elif audio_bytes and not pipe:
    st.warning("Please load the model first by clicking 'Load Model'.")

# Footer
st.markdown(
    "<hr><footer><p style='text-align: center;'>© 2024 nyra health GmbH</p></footer>",
    unsafe_allow_html=True,
)
import requests
import torch
import io
import soundfile as sf
from transformers import AutoModelForCausalLM, AutoProcessor, GenerationConfig
from urllib.request import urlopen
import torchaudio

# Define model path
model_path = "microsoft/Phi-4-multimodal-instruct"

# Load model and processor
processor = AutoProcessor.from_pretrained(model_path, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    model_path,
    device_map="cuda",
    torch_dtype="auto",
    trust_remote_code=True,
    attn_implementation='flash_attention_2',
).cuda()

# Load generation config
generation_config = GenerationConfig.from_pretrained(model_path)

# Define prompt structure
user_prompt = '<|user|>'
assistant_prompt = '<|assistant|>'
prompt_suffix = '<|end|>'

# Function to get speech segments using Silero VAD
def get_speech_segments(audio, sample_rate):
    model, utils = torch.hub.load('snakers4/silero-vad', 'silero_vad', source='github')
    (get_speech_timestamps, _, _, _, _) = utils
    if not isinstance(audio, torch.Tensor):
        audio = torch.tensor(audio, dtype=torch.float32)
    if audio.dim() > 1 and audio.shape[1] > 1:  # Convert stereo to mono
        audio = torch.mean(audio, dim=1, keepdim=False)
    if sample_rate != 16000:
        resampler = torchaudio.transforms.Resample(sample_rate, 16000)
        audio = resampler(audio)
        sample_rate = 16000
    speech_timestamps = get_speech_timestamps(audio, model, sampling_rate=sample_rate, return_seconds=True)
    return speech_timestamps, audio, sample_rate

# Function to split segments into 30-second chunks
def split_into_chunks(timestamps, max_duration=30.0):
    chunked_timestamps = []
    for segment in timestamps:
        start = segment['start']
        end = segment['end']
        duration = end - start
        if duration <= max_duration:
            chunked_timestamps.append({'start': start, 'end': end})
        else:
            current_start = start
            while current_start < end:
                current_end = min(current_start + max_duration, end)
                chunked_timestamps.append({'start': current_start, 'end': current_end})
                current_start = current_end
    return chunked_timestamps

# Function to transcribe a segment of audio
def transcribe_segment(audio, sample_rate, start, end):
    start_sample = int(start * sample_rate)
    end_sample = int(end * sample_rate)
    segment = audio[start_sample:end_sample].numpy()  # Convert back to numpy for processor
    speech_prompt = "Transcribe the audio to text, keep it verbatim and process all audio"
    prompt = f'{user_prompt}<|audio_1|>{speech_prompt}{prompt_suffix}{assistant_prompt}'
    inputs = processor(text=prompt, audios=[(segment, sample_rate)], return_tensors='pt').to('cuda:0')
    generate_ids = model.generate(
        **inputs,
        max_new_tokens=128000,
        generation_config=generation_config,
    )
    generate_ids = generate_ids[:, inputs['input_ids'].shape[1]:]
    response = processor.batch_decode(
        generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False
    )[0]
    return response

# Main execution
print("\n--- AUDIO PROCESSING ---")
audio_url = "https://device-mattreXXXXXXX-1355.mp3"  # Your audio URL
print(f"Downloading audio from: {audio_url}")

# Download and open audio file
audio, samplerate = sf.read(io.BytesIO(urlopen(audio_url).read()))

# Get speech segments and split into 30-second chunks
speech_timestamps, audio_tensor, sample_rate = get_speech_segments(audio, samplerate)
chunked_timestamps = split_into_chunks(speech_timestamps, max_duration=30.0)

# Transcribe each segment and combine into a single text
full_transcription = ""
for segment in chunked_timestamps:
    start = segment['start']
    end = segment['end']
    transcription = transcribe_segment(audio_tensor, sample_rate, start, end)
    full_transcription += transcription + " "  # Add space between segments

# Remove extra spaces and print the complete text
full_transcription = " ".join(full_transcription.split())  # Clean up multiple spaces
print(f'>>> Full Transcription\n{full_transcription}')
import requests
import torch
import io
import soundfile as sf
from transformers import AutoModelForCausalLM, AutoProcessor, GenerationConfig
from urllib.request import urlopen
import torchaudio

# Define model path
model_path = "microsoft/Phi-4-multimodal-instruct"

# Load model and processor
processor = AutoProcessor.from_pretrained(model_path, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    model_path,
    device_map="cuda",
    torch_dtype="auto",
    trust_remote_code=True,
    attn_implementation='flash_attention_2',
).cuda()

# Load generation config
generation_config = GenerationConfig.from_pretrained(model_path)

# Define prompt structure
user_prompt = '<|user|>'
assistant_prompt = '<|assistant|>'
prompt_suffix = '<|end|>'

# Function to get speech segments using Silero VAD
def get_speech_segments(audio, sample_rate):
    # Load Silero VAD model
    model, utils = torch.hub.load('snakers4/silero-vad', 'silero_vad', source='github')
    (get_speech_timestamps, _, _, _, _) = utils
    # Convert audio to tensor if it's not already
    if not isinstance(audio, torch.Tensor):
        audio = torch.tensor(audio, dtype=torch.float32)
    # Ensure audio is mono and at 16000 Hz (Silero VAD requirement)
    if audio.dim() > 1 and audio.shape[1] > 1:  # Convert stereo to mono
        audio = torch.mean(audio, dim=1, keepdim=False)
    if sample_rate != 16000:
        resampler = torchaudio.transforms.Resample(sample_rate, 16000)
        audio = resampler(audio)
        sample_rate = 16000
    # Get speech timestamps
    speech_timestamps = get_speech_timestamps(audio, model, sampling_rate=sample_rate, return_seconds=True)
    return speech_timestamps, audio, sample_rate

# Function to transcribe a segment of audio
def transcribe_segment(audio, sample_rate, start, end):
    start_sample = int(start * sample_rate)
    end_sample = int(end * sample_rate)
    segment = audio[start_sample:end_sample].numpy()  # Convert back to numpy for processor
    speech_prompt = "Transcribe the audio to text, keep it verbatim and process all audio"
    prompt = f'{user_prompt}<|audio_1|>{speech_prompt}{prompt_suffix}{assistant_prompt}'
    # Process with the model
    inputs = processor(text=prompt, audios=[(segment, sample_rate)], return_tensors='pt').to('cuda:0')
    generate_ids = model.generate(
        **inputs,
        max_new_tokens=128000,
        generation_config=generation_config,
    )
    generate_ids = generate_ids[:, inputs['input_ids'].shape[1]:]
    response = processor.batch_decode(
        generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False
    )[0]
    return response

# Main execution
print("\n--- AUDIO PROCESSING ---")
audio_url = "https://device-mattreXXXXXXX-1355.mp3"  # Your audio URL
print(f"Downloading audio from: {audio_url}")

# Download and open audio file
audio, samplerate = sf.read(io.BytesIO(urlopen(audio_url).read()))

# Get speech segments
speech_timestamps, audio_tensor, sample_rate = get_speech_segments(audio, samplerate)

# Transcribe each segment and combine
full_transcription = ""
for segment in speech_timestamps:
    start = segment['start']
    end = segment['end']
    transcription = transcribe_segment(audio_tensor, sample_rate, start, end)
    full_transcription += f"[{start:.2f}s - {end:.2f}s]: {transcription}\n"

print(f'>>> Full Transcription\n{full_transcription}')
import sys
import torch
from transformers import pipeline, WhisperProcessor, WhisperForConditionalGeneration
import soundfile as sf
import numpy as np

# Model from Hugging Face
MODEL_ID = "Na0s/Medical-Whisper-Large-v3"

# Set device
device = "cuda:0" if torch.cuda.is_available() else "cpu"

# Load model and processor manually to ensure control
processor = WhisperProcessor.from_pretrained(MODEL_ID)
model = WhisperForConditionalGeneration.from_pretrained(MODEL_ID).to(device)
model.eval()

# Initialize pipeline with custom model and processor
pipe = pipeline(
    "automatic-speech-recognition",
    model=model,
    tokenizer=processor.tokenizer,
    feature_extractor=processor.feature_extractor,
    chunk_length_s=30,  # 30-second chunks
    device=device,
)

def transcribe_audio(audio_path, language="en"):
    # Load audio file
    audio, sample_rate = sf.read(audio_path)

    # Ensure 16kHz sample rate
    if sample_rate != 16000:
        raise ValueError(f"Audio must be 16kHz, got {sample_rate}Hz. Please resample your file.")

    # Debug: Audio duration
    duration = len(audio) / sample_rate
    print(f"Audio duration: {duration:.2f} seconds")

    # Prepare audio input
    audio_input = {"array": audio, "sampling_rate": sample_rate}

    # Transcribe with chunking and explicit attention mask
    prediction = pipe(
        audio_input,
        batch_size=8,
        return_timestamps=False,  # Set True for timestamps if desired
        generate_kwargs={
            "language": language,
            "task": "transcribe",
            "max_length": 20000,
            # Explicitly pass attention mask (handled internally by pipeline)
        },
    )

    # Return the full transcription
    return prediction["text"]

if __name__ == "__main__":
    if len(sys.argv) != 2:
        print("Usage: python transcribe.py <audio_file_path>")
        sys.exit(1)

    audio_file = sys.argv[1]
    try:
        result = transcribe_audio(audio_file, language="en")  # Adjust language if needed
        print("Transcription:", result)
    except Exception as e:
        print(f"Error: {e}")
        sys.exit(1)
import os
import sys
import torch
from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor, pipeline
import soundfile as sf
from pydub import AudioSegment  # For MP3 conversion

def adjust_pauses_for_hf_pipeline_output(pipeline_output, split_threshold=0.12):
    """
    Adjust pause timings by distributing pauses up to the threshold evenly between adjacent words.
    """
    adjusted_chunks = pipeline_output["chunks"].copy()
    for i in range(len(adjusted_chunks) - 1):
        current_chunk = adjusted_chunks[i]
        next_chunk = adjusted_chunks[i + 1]
        current_start, current_end = current_chunk["timestamp"]
        next_start, next_end = next_chunk["timestamp"]
        pause_duration = next_start - current_end
        if pause_duration > 0:
            if pause_duration > split_threshold:
                distribute = split_threshold / 2
            else:
                distribute = pause_duration / 2
            adjusted_chunks[i]["timestamp"] = (current_start, current_end + distribute)
            adjusted_chunks[i + 1]["timestamp"] = (next_start - distribute, next_end)
    pipeline_output["chunks"] = adjusted_chunks
    return pipeline_output

def load_audio_file(file_path):
    """Load audio file and convert to 16kHz WAV if necessary."""
    file_ext = os.path.splitext(file_path)[1].lower()
    if file_ext not in ['.wav', '.mp3']:
        raise ValueError("Only WAV and MP3 files are supported")

    if file_ext == '.mp3':
        # Convert MP3 to WAV with 16kHz
        audio = AudioSegment.from_mp3(file_path)
        audio = audio.set_frame_rate(16000)
        temp_wav = "temp_16khz.wav"
        audio.export(temp_wav, format="wav")
        audio_data, sample_rate = sf.read(temp_wav)
        os.remove(temp_wav)  # Clean up temporary file
    else:
        # Read WAV directly
        audio_data, sample_rate = sf.read(file_path)
        if sample_rate != 16000:
            # Resample to 16kHz if not already at that rate
            audio = AudioSegment.from_wav(file_path)
            audio = audio.set_frame_rate(16000)
            temp_wav = "temp_16khz.wav"
            audio.export(temp_wav, format="wav")
            audio_data, sample_rate = sf.read(temp_wav)
            os.remove(temp_wav)
    return {"array": audio_data, "sampling_rate": sample_rate}

# Setup device and model
device = "cuda:0" if torch.cuda.is_available() else "cpu"
torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
model_id = "nyrahealth/CrisperWhisper"

model = AutoModelForSpeechSeq2Seq.from_pretrained(
    model_id, torch_dtype=torch_dtype, low_cpu_mem_usage=True, use_safetensors=True
)
model.to(device)
processor = AutoProcessor.from_pretrained(model_id)

pipe = pipeline(
    "automatic-speech-recognition",
    model=model,
    tokenizer=processor.tokenizer,
    feature_extractor=processor.feature_extractor,
    chunk_length_s=30,
    batch_size=16,
    return_timestamps='word',
    torch_dtype=torch_dtype,
    device=device,
)

if __name__ == "__main__":
    # Check for command line argument
    if len(sys.argv) != 2:
        print("Usage: python script.py <audio_file_path>")
        sys.exit(1)

    audio_file_path = sys.argv[1]
    try:
        # Load and process the audio file
        audio_input = load_audio_file(audio_file_path)

        # Verify sample rate
        if audio_input["sampling_rate"] != 16000:
            raise ValueError("Audio sample rate must be 16kHz")

        # Process the audio
        hf_pipeline_output = pipe(audio_input)
        crisper_whisper_result = adjust_pauses_for_hf_pipeline_output(hf_pipeline_output)

        # Print the result
        print(crisper_whisper_result)
    except FileNotFoundError:
        print(f"Error: The file {audio_file_path} was not found.")
    except ValueError as ve:
        print(f"Error: {str(ve)}")
    except Exception as e:
        print(f"An error occurred: {str(e)}")
Medical transcription using a fine-tuned version of Whisper from Hugging Face: https://huggingface.co/Na0s/Medical-Whisper-Large-v3/tree/main...
Run as:
python transcribe.py /path/to/audio.wav
Prerequisites:
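Likely needed, inferred from the script's imports (the gist itself doesn't spell them out):
pip install torch transformers soundfile numpy
A CUDA GPU is optional; the script checks torch.cuda.is_available() and falls back to CPU otherwise.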