Skip to content

Instantly share code, notes, and snippets.

@Blaizzy
Created December 18, 2025 10:49
Show Gist options
  • Select an option

  • Save Blaizzy/fd96083c4c8ccd00cd039d819934315a to your computer and use it in GitHub Desktop.

Select an option

Save Blaizzy/fd96083c4c8ccd00cd039d819934315a to your computer and use it in GitHub Desktop.
Chatterbox Turbo MLX port
#!/usr/bin/env python3
# Copyright (c) 2025 Resemble AI
# MIT License
# Weight conversion script: PyTorch -> MLX
"""
Converts Chatterbox Turbo weights from PyTorch to MLX format.
Usage:
python convert_weights.py --output model.safetensors
"""
import argparse
import os
from pathlib import Path
from typing import Dict, Any
import numpy as np
try:
import mlx.core as mx
from mlx.utils import tree_flatten
except ImportError:
raise ImportError("Please install mlx: pip install mlx")
try:
from safetensors import safe_open
from safetensors.numpy import save_file as save_safetensors
except ImportError:
raise ImportError("Please install safetensors: pip install safetensors")
def download_weights():
    """Fetch the Chatterbox Turbo checkpoint snapshot from the HuggingFace Hub.

    Returns:
        Path to the local snapshot directory.

    Raises:
        ImportError: if huggingface_hub is not installed.
    """
    try:
        from huggingface_hub import snapshot_download
        # Only pull weight/config/tokenizer artifacts, not the whole repo.
        repo_dir = snapshot_download(
            repo_id="ResembleAI/chatterbox-turbo",
            token=os.getenv("HF_TOKEN") or True,
            allow_patterns=["*.safetensors", "*.json", "*.txt", "*.pt", "*.model"],
        )
        return Path(repo_dir)
    except ImportError:
        raise ImportError("Please install huggingface_hub: pip install huggingface_hub")
def load_pytorch_weights(path: Path) -> Dict[str, np.ndarray]:
    """Read every tensor from a safetensors checkpoint into a name->array dict."""
    with safe_open(path, framework="numpy") as f:
        return {name: f.get_tensor(name) for name in f.keys()}
def convert_conv1d_weight(weight: np.ndarray) -> np.ndarray:
    """
    Convert a Conv1d kernel from PyTorch to MLX layout.

    PyTorch stores (out_channels, in_channels, kernel_size); MLX expects
    (out_channels, kernel_size, in_channels). Anything that is not 3-D
    (e.g. a bias) passes through untouched.
    """
    if weight.ndim != 3:
        return weight
    return weight.transpose(0, 2, 1)
def convert_conv_transpose1d_weight(weight: np.ndarray) -> np.ndarray:
    """
    Convert a ConvTranspose1d kernel from PyTorch to MLX layout.

    PyTorch stores (in_channels, out_channels, kernel_size); MLX expects
    (out_channels, kernel_size, in_channels). Non-3-D arrays pass through.
    """
    if weight.ndim != 3:
        return weight
    # out_channels first, kernel second, in_channels last.
    return weight.transpose(1, 2, 0)
def convert_lstm_weights(weights: Dict[str, np.ndarray], prefix: str) -> Dict[str, np.ndarray]:
    """
    Convert LSTM weights from PyTorch to MLX format.
    PyTorch LSTM has:
    - weight_ih_l{layer}: (4*hidden, input)
    - weight_hh_l{layer}: (4*hidden, hidden)
    - bias_ih_l{layer}: (4*hidden,)
    - bias_hh_l{layer}: (4*hidden,)
    MLX LSTM has:
    - Wx: (input, 4*hidden) - transposed
    - Wh: (hidden, 4*hidden) - transposed
    - bias: (4*hidden,) - sum of both biases

    NOTE(review): this helper transposes Wx/Wh, while map_ve_weights below
    copies them untransposed — the two cannot both be right for the same MLX
    LSTM; confirm which convention the MLX model actually uses before calling
    this function.
    """
    converted = {}
    for layer_idx in range(3):  # We have 3 LSTM layers
        # NOTE(review): layer 0 matches on `prefix` alone, later layers on
        # `prefix.lstm{n}` — presumably each layer is stored as a separate
        # single-layer LSTM module; verify against the checkpoint's key names.
        pt_prefix = f"{prefix}.lstm{layer_idx+1}" if layer_idx > 0 else prefix
        mlx_prefix = f"lstm{layer_idx+1}"
        # Check for different naming conventions
        weight_ih_key = None
        weight_hh_key = None
        bias_ih_key = None
        bias_hh_key = None
        # NOTE(review): the suffix searched is always "_l0" regardless of
        # layer_idx (consistent with one single-layer module per layer). If
        # the checkpoint instead uses one stacked LSTM with _l0/_l1/_l2
        # suffixes, every iteration would match layer 0 — confirm.
        for key in weights:
            if f"weight_ih_l0" in key and pt_prefix in key:
                weight_ih_key = key
            elif f"weight_hh_l0" in key and pt_prefix in key:
                weight_hh_key = key
            elif f"bias_ih_l0" in key and pt_prefix in key:
                bias_ih_key = key
            elif f"bias_hh_l0" in key and pt_prefix in key:
                bias_hh_key = key
        if weight_ih_key:
            # Wx: transpose from (4*hidden, input) to (input, 4*hidden)
            converted[f"{mlx_prefix}.Wx"] = weights[weight_ih_key].T
        if weight_hh_key:
            # Wh: transpose from (4*hidden, hidden) to (hidden, 4*hidden)
            converted[f"{mlx_prefix}.Wh"] = weights[weight_hh_key].T
        if bias_ih_key and bias_hh_key:
            # Combine biases: MLX keeps a single bias vector per layer
            converted[f"{mlx_prefix}.bias"] = weights[bias_ih_key] + weights[bias_hh_key]
    return converted
