Last active
March 28, 2026 14:43
-
-
Save briankung/7afbf743987df6b5b477e9b58b1d6d53 to your computer and use it in GitHub Desktop.
Testing live transcription on my M4 Max MacBook Pro (64GB). All runnable with `uv run <SCRIPT_NAME>`
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| # /// script | |
| # requires-python = ">=3.12" | |
| # dependencies = [ | |
| # "coremltools>=8.0", | |
| # "huggingface_hub", | |
| # "numpy", | |
| # "scipy", | |
| # "sounddevice", | |
| # "tiktoken", | |
| # "transformers", | |
| # ] | |
| # /// | |
import argparse
import os
import queue
import sys
from functools import lru_cache

import coremltools as ct
import numpy as np
import sounddevice as sd
from huggingface_hub import snapshot_download
from scipy.signal import stft as scipy_stft
from transformers import AutoTokenizer
| # --- Configuration --- | |
| MODEL_REPO = "phequals/canary-qwen-2.5b-coreml-int8" | |
| MODEL_LOCAL_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), ".canary-model") | |
| TOKENIZER_REPO = "Qwen/Qwen3-1.7B" | |
| SAMPLE_RATE = 16000 | |
| CHANNELS = 1 | |
| CHUNK_SECONDS = 5 | |
| WIN_SAMPLES = 400 # 25ms Hann window at 16kHz | |
| HOP_SAMPLES = 160 # 10ms hop at 16kHz | |
| N_MELS = 128 | |
| EMBED_DIM = 2048 | |
| MAX_NEW_TOKENS = 150 | |
| # Encoder expects exactly 500 mel frames (fixed CoreML input shape). | |
| # Required audio length: (500 - 1) * 160 + 400 = 80,240 samples β 5.015s | |
| TARGET_FRAMES = 500 | |
| TARGET_SAMPLES = (TARGET_FRAMES - 1) * HOP_SAMPLES + WIN_SAMPLES # 80,240 | |
| ENCODER_AUDIO_TOKENS = 63 # encoder output frames for TARGET_FRAMES input | |
| # RoPE parameters matching Qwen3-1.7B (head_dim=128, theta=10000) | |
| ROPE_HEAD_DIM = 128 | |
| ROPE_BASE = 10000.0 | |
| # KV cache max sequence length (from decoder state shape [1, 8, 256, 128]) | |
| KV_CACHE_MAX = 256 | |
| # Prompt prefix (text before the audio embeddings are inserted) | |
| PROMPT_PREFIX = "Transcribe the following: " | |
| audio_queue: queue.Queue = queue.Queue() | |
| def audio_callback(indata, frames, time, status): | |
| if status: | |
| print(status, file=sys.stderr) | |
| audio_queue.put(indata.copy()) | |
| def compute_mel(audio: np.ndarray, hann_window: np.ndarray, mel_filterbank: np.ndarray) -> np.ndarray: | |
| """Compute log-mel spectrogram matching NeMo/Canary preprocessing exactly. | |
| NeMo config: normalize=per_feature, log (natural), log_zero_guard_value=1e-5. | |
| Common mistakes: log10 instead of ln, and global instead of per-feature normalization. | |
| """ | |
| n_fft = (mel_filterbank.shape[1] - 1) * 2 | |
| _, _, spec = scipy_stft( | |
| audio, | |
| fs=SAMPLE_RATE, | |
| window=hann_window, | |
| nperseg=len(hann_window), | |
| noverlap=len(hann_window) - HOP_SAMPLES, | |
| nfft=n_fft, | |
| boundary=None, | |
| padded=False, | |
| ) | |
| power = np.abs(spec) ** 2 # [n_fft//2+1, T] | |
| mel = mel_filterbank @ power # [128, T] | |
| log_mel = np.log(mel + 1e-5) # natural log, additive guard | |
| # Per-feature normalization: each of the 128 mel bins normalized independently over time | |
| mean = log_mel.mean(axis=-1, keepdims=True) # [128, 1] | |
| std = log_mel.std(axis=-1, keepdims=True) # [128, 1] | |
| log_mel = (log_mel - mean) / (std + 1e-8) | |
| return log_mel.astype(np.float32) | |
| def first_value(d: dict) -> np.ndarray: | |
| return next(iter(d.values())) | |
| def compute_rope(position: int) -> tuple[np.ndarray, np.ndarray]: | |
| """RoPE cosine/sine embeddings for one position, shape [1, 1, ROPE_HEAD_DIM].""" | |
| inv_freq = 1.0 / (ROPE_BASE ** (np.arange(0, ROPE_HEAD_DIM, 2, dtype=np.float32) / ROPE_HEAD_DIM)) | |
| theta = position * inv_freq # [HEAD_DIM//2] | |
| cos_half, sin_half = np.cos(theta), np.sin(theta) | |
| cos = np.concatenate([cos_half, cos_half])[None, None, :] # [1, 1, HEAD_DIM] | |
| sin = np.concatenate([sin_half, sin_half])[None, None, :] | |
| return cos, sin | |
| def decoder_step( | |
| decoder: ct.models.MLModel, | |
| embed: np.ndarray, # [1, 1, 2048] | |
| position: int, | |
| state, | |
| ) -> np.ndarray: # returns [1, 1, 2048] | |
| cos, sin = compute_rope(position) | |
| result = decoder.predict({ | |
| "hidden_states": embed, | |
| "position_cos": cos, | |
| "position_sin": sin, | |
| "attention_mask": np.zeros([1, 1, 1, 1], dtype=np.float32), | |
| }, state=state) | |
| return first_value(result) # [1, 1, 2048] | |
| def transcribe( | |
| audio: np.ndarray, | |
| hann_window: np.ndarray, | |
| mel_filterbank: np.ndarray, | |
| encoder: ct.models.MLModel, | |
| projection: ct.models.MLModel, | |
| decoder: ct.models.MLModel, | |
| lm_head: ct.models.MLModel, | |
| embeddings: np.ndarray, | |
| tokenizer, | |
| prefix_ids: list, | |
| eos_id: int, | |
| ) -> str: | |
| # 1. Mel spectrogram β encoder | |
| # Pad or truncate to produce exactly TARGET_FRAMES (fixed encoder input shape). | |
| audio = audio[:TARGET_SAMPLES] | |
| if len(audio) < TARGET_SAMPLES: | |
| audio = np.pad(audio, (0, TARGET_SAMPLES - len(audio))) | |
| mel = compute_mel(audio, hann_window, mel_filterbank)[np.newaxis] # [1, 128, 500] | |
| enc_out = first_value(encoder.predict({"audio_signal": mel})) # [1, 63, 1024] | |
| # 2. Project all audio frames in one batch call | |
| audio_embeds = first_value(projection.predict({"encoder_output": enc_out}))[0] # [T_audio, 2048] | |
| # 3. Build full prefill sequence: prompt prefix tokens + audio frame embeddings | |
| prefix_embeds = embeddings[prefix_ids].astype(np.float32) # [N_prefix, 2048] | |
| input_embeds = np.concatenate([prefix_embeds, audio_embeds], axis=0) # [N_total, 2048] | |
| # 4. Prefill: feed each token sequentially into the stateful decoder | |
| state = decoder.make_state() | |
| last_hidden = None | |
| for pos, embed in enumerate(input_embeds): | |
| token = embed[np.newaxis, np.newaxis, :].astype(np.float32) # [1, 1, 2048] | |
| last_hidden = decoder_step(decoder, token, pos, state) | |
| # 5. Greedy generation (cap at KV cache space remaining) | |
| max_new = min(MAX_NEW_TOKENS, KV_CACHE_MAX - len(input_embeds)) | |
| generated_ids = [] | |
| pos = len(input_embeds) | |
| for _ in range(max_new): | |
| logits = first_value(lm_head.predict({"hidden_states": last_hidden})) | |
| next_id = int(np.argmax(logits.flatten())) | |
| if next_id == eos_id: | |
| break | |
| generated_ids.append(next_id) | |
| # Stop if the last 8 tokens are all identical (repetition loop) | |
| if len(generated_ids) >= 8 and len(set(generated_ids[-8:])) == 1: | |
| break | |
| next_embed = embeddings[next_id][np.newaxis, np.newaxis, :].astype(np.float32) | |
| last_hidden = decoder_step(decoder, next_embed, pos, state) | |
| pos += 1 | |
| return tokenizer.decode(generated_ids, skip_special_tokens=True) | |
| def pre_warm(encoder, projection, decoder, lm_head): | |
| """Trigger CoreML JIT compilation for all models with dummy inputs. | |
| The first predict() call per model compiles it (~54s for decoder). | |
| Doing this at startup avoids a silent freeze on the first audio chunk.""" | |
| print("Pre-warming models (CoreML JIT compilation, ~60s)...") | |
| encoder.predict({"audio_signal": np.zeros([1, 128, 500], dtype=np.float32)}) | |
| projection.predict({"encoder_output": np.zeros([1, ENCODER_AUDIO_TOKENS, 1024], dtype=np.float32)}) | |
| cos, sin = compute_rope(0) | |
| state = decoder.make_state() | |
| decoder.predict({ | |
| "hidden_states": np.zeros([1, 1, 2048], dtype=np.float32), | |
| "position_cos": cos, "position_sin": sin, | |
| "attention_mask": np.zeros([1, 1, 1, 1], dtype=np.float32), | |
| }, state=state) | |
| lm_head.predict({"hidden_states": np.zeros([1, 1, 2048], dtype=np.float32)}) | |
| print("Models ready.") | |
| def load_artifacts(model_dir: str): | |
| print("Loading CoreML models...") | |
| encoder = ct.models.MLModel(os.path.join(model_dir, "encoder_int8.mlpackage")) | |
| projection = ct.models.MLModel(os.path.join(model_dir, "projection.mlpackage")) | |
| decoder = ct.models.MLModel(os.path.join(model_dir, "canary_decoder_stateful_int8.mlpackage")) | |
| lm_head = ct.models.MLModel(os.path.join(model_dir, "canary_lm_head_int8.mlpackage")) | |
| print("Loading binary artifacts...") | |
| hann_window = np.fromfile(os.path.join(model_dir, "canary_mel_window.bin"), dtype=np.float32) | |
| mel_fb_flat = np.fromfile(os.path.join(model_dir, "canary_mel_filter_bank.bin"), dtype=np.float32) | |
| mel_filterbank = mel_fb_flat.reshape(N_MELS, mel_fb_flat.size // N_MELS) # [128, n_fft//2+1] | |
| # Embeddings stored as float16: shape [vocab_size, EMBED_DIM] | |
| # The file has 4 stray bytes of padding at the end β truncate to nearest row. | |
| raw = np.fromfile(os.path.join(model_dir, "canary_embeddings.bin"), dtype=np.float16) | |
| raw = raw[:raw.size - raw.size % EMBED_DIM] | |
| embeddings = raw.reshape(raw.size // EMBED_DIM, EMBED_DIM).astype(np.float32) | |
| return encoder, projection, decoder, lm_head, hann_window, mel_filterbank, embeddings | |
| def main(): | |
| parser = argparse.ArgumentParser(description="Live ASR with Canary-Qwen-2.5B (CoreML INT8)") | |
| parser.add_argument("--dry-run", action="store_true", | |
| help="Download, load models, receive one audio chunk, then exit 0") | |
| args = parser.parse_args() | |
| print(f"Downloading/verifying {MODEL_REPO}...") | |
| # Use local_dir so files land as real files (not HF cache symlinks). | |
| # CoreML fails to compile .mlpackage bundles that contain relative symlinks. | |
| model_dir = snapshot_download(MODEL_REPO, local_dir=MODEL_LOCAL_DIR) | |
| encoder, projection, decoder, lm_head, hann_window, mel_filterbank, embeddings = load_artifacts(model_dir) | |
| pre_warm(encoder, projection, decoder, lm_head) | |
| print(f"Loading tokenizer from {TOKENIZER_REPO}...") | |
| tokenizer = AutoTokenizer.from_pretrained(TOKENIZER_REPO) | |
| prefix_ids = tokenizer.encode(PROMPT_PREFIX, add_special_tokens=False) | |
| eos_id = tokenizer.eos_token_id | |
| print(f"Listening in {CHUNK_SECONDS}s chunks... (Ctrl+C to stop)") | |
| buffer = [] | |
| try: | |
| with sd.InputStream(samplerate=SAMPLE_RATE, channels=CHANNELS, callback=audio_callback): | |
| while True: | |
| chunk = audio_queue.get() | |
| buffer.append(chunk) | |
| if sum(len(c) for c in buffer) >= SAMPLE_RATE * CHUNK_SECONDS: | |
| if args.dry_run: | |
| print("Dry run complete.") | |
| sys.exit(0) | |
| audio_data = np.concatenate(buffer).flatten() | |
| text = transcribe( | |
| audio_data, hann_window, mel_filterbank, | |
| encoder, projection, decoder, lm_head, | |
| embeddings, tokenizer, prefix_ids, eos_id, | |
| ) | |
| print(f">>> {text}") | |
| buffer = [] | |
| except KeyboardInterrupt: | |
| print("\nSession ended.") | |
| if __name__ == "__main__": | |
| main() |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| # /// script | |
| # requires-python = ">=3.12" | |
| # dependencies = [ | |
| # "mlx", | |
| # "mlx-audio", | |
| # "sounddevice", | |
| # "numpy", | |
| # "scipy", | |
| # ] | |
| # /// | |
| import numpy as np | |
| import mlx.core as mx | |
| import sounddevice as sd | |
| from mlx_audio.stt import load | |
| import queue | |
| import sys | |
| # --- Configuration --- | |
| # FireRedASR2-AED is ideal for high-speed Mandarin/English live transcription | |
| MODEL_ID = "mlx-community/FireRedASR2-AED-mlx" | |
| SAMPLE_RATE = 16000 | |
| CHANNELS = 1 | |
| CHUNK_SECONDS = 5 | |
| audio_queue = queue.Queue() | |
| def audio_callback(indata, frames, time, status): | |
| if status: print(status, file=sys.stderr) | |
| audio_queue.put(indata.copy()) | |
| def transcribe_live(): | |
| print(f"π Loading {MODEL_ID} on M4 Max (MLX Native)...") | |
| # Load model via mlx-audio for high-performance transcription | |
| model = load(MODEL_ID) | |
| buffer = [] | |
| print(f"π€ Listening in {CHUNK_SECONDS}s chunks... (Ctrl+C to stop)") | |
| try: | |
| with sd.InputStream(samplerate=SAMPLE_RATE, channels=CHANNELS, callback=audio_callback): | |
| while True: | |
| chunk = audio_queue.get() | |
| buffer.append(chunk) | |
| # Process every 5 seconds | |
| if sum(len(c) for c in buffer) >= SAMPLE_RATE * CHUNK_SECONDS: | |
| audio_data = np.concatenate(buffer).flatten() | |
| # FIX: Convert numpy array to mlx array to avoid as_strided() TypeError | |
| mlx_audio = mx.array(audio_data) | |
| # FireRedASR-AED excels at bilingual Mandarin/English | |
| # Language can be set to "Mandarin", "English", or "Auto" | |
| result = model.generate(mlx_audio, language="Auto") | |
| print(f"π {result.text}") | |
| buffer = [] | |
| except KeyboardInterrupt: | |
| print("\nπ Session Ended.") | |
| if __name__ == "__main__": | |
| transcribe_live() |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| # /// script | |
| # requires-python = ">=3.12" | |
| # dependencies = [ | |
| # "mlx", | |
| # "parakeet-mlx", | |
| # "sounddevice", | |
| # "numpy", | |
| # "scipy", | |
| # ] | |
| # /// | |
| import numpy as np | |
| import mlx.core as mx | |
| import sounddevice as sd | |
| from parakeet_mlx import from_pretrained | |
| import queue | |
| import sys | |
| # --- Configuration --- | |
| MODEL_ID = "mlx-community/parakeet-tdt-0.6b-v3" | |
| SAMPLE_RATE = 16000 | |
| CHANNELS = 1 | |
| CHUNK_SECONDS = 5 | |
| audio_queue = queue.Queue() | |
| def audio_callback(indata, frames, time, status): | |
| if status: print(status, file=sys.stderr) | |
| audio_queue.put(indata.copy()) | |
| def transcribe_live(): | |
| print(f"π Loading {MODEL_ID} on M4 Max (MLX Native)...") | |
| # Load model with bfloat16 for optimal M4 Max performance | |
| model = from_pretrained(MODEL_ID) | |
| buffer = [] | |
| print(f"π€ Listening in {CHUNK_SECONDS}s chunks... (Ctrl+C to stop)") | |
| try: | |
| with sd.InputStream(samplerate=SAMPLE_RATE, channels=CHANNELS, callback=audio_callback): | |
| while True: | |
| chunk = audio_queue.get() | |
| buffer.append(chunk) | |
| # Accumulate 5 seconds | |
| if sum(len(c) for c in buffer) >= SAMPLE_RATE * CHUNK_SECONDS: | |
| audio_data = np.concatenate(buffer).flatten() | |
| # Convert to MLX array | |
| mlx_audio = mx.array(audio_data) | |
| # FIX: Use transcribe_stream for array-based transcription. | |
| # We open a fresh stream for each chunk to maintain independent 5s blocks. | |
| with model.transcribe_stream() as transcriber: | |
| transcriber.add_audio(mlx_audio) | |
| # Access text via the result object | |
| text = transcriber.result.text | |
| if text: | |
| print(f"π {text}") | |
| buffer = [] | |
| except KeyboardInterrupt: | |
| print("\nπ Session Ended.") | |
| if __name__ == "__main__": | |
| transcribe_live() |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| # /// script | |
| # requires-python = ">=3.12" | |
| # dependencies = [ | |
| # "mlx-audio", | |
| # "sounddevice", | |
| # "numpy", | |
| # "scipy", | |
| # ] | |
| # /// | |
| import numpy as np | |
| import sounddevice as sd | |
| from mlx_audio.stt import load | |
| import queue | |
| import sys | |
| # --- Configuration --- | |
| MODEL_ID = "mlx-community/Qwen3-ASR-0.6B-8bit" | |
| SAMPLE_RATE = 16000 | |
| CHANNELS = 1 | |
| audio_queue = queue.Queue() | |
| def audio_callback(indata, frames, time, status): | |
| if status: print(status, file=sys.stderr) | |
| audio_queue.put(indata.copy()) | |
| def stream_transcribe(): | |
| print(f"π Loading Qwen3-ASR Native MLX Streaming Model...") | |
| model = load(MODEL_ID) | |
| buffer = [] | |
| print("π€ Listening. Speak naturally... (Ctrl+C to stop)") | |
| try: | |
| with sd.InputStream(samplerate=SAMPLE_RATE, channels=CHANNELS, callback=audio_callback): | |
| while True: | |
| chunk = audio_queue.get() | |
| buffer.append(chunk) | |
| if sum(len(c) for c in buffer) >= SAMPLE_RATE * 5: | |
| audio_data = np.concatenate(buffer).flatten() | |
| # Process the generator | |
| needs_newline = False | |
| for result in model.stream_transcribe(audio_data, language="English"): | |
| if result.text.strip(): | |
| if not result.is_final: | |
| print(f"\rπ€ {result.text}", end="", flush=True) | |
| needs_newline = True | |
| else: | |
| print(f"\rβ {result.text}", flush=True) | |
| needs_newline = False | |
| if needs_newline: | |
| print(flush=True) | |
| buffer = [] | |
| except KeyboardInterrupt: | |
| print("\nπ Session Ended.") | |
| if __name__ == "__main__": | |
| stream_transcribe() |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| # /// script | |
| # requires-python = ">=3.12" | |
| # dependencies = [ | |
| # "mlx-whisper", | |
| # "sounddevice", | |
| # "numpy", | |
| # ] | |
| # /// | |
| import numpy as np | |
| import sounddevice as sd | |
| import mlx_whisper | |
| import queue | |
| import sys | |
| # --- Configuration --- | |
| # Use "mlx-community/whisper-large-v3-turbo" or "mlx-community/Qwen3-ASR-0.6B-MLX" | |
| MODEL_PATH = "mlx-community/whisper-large-v3-turbo" | |
| SAMPLE_RATE = 16000 | |
| CHANNELS = 1 | |
| audio_queue = queue.Queue() | |
| def audio_callback(indata, frames, time, status): | |
| if status: print(status, file=sys.stderr) | |
| audio_queue.put(indata.copy()) | |
| def transcribe_live(): | |
| print(f"π Initializing {MODEL_PATH} via uv...") | |
| buffer = [] | |
| try: | |
| with sd.InputStream(samplerate=SAMPLE_RATE, channels=CHANNELS, callback=audio_callback): | |
| print("π€ Listening... (Ctrl+C to stop)") | |
| while True: | |
| chunk = audio_queue.get() | |
| buffer.append(chunk) | |
| # Transcribe in 5-second increments for low latency | |
| if sum(len(c) for c in buffer) >= SAMPLE_RATE * 5: | |
| audio_data = np.concatenate(buffer).flatten() | |
| result = mlx_whisper.transcribe( | |
| audio_data, | |
| path_or_hf_repo=MODEL_PATH, | |
| fp16=True | |
| ) | |
| print(f"\rπ Transcript: {result['text']}", flush=True) | |
| buffer = [] | |
| except KeyboardInterrupt: | |
| print("\nπ Session Ended.") | |
| if __name__ == "__main__": | |
| transcribe_live() |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment