@Martins6 · Created May 20, 2024
STT with PyAudio and HuggingFace.
import threading
import wave
from dataclasses import asdict, dataclass

import pyaudio
import torch
from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor, pipeline
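
# Assumed setup (not pinned in the gist): pip install pyaudio torch transformers
# The Hugging Face ASR pipeline decodes audio files through ffmpeg, so the
# ffmpeg binary must also be on PATH.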

@dataclass
class StreamParams:
    """PyAudio stream configuration; to_dict() feeds pyaudio.PyAudio.open()."""

    format: int = pyaudio.paInt16
    channels: int = 1
    rate: int = 44100
    frames_per_buffer: int = 1024
    input: bool = True
    output: bool = False

    def to_dict(self) -> dict:
        return asdict(self)
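
# Note: Whisper-family checkpoints consume 16 kHz audio; the pipeline
# resamples file input on load (via ffmpeg), so the 44.1 kHz default works.
# Recording at the native rate directly is a one-field override:
#     StreamParams(rate=16_000)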

class STTManualRecorder:
    def __init__(
        self,
        wav_file: str = "recording.wav",
        model_hf_id: str = "distil-whisper/distil-small.en",
        torch_device: str = "auto",
    ) -> None:
        self.stream_params = StreamParams()
        self.wav_file = wav_file
        self.model_hf_id = model_hf_id
        self.setup_torch(torch_device)

    def setup_torch(self, torch_device: str):
        if torch_device == "auto":
            if torch.backends.mps.is_available():
                device = torch.device("mps")
                torch_dtype = torch.float16
            elif torch.cuda.is_available():
                device = torch.device("cuda")
                torch_dtype = torch.float16
            else:
                device = torch.device("cpu")
                torch_dtype = torch.float32
        else:
            # An explicit device string (e.g. "cpu", "cuda:0") bypasses
            # auto-detection; half precision only off-CPU.
            device = torch.device(torch_device)
            torch_dtype = torch.float32 if device.type == "cpu" else torch.float16
        self.device = device
        self.torch_dtype = torch_dtype
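
    # The device can also be pinned instead of auto-detected, e.g.
    # (hypothetical call sites):
    #     STTManualRecorder(torch_device="cpu")
    #     STTManualRecorder(torch_device="cuda:0")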

    def record(self):
        audio = pyaudio.PyAudio()
        print(self.stream_params.to_dict())
        stream = audio.open(**self.stream_params.to_dict())
        frames = []

        # input() only returns when ENTER is pressed, so prompt for ENTER.
        input("Press ENTER to start recording.")
        print("Recording... Press ENTER to stop.")

        # Watch stdin from a background thread so the capture loop below is
        # never blocked on the keyboard between buffers.
        stop = threading.Event()
        threading.Thread(
            target=lambda: (input(), stop.set()), daemon=True
        ).start()

        while not stop.is_set():
            try:
                data = stream.read(
                    self.stream_params.frames_per_buffer, exception_on_overflow=False
                )
                frames.append(data)
            except KeyboardInterrupt:
                break
        print("Recording stopped.")

        stream.stop_stream()
        stream.close()

        with wave.open(self.wav_file, "wb") as wf:
            wf.setnchannels(self.stream_params.channels)
            wf.setsampwidth(audio.get_sample_size(self.stream_params.format))
            wf.setframerate(self.stream_params.rate)
            wf.writeframes(b"".join(frames))
        audio.terminate()
        print(f"Saved recording to {self.wav_file}")

    def transcribe(self) -> str:
        model = AutoModelForSpeechSeq2Seq.from_pretrained(
            self.model_hf_id,
            torch_dtype=self.torch_dtype,
            low_cpu_mem_usage=True,
            use_safetensors=True,
        )
        model.to(self.device)
        processor = AutoProcessor.from_pretrained(self.model_hf_id)
        pipe = pipeline(
            "automatic-speech-recognition",
            model=model,
            tokenizer=processor.tokenizer,
            feature_extractor=processor.feature_extractor,
            max_new_tokens=128,
            torch_dtype=self.torch_dtype,
            device=self.device,
        )
        result = pipe(self.wav_file)
        return result["text"]
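
# Minimal end-to-end sketch (assumes a working default microphone and that
# the distil-small.en checkpoint downloads on first run):
if __name__ == "__main__":
    recorder = STTManualRecorder()
    recorder.record()
    print(recorder.transcribe())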