def map_ve_weights(pt_weights: Dict[str, np.ndarray]) -> Dict[str, np.ndarray]:
    """Map Voice Encoder weights from PyTorch names to MLX names.

    PyTorch's stacked LSTM exposes layers as lstm.*_l{0,1,2}; the MLX model
    names them lstm1..lstm3. Weight matrices are copied with their layout
    unchanged; the two PyTorch bias vectors of each layer are summed into the
    single MLX bias. The projection and similarity parameters pass through.
    """
    out = {}
    for idx in range(3):
        src = f"l{idx}"
        dst = f"lstm{idx + 1}"
        ih = pt_weights.get(f"lstm.weight_ih_{src}")
        if ih is not None:
            out[f"{dst}.Wx"] = ih
        hh = pt_weights.get(f"lstm.weight_hh_{src}")
        if hh is not None:
            out[f"{dst}.Wh"] = hh
        b_ih = pt_weights.get(f"lstm.bias_ih_{src}")
        b_hh = pt_weights.get(f"lstm.bias_hh_{src}")
        if b_ih is not None and b_hh is not None:
            out[f"{dst}.bias"] = b_ih + b_hh
    # Projection layer plus (training-only) similarity parameters.
    for name in ("proj.weight", "proj.bias", "similarity_weight", "similarity_bias"):
        if name in pt_weights:
            out[name] = pt_weights[name]
    return out
def map_t3_weights(pt_weights: Dict[str, np.ndarray]) -> Dict[str, np.ndarray]:
    """Map T3 (GPT-2 backbone) weights from PyTorch to MLX format.

    GPT-2 checkpoints store the attention/MLP projections (c_attn, c_proj,
    c_fc) as Conv1D with an (in_features, out_features) layout; MLX Linear
    wants (out_features, in_features), so those 2-D matrices are transposed.
    Real Conv1d kernels (3-D) get the standard layout swap; everything else
    is copied through under its original name.
    """
    mlx_weights = {}
    conv1d_style = ("c_attn", "c_proj", "c_fc")
    for name, tensor in pt_weights.items():
        converted = tensor
        if "weight" in name:
            if tensor.ndim == 2 and any(tag in name for tag in conv1d_style):
                # GPT-2 Conv1D: (in, out) -> (out, in)
                converted = tensor.T
            elif tensor.ndim == 3 and "conv" in name.lower():
                converted = convert_conv1d_weight(tensor)
        mlx_weights[name] = converted
    return mlx_weights
def map_s3gen_weights(pt_weights: Dict[str, np.ndarray]) -> Dict[str, np.ndarray]:
    """Map S3Gen weights from PyTorch to MLX format.

    Handles:
    1. Removing 'flow.' prefix for flow components
    2. Mapping encoder layer names (embed.out.0 -> embed.linear, etc.)
    3. Reconstructing weights from weight normalization parametrizations
    4. Mapping mel2wav weights with conv wrapper structure
    5. Conv1d weight transposition
    6. Decoder estimator block index remapping

    Args:
        pt_weights: flat state-dict of numpy tensors keyed by PyTorch names.
            NOTE: this dict is mutated in place — reconstructed weight-norm
            tensors are written back into it under their base key.

    Returns:
        New dict keyed by the MLX module paths.
    """
    import re
    mlx_weights = {}
    # First pass: handle weight normalization - reconstruct weights from parametrizations.
    # The Chatterbox weight_norm structure is:
    #   original0 = g (magnitude per output channel, shape: out, 1, 1)
    #   original1 = v (direction/weight, shape: out, in, kernel)
    # Actual weight = g * (v / ||v||) where the norm is over the (in, kernel) dims.
    weight_norm_weights = {}
    original0_dict = {}
    original1_dict = {}
    for key, value in pt_weights.items():
        if 'parametrizations.weight.original0' in key:
            base_key = key.replace('.parametrizations.weight.original0', '.weight')
            original0_dict[base_key] = value  # magnitude (out, 1, 1)
        elif 'parametrizations.weight.original1' in key:
            base_key = key.replace('.parametrizations.weight.original1', '.weight')
            original1_dict[base_key] = value  # direction (out, in, kernel)
    # Reconstruct weights: weight = g * (v / ||v||)
    for base_key in original0_dict:
        if base_key in original1_dict:
            g = original0_dict[base_key]  # magnitude (out, 1, 1) or similar
            v = original1_dict[base_key]  # direction (out, in, kernel) or similar
            # Compute norm over all dims except the first (out_channels)
            if v.ndim == 3:
                # Conv1d: v is (out, in, kernel) - norm over (in, kernel)
                v_norm = np.linalg.norm(v.reshape(v.shape[0], -1), axis=1, keepdims=True)
                v_norm = v_norm.reshape(v.shape[0], 1, 1)
            elif v.ndim == 2:
                # Linear: v is (out, in) - norm over (in)
                v_norm = np.linalg.norm(v, axis=1, keepdims=True)
            else:
                v_norm = np.linalg.norm(v)
            # Fused weight: g * (v / ||v||); the epsilon guards a zero-norm v.
            weight = g * (v / (v_norm + 1e-8))
            weight_norm_weights[base_key] = weight
    # Merge reconstructed weights back (mutates pt_weights before the main pass).
    for key, value in weight_norm_weights.items():
        pt_weights[key] = value
    for key, value in pt_weights.items():
        # Skip parametrization keys - we've already handled them
        if 'parametrizations' in key:
            continue
        new_key = key
        new_value = value
        # Remove 'flow.' prefix - MLX model doesn't have it
        if new_key.startswith('flow.'):
            new_key = new_key[5:]  # Remove 'flow.'
        # Handle decoder estimator block index remapping.
        # NOTE(review): the nesting here was reconstructed from a
        # whitespace-stripped source; the rewrites in this branch are assumed
        # to apply only to decoder.estimator keys — confirm against the
        # original file.
        if 'decoder.estimator' in new_key:
            # Map PyTorch tuple-based block structure to MLX named attributes.
            # PyTorch: down_blocks.X.0 (resnet), .1.Y (transformer_blocks), .2 (downsample)
            # MLX: down_blocks.X.resnet, .transformer_blocks.Y, .downsample
            new_key = re.sub(r'\.down_blocks\.(\d+)\.0\.', r'.down_blocks.\1.resnet.', new_key)
            new_key = re.sub(r'\.down_blocks\.(\d+)\.1\.(\d+)\.', r'.down_blocks.\1.transformer_blocks.\2.', new_key)
            new_key = re.sub(r'\.down_blocks\.(\d+)\.2\.', r'.down_blocks.\1.downsample.', new_key)
            # mid_blocks tuple -> named attrs (no downsample, just resnet + transformer_blocks)
            new_key = re.sub(r'\.mid_blocks\.(\d+)\.0\.', r'.mid_blocks.\1.resnet.', new_key)
            new_key = re.sub(r'\.mid_blocks\.(\d+)\.1\.(\d+)\.', r'.mid_blocks.\1.transformer_blocks.\2.', new_key)
            # up_blocks tuple -> named attrs
            new_key = re.sub(r'\.up_blocks\.(\d+)\.0\.', r'.up_blocks.\1.resnet.', new_key)
            new_key = re.sub(r'\.up_blocks\.(\d+)\.1\.(\d+)\.', r'.up_blocks.\1.transformer_blocks.\2.', new_key)
            new_key = re.sub(r'\.up_blocks\.(\d+)\.2\.', r'.up_blocks.\1.upsample.', new_key)
            # CausalBlock1D: block.0 = CausalConv -> block.0.conv.conv
            #                block.2 = LayerNorm -> block.1 (no wrapper)
            # block.2 must be renamed BEFORE block.0 gains the conv wrapper,
            # otherwise the second sub would not see the original indices.
            new_key = re.sub(r'\.block\.2\.', '.block.1.', new_key)  # LayerNorm, no conv wrapper
            new_key = re.sub(r'\.block\.0\.', '.block.0.conv.conv.', new_key)  # CausalConv has wrappers
            # mlp.1 -> mlp.0
            new_key = re.sub(r'\.mlp\.1\.', '.mlp.0.', new_key)
            # ff.net.2 -> ff.net.1
            new_key = re.sub(r'\.ff\.net\.2\.', '.ff.net.1.', new_key)
            # res_conv direct -> res_conv.conv (only when not already wrapped)
            if '.res_conv.' in new_key and '.conv.' not in new_key.split('.res_conv.')[1]:
                new_key = new_key.replace('.res_conv.', '.res_conv.conv.')
            # final_proj direct -> final_proj.conv
            if '.final_proj.' in new_key and '.conv.' not in new_key.split('.final_proj.')[1]:
                new_key = new_key.replace('.final_proj.', '.final_proj.conv.')
            # Handle downsample/upsample - these need .conv. wrapper for Conv1dPT or CausalConv
            if '.downsample.' in new_key:
                parts = new_key.split('.downsample.')
                suffix = parts[1]
                # Only add .conv. if it's a Conv1dPT/CausalConv (has weight/bias at the end)
                if suffix in ['weight', 'bias']:
                    # This is CausalConv (last block) - needs .conv.conv. wrapper
                    new_key = parts[0] + '.downsample.conv.conv.' + suffix
                elif not suffix.startswith('conv.'):
                    # Check if it's already wrapped
                    pass  # Already handled by Downsample1D
            if '.upsample.' in new_key:
                parts = new_key.split('.upsample.')
                suffix = parts[1]
                if suffix in ['weight', 'bias']:
                    # This is CausalConv (last block) - needs .conv.conv. wrapper
                    new_key = parts[0] + '.upsample.conv.conv.' + suffix
        # Encoder layer name mappings - MLX now uses same names as PyTorch;
        # norm_ff and norm_mha are kept as-is.
        # PyTorch: embed.out.0 (Linear), embed.out.1 (LayerNorm) -> MLX: embed.linear, embed.norm
        new_key = new_key.replace('.embed.out.0.', '.embed.linear.')
        new_key = new_key.replace('.embed.out.1.', '.embed.norm.')
        # PyTorch: up_embed.out.0 (Linear), up_embed.out.1 (LayerNorm) -> MLX: up_embed.linear, up_embed.norm
        new_key = new_key.replace('.up_embed.out.0.', '.up_embed.linear.')
        new_key = new_key.replace('.up_embed.out.1.', '.up_embed.norm.')
        # Handle mel2wav conv weights - MLX uses Conv1dPT wrapper with '.conv.' sub-module
        if 'mel2wav' in new_key:
            parts = new_key.split('.')
            # Handle f0_predictor.condnet - PyTorch uses Sequential indices 0,2,4,6,8
            # (conv layers interleaved with activations); MLX uses list indices
            # 0,1,2,3,4 with a Conv1dPT wrapper.
            if 'f0_predictor.condnet' in new_key:
                import re
                match = re.search(r'f0_predictor\.condnet\.(\d+)\.', new_key)
                if match:
                    pt_idx = int(match.group(1))
                    # PyTorch Sequential: conv at 0,2,4,6,8 -> MLX list: 0,1,2,3,4
                    mlx_idx = pt_idx // 2
                    new_key = re.sub(
                        r'f0_predictor\.condnet\.\d+\.',
                        f'f0_predictor.condnet.{mlx_idx}.conv.',
                        new_key
                    )
            # Handle conv_pre, conv_post, ups, source_downs - add .conv. wrapper
            elif any(x in new_key for x in ['conv_pre.', 'conv_post.', 'ups.', 'source_downs.']):
                if parts[-1] in ['weight', 'bias'] and '.conv.' not in new_key:
                    parts.insert(-1, 'conv')
                    new_key = '.'.join(parts)
            # Handle resblocks and source_resblocks
            elif 'resblocks' in new_key:
                if parts[-1] in ['weight', 'bias'] and '.conv.' not in new_key:
                    parts.insert(-1, 'conv')
                    new_key = '.'.join(parts)
        # Conv1d weights (3D) need transposition from (out, in, kernel) to (out, kernel, in)
        if "weight" in new_key and value.ndim == 3:
            # mel2wav.ups.X layers are ConvTranspose1d.
            # Use "ups." with dot to avoid matching "upsample" (which is Conv1d)
            if "mel2wav" in new_key and ".ups." in new_key:
                # ConvTranspose1d: (in, out, kernel) -> (out, kernel, in)
                new_value = convert_conv_transpose1d_weight(value)
            else:
                # Regular Conv1d: (out, in, kernel) -> (out, kernel, in)
                new_value = convert_conv1d_weight(value)
        # Linear weights - S3Gen uses standard PyTorch Linear which has the same
        # (out_features, in_features) layout as MLX, so NO transposition needed.
        mlx_weights[new_key] = new_value
    return mlx_weights
def convert_all_weights(ckpt_dir: Path) -> Dict[str, np.ndarray]:
    """Convert all model weights to MLX format.

    Looks for the three sub-model checkpoints under `ckpt_dir` (voice encoder,
    T3 and S3Gen); any file that is absent is skipped. Returned keys carry a
    "ve." / "t3." / "s3gen." namespace prefix.
    """
    all_weights = {}

    # Voice Encoder
    ve_path = ckpt_dir / "ve.safetensors"
    if ve_path.exists():
        print(f"Converting Voice Encoder weights from {ve_path}")
        mapped = map_ve_weights(load_pytorch_weights(ve_path))
        all_weights.update({f"ve.{name}": t for name, t in mapped.items()})
        print(f" Converted {len(mapped)} VE weights")

    # T3
    t3_path = ckpt_dir / "t3_turbo_v1.safetensors"
    if t3_path.exists():
        print(f"Converting T3 weights from {t3_path}")
        mapped = map_t3_weights(load_pytorch_weights(t3_path))
        all_weights.update({f"t3.{name}": t for name, t in mapped.items()})
        print(f" Converted {len(mapped)} T3 weights")

    # S3Gen
    s3gen_path = ckpt_dir / "s3gen_meanflow.safetensors"
    if s3gen_path.exists():
        print(f"Converting S3Gen weights from {s3gen_path}")
        mapped = map_s3gen_weights(load_pytorch_weights(s3gen_path))
        all_weights.update({f"s3gen.{name}": t for name, t in mapped.items()})
        print(f" Converted {len(mapped)} S3Gen weights")

    return all_weights
def main():
    """CLI entry point: locate or download the checkpoint, convert every
    sub-model's weights to MLX layout, and write one combined safetensors file.
    """
    parser = argparse.ArgumentParser(description="Convert Chatterbox weights to MLX format")
    parser.add_argument(
        "--ckpt-dir",
        type=str,
        default=None,
        help="Path to checkpoint directory (if not specified, downloads from HuggingFace)"
    )
    parser.add_argument(
        "--output",
        type=str,
        default="model.safetensors",
        help="Output file path"
    )
    args = parser.parse_args()

    # Get checkpoint directory
    if args.ckpt_dir:
        ckpt_dir = Path(args.ckpt_dir)
    else:
        print("Downloading weights from HuggingFace...")
        ckpt_dir = download_weights()
    print(f"Checkpoint directory: {ckpt_dir}")

    # List available weight files
    print("\nAvailable weight files:")
    for f in ckpt_dir.glob("*.safetensors"):
        print(f" {f.name}")

    # Convert weights
    print("\nConverting weights...")
    all_weights = convert_all_weights(ckpt_dir)
    print(f"\nTotal weights to save: {len(all_weights)}")

    # Print some weight shapes for verification.
    # Fix: the original enumerate() index was unused — iterate items directly.
    print("\nSample weight shapes:")
    for k, v in list(all_weights.items())[:10]:
        print(f" {k}: {v.shape}")

    # Save to safetensors
    output_path = Path(args.output)
    print(f"\nSaving to {output_path}...")
    save_safetensors(all_weights, str(output_path))
    print("Done!")
    print(f"Saved {len(all_weights)} weight tensors to {output_path}")


if __name__ == "__main__":
    main()
#!/usr/bin/env python3
# Copyright (c) 2025 Resemble AI
# MIT License
# Example script for Chatterbox MLX TTS
"""
Example usage of Chatterbox Turbo TTS with MLX on Apple Silicon.
Requirements:
pip install mlx numpy librosa soundfile transformers huggingface_hub
Usage:
python example_tts_turbo.py
# With voice cloning:
python example_tts_turbo.py --audio_prompt your_reference.wav
"""
import argparse
import time
import logging
import numpy as np
# Enable logging to see debug output
logging.basicConfig(level=logging.INFO, format='%(name)s - %(levelname)s - %(message)s')
def main():
    """Example CLI: synthesize speech with the MLX Chatterbox Turbo port.

    Parses sampling options, loads the model (preferring a locally converted
    model.safetensors next to this script), generates audio, reports timing
    statistics, and writes the waveform to a WAV file via soundfile.
    """
    parser = argparse.ArgumentParser(description="Chatterbox MLX TTS Example")
    parser.add_argument(
        "--text",
        type=str,
        default="Oh, that's hilarious! [chuckle] Um anyway, we do have a new model in store. It's the SkyNet T-800 series and it's got basically everything. Including AI integration with ChatGPT and all that jazz. Would you like me to get some prices for you?",
        help="Text to synthesize"
    )
    parser.add_argument(
        "--audio_prompt",
        type=str,
        default=None,
        help="Path to reference audio for voice cloning (optional, > 5 seconds)"
    )
    parser.add_argument(
        "--output",
        type=str,
        default="output_mlx.wav",
        help="Output audio file path"
    )
    parser.add_argument(
        "--temperature",
        type=float,
        default=0.8,
        help="Sampling temperature"
    )
    parser.add_argument(
        "--top_k",
        type=int,
        default=1000,
        help="Top-k sampling parameter"
    )
    parser.add_argument(
        "--top_p",
        type=float,
        default=0.95,
        help="Top-p (nucleus) sampling parameter"
    )
    args = parser.parse_args()
    print("=" * 60)
    print("Chatterbox MLX TTS - Apple Silicon Optimized")
    print("=" * 60)
    # Import here to avoid slow startup for --help
    print("\nLoading MLX and model components...")
    import mlx.core as mx
    from tts_turbo import ChatterboxTurboTTS
    # Check MLX backend
    print(f"MLX backend: {mx.default_device()}")
    # Load model
    print("\nLoading Chatterbox Turbo model...")
    start_time = time.time()
    # Check for converted weights next to this script; the model falls back to
    # an unconverted (randomly initialized) state when they are missing.
    import os
    script_dir = os.path.dirname(os.path.abspath(__file__))
    weights_path = os.path.join(script_dir, "model.safetensors")
    if os.path.exists(weights_path):
        print(f"Found converted weights at: {weights_path}")
    else:
        weights_path = None
        print("No converted weights found. Run convert_weights.py first for proper audio output.")
    try:
        model = ChatterboxTurboTTS.from_pretrained(weights_path=weights_path)
        load_time = time.time() - start_time
        print(f"Model loaded in {load_time:.2f}s")
    except Exception as e:
        # Loading can fail when the HF download or weight conversion is missing;
        # explain the remedy and bail out instead of crashing.
        print(f"Error loading model: {e}")
        print("\nNote: This MLX port requires weight conversion from PyTorch.")
        print("The model architecture is ready, but weights need to be converted.")
        print("\nTo convert weights, run:")
        print(" python convert_weights.py")
        return
    # Generate speech
    print(f"\nGenerating speech for: '{args.text[:50]}...'")
    print(f"Parameters: temperature={args.temperature}, top_k={args.top_k}, top_p={args.top_p}")
    if args.audio_prompt:
        print(f"Using voice reference: {args.audio_prompt}")
    start_time = time.time()
    try:
        wav = model.generate(
            text=args.text,
            audio_prompt_path=args.audio_prompt,
            temperature=args.temperature,
            top_k=args.top_k,
            top_p=args.top_p,
        )
        # Release MLX's cached GPU buffers after generation.
        mx.clear_cache()
        gen_time = time.time() - start_time
        wav_np = np.array(wav[0])
        duration = len(wav_np) / model.sr
        print(f"\nGeneration complete!")
        print(f" - Time: {gen_time:.2f}s")
        print(f" - Audio duration: {duration:.2f}s")
        print(f" - Real-time factor: {gen_time/duration:.2f}x")
        # Save output
        try:
            import soundfile as sf
            sf.write(args.output, wav_np, model.sr)
            print(f" - Saved to: {args.output}")
        except ImportError:
            raise ImportError("soundfile is not installed. Please install it using `pip install soundfile`")
    except Exception as e:
        print(f"Error during generation: {e}")
        import traceback
        traceback.print_exc()
    print("\nDone!")


