Skip to content

Instantly share code, notes, and snippets.

@briankung
Last active March 28, 2026 14:43
Show Gist options
  • Select an option

  • Save briankung/7afbf743987df6b5b477e9b58b1d6d53 to your computer and use it in GitHub Desktop.

Select an option

Save briankung/7afbf743987df6b5b477e9b58b1d6d53 to your computer and use it in GitHub Desktop.
Testing live transcription on my M4 Max MacBook Pro (64GB). Each script is runnable with `uv run <SCRIPT_NAME>`.
# /// script
# requires-python = ">=3.12"
# dependencies = [
# "coremltools>=8.0",
# "huggingface_hub",
# "numpy",
# "scipy",
# "sounddevice",
# "tiktoken",
# "transformers",
# ]
# ///
import argparse
import os
import queue
import sys
import coremltools as ct
import numpy as np
import sounddevice as sd
from huggingface_hub import snapshot_download
from scipy.signal import stft as scipy_stft
from transformers import AutoTokenizer
# --- Configuration ---
# Model/tokenizer sources (downloaded from the Hugging Face Hub).
MODEL_REPO = "phequals/canary-qwen-2.5b-coreml-int8"
MODEL_LOCAL_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), ".canary-model")
TOKENIZER_REPO = "Qwen/Qwen3-1.7B"
# Audio capture parameters.
SAMPLE_RATE = 16000
CHANNELS = 1
CHUNK_SECONDS = 5
# STFT framing for the mel spectrogram.
WIN_SAMPLES = 400 # 25ms Hann window at 16kHz
HOP_SAMPLES = 160 # 10ms hop at 16kHz
N_MELS = 128
EMBED_DIM = 2048
MAX_NEW_TOKENS = 150
# Encoder expects exactly 500 mel frames (fixed CoreML input shape).
# Required audio length: (500 - 1) * 160 + 400 = 80,240 samples ≈ 5.015s
TARGET_FRAMES = 500
TARGET_SAMPLES = (TARGET_FRAMES - 1) * HOP_SAMPLES + WIN_SAMPLES # 80,240
ENCODER_AUDIO_TOKENS = 63 # encoder output frames for TARGET_FRAMES input
# RoPE parameters matching Qwen3-1.7B (head_dim=128, theta=10000)
ROPE_HEAD_DIM = 128
ROPE_BASE = 10000.0
# KV cache max sequence length (from decoder state shape [1, 8, 256, 128])
KV_CACHE_MAX = 256
# Prompt prefix (text before the audio embeddings are inserted)
PROMPT_PREFIX = "Transcribe the following: "
# Thread-safe hand-off from the sounddevice capture thread to the main loop.
audio_queue: queue.Queue = queue.Queue()


def audio_callback(indata, frames, time, status):
    """sounddevice InputStream callback: enqueue a copy of each audio block.

    The copy matters: sounddevice reuses the `indata` buffer between calls.
    """
    if status:
        print(status, file=sys.stderr)
    audio_queue.put(indata.copy())
def compute_mel(audio: np.ndarray, hann_window: np.ndarray, mel_filterbank: np.ndarray) -> np.ndarray:
    """Log-mel features replicating NeMo/Canary preprocessing.

    NeMo config this matches: normalize=per_feature, natural log with an
    additive 1e-5 zero guard. The two classic mistakes are log10 instead of
    ln, and normalizing globally rather than per mel bin.
    """
    win_len = len(hann_window)
    n_fft = 2 * (mel_filterbank.shape[1] - 1)
    _, _, stft_out = scipy_stft(
        audio,
        fs=SAMPLE_RATE,
        window=hann_window,
        nperseg=win_len,
        noverlap=win_len - HOP_SAMPLES,
        nfft=n_fft,
        boundary=None,
        padded=False,
    )
    power_spec = np.abs(stft_out) ** 2          # [n_fft//2+1, T]
    mel_energies = mel_filterbank @ power_spec  # [128, T]
    features = np.log(mel_energies + 1e-5)      # natural log, additive guard
    # Per-feature normalization: each mel bin standardized over the time axis.
    bin_mean = features.mean(axis=-1, keepdims=True)  # [128, 1]
    bin_std = features.std(axis=-1, keepdims=True)    # [128, 1]
    normalized = (features - bin_mean) / (bin_std + 1e-8)
    return normalized.astype(np.float32)
