-
-
Save juanmc2005/ed6413e697e176cb36a149d8c40a3a5b to your computer and use it in GitHub Desktop.
import logging | |
import os | |
import sys | |
import traceback | |
from contextlib import contextmanager | |
import diart.operators as dops | |
import numpy as np | |
import rich | |
import rx.operators as ops | |
import whisper_timestamped as whisper | |
from diart import SpeakerDiarization, SpeakerDiarizationConfig | |
from diart.sources import MicrophoneAudioSource | |
from pyannote.core import Annotation, SlidingWindowFeature, SlidingWindow, Segment | |
def concat(chunks, collar=0.05):
    """Merge a list of `(diarization, waveform)` pairs into a single pair.

    The diarization annotations are combined into one annotation and
    contiguous single-speaker regions separated by pauses shorter than
    `collar` seconds are merged; the waveform chunks are concatenated
    into a single `SlidingWindowFeature`.
    """
    reference_annotation, reference_waveform = chunks[0]
    merged = Annotation(uri=reference_annotation.uri)
    waveform_parts = []
    for chunk_annotation, chunk_waveform in chunks:
        merged.update(chunk_annotation)
        waveform_parts.append(chunk_waveform.data)
    # Collapse short gaps between regions of the same speaker
    merged = merged.support(collar)
    # Reuse the resolution of the first chunk for the merged waveform
    resolution = reference_waveform.sliding_window
    window = SlidingWindow(resolution.duration, resolution.step, resolution.start)
    return merged, SlidingWindowFeature(np.concatenate(waveform_parts, axis=0), window)
def colorize_transcription(transcription):
    """Render a speaker-aware transcription as rich-formatted text.

    Parameters
    ----------
    transcription:
        List of `(speaker: int, text: str)` pairs, where a speaker of
        `-1` means "no speaker identified".

    Returns
    -------
    str
        One line per caption, each prefixed with a rich color tag for
        its speaker (plain text when the speaker is unknown).
    """
    colors = 2 * [
        "bright_red", "bright_blue", "bright_green", "orange3", "deep_pink1",
        "yellow2", "magenta", "cyan", "bright_magenta", "dodger_blue2"
    ]
    result = []
    for speaker, text in transcription:
        if speaker == -1:
            # No speaker found for this text, use default terminal color
            result.append(text)
        else:
            # Wrap around instead of raising IndexError when there are
            # more speakers than colors
            result.append(f"[{colors[speaker % len(colors)]}]{text}")
    return "\n".join(result)
@contextmanager
def suppress_stdout():
    """Temporarily redirect stdout to os.devnull.

    Auxiliary context manager to silence Whisper, which is quite verbose.
    All credit goes to:
    https://thesmithfam.org/blog/2012/10/25/temporarily-suppress-console-output-in-python/
    """
    saved_stdout = sys.stdout
    with open(os.devnull, "w") as silent:
        sys.stdout = silent
        try:
            yield
        finally:
            # Always restore the real stdout, even on error
            sys.stdout = saved_stdout
class WhisperTranscriber:
    """Speaker-aware transcription based on whisper-timestamped.

    Transcribes an audio chunk with Whisper and combines the word-level
    timestamps with a diarization annotation to produce a list of
    `(speaker: int, text: str)` captions.
    """

    def __init__(self, model="small", device=None):
        self.model = whisper.load_model(model, device=device)
        # Running transcript, used to condition future transcriptions
        self._buffer = ""

    def transcribe(self, waveform):
        """Transcribe audio using Whisper"""
        # Pad/trim audio to fit 30 seconds as required by Whisper
        audio = waveform.data.astype("float32").reshape(-1)
        audio = whisper.pad_or_trim(audio)
        # Transcribe the given audio while suppressing logs
        with suppress_stdout():
            transcription = whisper.transcribe(
                self.model,
                audio,
                # We use past transcriptions to condition the model
                initial_prompt=self._buffer,
                verbose=True  # to avoid progress bar
            )
        return transcription

    def identify_speakers(self, transcription, diarization, time_shift):
        """Iterate over transcription segments to assign speakers.

        `time_shift` (seconds) maps segment-relative timestamps to the
        absolute timeline of the diarization annotation.
        """
        speaker_captions = []
        for segment in transcription["segments"]:
            words = segment["words"]
            if not words:
                # No word timestamps: cannot locate the segment in time,
                # emit it without a speaker
                speaker_captions.append((-1, segment["text"]))
                continue
            # Crop diarization to the segment timestamps
            start = time_shift + words[0]["start"]
            end = time_shift + words[-1]["end"]
            dia = diarization.crop(Segment(start, end))
            # Assign a speaker to the segment based on diarization
            speakers = dia.labels()
            num_speakers = len(speakers)
            if num_speakers == 0:
                # No speakers were detected
                caption = (-1, segment["text"])
            elif num_speakers == 1:
                # Only one speaker is active in this segment
                spk_id = int(speakers[0].split("speaker")[1])
                caption = (spk_id, segment["text"])
            else:
                # Multiple speakers, select the one that speaks the most.
                # BUG FIX: the original used the argmax *index into the
                # local speakers list* as the speaker id; map it back to
                # the label and parse the id, as the other branches do.
                longest = speakers[int(np.argmax([
                    dia.label_duration(spk) for spk in speakers
                ]))]
                caption = (int(longest.split("speaker")[1]), segment["text"])
            speaker_captions.append(caption)
        return speaker_captions

    def __call__(self, diarization, waveform):
        # Step 1: Transcribe
        transcription = self.transcribe(waveform)
        # Update transcription buffer
        self._buffer += transcription["text"]
        # The audio may not be the beginning of the conversation
        time_shift = waveform.sliding_window.start
        # Step 2: Assign speakers
        speaker_transcriptions = self.identify_speakers(transcription, diarization, time_shift)
        return speaker_transcriptions