if __name__ == "__main__":
    main()
# Copyright (c) 2025 Resemble AI
# MIT License
# MLX port of ChatterboxTurboTTS
import os
import math
import logging
from dataclasses import dataclass
from pathlib import Path
from typing import Optional, Union
import numpy as np
import librosa
import mlx.core as mx
import mlx.nn as nn
from models.t3 import T3, T3Config, T3Cond
from models.s3gen import S3Gen, S3GEN_SR, S3GEN_SIL
from models.voice_encoder import VoiceEncoder
logger = logging.getLogger(__name__)
# Constants
S3_SR = 16000 # S3Tokenizer sample rate
REPO_ID = "ResembleAI/chatterbox-turbo"
def punc_norm(text: str) -> str:
    """
    Quick cleanup func for punctuation from LLMs or
    containing chars not seen often in the dataset.

    Capitalizes the first letter, collapses whitespace, maps uncommon
    punctuation (curly quotes, dashes, ellipsis) to ASCII equivalents, and
    appends a full stop when the text has no terminal punctuation.

    Args:
        text: raw input text; may be empty.

    Returns:
        Normalized text, or a fallback prompt when `text` is empty.
    """
    if len(text) == 0:
        return "You need to add some text for me to talk."
    # Capitalise first letter
    if text[0].islower():
        text = text[0].upper() + text[1:]
    # Remove multiple space chars
    text = " ".join(text.split())
    # Replace uncommon/llm punc.
    # Fix: the curly-quote pairs were mangled into plain quotes in transit
    # (producing a syntax error / no-op replacements); restored with explicit
    # Unicode escapes so the mapping survives any re-encoding.
    punc_to_replace = [
        ("\u2026", ", "),   # … ellipsis
        (":", ","),
        ("\u2014", "-"),    # — em dash
        ("\u2013", "-"),    # – en dash
        (" ,", ","),
        ("\u201c", '"'),    # " left curly double quote
        ("\u201d", '"'),    # " right curly double quote
        ("\u2018", "'"),    # ' left curly single quote
        ("\u2019", "'"),    # ' right curly single quote
    ]
    for old_char_sequence, new_char in punc_to_replace:
        text = text.replace(old_char_sequence, new_char)
    # Add full stop if no ending punc
    text = text.rstrip(" ")
    sentence_enders = {".", "!", "?", "-", ","}
    if not any(text.endswith(p) for p in sentence_enders):
        text += "."
    return text
@dataclass
class Conditionals:
    """
    Bundles the conditioning inputs for the two synthesis stages:
    `t3` holds the T3Cond consumed by the text-to-token model and
    `gen` holds the reference dict consumed by S3Gen.
    """
    t3: T3Cond
    gen: dict

    def save(self, fpath: Path):
        """Pickle both conditioning payloads to `fpath`."""
        import pickle
        payload = {'t3': self.t3, 'gen': self.gen}
        with open(fpath, 'wb') as f:
            pickle.dump(payload, f)

    @classmethod
    def load(cls, fpath: Path) -> 'Conditionals':
        """Restore a previously saved Conditionals from `fpath`."""
        import pickle
        with open(fpath, 'rb') as f:
            data = pickle.load(f)
        return cls(data['t3'], data['gen'])
class ChatterboxTurboTTS:
"""
MLX implementation of Chatterbox Turbo TTS.
Optimized for Apple Silicon.
"""
ENC_COND_LEN = 15 * S3_SR # 15 seconds for encoder conditioning
DEC_COND_LEN = 10 * S3GEN_SR # 10 seconds for decoder conditioning
    def __init__(
        self,
        t3: T3,
        s3gen: S3Gen,
        ve: VoiceEncoder,
        tokenizer,  # HuggingFace tokenizer (may be None when loading failed)
        conds: Optional[Conditionals] = None,
        local_path: Optional[str] = None,
    ):
        """Assemble the TTS pipeline from its pre-constructed components.

        Args:
            t3: text-to-speech-token model.
            s3gen: speech-token-to-waveform generator.
            ve: voice encoder producing speaker embeddings.
            tokenizer: HuggingFace text tokenizer, or None.
            conds: optional pre-computed conditioning (default voice).
            local_path: checkpoint directory the model was loaded from.
        """
        # Output sample rate is fixed to S3Gen's rate (S3GEN_SR module constant).
        self.sr = S3GEN_SR  # Output sample rate
        self.t3 = t3
        self.s3gen = s3gen
        self.ve = ve
        self.tokenizer = tokenizer
        self.conds = conds
        # Kept so helpers (e.g. PyTorch conditional extraction) can find
        # sibling checkpoint files later.
        self.local_path = local_path
    @classmethod
    def from_local(cls, ckpt_dir: Union[str, Path], device: str = "cpu") -> 'ChatterboxTurboTTS':
        """
        Load model from local checkpoint directory.

        Builds the three sub-models, loads converted MLX weights from
        model.safetensors when present, loads the HuggingFace tokenizer, and
        restores pre-computed conditionals from conds.pt when available.

        Args:
            ckpt_dir: Path to checkpoint directory
            device: Device to use (ignored in MLX, always uses Metal)
        Returns:
            ChatterboxTurboTTS instance
        """
        ckpt_dir = Path(ckpt_dir)
        # Load Voice Encoder
        ve = VoiceEncoder()
        # Create T3 config for Turbo
        hp = T3Config.turbo()
        # Create T3 model
        t3 = T3(hp)
        # Create S3Gen
        s3gen = S3Gen(meanflow=True)
        # Try to load converted weights from model.safetensors
        model_weights_path = ckpt_dir / "model.safetensors"
        if model_weights_path.exists():
            logger.info(f"Loading converted weights from {model_weights_path}")
            weights = mx.load(str(model_weights_path))
            # Split the flat dict by sub-model prefix and load each separately.
            ve_weights = {k.replace("ve.", ""): v for k, v in weights.items() if k.startswith("ve.")}
            t3_weights = {k.replace("t3.", ""): v for k, v in weights.items() if k.startswith("t3.")}
            s3gen_weights = {k.replace("s3gen.", ""): v for k, v in weights.items() if k.startswith("s3gen.")}
            # Debug: Print expected vs loaded keys for VE
            from mlx.utils import tree_flatten
            ve_param_keys = [k for k, _ in tree_flatten(ve.parameters())]
            print(f"VE model expects these parameter keys: {ve_param_keys[:10]}...")
            print(f"VE weights from file: {list(ve_weights.keys())[:10]}...")
            if ve_weights:
                logger.info(f"Loading {len(ve_weights)} VE weights")
                try:
                    # Prefer strict loading so key mismatches surface loudly.
                    ve.load_weights(list(ve_weights.items()), strict=True)
                    logger.info("VE weights loaded successfully with strict=True")
                except Exception as e:
                    logger.warning(f"VE strict loading failed: {e}")
                    logger.info("Falling back to strict=False")
                    ve.load_weights(list(ve_weights.items()), strict=False)
            if t3_weights:
                logger.info(f"Loading {len(t3_weights)} T3 weights")
                try:
                    t3.load_weights(list(t3_weights.items()), strict=True)
                    logger.info("T3 weights loaded successfully with strict=True")
                except Exception as e:
                    logger.warning(f"T3 strict loading failed: {e}")
                    logger.info("Falling back to strict=False")
                    t3.load_weights(list(t3_weights.items()), strict=False)
            if s3gen_weights:
                logger.info(f"Loading {len(s3gen_weights)} S3Gen weights")
                # S3Gen has some parameters generated at init (not from weights):
                # - encoder.embed.pos_enc.pe, encoder.up_embed.pos_enc.pe (positional encodings)
                # - mel2wav.stft_window (STFT window from scipy)
                # - trim_fade (fade buffer)
                init_generated_params = {
                    'encoder.embed.pos_enc.pe',
                    'encoder.up_embed.pos_enc.pe',
                    'mel2wav.stft_window',
                    'trim_fade',
                }
                # Get all S3Gen parameter keys
                s3gen_param_keys = set(k for k, _ in tree_flatten(s3gen.parameters()))
                loadable_param_keys = s3gen_param_keys - init_generated_params
                # Find matching weights (weights that exist in model's loadable params)
                matching_weights = [(k, v) for k, v in s3gen_weights.items() if k in loadable_param_keys]
                # Check for any weights in file that don't match model
                unmatched_weights = set(s3gen_weights.keys()) - s3gen_param_keys
                if unmatched_weights:
                    logger.debug(f"Weights in file not in model: {len(unmatched_weights)}")
                # Check for loadable params that don't have weights
                missing_weights = loadable_param_keys - set(s3gen_weights.keys())
                if missing_weights:
                    logger.warning(f"Model params without weights: {missing_weights}")
                logger.info(f"Loading {len(matching_weights)} S3Gen weights (excluding {len(init_generated_params)} init-generated params)")
                if matching_weights:
                    # Load with strict=False since we're intentionally excluding init-generated params
                    s3gen.load_weights(matching_weights, strict=False)
                    # Verify all expected weights were loaded
                    if len(matching_weights) == len(loadable_param_keys):
                        logger.info("S3Gen weights loaded successfully (all loadable params matched)")
                    else:
                        logger.warning(f"S3Gen loaded {len(matching_weights)}/{len(loadable_param_keys)} loadable params")
                else:
                    logger.warning("No matching S3Gen weights found - model may not work correctly")
            # Force evaluation so loading errors surface here, not at first use.
            mx.eval(ve.parameters(), t3.parameters(), s3gen.parameters())
            logger.info("Weights loaded successfully")
        else:
            logger.warning(f"No converted weights found at {model_weights_path}")
            logger.warning("Run convert_weights.py first to convert PyTorch weights to MLX format")
        # Load tokenizer
        try:
            from transformers import AutoTokenizer
            tokenizer = AutoTokenizer.from_pretrained(str(ckpt_dir))
            if tokenizer.pad_token is None:
                tokenizer.pad_token = tokenizer.eos_token
        except Exception as e:
            logger.warning(f"Could not load tokenizer: {e}")
            tokenizer = None
        # Load pre-computed conditionals (default voice shipped with the repo)
        conds = None
        builtin_voice = ckpt_dir / "conds.pt"
        if builtin_voice.exists():
            try:
                import torch
                conds_data = torch.load(builtin_voice, map_location='cpu', weights_only=True)
                # Convert to MLX arrays
                t3_cond_dict = conds_data.get('t3', {})
                gen_dict = conds_data.get('gen', {})
                # Helper to convert PyTorch tensor to numpy (handles requires_grad)
                def to_numpy(t):
                    if hasattr(t, 'detach'):
                        return t.detach().cpu().numpy()
                    elif hasattr(t, 'numpy'):
                        return t.numpy()
                    return np.array(t)
                # Convert tensors to MLX arrays
                speaker_emb = t3_cond_dict.get('speaker_emb')
                if speaker_emb is not None:
                    speaker_emb = mx.array(to_numpy(speaker_emb))
                else:
                    # Zero embedding as a neutral fallback speaker.
                    speaker_emb = mx.array(np.zeros((1, 256), dtype=np.float32))
                cond_tokens = t3_cond_dict.get('cond_prompt_speech_tokens')
                if cond_tokens is not None:
                    cond_tokens = mx.array(to_numpy(cond_tokens).astype(np.int32))
                t3_cond = T3Cond(
                    speaker_emb=speaker_emb,
                    cond_prompt_speech_tokens=cond_tokens,
                )
                gen_mlx = {}
                for k, v in gen_dict.items():
                    if hasattr(v, 'detach') or hasattr(v, 'numpy'):
                        gen_mlx[k] = mx.array(to_numpy(v))
                    elif isinstance(v, (int, float)):
                        gen_mlx[k] = v
                conds = Conditionals(t3_cond, gen_mlx)
                logger.info("Loaded pre-computed conditionals")
            except Exception as e:
                logger.warning(f"Could not load conditionals: {e}")
        return cls(t3, s3gen, ve, tokenizer, conds=conds, local_path=str(ckpt_dir))
@classmethod
def from_pretrained(cls, device: str = "cpu", weights_path: str = None) -> 'ChatterboxTurboTTS':
"""
Load model from HuggingFace Hub.
Args:
device: Device to use (ignored in MLX)
weights_path: Optional path to converted model.safetensors
Returns:
ChatterboxTurboTTS instance
"""
try:
from huggingface_hub import snapshot_download
local_path = snapshot_download(
repo_id=REPO_ID,
token=os.getenv("HF_TOKEN") or True,
allow_patterns=["*.safetensors", "*.json", "*.txt", "*.pt", "*.model"]
)
# If weights_path provided, always copy to ensure latest version is used
if weights_path:
import shutil
dest = Path(local_path) / "model.safetensors"
# Always copy to ensure we use the latest converted weights
shutil.copy(weights_path, dest)
logger.info(f"Copied converted weights to {dest}")
return cls.from_local(local_path, device)
except ImportError:
raise ImportError("Please install huggingface_hub: pip install huggingface_hub")
def norm_loudness(self, wav: np.ndarray, sr: int, target_lufs: float = -27) -> np.ndarray:
"""Normalize audio loudness."""
try:
import pyloudnorm as ln
meter = ln.Meter(sr)
loudness = meter.integrated_loudness(wav)
gain_db = target_lufs - loudness
gain_linear = 10.0 ** (gain_db / 20.0)
if math.isfinite(gain_linear) and gain_linear > 0.0:
wav = wav * gain_linear
except Exception as e:
logger.warning(f"Error in norm_loudness, skipping: {e}")
return wav
    def _extract_pytorch_conditionals(self, wav_fpath: str, norm_loudness: bool = True) -> tuple:
        """
        Extract all conditioning using PyTorch (S3Gen embeddings + T3 tokens).
        This matches the original PyTorch tts_turbo.prepare_conditionals behavior.

        Args:
            wav_fpath: Path to reference audio
            norm_loudness: Whether to normalize loudness
        Returns:
            Tuple of (s3gen_ref_dict, t3_cond_prompt_tokens) or (None, None) on failure
        """
        try:
            import sys
            import torch
            from safetensors.torch import load_file
            # Add PyTorch chatterbox to path — assumes the sibling repo layout
            # ../chatterbox/src relative to this file; TODO confirm.
            pytorch_path = str(Path(__file__).parent.parent / "chatterbox" / "src")
            if pytorch_path not in sys.path:
                sys.path.insert(0, pytorch_path)
            from chatterbox.models.s3gen import S3Gen as S3GenPT
            # Initialize PyTorch S3Gen
            s3gen_pt = S3GenPT()
            # Load weights (skipped silently when the file is absent)
            weights_path = Path(self.local_path) / "s3gen_meanflow.safetensors"
            if weights_path.exists():
                state_dict = load_file(str(weights_path))
                s3gen_pt.load_state_dict(state_dict, strict=False)
            # NOTE(review): indentation reconstructed — eval() is assumed to run
            # regardless of whether weights were found; confirm against original.
            s3gen_pt.eval()
            # Load and process audio at 24kHz for S3Gen
            s3gen_ref_wav, _ = librosa.load(wav_fpath, sr=S3GEN_SR)
            if norm_loudness:
                s3gen_ref_wav = self.norm_loudness(s3gen_ref_wav, S3GEN_SR)
            # Resample to 16kHz for tokenizer
            ref_16k_wav = librosa.resample(s3gen_ref_wav, orig_sr=S3GEN_SR, target_sr=S3_SR)
            # Trim to conditioning lengths
            s3gen_ref_wav = s3gen_ref_wav[:self.DEC_COND_LEN]
            ref_16k_wav = ref_16k_wav[:self.ENC_COND_LEN]
            with torch.no_grad():
                # Get S3Gen embeddings
                ref_dict = s3gen_pt.embed_ref(s3gen_ref_wav, S3GEN_SR)
                # Get T3 conditioning tokens using S3Gen's tokenizer (matches PyTorch exactly)
                plen = self.t3.hp.speech_cond_prompt_len
                t3_cond_prompt_tokens = None
                if plen and s3gen_pt.tokenizer is not None:
                    t3_cond_prompt_tokens, _ = s3gen_pt.tokenizer.forward(
                        [ref_16k_wav], max_len=plen
                    )
                    t3_cond_prompt_tokens = torch.atleast_2d(t3_cond_prompt_tokens)
            # Convert S3Gen dict to MLX arrays (non-tensor values pass through)
            mlx_ref_dict = {}
            for k, v in ref_dict.items():
                if v is not None and torch.is_tensor(v):
                    mlx_ref_dict[k] = mx.array(v.cpu().numpy())
                elif v is not None:
                    mlx_ref_dict[k] = v
                else:
                    mlx_ref_dict[k] = None
            # Convert T3 tokens to MLX
            mlx_t3_tokens = None
            if t3_cond_prompt_tokens is not None:
                mlx_t3_tokens = mx.array(t3_cond_prompt_tokens.cpu().numpy())
            logger.info("Extracted all conditionals using PyTorch")
            return mlx_ref_dict, mlx_t3_tokens
        except Exception as e:
            # Best-effort path: the caller falls back to pure-MLX extraction.
            logger.warning(f"Failed to extract conditionals with PyTorch: {e}")
            import traceback
            traceback.print_exc()
            return None, None
def prepare_conditionals(
    self,
    wav_fpath: str,
    exaggeration: float = 0.5,
    norm_loudness: bool = True,
):
    """
    Prepare conditioning from a reference audio file and store it in ``self.conds``.

    Loads the reference at 24 kHz for S3Gen, resamples to 16 kHz for the
    voice encoder, and extracts the S3Gen reference dict plus T3 prompt
    tokens — preferring the PyTorch extraction path, with an MLX fallback.

    Args:
        wav_fpath: Path to reference audio file (must be > 5 seconds).
        exaggeration: Emotion exaggeration factor (not used in Turbo).
        norm_loudness: Whether to normalize loudness of the reference audio.

    Raises:
        ValueError: If the reference audio is 5 seconds or shorter.
    """
    # Load reference audio at 24kHz for S3Gen
    s3gen_ref_wav, _sr = librosa.load(wav_fpath, sr=S3GEN_SR)
    # Explicit exception instead of `assert`: asserts are stripped under
    # `python -O`, which would silently skip this validation.
    if len(s3gen_ref_wav) / S3GEN_SR <= 5.0:
        raise ValueError("Audio prompt must be longer than 5 seconds!")
    if norm_loudness:
        s3gen_ref_wav = self.norm_loudness(s3gen_ref_wav, S3GEN_SR)
    # Resample to 16kHz for voice encoder
    ref_16k_wav = librosa.resample(s3gen_ref_wav, orig_sr=S3GEN_SR, target_sr=S3_SR)
    # Trim to conditioning length
    s3gen_ref_wav = s3gen_ref_wav[:self.DEC_COND_LEN]
    # Try to extract all conditionals using PyTorch (for better quality)
    s3gen_ref_dict, t3_cond_prompt_tokens = self._extract_pytorch_conditionals(
        wav_fpath, norm_loudness
    )
    # Fallback if PyTorch extraction failed
    if s3gen_ref_dict is None:
        logger.warning("PyTorch extraction failed, using MLX fallback (may have lower quality)")
        s3gen_ref_dict = self.s3gen.embed_ref(s3gen_ref_wav, S3GEN_SR)
    # Fallback for T3 tokens: zeros keep the expected prompt length but
    # carry no speaker information, hence the quality warning.
    plen = self.t3.hp.speech_cond_prompt_len
    if plen and t3_cond_prompt_tokens is None:
        logger.warning("Using zero tokens for T3 conditioning - audio quality may be poor")
        t3_cond_prompt_tokens = mx.zeros((1, plen), dtype=mx.int32)
    # Speaker embedding from the 16 kHz reference, mean-pooled over utterances
    ve_embed = self.ve.embeds_from_wavs([ref_16k_wav[:self.ENC_COND_LEN]], sample_rate=S3_SR)
    ve_embed = mx.array(np.mean(ve_embed, axis=0, keepdims=True))
    # Create T3 conditioning (emotion_adv only when the T3 config enables it)
    t3_cond = T3Cond(
        speaker_emb=ve_embed,
        cond_prompt_speech_tokens=t3_cond_prompt_tokens,
        emotion_adv=mx.array([[[exaggeration]]]) if self.t3.hp.emotion_adv else None,
    )
    self.conds = Conditionals(t3_cond, s3gen_ref_dict)
def generate(
    self,
    text: str,
    repetition_penalty: float = 1.2,
    min_p: float = 0.0,
    top_p: float = 0.95,
    audio_prompt_path: Optional[str] = None,
    exaggeration: float = 0.0,
    cfg_weight: float = 0.0,
    temperature: float = 0.8,
    top_k: int = 1000,
    norm_loudness: bool = True,
) -> mx.array:
    """
    Generate speech from text.

    Args:
        text: Input text to synthesize
        repetition_penalty: Penalty for repeating tokens
        min_p: Minimum probability threshold (not used in Turbo)
        top_p: Nucleus sampling threshold
        audio_prompt_path: Optional path to reference audio for voice cloning
        exaggeration: Emotion exaggeration (not used in Turbo)
        cfg_weight: Classifier-free guidance weight (not used in Turbo)
        temperature: Sampling temperature
        top_k: Top-k sampling parameter
        norm_loudness: Whether to normalize loudness of the *reference* audio
            (forwarded to `prepare_conditionals`). Output-loudness
            normalization is currently not applied.

    Returns:
        Generated waveform as MLX array (1, T)
    """
    # Anything at or above this value is an OOV/special token, not a
    # decodable speech token (6561 == 3**8, the S3 codebook size).
    S3_CODEBOOK_SIZE = 6561
    # Prepare conditionals if audio prompt provided
    if audio_prompt_path:
        self.prepare_conditionals(
            audio_prompt_path,
            exaggeration=exaggeration,
            norm_loudness=norm_loudness
        )
    else:
        assert self.conds is not None, "Please `prepare_conditionals` first or specify `audio_prompt_path`"
    # Warn about unsupported parameters
    if cfg_weight > 0.0 or exaggeration > 0.0 or min_p > 0.0:
        logger.warning("CFG, min_p and exaggeration are not supported by Turbo version and will be ignored.")
    # Normalize and tokenize text
    text = punc_norm(text)
    if self.tokenizer is not None:
        text_tokens = self.tokenizer(text, return_tensors="np", padding=True, truncation=True)
        text_tokens = mx.array(text_tokens.input_ids)
    else:
        # Fallback: simple character-level tokenization (for testing)
        logger.warning("No tokenizer available, using simple fallback")
        text_tokens = mx.array([[ord(c) for c in text[:512]]])
    # Generate speech tokens with T3
    speech_tokens = self.t3.inference_turbo(
        t3_cond=self.conds.t3,
        text_tokens=text_tokens,
        temperature=temperature,
        top_k=top_k,
        top_p=top_p,
        repetition_penalty=repetition_penalty,
    )
    # Drop OOV/special tokens, then append a short trailing silence
    speech_tokens = speech_tokens.reshape(-1)
    mask = np.where(speech_tokens < S3_CODEBOOK_SIZE)[0].tolist()
    speech_tokens = speech_tokens[mask]
    silence = mx.array([S3GEN_SIL, S3GEN_SIL, S3GEN_SIL], dtype=mx.int32)
    speech_tokens = mx.concatenate([speech_tokens, silence])
    speech_tokens = speech_tokens[None, :]  # Add batch dimension
    # Generate waveform with S3Gen
    wav, _ = self.s3gen.inference(
        speech_tokens=speech_tokens,
        ref_dict=self.conds.gen,
        n_cfm_timesteps=2,  # Turbo uses 2 steps
    )
    # Drop the batch dimension, then re-add a size-1 one -> (1, T).
    # (The previous MLX -> NumPy -> MLX round trip was a no-op and is removed.)
    return wav[0][None, :]
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment