Created
May 20, 2024 16:40
-
-
Save Martins6/5a9c682bb55ce89782617fe6dcb4cecc to your computer and use it in GitHub Desktop.
STT with PyAudio and HuggingFace.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import threading
import time
import wave
from dataclasses import asdict, dataclass

import pyaudio
import torch
from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor, pipeline
@dataclass
class StreamParams:
    """Settings for the PyAudio input stream.

    Every field name is spelled exactly like a keyword argument of
    ``pyaudio.PyAudio.open``, which is why ``format`` and ``input`` shadow
    builtins: ``to_dict()`` is meant to be splatted straight into ``open``.
    """

    # PyAudio sample-format constant (16-bit integer samples by default).
    format: int = pyaudio.paInt16
    # Number of audio channels captured (1 = mono).
    channels: int = 1
    # Sampling rate passed to PyAudio.
    rate: int = 44100
    # Frames per buffer; also the chunk size the recorder reads at a time.
    frames_per_buffer: int = 1024
    # Open the stream for capture ...
    input: bool = True
    # ... and not for playback.
    output: bool = False

    def to_dict(self) -> dict:
        """Return the fields as a ``{field_name: value}`` mapping."""
        params = asdict(self)
        return params
class STTManualRecorder:
    """Record microphone audio to a WAV file and transcribe it.

    Recording is started and stopped from the terminal. Transcription runs a
    Hugging Face speech-to-text model through a ``transformers``
    "automatic-speech-recognition" pipeline.
    """

    def __init__(
        self,
        wav_file: str = "recording.wav",
        model_hf_id: str = "distil-whisper/distil-small.en",
        torch_device: str = "auto",
    ) -> None:
        """Configure the recorder.

        Args:
            wav_file: Path the recording is written to and later transcribed from.
            model_hf_id: Hugging Face model id passed to ``from_pretrained``.
            torch_device: ``"auto"`` or any device string accepted by
                ``torch.device`` (see ``setup_torch``).
        """
        self.stream_params = StreamParams()
        self.wav_file = wav_file
        self.model_hf_id = model_hf_id
        # Lazily built ASR pipeline plus the settings it was built with, so
        # repeated transcribe() calls do not reload the model from disk.
        self._asr_pipeline = None
        self._asr_pipeline_key = None
        self.setup_torch(torch_device)

    def setup_torch(self, torch_device: str) -> None:
        """Select the inference device and dtype.

        Args:
            torch_device: ``"auto"`` picks MPS, then CUDA, then CPU, in that
                order of preference. Any other value is passed to
                ``torch.device`` (e.g. ``"cpu"``, ``"cuda:0"``, ``"mps"``).

        Sets ``self.device``, and sets ``self.torch_dtype`` to float32 on CPU
        and float16 otherwise.

        Raises:
            RuntimeError: if ``torch.device`` rejects the device string.
        """
        if torch_device == "auto":
            if torch.backends.mps.is_available():
                device = torch.device("mps")
            elif torch.cuda.is_available():
                device = torch.device("cuda")
            else:
                device = torch.device("cpu")
        else:
            # Previously a non-"auto" value left device/dtype unbound.
            device = torch.device(torch_device)
        self.device = device
        # Same dtype policy as the original auto-selection:
        # half precision on accelerators, full precision on CPU.
        self.torch_dtype = torch.float32 if device.type == "cpu" else torch.float16

    @staticmethod
    def _wait_for_stop(stop: threading.Event) -> None:
        """Block on stdin until a line consisting of one space arrives, then set *stop*.

        EOF on stdin also sets *stop*, so recording cannot run forever when
        stdin is closed.
        """
        try:
            while input() != " ":
                pass
        except EOFError:
            pass
        stop.set()

    def record(self) -> None:
        """Capture microphone audio between two SPACE commands and save it as WAV.

        Input is read line by line with ``input()``, so a "SPACE command" is a
        single space followed by ENTER. Ctrl+C while recording also stops the
        capture. The audio is written to ``self.wav_file``.
        """
        params = self.stream_params
        audio = pyaudio.PyAudio()
        frames: list[bytes] = []
        try:
            # Query before terminate(); the PyAudio handle is released below.
            sample_width = audio.get_sample_size(params.format)
            print(params.to_dict())

            print("Type SPACE and press ENTER to start recording.")
            while input() != " ":
                pass

            # Open the stream only now, so audio from the waiting period is not
            # sitting in the input buffer when capture starts.
            stream = audio.open(**params.to_dict())
            try:
                print("Recording... Type SPACE and press ENTER to stop.")
                stop = threading.Event()
                # stdin is watched in a background thread so the read loop never
                # blocks on input(). Previously only one buffer was read per
                # ENTER press. If Ctrl+C ends the loop, this daemon thread may
                # stay blocked on stdin until the next line is entered.
                threading.Thread(
                    target=self._wait_for_stop, args=(stop,), daemon=True
                ).start()
                while not stop.is_set():
                    try:
                        frames.append(
                            stream.read(
                                params.frames_per_buffer,
                                exception_on_overflow=False,
                            )
                        )
                    except KeyboardInterrupt:
                        break
                print("Recording stopped.")
            finally:
                stream.stop_stream()
                stream.close()
        finally:
            audio.terminate()

        with wave.open(self.wav_file, "wb") as wf:
            wf.setnchannels(params.channels)
            wf.setsampwidth(sample_width)
            wf.setframerate(params.rate)
            wf.writeframes(b"".join(frames))
        print(f"Saved recording to {self.wav_file}")

    def _get_pipeline(self):
        """Return the ASR pipeline, rebuilding it only when model/device/dtype changed."""
        key = (self.model_hf_id, str(self.device), self.torch_dtype)
        cached = getattr(self, "_asr_pipeline", None)
        if cached is not None and getattr(self, "_asr_pipeline_key", None) == key:
            return cached

        model = AutoModelForSpeechSeq2Seq.from_pretrained(
            self.model_hf_id,
            torch_dtype=self.torch_dtype,
            low_cpu_mem_usage=True,
            use_safetensors=True,
        )
        model.to(self.device)
        processor = AutoProcessor.from_pretrained(self.model_hf_id)
        pipe = pipeline(
            "automatic-speech-recognition",
            model=model,
            tokenizer=processor.tokenizer,
            feature_extractor=processor.feature_extractor,
            max_new_tokens=128,
            torch_dtype=self.torch_dtype,
            device=self.device,
        )
        self._asr_pipeline = pipe
        self._asr_pipeline_key = key
        return pipe

    def transcribe(self) -> str:
        """Transcribe ``self.wav_file`` and return the recognized text.

        The model and pipeline are loaded on the first call and reused
        afterwards, as long as the model id, device and dtype are unchanged.
        """
        result = self._get_pipeline()(self.wav_file)
        return result["text"]
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment