Skip to content

Instantly share code, notes, and snippets.

@sayan1999
Created April 16, 2026 14:58
Show Gist options
  • Select an option

  • Save sayan1999/05fd13931820058fc01a463bb7d7fe8f to your computer and use it in GitHub Desktop.

Select an option

Save sayan1999/05fd13931820058fc01a463bb7d7fe8f to your computer and use it in GitHub Desktop.
Chatterbox Cuda fp16
"""
Thin subclasses of the Chatterbox TTS models with partial FP16 quantization.
t3 + ve → FP16 (autoregressive transformer tolerates half precision)
s3gen → FP32 (flow matching + HiFiGAN vocoder — kept full precision for audio quality)
"""
import torch
from chatterbox.tts_turbo import ChatterboxTurboTTS as _Turbo
def _partial_fp16(model):
    """Cast the t3 transformer and voice encoder to half precision in place.

    The s3gen vocoder path stays FP32 (see module docstring), so any cached
    conditioning tensors destined for it are pinned to float32, while the t3
    conditioning follows the t3 model into float16.

    Returns the same model object, mutated.
    """
    model.t3 = model.t3.half()
    model.ve = model.ve.half()
    if model.conds is not None:
        model.conds.t3 = model.conds.t3.to(dtype=torch.float16)
        gen = model.conds.gen
        for key in list(gen):
            value = gen[key]
            if torch.is_tensor(value):
                # Keep s3gen-side conditioning tensors full precision.
                gen[key] = value.to(dtype=torch.float32)
    return model
class _GpuMixin:
    """Mixin layering partial-FP16 loading and autocast generation onto a
    Chatterbox TTS base class (must precede the base in the MRO)."""

    @classmethod
    def from_pretrained(cls, device="cuda"):
        """Load the base model, then cast t3/ve to FP16 when on a CUDA device.

        On any non-CUDA device the quantization step is skipped and a
        RuntimeWarning is emitted; the model is returned unmodified.
        """
        model = super().from_pretrained(device)
        if str(device).startswith("cuda"):
            return _partial_fp16(model)
        import warnings

        warnings.warn(
            f"Running on '{device}' — partial FP16 quantization skipped.",
            RuntimeWarning,
        )
        return model

    def generate(self, *args, **kwargs):
        # autocast reconciles any FP32 tensors created inline by the base class
        # (e.g. the exaggeration T3Cond update) against the FP16 t3 model.
        if not str(self.device).startswith("cuda"):
            return super().generate(*args, **kwargs)
        with torch.autocast(device_type="cuda", dtype=torch.float16):
            return super().generate(*args, **kwargs)
class ChatterboxTurboTTS(_GpuMixin, _Turbo):
    """Drop-in replacement for chatterbox's ChatterboxTurboTTS.

    The mixin comes first in the MRO so its ``from_pretrained``/``generate``
    overrides wrap the base implementations: t3 + ve run in FP16 on CUDA,
    and generation is wrapped in a float16 autocast context.
    """

    pass
if __name__ == "__main__":
    # Smoke test: load the model once so loading/quantization failures surface
    # early. Guarded so that importing this module as a library does not
    # trigger a heavy model download/load as a side effect.
    ChatterboxTurboTTS.from_pretrained()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment