Created
December 18, 2025 10:49
-
-
Save Blaizzy/fd96083c4c8ccd00cd039d819934315a to your computer and use it in GitHub Desktop.
Chatterbox Turbo MLX port
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| #!/usr/bin/env python3 | |
| # Copyright (c) 2025 Resemble AI | |
| # MIT License | |
| # Weight conversion script: PyTorch -> MLX | |
| """ | |
| Converts Chatterbox Turbo weights from PyTorch to MLX format. | |
| Usage: | |
| python convert_weights.py --output model.safetensors | |
| """ | |
| import argparse | |
| import os | |
| from pathlib import Path | |
| from typing import Dict, Any | |
| import numpy as np | |
| try: | |
| import mlx.core as mx | |
| from mlx.utils import tree_flatten | |
| except ImportError: | |
| raise ImportError("Please install mlx: pip install mlx") | |
| try: | |
| from safetensors import safe_open | |
| from safetensors.numpy import save_file as save_safetensors | |
| except ImportError: | |
| raise ImportError("Please install safetensors: pip install safetensors") | |
def download_weights():
    """Fetch the Chatterbox Turbo checkpoint snapshot from the HuggingFace Hub.

    Returns:
        Path to the local snapshot directory.

    Raises:
        ImportError: if huggingface_hub is not installed.
    """
    try:
        from huggingface_hub import snapshot_download
        snapshot_dir = snapshot_download(
            repo_id="ResembleAI/chatterbox-turbo",
            # Use HF_TOKEN when set; True falls back to the cached CLI login.
            token=os.getenv("HF_TOKEN") or True,
            allow_patterns=["*.safetensors", "*.json", "*.txt", "*.pt", "*.model"],
        )
    except ImportError:
        raise ImportError("Please install huggingface_hub: pip install huggingface_hub")
    return Path(snapshot_dir)
def load_pytorch_weights(path: Path) -> Dict[str, np.ndarray]:
    """Read every tensor from a PyTorch safetensors checkpoint.

    Returns a flat dict mapping tensor name to numpy array.
    """
    with safe_open(path, framework="numpy") as f:
        return {name: f.get_tensor(name) for name in f.keys()}
def convert_conv1d_weight(weight: np.ndarray) -> np.ndarray:
    """
    Convert a Conv1d weight from PyTorch to MLX layout.
    PyTorch: (out_channels, in_channels, kernel_size)
    MLX: (out_channels, kernel_size, in_channels)
    Arrays that are not 3-D are returned unchanged.
    """
    if weight.ndim != 3:
        return weight
    # Swap the channel and kernel axes; out_channels stays first.
    return weight.transpose(0, 2, 1)
def convert_conv_transpose1d_weight(weight: np.ndarray) -> np.ndarray:
    """
    Convert a ConvTranspose1d weight from PyTorch to MLX layout.
    PyTorch: (in_channels, out_channels, kernel_size)
    MLX: (out_channels, kernel_size, in_channels)
    Arrays that are not 3-D are returned unchanged.
    """
    if weight.ndim != 3:
        return weight
    # One permutation: out_channels first, kernel second, in_channels last.
    return weight.transpose(1, 2, 0)
def convert_lstm_weights(weights: Dict[str, np.ndarray], prefix: str) -> Dict[str, np.ndarray]:
    """
    Convert LSTM weights from PyTorch to MLX format.

    PyTorch LSTM has:
        - weight_ih_l{layer}: (4*hidden, input)
        - weight_hh_l{layer}: (4*hidden, hidden)
        - bias_ih_l{layer}: (4*hidden,)
        - bias_hh_l{layer}: (4*hidden,)
    MLX LSTM has:
        - Wx: (input, 4*hidden) - transposed
        - Wh: (hidden, 4*hidden) - transposed
        - bias: (4*hidden,) - sum of both biases

    Each of the 3 stacked LSTMs is a separate single-layer module, so every
    PyTorch parameter name ends in `_l0`; modules are distinguished only by
    prefix (`{prefix}` for the first, `{prefix}.lstm{N}` for the rest).

    Args:
        weights: flat PyTorch state dict.
        prefix: module path of the first LSTM.

    Returns:
        Dict with MLX-style keys `lstm{N}.Wx` / `lstm{N}.Wh` / `lstm{N}.bias`.
    """
    converted = {}
    for layer_idx in range(3):  # We have 3 LSTM layers
        pt_prefix = f"{prefix}.lstm{layer_idx+1}" if layer_idx > 0 else prefix
        mlx_prefix = f"lstm{layer_idx+1}"
        # Build the exact expected key names. The previous substring test
        # (`pt_prefix in key`) was a bug: the bare prefix of layer 0 is a
        # substring of every other layer's key, so lstm2/lstm3 parameters
        # overwrote layer 0's (last match won).
        weight_ih_key = f"{pt_prefix}.weight_ih_l0"
        weight_hh_key = f"{pt_prefix}.weight_hh_l0"
        bias_ih_key = f"{pt_prefix}.bias_ih_l0"
        bias_hh_key = f"{pt_prefix}.bias_hh_l0"
        if weight_ih_key in weights:
            # Wx: transpose from (4*hidden, input) to (input, 4*hidden)
            converted[f"{mlx_prefix}.Wx"] = weights[weight_ih_key].T
        if weight_hh_key in weights:
            # Wh: transpose from (4*hidden, hidden) to (hidden, 4*hidden)
            converted[f"{mlx_prefix}.Wh"] = weights[weight_hh_key].T
        if bias_ih_key in weights and bias_hh_key in weights:
            # MLX keeps a single bias vector; PyTorch's two biases are additive.
            converted[f"{mlx_prefix}.bias"] = weights[bias_ih_key] + weights[bias_hh_key]
    return converted
def map_ve_weights(pt_weights: Dict[str, np.ndarray]) -> Dict[str, np.ndarray]:
    """Map Voice Encoder weights from PyTorch naming to MLX naming.

    PyTorch's stacked LSTM layers l0/l1/l2 become lstm1/lstm2/lstm3.
    Weight matrices keep their PyTorch layout ((4*hidden, input) for Wx,
    (4*hidden, hidden) for Wh); the two PyTorch bias vectors are summed
    into the single MLX bias. The projection and (training-only)
    similarity parameters are copied through unchanged.
    """
    mlx_weights = {}
    for layer_idx in range(3):
        src = f"l{layer_idx}"
        dst = f"lstm{layer_idx + 1}"
        ih_key = f"lstm.weight_ih_{src}"
        hh_key = f"lstm.weight_hh_{src}"
        bias_ih_key = f"lstm.bias_ih_{src}"
        bias_hh_key = f"lstm.bias_hh_{src}"
        if ih_key in pt_weights:
            mlx_weights[f"{dst}.Wx"] = pt_weights[ih_key]
        if hh_key in pt_weights:
            mlx_weights[f"{dst}.Wh"] = pt_weights[hh_key]
        if bias_ih_key in pt_weights and bias_hh_key in pt_weights:
            mlx_weights[f"{dst}.bias"] = pt_weights[bias_ih_key] + pt_weights[bias_hh_key]
    # Pass-through parameters that need no renaming or reshaping.
    for name in ("proj.weight", "proj.bias", "similarity_weight", "similarity_bias"):
        if name in pt_weights:
            mlx_weights[name] = pt_weights[name]
    return mlx_weights
def map_t3_weights(pt_weights: Dict[str, np.ndarray]) -> Dict[str, np.ndarray]:
    """Map T3 weights from PyTorch to MLX format.

    GPT-2 stores its linear layers as Conv1D with weights in
    (in_features, out_features) order; MLX Linear expects
    (out_features, in_features), so c_attn/c_proj/c_fc matrices are
    transposed. 3-D Conv1d kernels are permuted to MLX's layout.
    Keys are left unchanged.
    """
    gpt2_conv1d_markers = ("c_attn", "c_proj", "c_fc")
    mlx_weights = {}
    for key, tensor in pt_weights.items():
        mapped = tensor
        if "weight" in key:
            if mapped.ndim == 2 and any(m in key for m in gpt2_conv1d_markers):
                # GPT-2 Conv1D (in, out) -> MLX Linear (out, in)
                mapped = mapped.T
            if mapped.ndim == 3 and "conv" in key.lower():
                mapped = convert_conv1d_weight(mapped)
        mlx_weights[key] = mapped
    return mlx_weights
def map_s3gen_weights(pt_weights: Dict[str, np.ndarray]) -> Dict[str, np.ndarray]:
    """Map S3Gen weights from PyTorch to MLX format.

    Handles:
    1. Removing 'flow.' prefix for flow components
    2. Mapping encoder layer names (embed.out.0 -> embed.linear, etc.)
    3. Reconstructing weights from weight normalization parametrizations
    4. Mapping mel2wav weights with conv wrapper structure
    5. Conv1d weight transposition
    6. Decoder estimator block index remapping

    The caller's dict is not modified; reconstructed weight-norm tensors are
    merged into a private copy.
    """
    import re
    # Work on a shallow copy: the weight-norm reconstruction below merges new
    # keys into the source mapping, and the caller's dict must stay untouched.
    # (Previously this function mutated its argument in place.)
    pt_weights = dict(pt_weights)
    mlx_weights = {}
    # First pass: handle weight normalization - reconstruct weights from parametrizations
    # The Chatterbox weight_norm structure is:
    #   original0 = g (magnitude per output channel, shape: out, 1, 1)
    #   original1 = v (direction/weight, shape: out, in, kernel)
    # Actual weight = g * (v / ||v||) where the norm is over the (in, kernel) dims
    weight_norm_weights = {}
    original0_dict = {}
    original1_dict = {}
    for key, value in pt_weights.items():
        if 'parametrizations.weight.original0' in key:
            base_key = key.replace('.parametrizations.weight.original0', '.weight')
            original0_dict[base_key] = value  # magnitude (out, 1, 1)
        elif 'parametrizations.weight.original1' in key:
            base_key = key.replace('.parametrizations.weight.original1', '.weight')
            original1_dict[base_key] = value  # direction (out, in, kernel)
    # Reconstruct weights: weight = g * (v / ||v||)
    for base_key in original0_dict:
        if base_key in original1_dict:
            g = original0_dict[base_key]  # magnitude (out, 1, 1) or similar
            v = original1_dict[base_key]  # direction (out, in, kernel) or similar
            # Compute norm over all dims except the first (out_channels)
            if v.ndim == 3:
                # Conv1d: v is (out, in, kernel) - norm over (in, kernel)
                v_norm = np.linalg.norm(v.reshape(v.shape[0], -1), axis=1, keepdims=True)
                v_norm = v_norm.reshape(v.shape[0], 1, 1)
            elif v.ndim == 2:
                # Linear: v is (out, in) - norm over (in)
                v_norm = np.linalg.norm(v, axis=1, keepdims=True)
            else:
                v_norm = np.linalg.norm(v)
            # Fused weight: g * (v / ||v||); g is already shaped for broadcasting.
            # Small epsilon guards against a zero-norm direction vector.
            weight = g * (v / (v_norm + 1e-8))
            weight_norm_weights[base_key] = weight
    # Merge reconstructed weights back into the local copy only.
    for key, value in weight_norm_weights.items():
        pt_weights[key] = value
    # Second pass: rename keys and fix tensor layouts.
    for key, value in pt_weights.items():
        # Skip parametrization keys - we've already handled them
        if 'parametrizations' in key:
            continue
        new_key = key
        new_value = value
        # Remove 'flow.' prefix - MLX model doesn't have it
        if new_key.startswith('flow.'):
            new_key = new_key[5:]  # Remove 'flow.'
        # Handle decoder estimator block index remapping
        if 'decoder.estimator' in new_key:
            # Map PyTorch tuple-based block structure to MLX named attributes
            # PyTorch: down_blocks.X.0 (resnet), .1.Y (transformer_blocks), .2 (downsample)
            # MLX: down_blocks.X.resnet, .transformer_blocks.Y, .downsample
            new_key = re.sub(r'\.down_blocks\.(\d+)\.0\.', r'.down_blocks.\1.resnet.', new_key)
            new_key = re.sub(r'\.down_blocks\.(\d+)\.1\.(\d+)\.', r'.down_blocks.\1.transformer_blocks.\2.', new_key)
            new_key = re.sub(r'\.down_blocks\.(\d+)\.2\.', r'.down_blocks.\1.downsample.', new_key)
            # mid_blocks tuple -> named attrs (no downsample, just resnet + transformer_blocks)
            new_key = re.sub(r'\.mid_blocks\.(\d+)\.0\.', r'.mid_blocks.\1.resnet.', new_key)
            new_key = re.sub(r'\.mid_blocks\.(\d+)\.1\.(\d+)\.', r'.mid_blocks.\1.transformer_blocks.\2.', new_key)
            # up_blocks tuple -> named attrs
            new_key = re.sub(r'\.up_blocks\.(\d+)\.0\.', r'.up_blocks.\1.resnet.', new_key)
            new_key = re.sub(r'\.up_blocks\.(\d+)\.1\.(\d+)\.', r'.up_blocks.\1.transformer_blocks.\2.', new_key)
            new_key = re.sub(r'\.up_blocks\.(\d+)\.2\.', r'.up_blocks.\1.upsample.', new_key)
            # CausalBlock1D: block.0 = CausalConv -> block.0.conv.conv
            # block.2 = LayerNorm (direct weight/bias) -> block.1 (no wrapper)
            new_key = re.sub(r'\.block\.2\.', '.block.1.', new_key)  # LayerNorm, no conv wrapper
            new_key = re.sub(r'\.block\.0\.', '.block.0.conv.conv.', new_key)  # CausalConv has wrappers
            # mlp.1 -> mlp.0
            new_key = re.sub(r'\.mlp\.1\.', '.mlp.0.', new_key)
            # ff.net.2 -> ff.net.1
            new_key = re.sub(r'\.ff\.net\.2\.', '.ff.net.1.', new_key)
            # res_conv direct -> res_conv.conv
            if '.res_conv.' in new_key and '.conv.' not in new_key.split('.res_conv.')[1]:
                new_key = new_key.replace('.res_conv.', '.res_conv.conv.')
            # final_proj direct -> final_proj.conv
            if '.final_proj.' in new_key and '.conv.' not in new_key.split('.final_proj.')[1]:
                new_key = new_key.replace('.final_proj.', '.final_proj.conv.')
            # Handle downsample/upsample - these need .conv. wrapper for Conv1dPT or CausalConv
            if '.downsample.' in new_key:
                parts = new_key.split('.downsample.')
                suffix = parts[1]
                # Only add .conv. if it's a Conv1dPT/CausalConv (has weight/bias at the end)
                if suffix in ['weight', 'bias']:
                    # This is CausalConv (last block) - needs .conv.conv. wrapper
                    new_key = parts[0] + '.downsample.conv.conv.' + suffix
                elif not suffix.startswith('conv.'):
                    # Already handled by Downsample1D
                    pass
            if '.upsample.' in new_key:
                parts = new_key.split('.upsample.')
                suffix = parts[1]
                if suffix in ['weight', 'bias']:
                    # This is CausalConv (last block) - needs .conv.conv. wrapper
                    new_key = parts[0] + '.upsample.conv.conv.' + suffix
        # Encoder layer name mappings.
        # PyTorch: embed.out.0 (Linear), embed.out.1 (LayerNorm) -> MLX: embed.linear, embed.norm
        new_key = new_key.replace('.embed.out.0.', '.embed.linear.')
        new_key = new_key.replace('.embed.out.1.', '.embed.norm.')
        # PyTorch: up_embed.out.0 (Linear), up_embed.out.1 (LayerNorm) -> MLX: up_embed.linear, up_embed.norm
        new_key = new_key.replace('.up_embed.out.0.', '.up_embed.linear.')
        new_key = new_key.replace('.up_embed.out.1.', '.up_embed.norm.')
        # Handle mel2wav conv weights - MLX uses Conv1dPT wrapper with '.conv.' sub-module
        if 'mel2wav' in new_key:
            parts = new_key.split('.')
            # f0_predictor.condnet: PyTorch uses Sequential indices 0,2,4,6,8;
            # MLX uses list indices 0,1,2,3,4 with a Conv1dPT wrapper.
            # (re is already imported at the top of this function.)
            if 'f0_predictor.condnet' in new_key:
                match = re.search(r'f0_predictor\.condnet\.(\d+)\.', new_key)
                if match:
                    pt_idx = int(match.group(1))
                    # PyTorch Sequential: conv at 0,2,4,6,8 -> MLX list: 0,1,2,3,4
                    mlx_idx = pt_idx // 2
                    new_key = re.sub(
                        r'f0_predictor\.condnet\.\d+\.',
                        f'f0_predictor.condnet.{mlx_idx}.conv.',
                        new_key
                    )
            # Handle conv_pre, conv_post, ups, source_downs - add .conv. wrapper
            elif any(x in new_key for x in ['conv_pre.', 'conv_post.', 'ups.', 'source_downs.']):
                if parts[-1] in ['weight', 'bias'] and '.conv.' not in new_key:
                    parts.insert(-1, 'conv')
                    new_key = '.'.join(parts)
            # Handle resblocks and source_resblocks
            elif 'resblocks' in new_key:
                if parts[-1] in ['weight', 'bias'] and '.conv.' not in new_key:
                    parts.insert(-1, 'conv')
                    new_key = '.'.join(parts)
        # Conv1d weights (3D) need transposition from (out, in, kernel) to (out, kernel, in)
        if "weight" in new_key and value.ndim == 3:
            # mel2wav.ups.X layers are ConvTranspose1d.
            # Use ".ups." with dots to avoid matching "upsample" (which is Conv1d).
            if "mel2wav" in new_key and ".ups." in new_key:
                # ConvTranspose1d: (in, out, kernel) -> (out, kernel, in)
                new_value = convert_conv_transpose1d_weight(value)
            else:
                # Regular Conv1d: (out, in, kernel) -> (out, kernel, in)
                new_value = convert_conv1d_weight(value)
        # Linear weights: S3Gen uses standard PyTorch Linear, whose
        # (out_features, in_features) layout already matches MLX - no transpose.
        mlx_weights[new_key] = new_value
    return mlx_weights
def convert_all_weights(ckpt_dir: Path) -> Dict[str, np.ndarray]:
    """Convert every available component checkpoint in `ckpt_dir` to MLX format.

    Looks for the Voice Encoder, T3, and S3Gen safetensors files, converts
    whichever exist, and returns one flat dict whose keys are namespaced with
    've.', 't3.' and 's3gen.' prefixes.
    """
    # (namespace, filename, mapper, long label, short label)
    components = (
        ("ve", "ve.safetensors", map_ve_weights, "Voice Encoder", "VE"),
        ("t3", "t3_turbo_v1.safetensors", map_t3_weights, "T3", "T3"),
        ("s3gen", "s3gen_meanflow.safetensors", map_s3gen_weights, "S3Gen", "S3Gen"),
    )
    all_weights = {}
    for namespace, filename, mapper, long_name, short_name in components:
        src = ckpt_dir / filename
        if not src.exists():
            continue
        print(f"Converting {long_name} weights from {src}")
        converted = mapper(load_pytorch_weights(src))
        for key, tensor in converted.items():
            all_weights[f"{namespace}.{key}"] = tensor
        print(f" Converted {len(converted)} {short_name} weights")
    return all_weights