# Suppress whisper-timestamped warnings for a clean output
logging.getLogger("whisper_timestamped").setLevel(logging.ERROR)

# Streaming diarization configuration: 5s windows hopping every 500ms.
# If you have a GPU, you can also set device=torch.device("cuda")
config = SpeakerDiarizationConfig(
    duration=5,
    step=0.5,
    latency="min",
    tau_active=0.5,
    rho_update=0.1,
    delta_new=0.57
)
dia = SpeakerDiarization(config)
# NOTE(review): gist commenters report this script hanging at
# "Listening..." with this argument; in newer diart releases the first
# MicrophoneAudioSource parameter is a block duration, and passing
# config.step reportedly fixes the hang — confirm against the installed
# diart version before changing.
source = MicrophoneAudioSource(config.sample_rate)

# If you have a GPU, you can also set device="cuda"
asr = WhisperTranscriber(model="small")

# Split the stream into 2s chunks for transcription
transcription_duration = 2
# Apply models in batches for better efficiency
batch_size = int(transcription_duration // config.step)

# Chain of operations to apply on the stream of microphone audio
source.stream.pipe(
    # Format audio stream to sliding windows of 5s with a step of 500ms
    dops.rearrange_audio_stream(
        config.duration, config.step, config.sample_rate
    ),
    # Wait until a batch is full
    # The output is a list of audio chunks
    ops.buffer_with_count(count=batch_size),
    # Obtain diarization prediction
    # The output is a list of pairs `(diarization, audio chunk)`
    ops.map(dia),
    # Concatenate 500ms predictions/chunks to form a single 2s chunk
    ops.map(concat),
    # Ignore this chunk if it does not contain speech
    ops.filter(lambda ann_wav: ann_wav[0].get_timeline().duration() > 0),
    # Obtain speaker-aware transcriptions
    # The output is a list of pairs `(speaker: int, caption: str)`
    ops.starmap(asr),
    # Color transcriptions according to the speaker
    # The output is plain text with color references for rich
    ops.map(colorize_transcription),
).subscribe(
    on_next=rich.print,  # print colored text
    on_error=lambda _: traceback.print_exc()  # print stacktrace if error
)

print("Listening...")
# Blocks and feeds microphone audio into the pipeline above
source.read()
I am on a Mac M1 Pro
- Does this code do realtime transcription with speaker labels?
- the gist is stuck at "listening..."
I can fix the listening halt on Ubuntu and Mac OS with M2 Ultra by just changing the following line (line 151):
source = MicrophoneAudioSource(config.sample_rate)
to this:
source = MicrophoneAudioSource(config.step)
Hi @juanmc2005,
First of all, thank you for your article on Medium and for your code.
I knew nothing about Pyannote, Diart or Whisper(/-timestamped) yesterday, and now, I am able to make them work independently in great part thanks to your documentation on Diart.
When it comes to running the current script, I have one last small problem.
I have audio that is interpreted correctly by both Whisper-Timestamped and Diart when I pass it through their respective example codes.
However, when using the current script, the transcription of the audio is completely wrong.
If I pass the audio directly (without splitting it first) to whisper.transcribe, the transcription is good, but the whole text is returned for every timestamp, which is expected now that the audio is not split.
I don't know if this is a problem that comes from some settings I haven't made properly in your script, or if this is a problem related to Whisper.
If you have any clues on the subject, I would be interested to hear about it.
If the problem comes from Whisper, I will ask in their repository.
Thank you in advance and take care.
Hi @juanmc2005
I got this error when running it.