@aimerib
Created January 19, 2025 14:43
This overly complicated example is a Qt application that uses the spacebar as push-to-talk: hold SPACE to record audio, which is sent to Qwen2-Audio for a response that is then spoken aloud via Kokoro TTS. ATTENTION: This code is littered with hard-coded "mps" device strings, so make sure to change those to your device type or let torch detect it automatically.
from PySide6.QtWidgets import (
    QApplication,
    QMainWindow,
    QLabel,
    QVBoxLayout,
    QHBoxLayout,
    QWidget,
    QTextEdit,
)
from PySide6.QtCore import Qt, QThread, Signal
import sounddevice as sd
import numpy as np
import sys
import torch
from torch import _TorchCompileInductorWrapper
from torch.nn.modules.conv import Conv1d
from torch.nn.modules.sparse import Embedding
from torch.nn.modules.container import ModuleList
from torch._dynamo import OptimizedModule
from torch._dynamo.eval_frame import OptimizeContext
from transformers import (
    Qwen2AudioForConditionalGeneration,
    AutoProcessor,
    Qwen2AudioEncoder,
    Qwen2Config,
    GenerationConfig,
)
from transformers.models.qwen2_audio.modeling_qwen2_audio import (
    Qwen2AudioEncoderLayer,
    Qwen2AudioSdpaAttention,
    Qwen2AudioMultiModalProjector,
)
from transformers.activations import GELUActivation
from transformers.models.qwen2_audio.configuration_qwen2_audio import (
    Qwen2AudioConfig,
    Qwen2AudioEncoderConfig,
)
from torch._dynamo.repro.after_dynamo import WrapBackendDebug
from transformers.models.qwen2.modeling_qwen2 import (
    Qwen2ForCausalLM,
    Qwen2Model,
    Qwen2DecoderLayer,
    Qwen2SdpaAttention,
    Qwen2RotaryEmbedding,
    Qwen2MLP,
    Qwen2RMSNorm,
)
from torch._dynamo.convert_frame import CatchErrorsWrapper, ConvertFrameAssert
from torch.nn.modules.activation import SiLU
from transformers.modeling_rope_utils import _compute_default_rope_parameters
from torch.nn.modules.normalization import LayerNorm
from torch.nn.modules.linear import Linear
from torch.nn.modules.pooling import AvgPool1d
from torch.nn.functional import gelu
from torch._dynamo.hooks import Hooks
from transformers.generation.configuration_utils import CompileConfig
from contextlib import nullcontext
import os
import torch._dynamo  # for compilation configs
from pathlib import Path
import asyncio
from kokoro_onnx import Kokoro
import time
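
# The ATTENTION note above suggests letting torch pick the device instead of the
# hard-coded "mps" strings used throughout. A minimal sketch of that
# (detect_device/DEVICE are my additions, not part of the original gist):
def detect_device() -> str:
    if torch.backends.mps.is_available():
        return "mps"
    if torch.cuda.is_available():
        return "cuda"
    return "cpu"

DEVICE = detect_device()  # swap the literal "mps"/"cpu" strings below for this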
# torch.load(..., weights_only=True) only unpickles types that are explicitly
# allow-listed, so every class reachable from the pickled module has to be
# registered here before load_model() can restore the cached compiled model.
torch.serialization.add_safe_globals(
    [
        OptimizedModule,
        Qwen2AudioForConditionalGeneration,
        Qwen2AudioEncoder,
        set,
        getattr,
        Conv1d,
        Embedding,
        ModuleList,
        Qwen2AudioEncoderLayer,
        Qwen2AudioSdpaAttention,
        Linear,
        Qwen2AudioConfig,
        Qwen2AudioEncoderConfig,
        LayerNorm,
        GELUActivation,
        gelu,
        AvgPool1d,
        Qwen2AudioMultiModalProjector,
        Qwen2ForCausalLM,
        Qwen2Model,
        Qwen2DecoderLayer,
        Qwen2SdpaAttention,
        Qwen2RotaryEmbedding,
        Qwen2Config,
        _compute_default_rope_parameters,
        Qwen2MLP,
        SiLU,
        Qwen2RMSNorm,
        GenerationConfig,
        OptimizeContext,
        CatchErrorsWrapper,
        ConvertFrameAssert,
        WrapBackendDebug,
        _TorchCompileInductorWrapper,
        Hooks,
        nullcontext,
        CompileConfig,
    ]
)
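
# The allow-list above is only needed because load_model() pickles the whole
# compiled module with torch.save(self.model, ...). A leaner alternative (my
# sketch, not what this gist does) is to persist just the weights:
#
#   torch.save(model.state_dict(), path)
#   model.load_state_dict(torch.load(path, weights_only=True))
#
# which keeps weights_only=True happy without registering framework internals.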
class AudioRecorderThread(QThread):
    finished = Signal(dict)

    def __init__(self, sample_rate=16000):
        super().__init__()
        try:
            self.sample_rate = sample_rate
            self.audio_data = []
            self.recording = False
            self.stream = None
        except Exception as e:
            print(f"Error initializing AudioRecorderThread: {e}")
            import traceback
            traceback.print_exc()
            raise

    def audio_callback(self, indata, frames, time, status):
        try:
            if status:
                print(f"Audio callback status: {status}")
            if self.recording:
                print("Recording...")
                self.audio_data.append(indata.copy())
                print(f"Audio data length: {len(self.audio_data)}")
        except Exception as e:
            print(f"Error in audio callback: {e}")
            import traceback
            traceback.print_exc()

    def start_recording(self):
        try:
            print("Starting recording")
            self.audio_data = []
            self.recording = True
            try:
                self.stream = sd.InputStream(
                    samplerate=self.sample_rate,
                    channels=1,
                    callback=self.audio_callback,
                )
                self.stream.start()
                print("Stream started")
            except Exception as e:
                print(f"Error initializing audio stream: {e}")
                import traceback
                traceback.print_exc()
                self.recording = False
                raise
        except Exception as e:
            print(f"Error in start_recording: {e}")
            import traceback
            traceback.print_exc()
            self.recording = False

    def stop_recording(self):
        try:
            print("Stopping recording")
            self.recording = False
            if self.stream:
                try:
                    self.stream.stop()
                    self.stream.close()
                    print("Stream stopped and closed")
                except Exception as e:
                    print(f"Error stopping stream: {e}")
                    import traceback
                    traceback.print_exc()
            try:
                print("Concatenating audio data")
                if self.audio_data:
                    audio = np.concatenate(self.audio_data, axis=0)
                    self.finished.emit({"audio": audio.flatten(), "sr": self.sample_rate})
                    print("Emitting finished signal")
                else:
                    print("No audio data recorded")
                    self.finished.emit({"audio": np.array([]), "sr": self.sample_rate})
            except Exception as e:
                print(f"Error processing recorded audio: {e}")
                import traceback
                traceback.print_exc()
                self.finished.emit({"audio": np.array([]), "sr": self.sample_rate})
        except Exception as e:
            print(f"Error in stop_recording: {e}")
            import traceback
            traceback.print_exc()
            self.finished.emit({"audio": np.array([]), "sr": self.sample_rate})

    # Note: MainWindow drives this class through start_recording()/stop_recording()
    # directly and never calls QThread.start(), so run() below is effectively dead
    # code left over from a blocking-thread design.
    def run(self):
        with sd.InputStream(
            samplerate=self.sample_rate, channels=1, callback=self.audio_callback
        ):
            print("Stream started")
            while self.recording:
                print("Recording...")
                sd.sleep(100)
        if self.audio_data:
            print("Concatenating audio data")
            complete_audio = np.concatenate(self.audio_data)
            audio_data_dict = {"audio": complete_audio, "sr": self.sample_rate}
            print("Emitting finished signal")
            self.finished.emit(audio_data_dict)
        else:
            print("No audio data recorded")
            self.finished.emit({"audio": np.array([]), "sr": self.sample_rate})
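
# A quick way to sanity-check the microphone path outside Qt (my sketch, not part
# of the original gist): record two seconds and play it straight back.
#
#   duration_s, sr = 2, 16000
#   clip = sd.rec(int(duration_s * sr), samplerate=sr, channels=1)
#   sd.wait()
#   sd.play(clip, sr, blocking=True)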
class StreamingTTSWorker(QThread):
    def __init__(self, kokoro: Kokoro, text, voice="af_bella", sampling_rate=16000):
        super().__init__()
        try:
            self.kokoro = kokoro
            self.text = text
            self.voice = voice
            self.sampling_rate = sampling_rate
            self._is_running = True
            self.stream = None
            self.loop = None
        except Exception as e:
            print(f"Error initializing StreamingTTSWorker: {e}")
            import traceback
            traceback.print_exc()
            raise

    async def process_audio(self):
        try:
            print("Starting TTS stream generation...")
            async for samples, sample_rate in self.kokoro.create_stream(
                self.text,
                voice=self.voice,
            ):
                if not self._is_running:
                    print("TTS stopped by request")
                    break
                try:
                    print(f"Playing audio chunk, length: {len(samples)}")
                    sd.play(samples, sample_rate, blocking=True)
                except Exception as e:
                    print(f"Error playing audio chunk: {e}")
                    import traceback
                    traceback.print_exc()
                    break
        except Exception as e:
            print(f"Error in TTS generation: {e}")
            import traceback
            traceback.print_exc()

    def run(self):
        try:
            # Playback actually happens through sd.play() in process_audio();
            # this OutputStream (whose callback never fills outdata) is vestigial
            # and mainly gives stop() something to close.
            try:
                self.stream = sd.OutputStream(
                    samplerate=self.sampling_rate,
                    channels=1,
                    callback=self.audio_callback,
                )
                print("TTS stream initialized")
            except Exception as e:
                print(f"Error initializing TTS stream: {e}")
                import traceback
                traceback.print_exc()
                return
            with self.stream:
                try:
                    self.loop = asyncio.new_event_loop()
                    asyncio.set_event_loop(self.loop)
                    self.loop.run_until_complete(self.process_audio())
                except Exception as e:
                    print(f"Error in TTS event loop: {e}")
                    import traceback
                    traceback.print_exc()
                finally:
                    if self.loop:
                        self.loop.close()
                        self.loop = None
        except Exception as e:
            print(f"Error in TTS worker: {e}")
            import traceback
            traceback.print_exc()
        finally:
            if self.stream:
                try:
                    self.stream.stop()
                    self.stream.close()
                    print("TTS stream cleaned up")
                except Exception as e:
                    print(f"Error cleaning up TTS stream: {e}")
                    import traceback
                    traceback.print_exc()

    def stop(self):
        try:
            self._is_running = False
            if self.stream:
                self.stream.stop()
                self.stream.close()
                print("TTS stopped and cleaned up")
            if self.loop and self.loop.is_running():
                self.loop.stop()
        except Exception as e:
            print(f"Error stopping TTS: {e}")
            import traceback
            traceback.print_exc()

    def audio_callback(self, outdata, frames, time, status):
        if status:
            print(f"TTS audio status: {status}")
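
# For reference, Kokoro's streaming API can be exercised outside Qt with a plain
# asyncio runner (my sketch, reusing the same create_stream call as above):
#
#   async def speak(kokoro: Kokoro, text: str):
#       async for samples, sr in kokoro.create_stream(text, voice="af_bella"):
#           sd.play(samples, sr, blocking=True)
#
#   asyncio.run(speak(kokoro, "Hello there"))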
class LLMWorker(QThread):
    finished = Signal(str)
    error = Signal(str)

    def __init__(self, model, processor, inputs, max_new_tokens=256):
        super().__init__()
        try:
            self.model = model
            self.processor = processor
            self.inputs = inputs
            # The original run() referenced self.max_new_tokens without ever
            # setting it, which raised AttributeError; 256 is an assumed default.
            self.max_new_tokens = max_new_tokens
            self._is_running = True
        except Exception as e:
            print(f"Error initializing LLMWorker: {e}")
            import traceback
            traceback.print_exc()
            raise

    def run(self):
        try:
            print("Starting LLM generation...")
            try:
                if torch.backends.mps.is_available():
                    print("Moving inputs to MPS...")
                    for key in self.inputs:
                        if isinstance(self.inputs[key], torch.Tensor):
                            self.inputs[key] = self.inputs[key].to("mps")
                    print("Inputs moved to MPS")
            except Exception as e:
                print(f"Error moving inputs to device: {e}")
                print("Falling back to CPU")
                self.model = self.model.to("cpu")
                for key in self.inputs:
                    if isinstance(self.inputs[key], torch.Tensor):
                        self.inputs[key] = self.inputs[key].to("cpu")
            gen_config = {
                "max_new_tokens": self.max_new_tokens,
                "do_sample": True,
                "temperature": 0.7,
                "top_p": 0.9,
                "pad_token_id": self.processor.tokenizer.pad_token_id,
                "eos_token_id": self.processor.tokenizer.eos_token_id,
            }
            print("Starting generation with config:", gen_config)
            try:
                with torch.no_grad():
                    start_time = time.time()
                    processed_inputs = {}
                    for key, value in self.inputs.items():
                        if key == "audios":
                            # Ensure the audio tensor has a batch dimension.
                            processed_inputs[key] = value.unsqueeze(0) if len(value.shape) == 1 else value
                        else:
                            processed_inputs[key] = value
                    print("Input shapes:", {k: v.shape if isinstance(v, torch.Tensor) else type(v) for k, v in processed_inputs.items()})
                    outputs = self.model.generate(
                        **processed_inputs,
                        **gen_config,
                    )
                    end_time = time.time()
                    print(f"Generation completed in {end_time - start_time:.2f} seconds")
            except Exception as e:
                print(f"Error during generation: {e}")
                import traceback
                traceback.print_exc()
                self.error.emit(f"Generation failed: {str(e)}")
                return
            try:
                print("Processing outputs...")
                full_text = self.processor.decode(outputs[0], skip_special_tokens=True)
                print(f"Full generated text length: {len(full_text)}")
                # The prompt is decoded along with the completion, so strip
                # everything up to the last "Assistant:" marker.
                parts = full_text.split("Assistant:")
                if len(parts) > 1:
                    response = parts[-1].strip()
                else:
                    last_system = full_text.rfind("System:")
                    last_user = full_text.rfind("User:")
                    last_marker = max(last_system, last_user)
                    if last_marker != -1:
                        next_marker = full_text.find("\n", last_marker)
                        if next_marker != -1:
                            response = full_text[next_marker:].strip()
                        else:
                            response = full_text[last_marker:].strip()
                    else:
                        response = full_text.strip()
                if not response:
                    response = "I apologize, but I couldn't generate a proper response. Please try again."
                print(f"Extracted response length: {len(response)}")
                self.finished.emit(response)
            except Exception as e:
                print(f"Error processing outputs: {e}")
                import traceback
                traceback.print_exc()
                self.error.emit(f"Failed to process model output: {str(e)}")
        except Exception as e:
            print(f"Unexpected error in LLMWorker: {e}")
            import traceback
            traceback.print_exc()
            self.error.emit(f"An unexpected error occurred: {str(e)}")
        finally:
            if torch.backends.mps.is_available():
                try:
                    torch.mps.empty_cache()
                    print("MPS cache cleared")
                except Exception as e:
                    print(f"Error clearing MPS cache: {e}")

    def stop(self):
        self._is_running = False
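
# A more robust way to drop the echoed prompt than splitting on "Assistant:" is
# to slice off the prompt tokens before decoding (a sketch using the names from
# LLMWorker.run above):
#
#   prompt_len = processed_inputs["input_ids"].shape[1]
#   response = self.processor.decode(outputs[0][prompt_len:], skip_special_tokens=True)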
class MainWindow(QMainWindow):
    def __init__(self):
        super().__init__()
        try:
            self.setWindowTitle("Audio Assistant")
            self.setGeometry(100, 100, 600, 400)
            self.conversation_history = []
            self.recorder = None
            self.processing = False
            self.setFocusPolicy(Qt.StrongFocus)
            central_widget = QWidget()
            self.setCentralWidget(central_widget)
            layout = QVBoxLayout(central_widget)
            status_layout = QHBoxLayout()
            self.status_label = QLabel("Loading models...")
            status_layout.addWidget(self.status_label)
            self.thinking_label = QLabel("")
            self.thinking_label.setAlignment(Qt.AlignCenter)
            status_layout.addWidget(self.thinking_label)
            layout.addLayout(status_layout)
            self.response_display = QTextEdit()
            self.response_display.setReadOnly(True)
            self.response_display.setMinimumHeight(200)
            layout.addWidget(self.response_display)
            self.show()
            self.activateWindow()
            self.setFocus()
            QApplication.processEvents()
            QThread.currentThread().msleep(100)
            self.init_models()
        except Exception as e:
            print(f"Error initializing MainWindow: {e}")
            import traceback
            traceback.print_exc()
            raise
    def keyPressEvent(self, event):
        try:
            if event.key() == Qt.Key_Space and not event.isAutoRepeat():
                if self.recorder is not None and not self.processing:
                    print("Starting recording...")
                    self.status_label.setText("Recording...")
                    QApplication.processEvents()
                    self.recorder.start_recording()
        except Exception as e:
            print(f"Error in keyPressEvent: {e}")
            import traceback
            traceback.print_exc()

    def keyReleaseEvent(self, event):
        try:
            if event.key() == Qt.Key_Space and not event.isAutoRepeat():
                if self.recorder is not None and self.recorder.recording:
                    print("Stopping recording...")
                    self.recorder.stop_recording()
                    self.status_label.setText("Processing...")
                    QApplication.processEvents()
        except Exception as e:
            print(f"Error in keyReleaseEvent: {e}")
            import traceback
            traceback.print_exc()
    def process_audio(self, audio_data: dict):
        try:
            if not audio_data["audio"].size:
                print("No audio data to process")
                self.status_label.setText("No audio recorded. Hold SPACE to try again.")
                return
            if self.processing:
                print("Already processing, ignoring new audio")
                return
            self.processing = True
            print("Processing audio")
            try:
                self.status_label.setText("Hold SPACE to record")
                self.thinking_label.setText("Thinking...")
                QApplication.processEvents()
                if not self.conversation_history or self.conversation_history[0]["role"] != "system":
                    self.conversation_history.insert(0, {
                        "role": "system",
                        "content": "You are Gorpo, a friendly and helpful personal assistant. Always provide clear, informative responses while maintaining a conversational tone.",
                    })
                # The <|audio_bos|><|AUDIO|><|audio_eos|> placeholder marks where
                # the processor splices the audio features into the prompt.
                self.conversation_history.append({
                    "role": "user",
                    "content": [{"audio": "<|audio_bos|><|AUDIO|><|audio_eos|>"}],
                })
                # Flatten the history into a plain System:/User:/Assistant: transcript.
                conversation_text = ""
                for message in self.conversation_history:
                    role = message["role"]
                    if role == "system":
                        conversation_text += f"System: {message['content']}\n"
                    elif role == "user":
                        if isinstance(message["content"], list):
                            for content in message["content"]:
                                if "audio" in content:
                                    conversation_text += f"User: {content['audio']}\n"
                                else:
                                    conversation_text += f"User: {content['text']}\n"
                        else:
                            conversation_text += f"User: {message['content']}\n"
                    elif role == "assistant":
                        conversation_text += f"Assistant: {message['content']}\n"
                conversation_text += "Assistant:"
                try:
                    inputs = self.processor(
                        text=conversation_text,
                        audios=audio_data["audio"],
                        return_tensors="pt",
                        sampling_rate=audio_data["sr"],
                    )
                except Exception as e:
                    print(f"Error processing audio: {e}")
                    raise RuntimeError(f"Failed to process audio input: {e}")
                if self.tts_worker and self.tts_worker.isRunning():
                    self.tts_worker.stop()
                    self.tts_worker.wait()
                if self.llm_worker and self.llm_worker.isRunning():
                    self.llm_worker.wait()
                self.llm_worker = LLMWorker(self.model, self.processor, inputs)
                self.llm_worker.finished.connect(self.handle_llm_response)
                self.llm_worker.error.connect(self.handle_llm_error)
                self.llm_worker.start()
            except Exception as e:
                print(f"Processing failed: {e}")
                import traceback
                traceback.print_exc()
                self.thinking_label.setText("")
                self.status_label.setText("Hold SPACE to record")
                self.response_display.setText("I apologize, but I encountered an error while processing your request.")
            finally:
                if not self.llm_worker or not self.llm_worker.isRunning():
                    self.processing = False
        except Exception as e:
            print(f"Error in process_audio: {e}")
            import traceback
            traceback.print_exc()
            self.processing = False
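
    # Note: the Qwen2-Audio processor also ships a chat template, so the
    # hand-rolled transcript above could likely be replaced with something like
    # (a sketch; check your transformers version):
    #
    #   text = self.processor.apply_chat_template(
    #       self.conversation_history, add_generation_prompt=True, tokenize=False
    #   )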
    def handle_llm_response(self, response: str):
        try:
            self.thinking_label.setText("")
            self.status_label.setText("Hold SPACE to record")
            self.response_display.setText(response)
            self.conversation_history.append({
                "role": "assistant",
                "content": response,
            })
            if self.tts_worker and self.tts_worker.isRunning():
                self.tts_worker.stop()
                self.tts_worker.wait()
            self.tts_worker = StreamingTTSWorker(
                self.kokoro,
                response,
                voice="af_bella",
                sampling_rate=self.processor.feature_extractor.sampling_rate,
            )
            self.tts_worker.start()
        except Exception as e:
            print(f"Error in handle_llm_response: {e}")
            import traceback
            traceback.print_exc()
        finally:
            self.processing = False

    def handle_llm_error(self, error_msg: str):
        try:
            self.thinking_label.setText("")
            self.status_label.setText("Hold SPACE to record")
            self.response_display.setText(error_msg)
        except Exception as e:
            print(f"Error in handle_llm_error: {e}")
            import traceback
            traceback.print_exc()
        finally:
            self.processing = False
    def init_models(self):
        try:
            print("Loading models...")
            self.status_label.setText("Loading models...")
            QApplication.processEvents()
            self.kokoro = Kokoro("civitai_scrapper/kokoro/kokoro-v0_19.onnx", "civitai_scrapper/kokoro/voices.json")
            self.tts_worker = None
            self.llm_worker = None
            self.cache_dir = Path("./model_cache")
            self.compiled_model_path = self.cache_dir / "compiled_model.pt"
            self.load_model()
            print("Models loaded, initializing recorder...")
            self.recorder = AudioRecorderThread(
                sample_rate=self.processor.feature_extractor.sampling_rate
            )
            print(f"Recorder initialized: {self.recorder is not None}")
            print(f"Recorder attributes: {dir(self.recorder)}")
            self.recorder.finished.connect(self.process_audio)
            print("Setup complete")
            self.status_label.setText("Hold SPACE to record")
            QApplication.processEvents()
        except Exception as e:
            print(f"Model initialization failed: {e}")
            import traceback
            traceback.print_exc()
            self.status_label.setText("Error loading models")
            QApplication.processEvents()
    def load_model(self):
        try:
            self.processor = AutoProcessor.from_pretrained(
                "Qwen/Qwen2-Audio-7B-Instruct", cache_dir=self.cache_dir
            )
            if self.compiled_model_path.exists():
                print("Loading pre-compiled model...")
                # weights_only=True is what makes the add_safe_globals() list at
                # the top of the file necessary.
                self.model = torch.load(
                    self.compiled_model_path, map_location="mps", weights_only=True
                )
            else:
                print("Loading and compiling model...")
                config = Qwen2Config.from_pretrained(
                    "Qwen/Qwen2-Audio-7B-Instruct",
                    cache_dir=self.cache_dir,
                    use_cache=True,
                )
                # use_flash_attention_2 generally requires CUDA; on MPS it may be
                # ignored or raise depending on the transformers version, in which
                # case the except branch below takes over.
                self.model = Qwen2AudioForConditionalGeneration.from_pretrained(
                    "Qwen/Qwen2-Audio-7B-Instruct",
                    config=config,
                    cache_dir=self.cache_dir,
                    torch_dtype=torch.float16,
                    device_map="mps",
                    use_flash_attention_2=True,
                )
                self.model = torch.compile(
                    self.model,
                    mode="max-autotune",
                    fullgraph=True,
                    dynamic=True,
                )
                os.makedirs(self.cache_dir, exist_ok=True)
                torch.save(self.model, self.compiled_model_path)
        except Exception as e:
            print(f"Compilation failed: {e}. Loading standard model...")
            # Caveat: this fallback repeats the same from_pretrained flags, so it
            # only helps when torch.compile or saving was what failed.
            config = Qwen2Config.from_pretrained(
                "Qwen/Qwen2-Audio-7B-Instruct",
                cache_dir=self.cache_dir,
                use_cache=True,
            )
            self.model = Qwen2AudioForConditionalGeneration.from_pretrained(
                "Qwen/Qwen2-Audio-7B-Instruct",
                config=config,
                cache_dir=self.cache_dir,
                torch_dtype=torch.float16,
                device_map="mps",
                use_flash_attention_2=True,
            )
    def closeEvent(self, event):
        if self.tts_worker and self.tts_worker.isRunning():
            self.tts_worker.stop()
            self.tts_worker.wait()
        if self.llm_worker and self.llm_worker.isRunning():
            self.llm_worker.wait()
        if self.recorder and self.recorder.recording:
            self.recorder.stop_recording()
        event.accept()
if __name__ == "__main__":
    app = QApplication(sys.argv)
    window = MainWindow()
    sys.exit(app.exec())
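
# To run this gist you need (judging from the imports): PySide6, sounddevice,
# numpy, torch, transformers, and kokoro-onnx, plus kokoro-v0_19.onnx and
# voices.json at the paths hard-coded in init_models(). Hold SPACE to talk.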