def main():
    """CLI entry point.

    Locates (or downloads) the PyTorch checkpoint directory, converts all
    component weights to MLX layout, and writes one combined safetensors file.
    """
    parser = argparse.ArgumentParser(description="Convert Chatterbox weights to MLX format")
    parser.add_argument(
        "--ckpt-dir",
        type=str,
        default=None,
        help="Path to checkpoint directory (if not specified, downloads from HuggingFace)"
    )
    parser.add_argument(
        "--output",
        type=str,
        default="model.safetensors",
        help="Output file path"
    )
    args = parser.parse_args()
    # Get checkpoint directory
    if args.ckpt_dir:
        ckpt_dir = Path(args.ckpt_dir)
    else:
        print("Downloading weights from HuggingFace...")
        ckpt_dir = download_weights()
    print(f"Checkpoint directory: {ckpt_dir}")
    # List available weight files
    print("\nAvailable weight files:")
    for f in ckpt_dir.glob("*.safetensors"):
        print(f" {f.name}")
    # Convert weights
    print("\nConverting weights...")
    all_weights = convert_all_weights(ckpt_dir)
    print(f"\nTotal weights to save: {len(all_weights)}")
    # Print the first few weight shapes for a quick sanity check.
    # (No enumerate: the index was previously bound but never used.)
    print("\nSample weight shapes:")
    for k, v in list(all_weights.items())[:10]:
        print(f" {k}: {v.shape}")
    # Save to safetensors
    output_path = Path(args.output)
    print(f"\nSaving to {output_path}...")
    save_safetensors(all_weights, str(output_path))
    print("Done!")
    print(f"Saved {len(all_weights)} weight tensors to {output_path}")
if __name__ == "__main__":
    main()
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| #!/usr/bin/env python3 | |
| # Copyright (c) 2025 Resemble AI | |
| # MIT License | |
| # Example script for Chatterbox MLX TTS | |
| """ | |
| Example usage of Chatterbox Turbo TTS with MLX on Apple Silicon. | |
| Requirements: | |
| pip install mlx numpy librosa soundfile transformers huggingface_hub | |
| Usage: | |
| python example_tts_turbo.py | |
| # With voice cloning: | |
| python example_tts_turbo.py --audio_prompt your_reference.wav | |
| """ | |
| import argparse | |
| import time | |
| import logging | |
| import numpy as np | |
| # Enable logging to see debug output | |
| logging.basicConfig(level=logging.INFO, format='%(name)s - %(levelname)s - %(message)s') | |
def main():
    """CLI entry point for the example: parse arguments, load the MLX
    Chatterbox Turbo model, synthesize speech from text (optionally cloning a
    reference voice), and write the result to a WAV file."""
    parser = argparse.ArgumentParser(description="Chatterbox MLX TTS Example")
    parser.add_argument(
        "--text",
        type=str,
        default="Oh, that's hilarious! [chuckle] Um anyway, we do have a new model in store. It's the SkyNet T-800 series and it's got basically everything. Including AI integration with ChatGPT and all that jazz. Would you like me to get some prices for you?",
        help="Text to synthesize"
    )
    parser.add_argument(
        "--audio_prompt",
        type=str,
        default=None,
        help="Path to reference audio for voice cloning (optional, > 5 seconds)"
    )
    parser.add_argument(
        "--output",
        type=str,
        default="output_mlx.wav",
        help="Output audio file path"
    )
    parser.add_argument(
        "--temperature",
        type=float,
        default=0.8,
        help="Sampling temperature"
    )
    parser.add_argument(
        "--top_k",
        type=int,
        default=1000,
        help="Top-k sampling parameter"
    )
    parser.add_argument(
        "--top_p",
        type=float,
        default=0.95,
        help="Top-p (nucleus) sampling parameter"
    )
    args = parser.parse_args()
    print("=" * 60)
    print("Chatterbox MLX TTS - Apple Silicon Optimized")
    print("=" * 60)
    # Import here to avoid slow startup for help
    print("\nLoading MLX and model components...")
    import mlx.core as mx
    from tts_turbo import ChatterboxTurboTTS
    # Check MLX backend
    print(f"MLX backend: {mx.default_device()}")
    # Load model
    print("\nLoading Chatterbox Turbo model...")
    start_time = time.time()
    # Check for converted weights in current directory (next to this script)
    import os
    script_dir = os.path.dirname(os.path.abspath(__file__))
    weights_path = os.path.join(script_dir, "model.safetensors")
    if os.path.exists(weights_path):
        print(f"Found converted weights at: {weights_path}")
    else:
        # No converted weights: pass None so from_pretrained uses its defaults.
        weights_path = None
        print("No converted weights found. Run convert_weights.py first for proper audio output.")
    try:
        model = ChatterboxTurboTTS.from_pretrained(weights_path=weights_path)
        load_time = time.time() - start_time
        print(f"Model loaded in {load_time:.2f}s")
    except Exception as e:
        # Loading can fail when weights are missing/incompatible; explain and bail out.
        print(f"Error loading model: {e}")
        print("\nNote: This MLX port requires weight conversion from PyTorch.")
        print("The model architecture is ready, but weights need to be converted.")
        print("\nTo convert weights, run:")
        print(" python convert_weights.py")
        return
    # Generate speech
    print(f"\nGenerating speech for: '{args.text[:50]}...'")
    print(f"Parameters: temperature={args.temperature}, top_k={args.top_k}, top_p={args.top_p}")
    if args.audio_prompt:
        print(f"Using voice reference: {args.audio_prompt}")
    start_time = time.time()
    try:
        wav = model.generate(
            text=args.text,
            audio_prompt_path=args.audio_prompt,
            temperature=args.temperature,
            top_k=args.top_k,
            top_p=args.top_p,
        )
        # Release MLX's cached buffers after the heavy generation pass.
        mx.clear_cache()
        gen_time = time.time() - start_time
        wav_np = np.array(wav[0])
        duration = len(wav_np) / model.sr
        print(f"\nGeneration complete!")
        print(f" - Time: {gen_time:.2f}s")
        # NOTE(review): divides by duration - would raise ZeroDivisionError on
        # an empty waveform; confirm generate() never returns zero samples.
        print(f" - Audio duration: {duration:.2f}s")
        print(f" - Real-time factor: {gen_time/duration:.2f}x")
        # Save output
        try:
            import soundfile as sf
            sf.write(args.output, wav_np, model.sr)
            print(f" - Saved to: {args.output}")
        except ImportError:
            # NOTE(review): this re-raise is immediately caught by the outer
            # `except Exception` below, so the message is printed with a
            # traceback rather than propagated to the caller.
            raise ImportError("soundfile is not installed. Please install it using `pip install soundfile`")
    except Exception as e:
        print(f"Error during generation: {e}")
        import traceback
        traceback.print_exc()
    print("\nDone!")
| if __name__ == "__main__": | |
| main() |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| # Copyright (c) 2025 Resemble AI | |
| # MIT License | |
| # MLX port of ChatterboxTurboTTS | |
| import os | |
| import math | |
| import logging | |
| from dataclasses import dataclass | |
| from pathlib import Path | |
| from typing import Optional, Union | |
| import numpy as np | |
| import librosa | |
| import mlx.core as mx | |
| import mlx.nn as nn | |
| from models.t3 import T3, T3Config, T3Cond | |
| from models.s3gen import S3Gen, S3GEN_SR, S3GEN_SIL | |
| from models.voice_encoder import VoiceEncoder | |
| logger = logging.getLogger(__name__) | |
| # Constants | |
| S3_SR = 16000 # S3Tokenizer sample rate | |
| REPO_ID = "ResembleAI/chatterbox-turbo" | |
def punc_norm(text: str) -> str:
    """
    Quick cleanup func for punctuation from LLMs or
    containing chars not seen often in the dataset.

    Capitalises the first letter, collapses runs of whitespace, normalises
    typographic punctuation to ASCII, and appends a full stop when the text
    has no terminal punctuation. Empty input gets a canned prompt string.
    """
    if len(text) == 0:
        return "You need to add some text for me to talk."
    # Capitalise first letter
    if text[0].islower():
        text = text[0].upper() + text[1:]
    # Remove multiple space chars
    text = " ".join(text.split())
    # Replace uncommon/llm punc. The four quote entries map typographic
    # (curly) quotes to their ASCII equivalents; these Unicode characters had
    # been mangled into plain ASCII quotes, which made the first pair a
    # syntax error and the apostrophe pairs no-ops.
    punc_to_replace = [
        ("…", ", "),
        (":", ","),
        ("—", "-"),
        ("–", "-"),
        (" ,", ","),
        ("“", '"'),
        ("”", '"'),
        ("‘", "'"),
        ("’", "'"),
    ]
    for old_char_sequence, new_char in punc_to_replace:
        text = text.replace(old_char_sequence, new_char)
    # Add full stop if no ending punc
    text = text.rstrip(" ")
    sentence_enders = {".", "!", "?", "-", ","}
    if not any(text.endswith(p) for p in sentence_enders):
        text += "."
    return text
@dataclass
class Conditionals:
    """
    Conditionals for T3 and S3Gen.

    t3  -- conditioning inputs consumed by the T3 text-to-token model.
    gen -- conditioning dict consumed by S3Gen.
    """
    t3: T3Cond
    gen: dict

    def save(self, fpath: Path):
        """Pickle both conditioning payloads to `fpath`."""
        import pickle
        payload = {'t3': self.t3, 'gen': self.gen}
        with open(fpath, 'wb') as fh:
            pickle.dump(payload, fh)

    @classmethod
    def load(cls, fpath: Path) -> 'Conditionals':
        """Unpickle conditionals previously written by `save`."""
        import pickle
        with open(fpath, 'rb') as fh:
            payload = pickle.load(fh)
        return cls(payload['t3'], payload['gen'])
| class ChatterboxTurboTTS: | |
| """ | |
| MLX implementation of Chatterbox Turbo TTS. | |
| Optimized for Apple Silicon. | |
| """ | |
| ENC_COND_LEN = 15 * S3_SR # 15 seconds for encoder conditioning | |
| DEC_COND_LEN = 10 * S3GEN_SR # 10 seconds for decoder conditioning | |
    def __init__(
        self,
        t3: T3,
        s3gen: S3Gen,
        ve: VoiceEncoder,
        tokenizer,  # HuggingFace tokenizer
        conds: Optional[Conditionals] = None,
        local_path: Optional[str] = None,
    ):
        """Wire up the three sub-models plus tokenizer.

        Args:
            t3: text-to-speech-token transformer.
            s3gen: speech-token-to-waveform generator.
            ve: voice encoder producing speaker embeddings.
            tokenizer: HuggingFace text tokenizer (may be None when loading failed).
            conds: optional pre-computed conditioning (built-in voice).
            local_path: checkpoint directory this model was loaded from, if any.
        """
        self.sr = S3GEN_SR  # Output sample rate
        self.t3 = t3
        self.s3gen = s3gen
        self.ve = ve
        self.tokenizer = tokenizer
        self.conds = conds
        self.local_path = local_path
    @classmethod
    def from_local(cls, ckpt_dir: Union[str, Path], device: str = "cpu") -> 'ChatterboxTurboTTS':
        """
        Load model from local checkpoint directory.

        Builds fresh VoiceEncoder / T3 / S3Gen modules, then (if present)
        loads converted MLX weights from `model.safetensors`, the HuggingFace
        tokenizer files, and the optional pre-computed conditionals `conds.pt`.

        Args:
            ckpt_dir: Path to checkpoint directory
            device: Device to use (ignored in MLX, always uses Metal)
        Returns:
            ChatterboxTurboTTS instance
        """
        ckpt_dir = Path(ckpt_dir)
        # Load Voice Encoder
        ve = VoiceEncoder()
        # Create T3 config for Turbo
        hp = T3Config.turbo()
        # Create T3 model
        t3 = T3(hp)
        # Create S3Gen
        s3gen = S3Gen(meanflow=True)
        # Try to load converted weights from model.safetensors
        model_weights_path = ckpt_dir / "model.safetensors"
        if model_weights_path.exists():
            logger.info(f"Loading converted weights from {model_weights_path}")
            weights = mx.load(str(model_weights_path))
            # Split weights by prefix and load into each model.
            # NOTE(review): str.replace substitutes every occurrence, not just
            # the leading prefix - a key containing e.g. ".ve." mid-path would
            # also be rewritten; confirm converted keys never repeat a prefix.
            ve_weights = {k.replace("ve.", ""): v for k, v in weights.items() if k.startswith("ve.")}
            t3_weights = {k.replace("t3.", ""): v for k, v in weights.items() if k.startswith("t3.")}
            s3gen_weights = {k.replace("s3gen.", ""): v for k, v in weights.items() if k.startswith("s3gen.")}
            # Debug: Print expected vs loaded keys for VE
            from mlx.utils import tree_flatten
            ve_param_keys = [k for k, _ in tree_flatten(ve.parameters())]
            print(f"VE model expects these parameter keys: {ve_param_keys[:10]}...")
            print(f"VE weights from file: {list(ve_weights.keys())[:10]}...")
            if ve_weights:
                logger.info(f"Loading {len(ve_weights)} VE weights")
                # Try strict loading first so key/shape mismatches surface;
                # fall back to best-effort loading when it fails.
                try:
                    ve.load_weights(list(ve_weights.items()), strict=True)
                    logger.info("VE weights loaded successfully with strict=True")
                except Exception as e:
                    logger.warning(f"VE strict loading failed: {e}")
                    logger.info("Falling back to strict=False")
                    ve.load_weights(list(ve_weights.items()), strict=False)
            if t3_weights:
                logger.info(f"Loading {len(t3_weights)} T3 weights")
                # Same strict-then-lenient strategy as for VE above.
                try:
                    t3.load_weights(list(t3_weights.items()), strict=True)
                    logger.info("T3 weights loaded successfully with strict=True")
                except Exception as e:
                    logger.warning(f"T3 strict loading failed: {e}")
                    logger.info("Falling back to strict=False")
                    t3.load_weights(list(t3_weights.items()), strict=False)
            if s3gen_weights:
                logger.info(f"Loading {len(s3gen_weights)} S3Gen weights")
                # S3Gen has some parameters generated at init (not from weights):
                # - encoder.embed.pos_enc.pe, encoder.up_embed.pos_enc.pe (positional encodings)
                # - mel2wav.stft_window (STFT window from scipy)
                # - trim_fade (fade buffer)
                init_generated_params = {
                    'encoder.embed.pos_enc.pe',
                    'encoder.up_embed.pos_enc.pe',
                    'mel2wav.stft_window',
                    'trim_fade',
                }
                # Get all S3Gen parameter keys
                s3gen_param_keys = set(k for k, _ in tree_flatten(s3gen.parameters()))
                loadable_param_keys = s3gen_param_keys - init_generated_params
                # Find matching weights (weights that exist in model's loadable params)
                matching_weights = [(k, v) for k, v in s3gen_weights.items() if k in loadable_param_keys]
                # Check for any weights in file that don't match model
                unmatched_weights = set(s3gen_weights.keys()) - s3gen_param_keys
                if unmatched_weights:
                    logger.debug(f"Weights in file not in model: {len(unmatched_weights)}")
                # Check for loadable params that don't have weights
                missing_weights = loadable_param_keys - set(s3gen_weights.keys())
                if missing_weights:
                    logger.warning(f"Model params without weights: {missing_weights}")
                logger.info(f"Loading {len(matching_weights)} S3Gen weights (excluding {len(init_generated_params)} init-generated params)")
                if matching_weights:
                    # Load with strict=False since we're intentionally excluding init-generated params
                    s3gen.load_weights(matching_weights, strict=False)
                    # Verify all expected weights were loaded
                    if len(matching_weights) == len(loadable_param_keys):
                        logger.info("S3Gen weights loaded successfully (all loadable params matched)")
                    else:
                        logger.warning(f"S3Gen loaded {len(matching_weights)}/{len(loadable_param_keys)} loadable params")
                else:
                    logger.warning("No matching S3Gen weights found - model may not work correctly")
            # Force materialization of all lazily-loaded parameters now.
            mx.eval(ve.parameters(), t3.parameters(), s3gen.parameters())
            logger.info("Weights loaded successfully")
        else:
            logger.warning(f"No converted weights found at {model_weights_path}")
            logger.warning("Run convert_weights.py first to convert PyTorch weights to MLX format")
        # Load tokenizer (best-effort: the model object is still constructed
        # with tokenizer=None when this fails).
        try:
            from transformers import AutoTokenizer
            tokenizer = AutoTokenizer.from_pretrained(str(ckpt_dir))
            if tokenizer.pad_token is None:
                tokenizer.pad_token = tokenizer.eos_token
        except Exception as e:
            logger.warning(f"Could not load tokenizer: {e}")
            tokenizer = None
        # Load pre-computed conditionals (built-in voice), if present.
        conds = None
        builtin_voice = ckpt_dir / "conds.pt"
        if builtin_voice.exists():
            try:
                import torch
                conds_data = torch.load(builtin_voice, map_location='cpu', weights_only=True)
                # Convert to MLX arrays
                t3_cond_dict = conds_data.get('t3', {})
                gen_dict = conds_data.get('gen', {})
                # Helper to convert PyTorch tensor to numpy (handles requires_grad)
                def to_numpy(t):
                    if hasattr(t, 'detach'):
                        return t.detach().cpu().numpy()
                    elif hasattr(t, 'numpy'):
                        return t.numpy()
                    return np.array(t)
                # Convert tensors to MLX arrays; fall back to a zero speaker
                # embedding (1, 256) when none was saved.
                speaker_emb = t3_cond_dict.get('speaker_emb')
                if speaker_emb is not None:
                    speaker_emb = mx.array(to_numpy(speaker_emb))
                else:
                    speaker_emb = mx.array(np.zeros((1, 256), dtype=np.float32))
                cond_tokens = t3_cond_dict.get('cond_prompt_speech_tokens')
                if cond_tokens is not None:
                    cond_tokens = mx.array(to_numpy(cond_tokens).astype(np.int32))
                t3_cond = T3Cond(
                    speaker_emb=speaker_emb,
                    cond_prompt_speech_tokens=cond_tokens,
                )
                gen_mlx = {}
                for k, v in gen_dict.items():
                    # Tensors become MLX arrays; plain scalars pass through.
                    # Other value types are silently dropped.
                    if hasattr(v, 'detach') or hasattr(v, 'numpy'):
                        gen_mlx[k] = mx.array(to_numpy(v))
                    elif isinstance(v, (int, float)):
                        gen_mlx[k] = v
                conds = Conditionals(t3_cond, gen_mlx)
                logger.info("Loaded pre-computed conditionals")
            except Exception as e:
                logger.warning(f"Could not load conditionals: {e}")
        return cls(t3, s3gen, ve, tokenizer, conds=conds, local_path=str(ckpt_dir))
| @classmethod | |
| def from_pretrained(cls, device: str = "cpu", weights_path: str = None) -> 'ChatterboxTurboTTS': | |
| """ | |
| Load model from HuggingFace Hub. | |
| Args: | |
| device: Device to use (ignored in MLX) | |
| weights_path: Optional path to converted model.safetensors | |
| Returns: | |
| ChatterboxTurboTTS instance | |
| """ | |
| try: | |
| from huggingface_hub import snapshot_download | |
| local_path = snapshot_download( | |
| repo_id=REPO_ID, | |
| token=os.getenv("HF_TOKEN") or True, | |
| allow_patterns=["*.safetensors", "*.json", "*.txt", "*.pt", "*.model"] | |
| ) | |
| # If weights_path provided, always copy to ensure latest version is used | |
| if weights_path: | |
| import shutil | |
| dest = Path(local_path) / "model.safetensors" | |
| # Always copy to ensure we use the latest converted weights | |
| shutil.copy(weights_path, dest) | |
| logger.info(f"Copied converted weights to {dest}") | |
| return cls.from_local(local_path, device) | |
| except ImportError: | |
| raise ImportError("Please install huggingface_hub: pip install huggingface_hub") | |
| def norm_loudness(self, wav: np.ndarray, sr: int, target_lufs: float = -27) -> np.ndarray: | |
| """Normalize audio loudness.""" | |
| try: | |
| import pyloudnorm as ln | |
| meter = ln.Meter(sr) | |
| loudness = meter.integrated_loudness(wav) | |
| gain_db = target_lufs - loudness | |
| gain_linear = 10.0 ** (gain_db / 20.0) | |
| if math.isfinite(gain_linear) and gain_linear > 0.0: | |
| wav = wav * gain_linear | |
| except Exception as e: | |
| logger.warning(f"Error in norm_loudness, skipping: {e}") | |
| return wav | |
    def _extract_pytorch_conditionals(self, wav_fpath: str, norm_loudness: bool = True) -> tuple:
        """
        Extract all conditioning using PyTorch (S3Gen embeddings + T3 tokens).

        This matches the original PyTorch tts_turbo.prepare_conditionals behavior.
        Results are converted to MLX arrays before returning; any failure is
        logged and reported as (None, None) so callers can fall back to MLX.

        Args:
            wav_fpath: Path to reference audio
            norm_loudness: Whether to normalize loudness

        Returns:
            Tuple of (s3gen_ref_dict, t3_cond_prompt_tokens) or (None, None) on failure
        """
        try:
            import sys
            import torch
            from safetensors.torch import load_file
            # Make the PyTorch chatterbox package importable.
            # NOTE(review): mutates sys.path process-wide and assumes a sibling
            # repo layout ../chatterbox/src relative to this file — confirm.
            pytorch_path = str(Path(__file__).parent.parent / "chatterbox" / "src")
            if pytorch_path not in sys.path:
                sys.path.insert(0, pytorch_path)
            from chatterbox.models.s3gen import S3Gen as S3GenPT
            # Initialize PyTorch S3Gen
            s3gen_pt = S3GenPT()
            # Load weights (best effort: skipped silently if the file is absent,
            # leaving the randomly-initialized module)
            weights_path = Path(self.local_path) / "s3gen_meanflow.safetensors"
            if weights_path.exists():
                state_dict = load_file(str(weights_path))
                s3gen_pt.load_state_dict(state_dict, strict=False)
            s3gen_pt.eval()
            # Load and process audio at 24kHz for S3Gen
            s3gen_ref_wav, _ = librosa.load(wav_fpath, sr=S3GEN_SR)
            if norm_loudness:
                s3gen_ref_wav = self.norm_loudness(s3gen_ref_wav, S3GEN_SR)
            # Resample to 16kHz for tokenizer
            ref_16k_wav = librosa.resample(s3gen_ref_wav, orig_sr=S3GEN_SR, target_sr=S3_SR)
            # Trim to conditioning lengths
            s3gen_ref_wav = s3gen_ref_wav[:self.DEC_COND_LEN]
            ref_16k_wav = ref_16k_wav[:self.ENC_COND_LEN]
            with torch.no_grad():
                # Get S3Gen embeddings
                ref_dict = s3gen_pt.embed_ref(s3gen_ref_wav, S3GEN_SR)
                # Get T3 conditioning tokens using S3Gen's tokenizer (matches PyTorch exactly)
                plen = self.t3.hp.speech_cond_prompt_len
                t3_cond_prompt_tokens = None
                if plen and s3gen_pt.tokenizer is not None:
                    t3_cond_prompt_tokens, _ = s3gen_pt.tokenizer.forward(
                        [ref_16k_wav], max_len=plen
                    )
                    # Ensure a leading batch dimension regardless of tokenizer output
                    t3_cond_prompt_tokens = torch.atleast_2d(t3_cond_prompt_tokens)
            # Convert S3Gen dict values to MLX arrays (non-tensor values pass through)
            mlx_ref_dict = {}
            for k, v in ref_dict.items():
                if v is not None and torch.is_tensor(v):
                    mlx_ref_dict[k] = mx.array(v.cpu().numpy())
                elif v is not None:
                    mlx_ref_dict[k] = v
                else:
                    mlx_ref_dict[k] = None
            # Convert T3 tokens to MLX
            mlx_t3_tokens = None
            if t3_cond_prompt_tokens is not None:
                mlx_t3_tokens = mx.array(t3_cond_prompt_tokens.cpu().numpy())
            logger.info("Extracted all conditionals using PyTorch")
            return mlx_ref_dict, mlx_t3_tokens
        except Exception as e:
            # Deliberate best-effort catch-all: the caller falls back to the
            # pure-MLX path when this returns (None, None).
            logger.warning(f"Failed to extract conditionals with PyTorch: {e}")
            import traceback
            traceback.print_exc()
            return None, None
| def prepare_conditionals( | |
| self, | |
| wav_fpath: str, | |
| exaggeration: float = 0.5, | |
| norm_loudness: bool = True, | |
| ): | |
| """ | |
| Prepare conditioning from a reference audio file. | |
| Args: | |
| wav_fpath: Path to reference audio file (should be > 5 seconds) | |
| exaggeration: Emotion exaggeration factor (not used in Turbo) | |
| norm_loudness: Whether to normalize loudness | |
| """ | |
| # Load reference audio at 24kHz for S3Gen | |
| s3gen_ref_wav, _sr = librosa.load(wav_fpath, sr=S3GEN_SR) | |
| assert len(s3gen_ref_wav) / S3GEN_SR > 5.0, "Audio prompt must be longer than 5 seconds!" | |
| if norm_loudness: | |
| s3gen_ref_wav = self.norm_loudness(s3gen_ref_wav, S3GEN_SR) | |
| # Resample to 16kHz for voice encoder | |
| ref_16k_wav = librosa.resample(s3gen_ref_wav, orig_sr=S3GEN_SR, target_sr=S3_SR) | |
| # Trim to conditioning length | |
| s3gen_ref_wav = s3gen_ref_wav[:self.DEC_COND_LEN] | |
| # Try to extract all conditionals using PyTorch (for better quality) | |
| s3gen_ref_dict, t3_cond_prompt_tokens = self._extract_pytorch_conditionals(wav_fpath, norm_loudness) | |
| # Fallback if PyTorch extraction failed | |
| if s3gen_ref_dict is None: | |
| logger.warning("PyTorch extraction failed, using MLX fallback (may have lower quality)") | |
| s3gen_ref_dict = self.s3gen.embed_ref(s3gen_ref_wav, S3GEN_SR) | |
| # Fallback for T3 tokens | |
| plen = self.t3.hp.speech_cond_prompt_len | |
| if plen and t3_cond_prompt_tokens is None: | |
| logger.warning("Using zero tokens for T3 conditioning - audio quality may be poor") | |
| t3_cond_prompt_tokens = mx.zeros((1, plen), dtype=mx.int32) | |
| # Get voice encoder speaker embedding | |
| ve_embed = self.ve.embeds_from_wavs([ref_16k_wav[:self.ENC_COND_LEN]], sample_rate=S3_SR) | |
| ve_embed = mx.array(np.mean(ve_embed, axis=0, keepdims=True)) | |
| # Create T3 conditioning | |
| t3_cond = T3Cond( | |
| speaker_emb=ve_embed, | |
| cond_prompt_speech_tokens=t3_cond_prompt_tokens, | |
| emotion_adv=mx.array([[[exaggeration]]]) if self.t3.hp.emotion_adv else None, | |
| ) | |
| self.conds = Conditionals(t3_cond, s3gen_ref_dict) | |
    def generate(
        self,
        text: str,
        repetition_penalty: float = 1.2,
        min_p: float = 0.0,
        top_p: float = 0.95,
        audio_prompt_path: Optional[str] = None,
        exaggeration: float = 0.0,
        cfg_weight: float = 0.0,
        temperature: float = 0.8,
        top_k: int = 1000,
        norm_loudness: bool = True,
    ) -> mx.array:
        """
        Generate speech from text.

        Args:
            text: Input text to synthesize
            repetition_penalty: Penalty for repeating tokens
            min_p: Minimum probability threshold (not used in Turbo)
            top_p: Nucleus sampling threshold
            audio_prompt_path: Optional path to reference audio for voice cloning
            exaggeration: Emotion exaggeration (not used in Turbo)
            cfg_weight: Classifier-free guidance weight (not used in Turbo)
            temperature: Sampling temperature
            top_k: Top-k sampling parameter
            norm_loudness: Whether to normalize loudness. NOTE: currently only
                affects conditioning preparation; output-loudness normalization
                is commented out below.

        Returns:
            Generated waveform as MLX array (1, T)
        """
        # Prepare conditionals if audio prompt provided; otherwise require
        # that prepare_conditionals() was already called.
        if audio_prompt_path:
            self.prepare_conditionals(
                audio_prompt_path,
                exaggeration=exaggeration,
                norm_loudness=norm_loudness
            )
        else:
            assert self.conds is not None, "Please `prepare_conditionals` first or specify `audio_prompt_path`"
        # Warn about parameters the Turbo model ignores.
        if cfg_weight > 0.0 or exaggeration > 0.0 or min_p > 0.0:
            logger.warning("CFG, min_p and exaggeration are not supported by Turbo version and will be ignored.")
        # Normalize punctuation, then tokenize text.
        text = punc_norm(text)
        if self.tokenizer is not None:
            text_tokens = self.tokenizer(text, return_tensors="np", padding=True, truncation=True)
            text_tokens = mx.array(text_tokens.input_ids)
        else:
            # Fallback: simple character-level tokenization (for testing);
            # input is capped at 512 characters in this path.
            logger.warning("No tokenizer available, using simple fallback")
            text_tokens = mx.array([[ord(c) for c in text[:512]]])
        # Generate speech tokens with T3.
        speech_tokens = self.t3.inference_turbo(
            t3_cond=self.conds.t3,
            text_tokens=text_tokens,
            temperature=temperature,
            top_k=top_k,
            top_p=top_p,
            repetition_penalty=repetition_penalty,
        )
        # Drop tokens >= 6561 (presumably the valid speech-codebook size — TODO
        # confirm) and append three silence tokens before vocoding.
        speech_tokens = speech_tokens.reshape(-1)
        mask = np.where(speech_tokens < 6561)[0].tolist()
        speech_tokens = speech_tokens[mask]
        silence = mx.array([S3GEN_SIL, S3GEN_SIL, S3GEN_SIL], dtype=mx.int32)
        speech_tokens = mx.concatenate([speech_tokens, silence])
        speech_tokens = speech_tokens[None, :]  # Add batch dimension
        # Generate waveform with S3Gen.
        wav, _ = self.s3gen.inference(
            speech_tokens=speech_tokens,
            ref_dict=self.conds.gen,
            n_cfm_timesteps=2,  # Turbo uses 2 steps
        )
        # Post-process: strip batch dim, round-trip through numpy.
        wav = wav[0]  # Remove batch dimension
        wav_np = np.array(wav)
        # Output loudness normalization is intentionally disabled here;
        # re-enable if generated audio levels need matching to target LUFS.
        # # Normalize loudness
        # if norm_loudness:
        #     wav_np = self.norm_loudness(wav_np, self.sr)
        return mx.array(wav_np)[None, :]  # (1, T)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment