LTX2 inference in int8, with group offload, no pipeline
#!/usr/bin/env python3
import argparse
import gc
import os
import copy
from dataclasses import dataclass
from pathlib import Path
from typing import List, Optional, Tuple, Union
import numpy as np
import torch
from diffusers.utils import load_image
from diffusers.video_processor import VideoProcessor
from diffusers.utils.torch_utils import randn_tensor
import shutil
import subprocess
import tempfile
def _write_wav_pcm16(path: str, audio_f32: np.ndarray, sample_rate: int) -> None:
"""
Write mono float32 audio in [-1, 1] to 16-bit PCM WAV using the stdlib.
No external deps.
"""
import wave
audio = np.clip(audio_f32, -1.0, 1.0)
audio_i16 = (audio * 32767.0).astype(np.int16)
with wave.open(path, "wb") as wf:
wf.setnchannels(1)
wf.setsampwidth(2) # 16-bit
wf.setframerate(int(sample_rate))
wf.writeframes(audio_i16.tobytes())
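# Minimal sketch of the WAV writer above (not called by main(); the 440 Hz tone,
# one-second duration, and 24 kHz rate are illustrative assumptions):
def _demo_write_wav(path: str = "tone.wav", sample_rate: int = 24000) -> None:
    t = np.linspace(0.0, 1.0, sample_rate, endpoint=False, dtype=np.float32)
    tone = (0.5 * np.sin(2.0 * np.pi * 440.0 * t)).astype(np.float32)  # mono, in [-1, 1]
    _write_wav_pcm16(path, tone, sample_rate)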
def write_mp4_with_audio_ffmpeg(
*,
out_path: str,
video_u8: np.ndarray,
fps: float,
audio: torch.Tensor,
audio_sample_rate: int,
macro_block_size: int = 16,
) -> str:
"""
video_u8: uint8 frames in one of these shapes:
- [F, H, W, C]
- [B, F, H, W, C] (we'll take B=0)
audio: torch tensor shaped [T] or [1, T] or [B, T] (we'll take the first channel/batch)
"""
    from diffusers.utils import export_to_video  # video-only writer
out_path = str(Path(out_path).resolve())
Path(out_path).parent.mkdir(parents=True, exist_ok=True)
# Normalize video shape to list of frames for export_to_video
if video_u8.ndim == 5:
video_u8 = video_u8[0] # take first batch
if video_u8.ndim != 4:
raise ValueError(f"video_u8 must be [F,H,W,C] or [B,F,H,W,C], got shape={video_u8.shape}")
frames = [video_u8[i] for i in range(video_u8.shape[0])]
# Normalize audio tensor to 1D float32 CPU numpy
if isinstance(audio, torch.Tensor):
a = audio.detach()
if a.ndim == 2:
a = a[0]
a = a.float().cpu().numpy()
else:
a = np.asarray(audio, dtype=np.float32)
if a.ndim == 2:
a = a[0]
with tempfile.TemporaryDirectory() as td:
td = Path(td)
video_path = str(td / "video.mp4")
audio_path = str(td / "audio.wav")
muxed_path = out_path
# 1) Write MP4 video-only
export_to_video(
frames,
output_video_path=video_path,
fps=int(round(float(fps))),
macro_block_size=macro_block_size,
)
# 2) Write WAV audio
_write_wav_pcm16(audio_path, a, int(audio_sample_rate))
# 3) Mux with ffmpeg
ffmpeg = shutil.which("ffmpeg")
if ffmpeg is None:
raise RuntimeError(
"ffmpeg not found on PATH. Install ffmpeg or add it to PATH. "
"Diffusers' export_to_video also relies on ffmpeg via imageio in many setups."
)
cmd = [
ffmpeg, "-y",
"-i", video_path,
"-i", audio_path,
"-c:v", "copy",
"-c:a", "aac",
"-b:a", "192k",
"-shortest",
muxed_path,
]
try:
subprocess.run(cmd, check=True, text=True, capture_output=True)
except subprocess.CalledProcessError as e:
print("\n[ffmpeg] mux failed. stderr:\n")
print(e.stderr)
# Retry: explicitly map streams and re-encode video to avoid timestamp/codec copy issues
cmd2 = [
ffmpeg, "-y",
"-i", video_path,
"-i", audio_path,
"-map", "0:v:0",
"-map", "1:a:0",
"-c:v", "libx264",
"-pix_fmt", "yuv420p",
"-r", str(int(round(float(fps)))),
"-c:a", "aac",
"-b:a", "192k",
"-shortest",
"-movflags", "+faststart",
muxed_path,
]
try:
subprocess.run(cmd2, check=True, text=True, capture_output=True)
except subprocess.CalledProcessError as e2:
print("\n[ffmpeg] retry failed. stderr:\n")
print(e2.stderr)
raise
return out_path
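# A minimal usage sketch for the muxer above (not called by main(); the shapes and the
# 24 kHz sample rate are illustrative assumptions, and ffmpeg must be on PATH):
def _demo_mux_silent_black_clip(out_path: str = "demo_mux.mp4") -> str:
    frames = np.zeros((24, 256, 256, 3), dtype=np.uint8)   # 1 s of black video, [F, H, W, C]
    silence = torch.zeros(1, 24000)                        # 1 s of silence, [B, T]
    return write_mp4_with_audio_ffmpeg(
        out_path=out_path,
        video_u8=frames,
        fps=24.0,
        audio=silence,
        audio_sample_rate=24000,
    )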
# ---------------------------
# Quanto helpers + cache I/O
# ---------------------------
def quantize_quanto_int8(module: torch.nn.Module, name: str, exclude: Optional[str] = None) -> torch.nn.Module:
"""
Quantize a module with optimum-quanto (qint8) and freeze.
"""
from optimum.quanto import freeze, qint8, quantize # type: ignore
kwargs = {"weights": qint8}
if exclude:
kwargs["exclude"] = exclude
print(f"[quanto] Quantizing {name} -> qint8 ...")
quantize(module, **kwargs)
freeze(module)
module.eval()
print(f"[quanto] Done: {name}")
return module
def torch_load_full_module(path: Path) -> torch.nn.Module:
return torch.load(path, map_location="cpu", weights_only=False)
def torch_save_full_module(module: torch.nn.Module, path: Path) -> None:
path.parent.mkdir(parents=True, exist_ok=True)
# Full pickled module (fragile across package upgrades; delete cache when upgrading)
torch.save(module, path)
def load_or_build_quantized(path: Path, build_fn, name: str) -> torch.nn.Module:
if path.exists():
print(f"[cache] Loading {name} from: {path}")
return torch_load_full_module(path)
print(f"[cache] Miss for {name}; building + quantizing + saving to: {path}")
module = build_fn()
torch_save_full_module(module, path)
return module
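# Cache flow sketch (illustrative; `SomeModel` and the path are placeholders, not part of
# this script). The first run builds + quantizes + pickles the module; later runs torch.load it:
#   def build():
#       m = SomeModel.from_pretrained(..., torch_dtype=torch.bfloat16).cpu()
#       return quantize_quanto_int8(m, "some_model")
#   m = load_or_build_quantized(Path("./ltx2_quanto_cache/some_model_qint8.pt"), build, "some_model")
# The cache stores a fully pickled module, so delete it after upgrading diffusers/transformers/quanto.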
def cuda_gc():
gc.collect()
if torch.cuda.is_available():
torch.cuda.empty_cache()
# ----------------------------------------
# Gemma prompt embedding (copied behavior)
# ----------------------------------------
def pack_text_embeds(
text_hidden_states: torch.Tensor, # [B, T, H, L]
sequence_lengths: torch.Tensor, # [B]
device: Union[str, torch.device],
padding_side: str = "left",
scale_factor: int = 8,
eps: float = 1e-6,
) -> torch.Tensor:
"""
Same as LTX2ImageToVideoPipeline._pack_text_embeds
"""
batch_size, seq_len, hidden_dim, num_layers = text_hidden_states.shape
original_dtype = text_hidden_states.dtype
token_indices = torch.arange(seq_len, device=device).unsqueeze(0)
if padding_side == "right":
mask = token_indices < sequence_lengths[:, None]
elif padding_side == "left":
start_indices = seq_len - sequence_lengths[:, None]
mask = token_indices >= start_indices
else:
raise ValueError(f"padding_side must be 'left' or 'right', got {padding_side}")
mask = mask[:, :, None, None] # [B,T,1,1]
masked_text_hidden_states = text_hidden_states.masked_fill(~mask, 0.0)
num_valid_positions = (sequence_lengths * hidden_dim).view(batch_size, 1, 1, 1)
masked_mean = masked_text_hidden_states.sum(dim=(1, 2), keepdim=True) / (num_valid_positions + eps)
x_min = text_hidden_states.masked_fill(~mask, float("inf")).amin(dim=(1, 2), keepdim=True)
x_max = text_hidden_states.masked_fill(~mask, float("-inf")).amax(dim=(1, 2), keepdim=True)
normalized = (text_hidden_states - masked_mean) / (x_max - x_min + eps)
normalized = normalized * scale_factor
normalized = normalized.flatten(2) # [B,T,H*L]
mask_flat = mask.squeeze(-1).expand(-1, -1, hidden_dim * num_layers)
normalized = normalized.masked_fill(~mask_flat, 0.0)
return normalized.to(dtype=original_dtype)
@torch.no_grad()
def gemma_prompt_embeds(
prompt: Union[str, List[str]],
tokenizer,
text_encoder,
device: torch.device,
max_sequence_length: int = 1024,
scale_factor: int = 8,
dtype: Optional[torch.dtype] = None,
num_videos_per_prompt: int = 1,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Produces (prompt_embeds, attention_mask) as expected by LTX2 encode_prompt,
matching the pipeline's behavior.
"""
prompt_list = [prompt] if isinstance(prompt, str) else prompt
batch_size = len(prompt_list)
tokenizer.padding_side = "left"
if tokenizer.pad_token is None:
tokenizer.pad_token = tokenizer.eos_token
prompt_list = [p.strip() for p in prompt_list]
inputs = tokenizer(
prompt_list,
padding="max_length",
max_length=max_sequence_length,
truncation=True,
add_special_tokens=True,
return_tensors="pt",
)
input_ids = inputs.input_ids.to(device)
attn_mask = inputs.attention_mask.to(device)
# Ensure encoder on GPU
text_encoder.to(device)
text_encoder.eval()
out = text_encoder(input_ids=input_ids, attention_mask=attn_mask, output_hidden_states=True)
hidden_states = torch.stack(out.hidden_states, dim=-1) # [B,T,H,L]
seq_lens = attn_mask.sum(dim=-1)
embeds = pack_text_embeds(
hidden_states,
seq_lens,
device=device,
padding_side=tokenizer.padding_side,
scale_factor=scale_factor,
)
if dtype is not None:
embeds = embeds.to(dtype=dtype)
# duplicate for num_videos_per_prompt
_, seq_len, _ = embeds.shape
embeds = embeds.repeat(1, num_videos_per_prompt, 1).view(batch_size * num_videos_per_prompt, seq_len, -1)
attn_mask = attn_mask.view(batch_size, -1).repeat(num_videos_per_prompt, 1)
return embeds, attn_mask
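# Usage sketch (assumes tokenizer/text_encoder are loaded as in main(); the prompt is illustrative):
#   embeds, mask = gemma_prompt_embeds(
#       "a corgi surfing a wave", tokenizer, text_encoder,
#       device=torch.device("cuda"), max_sequence_length=1024, dtype=torch.bfloat16,
#   )
#   # embeds: [1, 1024, hidden_dim * num_layers]  (packed across all Gemma hidden layers)
#   # mask:   [1, 1024]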
# -----------------------------------------
# Transformer group-swap (N blocks at once)
# -----------------------------------------
@dataclass
class GroupSwapState:
blocks: List[torch.nn.Module]
device: torch.device
group_size: int
keep_non_block_on_gpu: bool
non_block_modules: List[torch.nn.Module]
current_group_start: Optional[int] = None
def _move_group(self, start: int, target: Union[str, torch.device]) -> None:
end = min(start + self.group_size, len(self.blocks))
for j in range(start, end):
self.blocks[j].to(target)
def ensure_group_loaded_for_index(self, i: int) -> None:
start = (i // self.group_size) * self.group_size
if self.current_group_start == start:
return
# Load new group first (so forward never sees missing weights)
import time
t0 = time.time()
self._move_group(start, self.device)
torch.cuda.synchronize()
print(f"loaded group {start} in {time.time()-t0:.2f}s")
# Optionally manage non-block modules on first group load of a forward pass
if not self.keep_non_block_on_gpu:
for m in self.non_block_modules:
m.to(self.device)
# Unload previous group
if self.current_group_start is not None:
self._move_group(self.current_group_start, "cpu")
self.current_group_start = start
def unload_all(self) -> None:
if self.current_group_start is not None:
self._move_group(self.current_group_start, "cpu")
self.current_group_start = None
if not self.keep_non_block_on_gpu:
for m in self.non_block_modules:
m.to("cpu")
def attach_group_swap_hooks(
transformer: torch.nn.Module,
group_size: int,
device: torch.device,
keep_non_block_on_gpu: bool = True,
) -> GroupSwapState:
assert hasattr(transformer, "transformer_blocks"), "Transformer must have transformer_blocks"
blocks: List[torch.nn.Module] = list(transformer.transformer_blocks)
non_block = []
for name, child in transformer.named_children():
if name != "transformer_blocks":
non_block.append(child)
state = GroupSwapState(
blocks=blocks,
device=device,
group_size=int(group_size),
keep_non_block_on_gpu=bool(keep_non_block_on_gpu),
non_block_modules=non_block,
)
# Start from CPU
transformer.to("cpu")
for b in blocks:
b.to("cpu")
if keep_non_block_on_gpu:
for m in non_block:
m.to(device)
else:
for m in non_block:
m.to("cpu")
def make_pre_hook(idx: int):
def _pre_hook(_module, _inputs):
state.ensure_group_loaded_for_index(idx)
return _pre_hook
for i, block in enumerate(blocks):
block.register_forward_pre_hook(make_pre_hook(i), with_kwargs=False)
# IMPORTANT: unload AFTER the entire transformer forward finishes
def _transformer_post_hook(_module, _inputs, _output):
state.unload_all()
return _output
transformer.register_forward_hook(_transformer_post_hook, with_kwargs=False)
return state
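# Usage sketch for the group swapper (illustrative; any module exposing .transformer_blocks works):
#   state = attach_group_swap_hooks(transformer, group_size=8, device=torch.device("cuda"))
#   out = transformer(**inputs)   # pre-hooks stream blocks to GPU 8 at a time, evicting the previous group
#   state.unload_all()            # the post-hook already runs this after each full forward; safe to call again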
# -------------------------
# Main generation pipeline
# -------------------------
class DummyTextEncoder(torch.nn.Module):
def __init__(self, dtype: torch.dtype):
super().__init__()
self._dtype = dtype
@property
def dtype(self) -> torch.dtype:
return self._dtype
def forward(self, *args, **kwargs):
raise RuntimeError("DummyTextEncoder was called. Provide prompt_embeds to the pipeline.")
@torch.no_grad()
def decode_from_latents_components(
*,
vae,
audio_vae,
vocoder,
video_latents: torch.Tensor, # expected denormalized/unpacked: [B, C, F, H, W]
audio_latents: torch.Tensor, # expected denormalized/unpacked: [B, C, L, M]
device: torch.device,
output_type: str = "np", # "np" or "pil" (pil requires pillow installed; diffusers will handle)
):
"""
Decodes LTX-2 video+audio latents using the raw diffusers components (no pipeline).
Assumptions:
- video_latents are already denormalized + unpacked (as returned by your custom denoise function).
- audio_latents are already denormalized + unpacked.
"""
# ---- Video decode ----
vae.enable_tiling(
tile_sample_min_height=256,
tile_sample_min_width=256,
tile_sample_min_num_frames=32,
tile_sample_stride_height=128,
tile_sample_stride_width=128,
tile_sample_stride_num_frames=16,
)
vae.to(device)
latents_v = video_latents.to(device=device, dtype=vae.dtype)
# Stock pipeline: if timestep_conditioning is disabled, timestep=None
timestep = None
video = vae.decode(latents_v, timestep, return_dict=False)[0]
# Postprocess to np/pil like pipeline does
# Need the same scale factor that the pipeline uses:
vae_scale_factor = getattr(vae, "spatial_compression_ratio", 32)
video_processor = VideoProcessor(vae_scale_factor=vae_scale_factor, resample="bilinear")
video = video_processor.postprocess_video(video, output_type=output_type)
# ---- Audio decode ----
audio_vae.to(device)
vocoder.to(device)
latents_a = audio_latents.to(device=device, dtype=audio_vae.dtype)
mel = audio_vae.decode(latents_a, return_dict=False)[0]
audio = vocoder(mel)
return video, audio
# ---- helpers copied from the pipeline (same behavior) ----
def retrieve_latents(encoder_output, generator=None, sample_mode="sample"):
if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
return encoder_output.latent_dist.sample(generator)
elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
return encoder_output.latent_dist.mode()
elif hasattr(encoder_output, "latents"):
return encoder_output.latents
else:
raise AttributeError("Could not access latents of provided encoder_output")
def calculate_shift(image_seq_len, base_seq_len=256, max_seq_len=4096, base_shift=0.5, max_shift=1.15):
m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
b = base_shift - m * base_seq_len
return image_seq_len * m + b
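# Worked example with the defaults above: the shift is linear in the sequence length,
# interpolating from base_shift at base_seq_len to max_shift at max_seq_len:
#   calculate_shift(256)  -> 0.500
#   calculate_shift(2176) -> 0.825   (midpoint)
#   calculate_shift(4096) -> 1.150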
def retrieve_timesteps(scheduler, num_inference_steps, device, timesteps=None, sigmas=None, **kwargs):
if timesteps is not None and sigmas is not None:
raise ValueError("Only one of `timesteps` or `sigmas` can be passed.")
if timesteps is not None:
scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
timesteps = scheduler.timesteps
num_inference_steps = len(timesteps)
elif sigmas is not None:
scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
timesteps = scheduler.timesteps
else:
scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
timesteps = scheduler.timesteps
return timesteps, num_inference_steps
def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True)
std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)
noise_pred_rescaled = noise_cfg * (std_text / std_cfg)
return guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg
# ---- latent packing utilities (copied from LTX2 pipeline) ----
def pack_latents(latents: torch.Tensor, patch_size: int = 1, patch_size_t: int = 1) -> torch.Tensor:
b, c, f, h, w = latents.shape
f2 = f // patch_size_t
h2 = h // patch_size
w2 = w // patch_size
latents = latents.reshape(b, -1, f2, patch_size_t, h2, patch_size, w2, patch_size)
latents = latents.permute(0, 2, 4, 6, 1, 3, 5, 7).flatten(4, 7).flatten(1, 3)
return latents
def unpack_latents(latents: torch.Tensor, num_frames: int, height: int, width: int, patch_size: int = 1, patch_size_t: int = 1) -> torch.Tensor:
b = latents.size(0)
latents = latents.reshape(b, num_frames, height, width, -1, patch_size_t, patch_size, patch_size)
latents = latents.permute(0, 4, 1, 5, 2, 6, 3, 7).flatten(6, 7).flatten(4, 5).flatten(2, 3)
return latents
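# A quick round-trip check for the video packing helpers above (not called by main();
# the latent sizes are arbitrary and patch sizes default to 1 as in this script):
def _demo_pack_unpack_roundtrip() -> None:
    x = torch.randn(1, 128, 4, 8, 8)    # [B, C, F, H, W]
    packed = pack_latents(x)            # [B, F*H*W, C] with patch_size=patch_size_t=1
    restored = unpack_latents(packed, num_frames=4, height=8, width=8)
    assert torch.equal(x, restored)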
def normalize_latents(latents: torch.Tensor, latents_mean: torch.Tensor, latents_std: torch.Tensor, scaling_factor: float = 1.0) -> torch.Tensor:
latents_mean = latents_mean.view(1, -1, 1, 1, 1).to(latents.device, latents.dtype)
latents_std = latents_std.view(1, -1, 1, 1, 1).to(latents.device, latents.dtype)
return (latents - latents_mean) * scaling_factor / latents_std
def denormalize_latents(latents: torch.Tensor, latents_mean: torch.Tensor, latents_std: torch.Tensor, scaling_factor: float = 1.0) -> torch.Tensor:
latents_mean = latents_mean.view(1, -1, 1, 1, 1).to(latents.device, latents.dtype)
latents_std = latents_std.view(1, -1, 1, 1, 1).to(latents.device, latents.dtype)
return latents * latents_std / scaling_factor + latents_mean
def pack_audio_latents(latents: torch.Tensor) -> torch.Tensor:
# [B, C, L, M] -> [B, L, C*M] (implicit patch)
return latents.transpose(1, 2).flatten(2, 3)
def unpack_audio_latents(latents: torch.Tensor, latent_length: int, num_mel_bins: int) -> torch.Tensor:
# [B, L, C*M] -> [B, C, L, M]
latents = latents.unflatten(2, (-1, num_mel_bins)).transpose(1, 2)
return latents
def denormalize_audio_latents(latents: torch.Tensor, latents_mean: torch.Tensor, latents_std: torch.Tensor) -> torch.Tensor:
latents_mean = latents_mean.to(latents.device, latents.dtype)
latents_std = latents_std.to(latents.device, latents.dtype)
return (latents * latents_std) + latents_mean
# ---- main custom denoise ----
@torch.no_grad()
def ltx2_denoise_to_latents(
*,
image,
prompt_embeds: torch.Tensor,
prompt_attention_mask: torch.Tensor,
negative_prompt_embeds: torch.Tensor,
negative_prompt_attention_mask: torch.Tensor,
scheduler,
vae,
audio_vae,
connectors,
transformer,
num_frames: int,
frame_rate: float,
height: int,
width: int,
num_inference_steps: int,
guidance_scale: float,
guidance_rescale: float,
device: torch.device,
dtype: torch.dtype,
generator: torch.Generator,
swap_state=None, # your GroupSwapState
keep_non_block_on_gpu=True,
step_log_every: int = 1, # set to 0 to disable
):
# Move only what must be on GPU for setup
connectors.to(device)
vae.to(device) # used for encode here; you can offload later yourself
# If you are doing block swapping, you MUST have non-block modules available on GPU before coords prep.
if swap_state is not None:
for m in swap_state.non_block_modules:
m.to(device)
# NEW: also move root-level tensors like scale_shift_table
move_module_root_tensors(transformer, device)
# Video processor (same as pipeline)
vae_spatial = getattr(vae, "spatial_compression_ratio", 32)
vae_temporal = getattr(vae, "temporal_compression_ratio", 8)
video_processor = VideoProcessor(vae_scale_factor=vae_spatial, resample="bilinear")
# Prepare prompt embeds like pipeline does (CFG concat)
do_cfg = guidance_scale > 1.0
if do_cfg:
prompt_embeds_cat = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
prompt_mask_cat = torch.cat([negative_prompt_attention_mask, prompt_attention_mask], dim=0)
else:
prompt_embeds_cat = prompt_embeds
prompt_mask_cat = prompt_attention_mask
additive_attention_mask = (1 - prompt_mask_cat.to(prompt_embeds_cat.dtype)) * -1000000.0
connector_prompt_embeds, connector_audio_prompt_embeds, connector_attention_mask = connectors(
prompt_embeds_cat, additive_attention_mask, additive_mask=True
)
# Preprocess image and encode to init latents (argmax path as in pipeline)
image_t = video_processor.preprocess(image, height=height, width=width).to(device=device, dtype=prompt_embeds_cat.dtype)
# ---- prepare video latents ----
h_lat = height // vae_spatial
w_lat = width // vae_spatial
f_lat = (num_frames - 1) // vae_temporal + 1
# encode first frame only (unsqueeze temporal dim)
init = retrieve_latents(vae.encode(image_t[0].unsqueeze(0).unsqueeze(2)), generator, "argmax")
init = init.to(dtype=torch.float32)
init = normalize_latents(init, vae.latents_mean, vae.latents_std) # Why is vae.config.scaling_factor not at the end?
init = init.repeat(1, 1, f_lat, 1, 1)
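    # Conditioning mask: 1.0 on the first latent frame (the encoded image), 0.0 elsewhere,
    # so frame 0 keeps the image latent while all later frames start from pure noise.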
mask = torch.zeros((1, 1, f_lat, h_lat, w_lat), device=device, dtype=torch.float32)
mask[:, :, 0] = 1.0
noise = randn_tensor((1, init.shape[1], f_lat, h_lat, w_lat), generator=generator, device=device, dtype=torch.float32)
latents = init * mask + noise * (1 - mask)
patch_size = transformer.config.patch_size
patch_size_t = transformer.config.patch_size_t
conditioning_mask = pack_latents(mask, patch_size, patch_size_t).squeeze(-1)
latents = pack_latents(latents, patch_size, patch_size_t)
if do_cfg:
conditioning_mask = torch.cat([conditioning_mask, conditioning_mask], dim=0)
# ---- prepare audio latents ----
audio_vae_mel_comp = getattr(audio_vae, "mel_compression_ratio", 4)
audio_vae_temp_comp = getattr(audio_vae, "temporal_compression_ratio", 4)
audio_sr = audio_vae.config.sample_rate
hop = audio_vae.config.mel_hop_length
duration_s = num_frames / frame_rate
latents_per_second = float(audio_sr) / float(hop) / float(audio_vae_temp_comp)
audio_num_frames = round(duration_s * latents_per_second)
num_mel_bins = audio_vae.config.mel_bins
latent_mel_bins = num_mel_bins // audio_vae_mel_comp
audio_shape = (1, audio_vae.config.latent_channels, audio_num_frames, latent_mel_bins)
audio_latents = randn_tensor(audio_shape, generator=generator, device=device, dtype=torch.float32)
audio_latents = pack_audio_latents(audio_latents)
# ---- timesteps ----
video_seq_len = f_lat * h_lat * w_lat
sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)
mu = calculate_shift(
video_seq_len,
scheduler.config.get("base_image_seq_len", 1024),
scheduler.config.get("max_image_seq_len", 4096),
scheduler.config.get("base_shift", 0.95),
scheduler.config.get("max_shift", 2.05),
)
audio_scheduler = copy.deepcopy(scheduler)
_ = retrieve_timesteps(audio_scheduler, num_inference_steps, device, sigmas=sigmas, mu=mu)
timesteps, _ = retrieve_timesteps(scheduler, num_inference_steps, device, sigmas=sigmas, mu=mu)
# ---- coords (requires rope/audio_rope on GPU) ----
video_coords = transformer.rope.prepare_video_coords(
latents.shape[0], f_lat, h_lat, w_lat, latents.device, fps=frame_rate
)
audio_coords = transformer.audio_rope.prepare_audio_coords(
audio_latents.shape[0], audio_num_frames, audio_latents.device
)
# ---- denoising loop ----
for i, t in enumerate(timesteps):
latent_model_input = torch.cat([latents] * 2) if do_cfg else latents
audio_in = torch.cat([audio_latents] * 2) if do_cfg else audio_latents
latent_model_input = latent_model_input.to(dtype)
audio_in = audio_in.to(dtype)
timestep = t.expand(latent_model_input.shape[0])
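        # Conditioned (mask == 1) positions get timestep 0, i.e. they are treated as already clean.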
video_timestep = timestep.unsqueeze(-1) * (1 - conditioning_mask)
with transformer.cache_context("cond_uncond"):
noise_pred_video, noise_pred_audio = transformer(
hidden_states=latent_model_input,
audio_hidden_states=audio_in,
encoder_hidden_states=connector_prompt_embeds,
audio_encoder_hidden_states=connector_audio_prompt_embeds,
timestep=video_timestep,
audio_timestep=timestep,
encoder_attention_mask=connector_attention_mask,
audio_encoder_attention_mask=connector_attention_mask,
num_frames=f_lat,
height=h_lat,
width=w_lat,
fps=frame_rate,
audio_num_frames=audio_num_frames,
video_coords=video_coords,
audio_coords=audio_coords,
attention_kwargs=None,
return_dict=False,
)
noise_pred_video = noise_pred_video.float()
noise_pred_audio = noise_pred_audio.float()
if do_cfg:
v_u, v_t = noise_pred_video.chunk(2)
a_u, a_t = noise_pred_audio.chunk(2)
noise_pred_video = v_u + guidance_scale * (v_t - v_u)
noise_pred_audio = a_u + guidance_scale * (a_t - a_u)
if guidance_rescale > 0:
noise_pred_video = rescale_noise_cfg(noise_pred_video, v_t, guidance_rescale=guidance_rescale)
noise_pred_audio = rescale_noise_cfg(noise_pred_audio, a_t, guidance_rescale=guidance_rescale)
# scheduler step (video)
noise_pred_video_u = unpack_latents(noise_pred_video, f_lat, h_lat, w_lat, patch_size, patch_size_t)
latents_u = unpack_latents(latents, f_lat, h_lat, w_lat, patch_size, patch_size_t)
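        # Step only the noisy frames; the conditioning frame (index 0) is re-attached unchanged below.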
noise_pred_video_u = noise_pred_video_u[:, :, 1:]
noise_latents = latents_u[:, :, 1:]
pred = scheduler.step(noise_pred_video_u, t, noise_latents, return_dict=False)[0]
latents_u = torch.cat([latents_u[:, :, :1], pred], dim=2)
latents = pack_latents(latents_u, patch_size, patch_size_t)
# scheduler step (audio)
audio_latents = audio_scheduler.step(noise_pred_audio, t, audio_latents, return_dict=False)[0]
if step_log_every and (i % step_log_every == 0):
print(f"[step {i:02d}] t={int(t)} latents={latents.device}/{latents.dtype}")
# ---- denormalize / unpack to match output_type="latent" from pipeline ----
latents_u = unpack_latents(latents, f_lat, h_lat, w_lat, patch_size, patch_size_t)
latents_u = denormalize_latents(latents_u, vae.latents_mean, vae.latents_std, vae.config.scaling_factor)
audio_latents = denormalize_audio_latents(audio_latents, audio_vae.latents_mean, audio_vae.latents_std)
audio_latents = unpack_audio_latents(audio_latents, audio_num_frames, num_mel_bins=latent_mel_bins)
# Optional: move non-block modules back off GPU if you want strict VRAM control
if swap_state is not None and not keep_non_block_on_gpu:
for m in swap_state.non_block_modules:
m.to("cpu")
move_module_root_tensors(transformer, "cpu")
return latents_u, audio_latents
def move_module_root_tensors(module: torch.nn.Module, device: Union[str, torch.device]) -> None:
"""
Move only parameters/buffers registered directly on `module` (recurse=False),
without touching submodules (e.g., transformer_blocks).
"""
# Parameters registered on the module itself (not children)
for name, p in list(module._parameters.items()):
if p is None:
continue
# Preserve the Parameter object; just move its storage
p.data = p.data.to(device)
# Buffers registered on the module itself (not children)
for name, b in list(module._buffers.items()):
if b is None:
continue
module._buffers[name] = b.to(device)
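# Sketch (illustrative module, not from this script): given a root-level buffer like LTX2's
# scale_shift_table plus child submodules, this moves only the root tensor:
#   class _Root(torch.nn.Module):
#       def __init__(self):
#           super().__init__()
#           self.register_buffer("scale_shift_table", torch.zeros(2, 8))
#           self.child = torch.nn.Linear(8, 8)
#   r = _Root()
#   move_module_root_tensors(r, "cuda")
#   # r.scale_shift_table is now on cuda; r.child.weight stays on cpu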
def main():
parser = argparse.ArgumentParser()
parser.add_argument("--image", required=True)
parser.add_argument("--prompt", required=True)
parser.add_argument("--negative", default="worst quality, inconsistent motion, blurry, jittery, distorted")
parser.add_argument("--out", default="ltx2_i2v.mp4")
parser.add_argument("--num_frames", type=int, default=121)
parser.add_argument("--fps", type=float, default=24.0)
parser.add_argument("--steps", type=int, default=20)
parser.add_argument("--guidance", type=float, default=4.0)
parser.add_argument("--seed", type=int, default=0)
parser.add_argument("--cache_dir", default="./ltx2_quanto_cache")
parser.add_argument("--group_size", type=int, default=8, help="How many transformer blocks to keep on GPU at a time.")
parser.add_argument("--keep_non_block_on_gpu", action="store_true", help="Keep transformer non-block submodules on GPU.")
parser.add_argument("--dtype", default="bf16", choices=["bf16", "fp16"])
args = parser.parse_args()
if not torch.cuda.is_available():
raise RuntimeError("CUDA is required for practical inference here.")
device = torch.device("cuda")
torch_dtype = torch.bfloat16 if (args.dtype == "bf16" and torch.cuda.is_bf16_supported()) else torch.float16
print(f"[setup] device={device}, torch_dtype={torch_dtype}")
model_id = "Lightricks/LTX-2"
cache_dir = Path(args.cache_dir)
# Load non-huge components (CPU is fine; we'll move to GPU only when needed)
from diffusers.schedulers import FlowMatchEulerDiscreteScheduler
from diffusers.models.autoencoders import AutoencoderKLLTX2Video, AutoencoderKLLTX2Audio
from diffusers.pipelines.ltx2.connectors import LTX2TextConnectors
from diffusers.pipelines.ltx2.vocoder import LTX2Vocoder
from diffusers.pipelines.ltx2 import LTX2ImageToVideoPipeline
from transformers import Gemma3ForConditionalGeneration, GemmaTokenizer
from diffusers.models.transformers import LTX2VideoTransformer3DModel
print("[setup] Loading scheduler/vae/audio_vae/vocoder/connectors/tokenizer ...")
scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(model_id, subfolder="scheduler")
vae = AutoencoderKLLTX2Video.from_pretrained(model_id, subfolder="vae", torch_dtype=torch_dtype).cpu()
audio_vae = AutoencoderKLLTX2Audio.from_pretrained(model_id, subfolder="audio_vae", torch_dtype=torch_dtype).cpu()
vocoder = LTX2Vocoder.from_pretrained(model_id, subfolder="vocoder", torch_dtype=torch_dtype).cpu()
connectors = LTX2TextConnectors.from_pretrained(model_id, subfolder="connectors", torch_dtype=torch_dtype).cpu()
# connectors = quantize_quanto_int8(connectors, "connectors")
tokenizer = GemmaTokenizer.from_pretrained(model_id, subfolder="tokenizer")
# ---------------------------
# Phase 1: text encoder only
# ---------------------------
text_cache = cache_dir / "text_encoder_gemma3_qint8.pt"
def build_text_encoder():
te = Gemma3ForConditionalGeneration.from_pretrained(model_id, subfolder="text_encoder", torch_dtype=torch_dtype).cpu()
te = quantize_quanto_int8(te, "text_encoder")
return te
text_encoder = load_or_build_quantized(text_cache, build_text_encoder, "text_encoder")
# Compute prompt embeddings on GPU (text encoder present)
print("[text] Moving text encoder to GPU and encoding prompts ...")
generator = torch.Generator(device="cuda")
generator.manual_seed(args.seed)
prompt_embeds, prompt_mask = gemma_prompt_embeds(
args.prompt, tokenizer, text_encoder, device=device, max_sequence_length=1024,
scale_factor=8, dtype=torch_dtype, num_videos_per_prompt=1
)
neg_embeds, neg_mask = gemma_prompt_embeds(
args.negative or "", tokenizer, text_encoder, device=device, max_sequence_length=1024,
scale_factor=8, dtype=torch_dtype, num_videos_per_prompt=1
)
# Delete text encoder completely before loading transformer
print("[text] Deleting text encoder and clearing VRAM ...")
text_encoder.to("cpu")
del text_encoder
cuda_gc()
# -------------------------------
# Phase 2: transformer only (qint8)
# -------------------------------
tr_cache = cache_dir / "transformer_ltx2_qint8.pt"
def build_transformer():
tr = LTX2VideoTransformer3DModel.from_pretrained(model_id, subfolder="transformer", torch_dtype=torch_dtype).cpu()
tr = quantize_quanto_int8(tr, "transformer")
return tr
transformer = load_or_build_quantized(tr_cache, build_transformer, "transformer")
# Attach N-block-at-a-time swaps
print(f"[swap] Attaching group swap hooks: group_size={args.group_size}, keep_non_block_on_gpu={args.keep_non_block_on_gpu}")
swap_state = attach_group_swap_hooks(
transformer,
group_size=args.group_size,
device=device,
keep_non_block_on_gpu=True # args.keep_non_block_on_gpu,
)
def debug_once(module, inputs):
# only print once
if getattr(module, "_printed", False):
return
module._printed = True
# inspect first tensor input device
hs = inputs[0] if len(inputs) else None
if isinstance(hs, (tuple, list)) and len(hs):
hs = hs[0]
if torch.is_tensor(hs):
print("[debug] transformer hidden_states device:", hs.device, "dtype:", hs.dtype)
transformer.register_forward_pre_hook(debug_once)
# Load image
image = load_image(args.image)
width, height = image.size
width //= 2
height //= 2
    if (height % 32) != 0 or (width % 32) != 0:
        raise ValueError(f"Halved image dimensions must be divisible by 32; got {width}x{height}")
# -------------------------------
# Run denoising only -> latents
# -------------------------------
print("[infer] Running denoising to latents (transformer swaps active) ...")
keep_non_block = True # strongly recommended
video_latents, audio_latents = ltx2_denoise_to_latents(
image=image,
prompt_embeds=prompt_embeds,
prompt_attention_mask=prompt_mask,
negative_prompt_embeds=neg_embeds,
negative_prompt_attention_mask=neg_mask,
scheduler=scheduler,
vae=vae,
audio_vae=audio_vae,
connectors=connectors,
transformer=transformer,
num_frames=args.num_frames,
frame_rate=args.fps,
height=height,
width=width,
num_inference_steps=args.steps,
guidance_scale=args.guidance,
guidance_rescale=0.0,
device=device,
dtype=torch_dtype,
generator=generator,
swap_state=swap_state,
keep_non_block_on_gpu=keep_non_block,
step_log_every=1,
)
# Free transformer before decode (per your requirement)
print("[swap] Freeing transformer and clearing VRAM before decode ...")
try:
swap_state.unload_all()
except Exception:
pass
transformer.to("cpu")
del transformer
cuda_gc()
# -------------------------------
# Decode latents -> final outputs
# -------------------------------
print("[decode] Decoding video/audio ...")
video, audio = decode_from_latents_components(
vae=vae,
audio_vae=audio_vae,
vocoder=vocoder,
video_latents=video_latents,
audio_latents=audio_latents,
device=device,
output_type="np",
)
# Write MP4 + audio using diffusers encode_video if available, else ffmpeg mux fallback
out_path = str(Path(args.out).resolve())
# video is float in [0,1], shape likely [B,F,H,W,C]
# video = (video + 1) / 2
video = 1.0 - video
video_u8 = (video.astype(np.float32) * 255.0).round().clip(0, 255).astype(np.uint8)
# video_u8 = 255 - video_u8
# audio is torch tensor (from vocoder). Usually [B, T].
sr = int(getattr(vocoder.config, "output_sampling_rate", 24000))
out_path = write_mp4_with_audio_ffmpeg(
out_path=str(Path(args.out).resolve()),
video_u8=video_u8,
fps=args.fps,
audio=audio,
audio_sample_rate=sr,
macro_block_size=16, # set to 1 if you ever output odd sizes
)
print(f"[done] Wrote: {out_path}")
if __name__ == "__main__":
main()
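# Example invocation (the script name, image path, and prompt are placeholders):
#   python ltx2_i2v_int8_offload.py --image input.png \
#       --prompt "a corgi surfing a wave at sunset" \
#       --out ltx2_i2v.mp4 --num_frames 121 --fps 24 --steps 20 --guidance 4.0 \
#       --group_size 8 --dtype bf16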