def first_value(d: dict) -> np.ndarray:
    """Pull the first (and, for CoreML predict() results, only) value out of *d*."""
    first_key = next(iter(d))
    return d[first_key]
def compute_rope(position: int) -> tuple[np.ndarray, np.ndarray]:
    """Rotary position embeddings (cos, sin) for a single token position.

    Each returned array has shape [1, 1, ROPE_HEAD_DIM]. The half-dimension
    angle values are duplicated via concatenation to span the full head dim,
    matching the rotate-half RoPE layout.
    """
    even_dims = np.arange(0, ROPE_HEAD_DIM, 2, dtype=np.float32)
    inv_freq = 1.0 / (ROPE_BASE ** (even_dims / ROPE_HEAD_DIM))
    angles = position * inv_freq  # [HEAD_DIM//2]
    cos_full = np.concatenate([np.cos(angles), np.cos(angles)])
    sin_full = np.concatenate([np.sin(angles), np.sin(angles)])
    return cos_full[None, None, :], sin_full[None, None, :]
def decoder_step(
    decoder: ct.models.MLModel,
    embed: np.ndarray,  # [1, 1, 2048]
    position: int,
    state,
) -> np.ndarray:
    """Run one embedding through the stateful decoder at *position*.

    Returns the decoder hidden state, shape [1, 1, 2048].
    """
    position_cos, position_sin = compute_rope(position)
    inputs = {
        "hidden_states": embed,
        "position_cos": position_cos,
        "position_sin": position_sin,
        # Single-element, all-zero mask (one new token per call).
        "attention_mask": np.zeros([1, 1, 1, 1], dtype=np.float32),
    }
    outputs = decoder.predict(inputs, state=state)
    return first_value(outputs)  # [1, 1, 2048]
def transcribe(
    audio: np.ndarray,
    hann_window: np.ndarray,
    mel_filterbank: np.ndarray,
    encoder: ct.models.MLModel,
    projection: ct.models.MLModel,
    decoder: ct.models.MLModel,
    lm_head: ct.models.MLModel,
    embeddings: np.ndarray,
    tokenizer,
    prefix_ids: list,
    eos_id: int,
) -> str:
    """Transcribe one audio chunk through the four-model CoreML pipeline.

    Pipeline: mel spectrogram -> encoder -> projection (audio embeddings) ->
    stateful decoder prefill (prompt prefix tokens then audio embeddings) ->
    greedy token-by-token generation -> tokenizer decode.

    Args:
        audio: 1-D waveform; truncated/zero-padded to TARGET_SAMPLES.
        hann_window, mel_filterbank: preprocessing artifacts for compute_mel.
        encoder, projection, decoder, lm_head: loaded CoreML models.
        embeddings: [vocab, EMBED_DIM] token embedding table (float32).
        tokenizer: used only to decode the generated ids at the end.
        prefix_ids: token ids of PROMPT_PREFIX, fed before the audio frames.
        eos_id: generation stops when this id is produced.

    Returns:
        The decoded transcript with special tokens stripped.
    """
    # 1. Mel spectrogram -> encoder
    # Pad or truncate to produce exactly TARGET_FRAMES (fixed encoder input shape).
    audio = audio[:TARGET_SAMPLES]
    if len(audio) < TARGET_SAMPLES:
        audio = np.pad(audio, (0, TARGET_SAMPLES - len(audio)))
    mel = compute_mel(audio, hann_window, mel_filterbank)[np.newaxis]  # [1, 128, 500]
    enc_out = first_value(encoder.predict({"audio_signal": mel}))  # [1, 63, 1024]
    # 2. Project all audio frames in one batch call
    audio_embeds = first_value(projection.predict({"encoder_output": enc_out}))[0]  # [T_audio, 2048]
    # 3. Build full prefill sequence: prompt prefix tokens + audio frame embeddings
    prefix_embeds = embeddings[prefix_ids].astype(np.float32)  # [N_prefix, 2048]
    input_embeds = np.concatenate([prefix_embeds, audio_embeds], axis=0)  # [N_total, 2048]
    # 4. Prefill: feed each token sequentially into the stateful decoder
    # (the CoreML decoder takes one position per call; the KV cache lives in `state`).
    state = decoder.make_state()
    last_hidden = None
    for pos, embed in enumerate(input_embeds):
        token = embed[np.newaxis, np.newaxis, :].astype(np.float32)  # [1, 1, 2048]
        last_hidden = decoder_step(decoder, token, pos, state)
    # 5. Greedy generation (cap at KV cache space remaining)
    # If the prefill already fills the cache, max_new <= 0 and the loop is skipped.
    max_new = min(MAX_NEW_TOKENS, KV_CACHE_MAX - len(input_embeds))
    generated_ids = []
    pos = len(input_embeds)
    for _ in range(max_new):
        logits = first_value(lm_head.predict({"hidden_states": last_hidden}))
        next_id = int(np.argmax(logits.flatten()))
        if next_id == eos_id:
            break
        generated_ids.append(next_id)
        # Stop if the last 8 tokens are all identical (repetition loop)
        if len(generated_ids) >= 8 and len(set(generated_ids[-8:])) == 1:
            break
        next_embed = embeddings[next_id][np.newaxis, np.newaxis, :].astype(np.float32)
        last_hidden = decoder_step(decoder, next_embed, pos, state)
        pos += 1
    return tokenizer.decode(generated_ids, skip_special_tokens=True)
def pre_warm(encoder, projection, decoder, lm_head):
    """Force CoreML compilation of every model via dummy predict() calls.

    The first predict() on a model triggers compilation (~54s for the
    decoder); doing it up front avoids an apparent freeze on the first
    real audio chunk.
    """
    print("Pre-warming models (CoreML JIT compilation, ~60s)...")
    encoder.predict({"audio_signal": np.zeros([1, 128, 500], dtype=np.float32)})
    projection.predict({"encoder_output": np.zeros([1, ENCODER_AUDIO_TOKENS, 1024], dtype=np.float32)})
    rope_cos, rope_sin = compute_rope(0)
    warm_state = decoder.make_state()
    decoder.predict(
        {
            "hidden_states": np.zeros([1, 1, 2048], dtype=np.float32),
            "position_cos": rope_cos,
            "position_sin": rope_sin,
            "attention_mask": np.zeros([1, 1, 1, 1], dtype=np.float32),
        },
        state=warm_state,
    )
    lm_head.predict({"hidden_states": np.zeros([1, 1, 2048], dtype=np.float32)})
    print("Models ready.")
def load_artifacts(model_dir: str):
    """Load the four CoreML models plus the mel window/filterbank and embedding table.

    Returns (encoder, projection, decoder, lm_head, hann_window,
    mel_filterbank, embeddings).
    """
    def _model(name: str) -> ct.models.MLModel:
        # One-liner to keep the four package loads uniform.
        return ct.models.MLModel(os.path.join(model_dir, name))

    print("Loading CoreML models...")
    encoder = _model("encoder_int8.mlpackage")
    projection = _model("projection.mlpackage")
    decoder = _model("canary_decoder_stateful_int8.mlpackage")
    lm_head = _model("canary_lm_head_int8.mlpackage")
    print("Loading binary artifacts...")
    hann_window = np.fromfile(os.path.join(model_dir, "canary_mel_window.bin"), dtype=np.float32)
    fb_flat = np.fromfile(os.path.join(model_dir, "canary_mel_filter_bank.bin"), dtype=np.float32)
    mel_filterbank = fb_flat.reshape(N_MELS, fb_flat.size // N_MELS)  # [128, n_fft//2+1]
    # Embeddings ship as float16 rows of EMBED_DIM; the file carries 4 stray
    # trailing padding bytes, so drop anything that does not fill a whole row.
    raw_f16 = np.fromfile(os.path.join(model_dir, "canary_embeddings.bin"), dtype=np.float16)
    usable = raw_f16.size - raw_f16.size % EMBED_DIM
    embeddings = raw_f16[:usable].reshape(usable // EMBED_DIM, EMBED_DIM).astype(np.float32)
    return encoder, projection, decoder, lm_head, hann_window, mel_filterbank, embeddings
def main():
    """Entry point: download artifacts, warm the models, then transcribe live audio.

    Microphone audio is buffered in CHUNK_SECONDS blocks; each full chunk is
    transcribed and printed. With --dry-run the script exits 0 after the first
    full chunk is buffered (before any transcription), verifying that
    download, load, and pre-warm all work end to end.
    """
    parser = argparse.ArgumentParser(description="Live ASR with Canary-Qwen-2.5B (CoreML INT8)")
    parser.add_argument("--dry-run", action="store_true",
                        help="Download, load models, receive one audio chunk, then exit 0")
    args = parser.parse_args()
    print(f"Downloading/verifying {MODEL_REPO}...")
    # Use local_dir so files land as real files (not HF cache symlinks).
    # CoreML fails to compile .mlpackage bundles that contain relative symlinks.
    model_dir = snapshot_download(MODEL_REPO, local_dir=MODEL_LOCAL_DIR)
    encoder, projection, decoder, lm_head, hann_window, mel_filterbank, embeddings = load_artifacts(model_dir)
    pre_warm(encoder, projection, decoder, lm_head)
    print(f"Loading tokenizer from {TOKENIZER_REPO}...")
    tokenizer = AutoTokenizer.from_pretrained(TOKENIZER_REPO)
    prefix_ids = tokenizer.encode(PROMPT_PREFIX, add_special_tokens=False)
    eos_id = tokenizer.eos_token_id
    print(f"Listening in {CHUNK_SECONDS}s chunks... (Ctrl+C to stop)")
    buffer = []
    try:
        with sd.InputStream(samplerate=SAMPLE_RATE, channels=CHANNELS, callback=audio_callback):
            while True:
                chunk = audio_queue.get()
                buffer.append(chunk)
                if sum(len(c) for c in buffer) >= SAMPLE_RATE * CHUNK_SECONDS:
                    if args.dry_run:
                        print("Dry run complete.")
                        sys.exit(0)
                    audio_data = np.concatenate(buffer).flatten()
                    text = transcribe(
                        audio_data, hann_window, mel_filterbank,
                        encoder, projection, decoder, lm_head,
                        embeddings, tokenizer, prefix_ids, eos_id,
                    )
                    print(f">>> {text}")
                    # NOTE(review): chunks do not overlap, so a word spanning a
                    # chunk boundary may be clipped.
                    buffer = []
    except KeyboardInterrupt:
        print("\nSession ended.")


if __name__ == "__main__":
    main()
# /// script
# requires-python = ">=3.12"
# dependencies = [
# "mlx",
# "mlx-audio",
# "sounddevice",
# "numpy",
# "scipy",
# ]
# ///
import numpy as np
import mlx.core as mx
import sounddevice as sd
from mlx_audio.stt import load
import queue
import sys
# --- Configuration ---
# FireRedASR2-AED targets fast bilingual (Mandarin/English) live transcription.
MODEL_ID = "mlx-community/FireRedASR2-AED-mlx"
SAMPLE_RATE = 16000
CHANNELS = 1
CHUNK_SECONDS = 5

# Queue bridging the sounddevice callback thread and the main loop.
audio_queue = queue.Queue()


def audio_callback(indata, frames, time, status):
    """InputStream callback: report any stream status, enqueue a copied block."""
    if status:
        print(status, file=sys.stderr)
    audio_queue.put(indata.copy())
def transcribe_live():
    """Capture microphone audio and transcribe each CHUNK_SECONDS block via MLX.

    Blocks arrive on audio_queue from the sounddevice callback; once enough
    samples for one chunk have accumulated they are flattened, converted to an
    MLX array, and transcribed. Runs until Ctrl+C.
    """
    print(f"πŸš€ Loading {MODEL_ID} on M4 Max (MLX Native)...")
    # Load model via mlx-audio for high-performance transcription
    model = load(MODEL_ID)
    buffer = []
    print(f"🎀 Listening in {CHUNK_SECONDS}s chunks... (Ctrl+C to stop)")
    try:
        with sd.InputStream(samplerate=SAMPLE_RATE, channels=CHANNELS, callback=audio_callback):
            while True:
                chunk = audio_queue.get()
                buffer.append(chunk)
                # Process every 5 seconds
                if sum(len(c) for c in buffer) >= SAMPLE_RATE * CHUNK_SECONDS:
                    audio_data = np.concatenate(buffer).flatten()
                    # FIX: Convert numpy array to mlx array to avoid as_strided() TypeError
                    mlx_audio = mx.array(audio_data)
                    # FireRedASR-AED excels at bilingual Mandarin/English
                    # Language can be set to "Mandarin", "English", or "Auto"
                    # NOTE(review): `generate(..., language=...)` and `result.text`
                    # follow the mlx-audio STT API — confirm against the installed version.
                    result = model.generate(mlx_audio, language="Auto")
                    print(f"πŸ“ {result.text}")
                    buffer = []
    except KeyboardInterrupt:
        print("\nπŸ›‘ Session Ended.")


if __name__ == "__main__":
    transcribe_live()
# /// script
# requires-python = ">=3.12"
# dependencies = [
# "mlx",
# "parakeet-mlx",
# "sounddevice",
# "numpy",
# "scipy",
# ]
# ///
import numpy as np
import mlx.core as mx
import sounddevice as sd
from parakeet_mlx import from_pretrained
import queue
import sys
# --- Configuration ---
MODEL_ID = "mlx-community/parakeet-tdt-0.6b-v3"
SAMPLE_RATE = 16000
CHANNELS = 1
CHUNK_SECONDS = 5

# Hands captured audio blocks from the sounddevice thread to the main loop.
audio_queue = queue.Queue()


def audio_callback(indata, frames, time, status):
    """Surface any stream status on stderr and queue a copy of the block."""
    if status:
        print(status, file=sys.stderr)
    audio_queue.put(indata.copy())
def transcribe_live():
    """Capture microphone audio and transcribe 5-second chunks with parakeet-mlx.

    Each accumulated chunk is fed through a fresh transcribe_stream context so
    the 5-second blocks stay independent of one another. Runs until Ctrl+C.
    """
    print(f"πŸš€ Loading {MODEL_ID} on M4 Max (MLX Native)...")
    # Load model with bfloat16 for optimal M4 Max performance
    # NOTE(review): no dtype is actually passed here — the bfloat16 claim
    # depends on from_pretrained's default; confirm against parakeet-mlx docs.
    model = from_pretrained(MODEL_ID)
    buffer = []
    print(f"🎀 Listening in {CHUNK_SECONDS}s chunks... (Ctrl+C to stop)")
    try:
        with sd.InputStream(samplerate=SAMPLE_RATE, channels=CHANNELS, callback=audio_callback):
            while True:
                chunk = audio_queue.get()
                buffer.append(chunk)
                # Accumulate 5 seconds
                if sum(len(c) for c in buffer) >= SAMPLE_RATE * CHUNK_SECONDS:
                    audio_data = np.concatenate(buffer).flatten()
                    # Convert to MLX array
                    mlx_audio = mx.array(audio_data)
                    # FIX: Use transcribe_stream for array-based transcription.
                    # We open a fresh stream for each chunk to maintain independent 5s blocks.
                    # NOTE(review): assumes transcriber.result is populated
                    # synchronously after add_audio — verify with the library.
                    with model.transcribe_stream() as transcriber:
                        transcriber.add_audio(mlx_audio)
                        # Access text via the result object
                        text = transcriber.result.text
                        if text:
                            print(f"πŸ“ {text}")
                    buffer = []
    except KeyboardInterrupt:
        print("\nπŸ›‘ Session Ended.")


if __name__ == "__main__":
    transcribe_live()
# /// script
# requires-python = ">=3.12"
# dependencies = [
# "mlx-audio",
# "sounddevice",
# "numpy",
# "scipy",
# ]
# ///
import numpy as np
import sounddevice as sd
from mlx_audio.stt import load
import queue
import sys
# --- Configuration ---
MODEL_ID = "mlx-community/Qwen3-ASR-0.6B-8bit"
SAMPLE_RATE = 16000
CHANNELS = 1

# Bridge between the sounddevice callback thread and the streaming loop.
audio_queue = queue.Queue()


def audio_callback(indata, frames, time, status):
    """Queue a copy of each captured audio block; report status on stderr."""
    if status:
        print(status, file=sys.stderr)
    audio_queue.put(indata.copy())
def stream_transcribe():
    """Stream microphone audio to Qwen3-ASR in 5-second batches.

    Non-final (partial) results are redrawn in place using a carriage return;
    final results get their own line. Runs until Ctrl+C.
    """
    print(f"πŸš€ Loading Qwen3-ASR Native MLX Streaming Model...")
    model = load(MODEL_ID)
    buffer = []
    print("🎀 Listening. Speak naturally... (Ctrl+C to stop)")
    try:
        with sd.InputStream(samplerate=SAMPLE_RATE, channels=CHANNELS, callback=audio_callback):
            while True:
                chunk = audio_queue.get()
                buffer.append(chunk)
                if sum(len(c) for c in buffer) >= SAMPLE_RATE * 5:
                    audio_data = np.concatenate(buffer).flatten()
                    # Process the generator
                    # NOTE(review): stream_transcribe yielding results with
                    # .text/.is_final follows the mlx-audio API — confirm
                    # against the installed version.
                    needs_newline = False
                    for result in model.stream_transcribe(audio_data, language="English"):
                        if result.text.strip():
                            if not result.is_final:
                                # Partial hypothesis: overwrite the current line.
                                print(f"\r🎀 {result.text}", end="", flush=True)
                                needs_newline = True
                            else:
                                print(f"\rβœ… {result.text}", flush=True)
                                needs_newline = False
                    # Terminate a dangling partial line before the next chunk.
                    if needs_newline:
                        print(flush=True)
                    buffer = []
    except KeyboardInterrupt:
        print("\nπŸ›‘ Session Ended.")


if __name__ == "__main__":
    stream_transcribe()
# /// script
# requires-python = ">=3.12"
# dependencies = [
# "mlx-whisper",
# "sounddevice",
# "numpy",
# ]
# ///
import numpy as np
import sounddevice as sd
import mlx_whisper
import queue
import sys
# --- Configuration ---
# Use "mlx-community/whisper-large-v3-turbo" or "mlx-community/Qwen3-ASR-0.6B-MLX"
MODEL_PATH = "mlx-community/whisper-large-v3-turbo"
SAMPLE_RATE = 16000
CHANNELS = 1

# Thread-safe queue feeding captured audio to the main loop.
audio_queue = queue.Queue()


def audio_callback(indata, frames, time, status):
    """sounddevice callback: surface status messages, enqueue a copied block."""
    if status:
        print(status, file=sys.stderr)
    audio_queue.put(indata.copy())
def transcribe_live():
    """Capture microphone audio and transcribe 5-second chunks with mlx-whisper.

    mlx_whisper.transcribe is called per chunk with the HF repo path; the model
    is cached by the library after the first call. Runs until Ctrl+C.
    """
    print(f"πŸš€ Initializing {MODEL_PATH} via uv...")
    buffer = []
    try:
        with sd.InputStream(samplerate=SAMPLE_RATE, channels=CHANNELS, callback=audio_callback):
            print("🎀 Listening... (Ctrl+C to stop)")
            while True:
                chunk = audio_queue.get()
                buffer.append(chunk)
                # Transcribe in 5-second increments for low latency
                if sum(len(c) for c in buffer) >= SAMPLE_RATE * 5:
                    audio_data = np.concatenate(buffer).flatten()
                    result = mlx_whisper.transcribe(
                        audio_data,
                        path_or_hf_repo=MODEL_PATH,
                        fp16=True
                    )
                    print(f"\rπŸ“ Transcript: {result['text']}", flush=True)
                    buffer = []
    except KeyboardInterrupt:
        print("\nπŸ›‘ Session Ended.")


if __name__ == "__main__":
    transcribe_live()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment