LTX-2 image-to-video inference in int8 with grouped transformer-block offloading, without the Diffusers pipeline
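A typical invocation might look like the following (flags per the argparse definitions in the script; the file name and paths are illustrative):

    python ltx2_i2v_int8_offload.py --image input.jpg --prompt "a dog runs across a field" \
        --num_frames 121 --fps 24 --steps 20 --guidance 4.0 --group_size 8 --out ltx2_i2v.mp4

On the first run the text encoder and transformer are quantized to int8 and cached under --cache_dir; later runs load the cached modules directly.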
#!/usr/bin/env python3
import argparse
import copy
import gc
import os
import shutil
import subprocess
import tempfile
from dataclasses import dataclass
from pathlib import Path
from typing import List, Optional, Tuple, Union

import numpy as np
import torch

from diffusers.utils import load_image
from diffusers.utils.torch_utils import randn_tensor
from diffusers.video_processor import VideoProcessor
| def _write_wav_pcm16(path: str, audio_f32: np.ndarray, sample_rate: int) -> None: | |
| """ | |
| Write mono float32 audio in [-1, 1] to 16-bit PCM WAV using the stdlib. | |
| No external deps. | |
| """ | |
| import wave | |
| audio = np.clip(audio_f32, -1.0, 1.0) | |
| audio_i16 = (audio * 32767.0).astype(np.int16) | |
| with wave.open(path, "wb") as wf: | |
| wf.setnchannels(1) | |
| wf.setsampwidth(2) # 16-bit | |
| wf.setframerate(int(sample_rate)) | |
| wf.writeframes(audio_i16.tobytes()) | |
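# Illustrative check for _write_wav_pcm16 (not executed by this script): one second of a
# 440 Hz sine tone at 16 kHz, written as a 16-bit PCM WAV.
#
#     sr = 16000
#     t = np.linspace(0.0, 1.0, sr, endpoint=False)
#     tone = (0.25 * np.sin(2.0 * np.pi * 440.0 * t)).astype(np.float32)
#     _write_wav_pcm16("tone.wav", tone, sr)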
| def write_mp4_with_audio_ffmpeg( | |
| *, | |
| out_path: str, | |
| video_u8: np.ndarray, | |
| fps: float, | |
| audio: torch.Tensor, | |
| audio_sample_rate: int, | |
| macro_block_size: int = 16, | |
| ) -> str: | |
| """ | |
| video_u8: uint8 frames in one of these shapes: | |
| - [F, H, W, C] | |
| - [B, F, H, W, C] (we'll take B=0) | |
| audio: torch tensor shaped [T] or [1, T] or [B, T] (we'll take the first channel/batch) | |
| """ | |
    from diffusers.utils import export_to_video  # diffusers' video-only writer
| out_path = str(Path(out_path).resolve()) | |
| Path(out_path).parent.mkdir(parents=True, exist_ok=True) | |
| # Normalize video shape to list of frames for export_to_video | |
| if video_u8.ndim == 5: | |
| video_u8 = video_u8[0] # take first batch | |
| if video_u8.ndim != 4: | |
| raise ValueError(f"video_u8 must be [F,H,W,C] or [B,F,H,W,C], got shape={video_u8.shape}") | |
| frames = [video_u8[i] for i in range(video_u8.shape[0])] | |
| # Normalize audio tensor to 1D float32 CPU numpy | |
| if isinstance(audio, torch.Tensor): | |
| a = audio.detach() | |
| if a.ndim == 2: | |
| a = a[0] | |
| a = a.float().cpu().numpy() | |
| else: | |
| a = np.asarray(audio, dtype=np.float32) | |
| if a.ndim == 2: | |
| a = a[0] | |
| with tempfile.TemporaryDirectory() as td: | |
| td = Path(td) | |
| video_path = str(td / "video.mp4") | |
| audio_path = str(td / "audio.wav") | |
| muxed_path = out_path | |
| # 1) Write MP4 video-only | |
| export_to_video( | |
| frames, | |
| output_video_path=video_path, | |
| fps=int(round(float(fps))), | |
| macro_block_size=macro_block_size, | |
| ) | |
| # 2) Write WAV audio | |
| _write_wav_pcm16(audio_path, a, int(audio_sample_rate)) | |
| # 3) Mux with ffmpeg | |
| ffmpeg = shutil.which("ffmpeg") | |
| if ffmpeg is None: | |
| raise RuntimeError( | |
| "ffmpeg not found on PATH. Install ffmpeg or add it to PATH. " | |
| "Diffusers' export_to_video also relies on ffmpeg via imageio in many setups." | |
| ) | |
| cmd = [ | |
| ffmpeg, "-y", | |
| "-i", video_path, | |
| "-i", audio_path, | |
| "-c:v", "copy", | |
| "-c:a", "aac", | |
| "-b:a", "192k", | |
| "-shortest", | |
| muxed_path, | |
| ] | |
| try: | |
| subprocess.run(cmd, check=True, text=True, capture_output=True) | |
| except subprocess.CalledProcessError as e: | |
| print("\n[ffmpeg] mux failed. stderr:\n") | |
| print(e.stderr) | |
| # Retry: explicitly map streams and re-encode video to avoid timestamp/codec copy issues | |
| cmd2 = [ | |
| ffmpeg, "-y", | |
| "-i", video_path, | |
| "-i", audio_path, | |
| "-map", "0:v:0", | |
| "-map", "1:a:0", | |
| "-c:v", "libx264", | |
| "-pix_fmt", "yuv420p", | |
| "-r", str(int(round(float(fps)))), | |
| "-c:a", "aac", | |
| "-b:a", "192k", | |
| "-shortest", | |
| "-movflags", "+faststart", | |
| muxed_path, | |
| ] | |
| try: | |
| subprocess.run(cmd2, check=True, text=True, capture_output=True) | |
| except subprocess.CalledProcessError as e2: | |
| print("\n[ffmpeg] retry failed. stderr:\n") | |
| print(e2.stderr) | |
| raise | |
| return out_path | |
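# Illustrative call (not executed by this script; assumes ffmpeg is on PATH): mux one second of
# black 256x256 frames at 24 fps with one second of silence at 24 kHz.
#
#     frames = np.zeros((24, 256, 256, 3), dtype=np.uint8)
#     silence = torch.zeros(24000)
#     write_mp4_with_audio_ffmpeg(out_path="demo.mp4", video_u8=frames, fps=24.0,
#                                 audio=silence, audio_sample_rate=24000)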
| # --------------------------- | |
| # Quanto helpers + cache I/O | |
| # --------------------------- | |
| def quantize_quanto_int8(module: torch.nn.Module, name: str, exclude: Optional[str] = None) -> torch.nn.Module: | |
| """ | |
| Quantize a module with optimum-quanto (qint8) and freeze. | |
| """ | |
| from optimum.quanto import freeze, qint8, quantize # type: ignore | |
| kwargs = {"weights": qint8} | |
| if exclude: | |
| kwargs["exclude"] = exclude | |
| print(f"[quanto] Quantizing {name} -> qint8 ...") | |
| quantize(module, **kwargs) | |
| freeze(module) | |
| module.eval() | |
| print(f"[quanto] Done: {name}") | |
| return module | |
| def torch_load_full_module(path: Path) -> torch.nn.Module: | |
| return torch.load(path, map_location="cpu", weights_only=False) | |
| def torch_save_full_module(module: torch.nn.Module, path: Path) -> None: | |
| path.parent.mkdir(parents=True, exist_ok=True) | |
| # Full pickled module (fragile across package upgrades; delete cache when upgrading) | |
| torch.save(module, path) | |
| def load_or_build_quantized(path: Path, build_fn, name: str) -> torch.nn.Module: | |
| if path.exists(): | |
| print(f"[cache] Loading {name} from: {path}") | |
| return torch_load_full_module(path) | |
| print(f"[cache] Miss for {name}; building + quantizing + saving to: {path}") | |
| module = build_fn() | |
| torch_save_full_module(module, path) | |
| return module | |
| def cuda_gc(): | |
| gc.collect() | |
| if torch.cuda.is_available(): | |
| torch.cuda.empty_cache() | |
| # ---------------------------------------- | |
| # Gemma prompt embedding (copied behavior) | |
| # ---------------------------------------- | |
| def pack_text_embeds( | |
| text_hidden_states: torch.Tensor, # [B, T, H, L] | |
| sequence_lengths: torch.Tensor, # [B] | |
| device: Union[str, torch.device], | |
| padding_side: str = "left", | |
| scale_factor: int = 8, | |
| eps: float = 1e-6, | |
| ) -> torch.Tensor: | |
| """ | |
| Same as LTX2ImageToVideoPipeline._pack_text_embeds | |
| """ | |
| batch_size, seq_len, hidden_dim, num_layers = text_hidden_states.shape | |
| original_dtype = text_hidden_states.dtype | |
| token_indices = torch.arange(seq_len, device=device).unsqueeze(0) | |
| if padding_side == "right": | |
| mask = token_indices < sequence_lengths[:, None] | |
| elif padding_side == "left": | |
| start_indices = seq_len - sequence_lengths[:, None] | |
| mask = token_indices >= start_indices | |
| else: | |
| raise ValueError(f"padding_side must be 'left' or 'right', got {padding_side}") | |
| mask = mask[:, :, None, None] # [B,T,1,1] | |
| masked_text_hidden_states = text_hidden_states.masked_fill(~mask, 0.0) | |
| num_valid_positions = (sequence_lengths * hidden_dim).view(batch_size, 1, 1, 1) | |
| masked_mean = masked_text_hidden_states.sum(dim=(1, 2), keepdim=True) / (num_valid_positions + eps) | |
| x_min = text_hidden_states.masked_fill(~mask, float("inf")).amin(dim=(1, 2), keepdim=True) | |
| x_max = text_hidden_states.masked_fill(~mask, float("-inf")).amax(dim=(1, 2), keepdim=True) | |
| normalized = (text_hidden_states - masked_mean) / (x_max - x_min + eps) | |
| normalized = normalized * scale_factor | |
| normalized = normalized.flatten(2) # [B,T,H*L] | |
| mask_flat = mask.squeeze(-1).expand(-1, -1, hidden_dim * num_layers) | |
| normalized = normalized.masked_fill(~mask_flat, 0.0) | |
| return normalized.to(dtype=original_dtype) | |
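# Shape sketch for pack_text_embeds: [B, T, H, L] stacked hidden states become [B, T, H * L],
# masked to the valid (left-padded) tokens. A tiny self-contained check (not executed):
#
#     hs = torch.randn(1, 16, 8, 4)                  # stand-in for [B, T, hidden, layers]
#     lens = torch.tensor([10])                      # 10 valid tokens, left padding
#     packed = pack_text_embeds(hs, lens, device="cpu")
#     assert packed.shape == (1, 16, 8 * 4)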
| @torch.no_grad() | |
| def gemma_prompt_embeds( | |
| prompt: Union[str, List[str]], | |
| tokenizer, | |
| text_encoder, | |
| device: torch.device, | |
| max_sequence_length: int = 1024, | |
| scale_factor: int = 8, | |
| dtype: Optional[torch.dtype] = None, | |
| num_videos_per_prompt: int = 1, | |
| ) -> Tuple[torch.Tensor, torch.Tensor]: | |
| """ | |
| Produces (prompt_embeds, attention_mask) as expected by LTX2 encode_prompt, | |
| matching the pipeline's behavior. | |
| """ | |
| prompt_list = [prompt] if isinstance(prompt, str) else prompt | |
| batch_size = len(prompt_list) | |
| tokenizer.padding_side = "left" | |
| if tokenizer.pad_token is None: | |
| tokenizer.pad_token = tokenizer.eos_token | |
| prompt_list = [p.strip() for p in prompt_list] | |
| inputs = tokenizer( | |
| prompt_list, | |
| padding="max_length", | |
| max_length=max_sequence_length, | |
| truncation=True, | |
| add_special_tokens=True, | |
| return_tensors="pt", | |
| ) | |
| input_ids = inputs.input_ids.to(device) | |
| attn_mask = inputs.attention_mask.to(device) | |
| # Ensure encoder on GPU | |
| text_encoder.to(device) | |
| text_encoder.eval() | |
| out = text_encoder(input_ids=input_ids, attention_mask=attn_mask, output_hidden_states=True) | |
| hidden_states = torch.stack(out.hidden_states, dim=-1) # [B,T,H,L] | |
| seq_lens = attn_mask.sum(dim=-1) | |
| embeds = pack_text_embeds( | |
| hidden_states, | |
| seq_lens, | |
| device=device, | |
| padding_side=tokenizer.padding_side, | |
| scale_factor=scale_factor, | |
| ) | |
| if dtype is not None: | |
| embeds = embeds.to(dtype=dtype) | |
| # duplicate for num_videos_per_prompt | |
| _, seq_len, _ = embeds.shape | |
| embeds = embeds.repeat(1, num_videos_per_prompt, 1).view(batch_size * num_videos_per_prompt, seq_len, -1) | |
| attn_mask = attn_mask.view(batch_size, -1).repeat(num_videos_per_prompt, 1) | |
| return embeds, attn_mask | |
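# Resulting shapes for a single prompt with max_sequence_length=1024: prompt_embeds is
# [1, 1024, hidden_size * num_hidden_states] (all Gemma hidden-state layers stacked and packed)
# and the attention mask is [1, 1024].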
| # ----------------------------------------- | |
| # Transformer group-swap (N blocks at once) | |
| # ----------------------------------------- | |
| @dataclass | |
| class GroupSwapState: | |
| blocks: List[torch.nn.Module] | |
| device: torch.device | |
| group_size: int | |
| keep_non_block_on_gpu: bool | |
| non_block_modules: List[torch.nn.Module] | |
| current_group_start: Optional[int] = None | |
| def _move_group(self, start: int, target: Union[str, torch.device]) -> None: | |
| end = min(start + self.group_size, len(self.blocks)) | |
| for j in range(start, end): | |
| self.blocks[j].to(target) | |
| def ensure_group_loaded_for_index(self, i: int) -> None: | |
| start = (i // self.group_size) * self.group_size | |
| if self.current_group_start == start: | |
| return | |
| # Load new group first (so forward never sees missing weights) | |
| import time | |
| t0 = time.time() | |
| self._move_group(start, self.device) | |
| torch.cuda.synchronize() | |
| print(f"loaded group {start} in {time.time()-t0:.2f}s") | |
| # Optionally manage non-block modules on first group load of a forward pass | |
| if not self.keep_non_block_on_gpu: | |
| for m in self.non_block_modules: | |
| m.to(self.device) | |
| # Unload previous group | |
| if self.current_group_start is not None: | |
| self._move_group(self.current_group_start, "cpu") | |
| self.current_group_start = start | |
| def unload_all(self) -> None: | |
| if self.current_group_start is not None: | |
| self._move_group(self.current_group_start, "cpu") | |
| self.current_group_start = None | |
| if not self.keep_non_block_on_gpu: | |
| for m in self.non_block_modules: | |
| m.to("cpu") | |
| def attach_group_swap_hooks( | |
| transformer: torch.nn.Module, | |
| group_size: int, | |
| device: torch.device, | |
| keep_non_block_on_gpu: bool = True, | |
| ) -> GroupSwapState: | |
| assert hasattr(transformer, "transformer_blocks"), "Transformer must have transformer_blocks" | |
| blocks: List[torch.nn.Module] = list(transformer.transformer_blocks) | |
| non_block = [] | |
| for name, child in transformer.named_children(): | |
| if name != "transformer_blocks": | |
| non_block.append(child) | |
| state = GroupSwapState( | |
| blocks=blocks, | |
| device=device, | |
| group_size=int(group_size), | |
| keep_non_block_on_gpu=bool(keep_non_block_on_gpu), | |
| non_block_modules=non_block, | |
| ) | |
| # Start from CPU | |
| transformer.to("cpu") | |
| for b in blocks: | |
| b.to("cpu") | |
| if keep_non_block_on_gpu: | |
| for m in non_block: | |
| m.to(device) | |
| else: | |
| for m in non_block: | |
| m.to("cpu") | |
| def make_pre_hook(idx: int): | |
| def _pre_hook(_module, _inputs): | |
| state.ensure_group_loaded_for_index(idx) | |
| return _pre_hook | |
| for i, block in enumerate(blocks): | |
| block.register_forward_pre_hook(make_pre_hook(i), with_kwargs=False) | |
| # IMPORTANT: unload AFTER the entire transformer forward finishes | |
| def _transformer_post_hook(_module, _inputs, _output): | |
| state.unload_all() | |
| return _output | |
| transformer.register_forward_hook(_transformer_post_hook, with_kwargs=False) | |
| return state | |
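# Group-swap behavior sketch: block i triggers loading of the group starting at
# (i // group_size) * group_size. For example, with group_size=8, block 13 loads blocks 8..15
# onto the GPU and moves the previous group (blocks 0..7) back to the CPU. The post-forward hook
# on the transformer unloads the final group once the whole forward pass has finished. Note that
# the new group is moved in before the old one is moved out, so peak VRAM briefly holds two groups.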
| # ------------------------- | |
| # Main generation pipeline | |
| # ------------------------- | |
| class DummyTextEncoder(torch.nn.Module): | |
| def __init__(self, dtype: torch.dtype): | |
| super().__init__() | |
| self._dtype = dtype | |
| @property | |
| def dtype(self) -> torch.dtype: | |
| return self._dtype | |
| def forward(self, *args, **kwargs): | |
| raise RuntimeError("DummyTextEncoder was called. Provide prompt_embeds to the pipeline.") | |
| @torch.no_grad() | |
| def decode_from_latents_components( | |
| *, | |
| vae, | |
| audio_vae, | |
| vocoder, | |
| video_latents: torch.Tensor, # expected denormalized/unpacked: [B, C, F, H, W] | |
| audio_latents: torch.Tensor, # expected denormalized/unpacked: [B, C, L, M] | |
| device: torch.device, | |
| output_type: str = "np", # "np" or "pil" (pil requires pillow installed; diffusers will handle) | |
| ): | |
| """ | |
| Decodes LTX-2 video+audio latents using the raw diffusers components (no pipeline). | |
| Assumptions: | |
| - video_latents are already denormalized + unpacked (as returned by your custom denoise function). | |
| - audio_latents are already denormalized + unpacked. | |
| """ | |
| # ---- Video decode ---- | |
| vae.enable_tiling( | |
| tile_sample_min_height=256, | |
| tile_sample_min_width=256, | |
| tile_sample_min_num_frames=32, | |
| tile_sample_stride_height=128, | |
| tile_sample_stride_width=128, | |
| tile_sample_stride_num_frames=16, | |
| ) | |
| vae.to(device) | |
| latents_v = video_latents.to(device=device, dtype=vae.dtype) | |
| # Stock pipeline: if timestep_conditioning is disabled, timestep=None | |
| timestep = None | |
| video = vae.decode(latents_v, timestep, return_dict=False)[0] | |
| # Postprocess to np/pil like pipeline does | |
| # Need the same scale factor that the pipeline uses: | |
| vae_scale_factor = getattr(vae, "spatial_compression_ratio", 32) | |
| video_processor = VideoProcessor(vae_scale_factor=vae_scale_factor, resample="bilinear") | |
| video = video_processor.postprocess_video(video, output_type=output_type) | |
| # ---- Audio decode ---- | |
| audio_vae.to(device) | |
| vocoder.to(device) | |
| latents_a = audio_latents.to(device=device, dtype=audio_vae.dtype) | |
| mel = audio_vae.decode(latents_a, return_dict=False)[0] | |
| audio = vocoder(mel) | |
| return video, audio | |
| # ---- helpers copied from the pipeline (same behavior) ---- | |
| def retrieve_latents(encoder_output, generator=None, sample_mode="sample"): | |
| if hasattr(encoder_output, "latent_dist") and sample_mode == "sample": | |
| return encoder_output.latent_dist.sample(generator) | |
| elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax": | |
| return encoder_output.latent_dist.mode() | |
| elif hasattr(encoder_output, "latents"): | |
| return encoder_output.latents | |
| else: | |
| raise AttributeError("Could not access latents of provided encoder_output") | |
| def calculate_shift(image_seq_len, base_seq_len=256, max_seq_len=4096, base_shift=0.5, max_shift=1.15): | |
| m = (max_shift - base_shift) / (max_seq_len - base_seq_len) | |
| b = base_shift - m * base_seq_len | |
| return image_seq_len * m + b | |
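# Worked example with this function's defaults (the call below passes the scheduler's configured
# values instead): for image_seq_len = 1024,
#     m  = (1.15 - 0.5) / (4096 - 256) ≈ 1.693e-4
#     b  = 0.5 - m * 256               ≈ 0.4567
#     mu ≈ 1024 * 1.693e-4 + 0.4567    ≈ 0.63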
| def retrieve_timesteps(scheduler, num_inference_steps, device, timesteps=None, sigmas=None, **kwargs): | |
| if timesteps is not None and sigmas is not None: | |
| raise ValueError("Only one of `timesteps` or `sigmas` can be passed.") | |
| if timesteps is not None: | |
| scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) | |
| timesteps = scheduler.timesteps | |
| num_inference_steps = len(timesteps) | |
| elif sigmas is not None: | |
| scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) | |
| timesteps = scheduler.timesteps | |
| else: | |
| scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) | |
| timesteps = scheduler.timesteps | |
| return timesteps, num_inference_steps | |
| def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0): | |
| std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True) | |
| std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True) | |
| noise_pred_rescaled = noise_cfg * (std_text / std_cfg) | |
| return guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg | |
| # ---- latent packing utilities (copied from LTX2 pipeline) ---- | |
| def pack_latents(latents: torch.Tensor, patch_size: int = 1, patch_size_t: int = 1) -> torch.Tensor: | |
| b, c, f, h, w = latents.shape | |
| f2 = f // patch_size_t | |
| h2 = h // patch_size | |
| w2 = w // patch_size | |
| latents = latents.reshape(b, -1, f2, patch_size_t, h2, patch_size, w2, patch_size) | |
| latents = latents.permute(0, 2, 4, 6, 1, 3, 5, 7).flatten(4, 7).flatten(1, 3) | |
| return latents | |
| def unpack_latents(latents: torch.Tensor, num_frames: int, height: int, width: int, patch_size: int = 1, patch_size_t: int = 1) -> torch.Tensor: | |
| b = latents.size(0) | |
| latents = latents.reshape(b, num_frames, height, width, -1, patch_size_t, patch_size, patch_size) | |
| latents = latents.permute(0, 4, 1, 5, 2, 6, 3, 7).flatten(6, 7).flatten(4, 5).flatten(2, 3) | |
| return latents | |
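# Assuming patch_size = patch_size_t = 1 (the values are read from transformer.config below),
# pack_latents is effectively [B, C, F, H, W] -> [B, F*H*W, C] and unpack_latents inverts it.
# Illustrative round-trip (not executed):
#
#     x = torch.randn(1, 4, 2, 3, 3)
#     packed = pack_latents(x)                       # [1, 2*3*3, 4] == [1, 18, 4]
#     assert torch.equal(unpack_latents(packed, 2, 3, 3), x)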
| def normalize_latents(latents: torch.Tensor, latents_mean: torch.Tensor, latents_std: torch.Tensor, scaling_factor: float = 1.0) -> torch.Tensor: | |
| latents_mean = latents_mean.view(1, -1, 1, 1, 1).to(latents.device, latents.dtype) | |
| latents_std = latents_std.view(1, -1, 1, 1, 1).to(latents.device, latents.dtype) | |
| return (latents - latents_mean) * scaling_factor / latents_std | |
| def denormalize_latents(latents: torch.Tensor, latents_mean: torch.Tensor, latents_std: torch.Tensor, scaling_factor: float = 1.0) -> torch.Tensor: | |
| latents_mean = latents_mean.view(1, -1, 1, 1, 1).to(latents.device, latents.dtype) | |
| latents_std = latents_std.view(1, -1, 1, 1, 1).to(latents.device, latents.dtype) | |
| return latents * latents_std / scaling_factor + latents_mean | |
| def pack_audio_latents(latents: torch.Tensor) -> torch.Tensor: | |
| # [B, C, L, M] -> [B, L, C*M] (implicit patch) | |
| return latents.transpose(1, 2).flatten(2, 3) | |
| def unpack_audio_latents(latents: torch.Tensor, latent_length: int, num_mel_bins: int) -> torch.Tensor: | |
| # [B, L, C*M] -> [B, C, L, M] | |
| latents = latents.unflatten(2, (-1, num_mel_bins)).transpose(1, 2) | |
| return latents | |
| def denormalize_audio_latents(latents: torch.Tensor, latents_mean: torch.Tensor, latents_std: torch.Tensor) -> torch.Tensor: | |
| latents_mean = latents_mean.to(latents.device, latents.dtype) | |
| latents_std = latents_std.to(latents.device, latents.dtype) | |
| return (latents * latents_std) + latents_mean | |
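# Audio latent packing round-trip (illustrative, not executed): pack_audio_latents flattens
# [B, C, L, M] to [B, L, C*M] and unpack_audio_latents restores it.
#
#     a = torch.randn(1, 8, 6, 4)                    # [B, C, L, M]
#     assert torch.equal(unpack_audio_latents(pack_audio_latents(a), latent_length=6, num_mel_bins=4), a)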
| # ---- main custom denoise ---- | |
| @torch.no_grad() | |
| def ltx2_denoise_to_latents( | |
| *, | |
| image, | |
| prompt_embeds: torch.Tensor, | |
| prompt_attention_mask: torch.Tensor, | |
| negative_prompt_embeds: torch.Tensor, | |
| negative_prompt_attention_mask: torch.Tensor, | |
| scheduler, | |
| vae, | |
| audio_vae, | |
| connectors, | |
| transformer, | |
| num_frames: int, | |
| frame_rate: float, | |
| height: int, | |
| width: int, | |
| num_inference_steps: int, | |
| guidance_scale: float, | |
| guidance_rescale: float, | |
| device: torch.device, | |
| dtype: torch.dtype, | |
| generator: torch.Generator, | |
| swap_state=None, # your GroupSwapState | |
| keep_non_block_on_gpu=True, | |
| step_log_every: int = 1, # set to 0 to disable | |
| ): | |
| # Move only what must be on GPU for setup | |
| connectors.to(device) | |
| vae.to(device) # used for encode here; you can offload later yourself | |
| # If you are doing block swapping, you MUST have non-block modules available on GPU before coords prep. | |
| if swap_state is not None: | |
| for m in swap_state.non_block_modules: | |
| m.to(device) | |
| # NEW: also move root-level tensors like scale_shift_table | |
| move_module_root_tensors(transformer, device) | |
| # Video processor (same as pipeline) | |
| vae_spatial = getattr(vae, "spatial_compression_ratio", 32) | |
| vae_temporal = getattr(vae, "temporal_compression_ratio", 8) | |
| video_processor = VideoProcessor(vae_scale_factor=vae_spatial, resample="bilinear") | |
| # Prepare prompt embeds like pipeline does (CFG concat) | |
| do_cfg = guidance_scale > 1.0 | |
| if do_cfg: | |
| prompt_embeds_cat = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) | |
| prompt_mask_cat = torch.cat([negative_prompt_attention_mask, prompt_attention_mask], dim=0) | |
| else: | |
| prompt_embeds_cat = prompt_embeds | |
| prompt_mask_cat = prompt_attention_mask | |
| additive_attention_mask = (1 - prompt_mask_cat.to(prompt_embeds_cat.dtype)) * -1000000.0 | |
| connector_prompt_embeds, connector_audio_prompt_embeds, connector_attention_mask = connectors( | |
| prompt_embeds_cat, additive_attention_mask, additive_mask=True | |
| ) | |
| # Preprocess image and encode to init latents (argmax path as in pipeline) | |
| image_t = video_processor.preprocess(image, height=height, width=width).to(device=device, dtype=prompt_embeds_cat.dtype) | |
| # ---- prepare video latents ---- | |
| h_lat = height // vae_spatial | |
| w_lat = width // vae_spatial | |
| f_lat = (num_frames - 1) // vae_temporal + 1 | |
| # encode first frame only (unsqueeze temporal dim) | |
| init = retrieve_latents(vae.encode(image_t[0].unsqueeze(0).unsqueeze(2)), generator, "argmax") | |
| init = init.to(dtype=torch.float32) | |
| init = normalize_latents(init, vae.latents_mean, vae.latents_std) # Why is vae.config.scaling_factor not at the end? | |
| init = init.repeat(1, 1, f_lat, 1, 1) | |
| mask = torch.zeros((1, 1, f_lat, h_lat, w_lat), device=device, dtype=torch.float32) | |
| mask[:, :, 0] = 1.0 | |
| noise = randn_tensor((1, init.shape[1], f_lat, h_lat, w_lat), generator=generator, device=device, dtype=torch.float32) | |
| latents = init * mask + noise * (1 - mask) | |
| patch_size = transformer.config.patch_size | |
| patch_size_t = transformer.config.patch_size_t | |
| conditioning_mask = pack_latents(mask, patch_size, patch_size_t).squeeze(-1) | |
| latents = pack_latents(latents, patch_size, patch_size_t) | |
| if do_cfg: | |
| conditioning_mask = torch.cat([conditioning_mask, conditioning_mask], dim=0) | |
| # ---- prepare audio latents ---- | |
| audio_vae_mel_comp = getattr(audio_vae, "mel_compression_ratio", 4) | |
| audio_vae_temp_comp = getattr(audio_vae, "temporal_compression_ratio", 4) | |
| audio_sr = audio_vae.config.sample_rate | |
| hop = audio_vae.config.mel_hop_length | |
| duration_s = num_frames / frame_rate | |
| latents_per_second = float(audio_sr) / float(hop) / float(audio_vae_temp_comp) | |
| audio_num_frames = round(duration_s * latents_per_second) | |
| num_mel_bins = audio_vae.config.mel_bins | |
| latent_mel_bins = num_mel_bins // audio_vae_mel_comp | |
| audio_shape = (1, audio_vae.config.latent_channels, audio_num_frames, latent_mel_bins) | |
| audio_latents = randn_tensor(audio_shape, generator=generator, device=device, dtype=torch.float32) | |
| audio_latents = pack_audio_latents(audio_latents) | |
| # ---- timesteps ---- | |
| video_seq_len = f_lat * h_lat * w_lat | |
| sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) | |
| mu = calculate_shift( | |
| video_seq_len, | |
| scheduler.config.get("base_image_seq_len", 1024), | |
| scheduler.config.get("max_image_seq_len", 4096), | |
| scheduler.config.get("base_shift", 0.95), | |
| scheduler.config.get("max_shift", 2.05), | |
| ) | |
| audio_scheduler = copy.deepcopy(scheduler) | |
| _ = retrieve_timesteps(audio_scheduler, num_inference_steps, device, sigmas=sigmas, mu=mu) | |
| timesteps, _ = retrieve_timesteps(scheduler, num_inference_steps, device, sigmas=sigmas, mu=mu) | |
| # ---- coords (requires rope/audio_rope on GPU) ---- | |
| video_coords = transformer.rope.prepare_video_coords( | |
| latents.shape[0], f_lat, h_lat, w_lat, latents.device, fps=frame_rate | |
| ) | |
| audio_coords = transformer.audio_rope.prepare_audio_coords( | |
| audio_latents.shape[0], audio_num_frames, audio_latents.device | |
| ) | |
| # ---- denoising loop ---- | |
| for i, t in enumerate(timesteps): | |
| latent_model_input = torch.cat([latents] * 2) if do_cfg else latents | |
| audio_in = torch.cat([audio_latents] * 2) if do_cfg else audio_latents | |
| latent_model_input = latent_model_input.to(dtype) | |
| audio_in = audio_in.to(dtype) | |
| timestep = t.expand(latent_model_input.shape[0]) | |
| video_timestep = timestep.unsqueeze(-1) * (1 - conditioning_mask) | |
| with transformer.cache_context("cond_uncond"): | |
| noise_pred_video, noise_pred_audio = transformer( | |
| hidden_states=latent_model_input, | |
| audio_hidden_states=audio_in, | |
| encoder_hidden_states=connector_prompt_embeds, | |
| audio_encoder_hidden_states=connector_audio_prompt_embeds, | |
| timestep=video_timestep, | |
| audio_timestep=timestep, | |
| encoder_attention_mask=connector_attention_mask, | |
| audio_encoder_attention_mask=connector_attention_mask, | |
| num_frames=f_lat, | |
| height=h_lat, | |
| width=w_lat, | |
| fps=frame_rate, | |
| audio_num_frames=audio_num_frames, | |
| video_coords=video_coords, | |
| audio_coords=audio_coords, | |
| attention_kwargs=None, | |
| return_dict=False, | |
| ) | |
| noise_pred_video = noise_pred_video.float() | |
| noise_pred_audio = noise_pred_audio.float() | |
| if do_cfg: | |
| v_u, v_t = noise_pred_video.chunk(2) | |
| a_u, a_t = noise_pred_audio.chunk(2) | |
| noise_pred_video = v_u + guidance_scale * (v_t - v_u) | |
| noise_pred_audio = a_u + guidance_scale * (a_t - a_u) | |
| if guidance_rescale > 0: | |
| noise_pred_video = rescale_noise_cfg(noise_pred_video, v_t, guidance_rescale=guidance_rescale) | |
| noise_pred_audio = rescale_noise_cfg(noise_pred_audio, a_t, guidance_rescale=guidance_rescale) | |
| # scheduler step (video) | |
| noise_pred_video_u = unpack_latents(noise_pred_video, f_lat, h_lat, w_lat, patch_size, patch_size_t) | |
| latents_u = unpack_latents(latents, f_lat, h_lat, w_lat, patch_size, patch_size_t) | |
| noise_pred_video_u = noise_pred_video_u[:, :, 1:] | |
| noise_latents = latents_u[:, :, 1:] | |
| pred = scheduler.step(noise_pred_video_u, t, noise_latents, return_dict=False)[0] | |
| latents_u = torch.cat([latents_u[:, :, :1], pred], dim=2) | |
| latents = pack_latents(latents_u, patch_size, patch_size_t) | |
| # scheduler step (audio) | |
| audio_latents = audio_scheduler.step(noise_pred_audio, t, audio_latents, return_dict=False)[0] | |
| if step_log_every and (i % step_log_every == 0): | |
| print(f"[step {i:02d}] t={int(t)} latents={latents.device}/{latents.dtype}") | |
| # ---- denormalize / unpack to match output_type="latent" from pipeline ---- | |
| latents_u = unpack_latents(latents, f_lat, h_lat, w_lat, patch_size, patch_size_t) | |
| latents_u = denormalize_latents(latents_u, vae.latents_mean, vae.latents_std, vae.config.scaling_factor) | |
| audio_latents = denormalize_audio_latents(audio_latents, audio_vae.latents_mean, audio_vae.latents_std) | |
| audio_latents = unpack_audio_latents(audio_latents, audio_num_frames, num_mel_bins=latent_mel_bins) | |
| # Optional: move non-block modules back off GPU if you want strict VRAM control | |
| if swap_state is not None and not keep_non_block_on_gpu: | |
| for m in swap_state.non_block_modules: | |
| m.to("cpu") | |
| move_module_root_tensors(transformer, "cpu") | |
| return latents_u, audio_latents | |
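# The returned video latents are [B, C, F_lat, H_lat, W_lat] and the audio latents are
# [B, C_audio, L_audio, latent_mel_bins], both denormalized and unpacked, matching what
# decode_from_latents_components above expects.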
| def move_module_root_tensors(module: torch.nn.Module, device: Union[str, torch.device]) -> None: | |
| """ | |
| Move only parameters/buffers registered directly on `module` (recurse=False), | |
| without touching submodules (e.g., transformer_blocks). | |
| """ | |
| # Parameters registered on the module itself (not children) | |
| for name, p in list(module._parameters.items()): | |
| if p is None: | |
| continue | |
| # Preserve the Parameter object; just move its storage | |
| p.data = p.data.to(device) | |
| # Buffers registered on the module itself (not children) | |
| for name, b in list(module._buffers.items()): | |
| if b is None: | |
| continue | |
| module._buffers[name] = b.to(device) | |
| def main(): | |
| parser = argparse.ArgumentParser() | |
| parser.add_argument("--image", required=True) | |
| parser.add_argument("--prompt", required=True) | |
| parser.add_argument("--negative", default="worst quality, inconsistent motion, blurry, jittery, distorted") | |
| parser.add_argument("--out", default="ltx2_i2v.mp4") | |
| parser.add_argument("--num_frames", type=int, default=121) | |
| parser.add_argument("--fps", type=float, default=24.0) | |
| parser.add_argument("--steps", type=int, default=20) | |
| parser.add_argument("--guidance", type=float, default=4.0) | |
| parser.add_argument("--seed", type=int, default=0) | |
| parser.add_argument("--cache_dir", default="./ltx2_quanto_cache") | |
| parser.add_argument("--group_size", type=int, default=8, help="How many transformer blocks to keep on GPU at a time.") | |
| parser.add_argument("--keep_non_block_on_gpu", action="store_true", help="Keep transformer non-block submodules on GPU.") | |
| parser.add_argument("--dtype", default="bf16", choices=["bf16", "fp16"]) | |
| args = parser.parse_args() | |
| if not torch.cuda.is_available(): | |
| raise RuntimeError("CUDA is required for practical inference here.") | |
| device = torch.device("cuda") | |
| torch_dtype = torch.bfloat16 if (args.dtype == "bf16" and torch.cuda.is_bf16_supported()) else torch.float16 | |
| print(f"[setup] device={device}, torch_dtype={torch_dtype}") | |
| model_id = "Lightricks/LTX-2" | |
| cache_dir = Path(args.cache_dir) | |
| # Load non-huge components (CPU is fine; we'll move to GPU only when needed) | |
| from diffusers.schedulers import FlowMatchEulerDiscreteScheduler | |
| from diffusers.models.autoencoders import AutoencoderKLLTX2Video, AutoencoderKLLTX2Audio | |
| from diffusers.pipelines.ltx2.connectors import LTX2TextConnectors | |
| from diffusers.pipelines.ltx2.vocoder import LTX2Vocoder | |
| from diffusers.pipelines.ltx2 import LTX2ImageToVideoPipeline | |
| from transformers import Gemma3ForConditionalGeneration, GemmaTokenizer | |
| from diffusers.models.transformers import LTX2VideoTransformer3DModel | |
| print("[setup] Loading scheduler/vae/audio_vae/vocoder/connectors/tokenizer ...") | |
| scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(model_id, subfolder="scheduler") | |
| vae = AutoencoderKLLTX2Video.from_pretrained(model_id, subfolder="vae", torch_dtype=torch_dtype).cpu() | |
| audio_vae = AutoencoderKLLTX2Audio.from_pretrained(model_id, subfolder="audio_vae", torch_dtype=torch_dtype).cpu() | |
| vocoder = LTX2Vocoder.from_pretrained(model_id, subfolder="vocoder", torch_dtype=torch_dtype).cpu() | |
| connectors = LTX2TextConnectors.from_pretrained(model_id, subfolder="connectors", torch_dtype=torch_dtype).cpu() | |
| # connectors = quantize_quanto_int8(connectors, "connectors") | |
| tokenizer = GemmaTokenizer.from_pretrained(model_id, subfolder="tokenizer") | |
| # --------------------------- | |
| # Phase 1: text encoder only | |
| # --------------------------- | |
| text_cache = cache_dir / "text_encoder_gemma3_qint8.pt" | |
| def build_text_encoder(): | |
| te = Gemma3ForConditionalGeneration.from_pretrained(model_id, subfolder="text_encoder", torch_dtype=torch_dtype).cpu() | |
| te = quantize_quanto_int8(te, "text_encoder") | |
| return te | |
| text_encoder = load_or_build_quantized(text_cache, build_text_encoder, "text_encoder") | |
| # Compute prompt embeddings on GPU (text encoder present) | |
| print("[text] Moving text encoder to GPU and encoding prompts ...") | |
| generator = torch.Generator(device="cuda") | |
| generator.manual_seed(args.seed) | |
| prompt_embeds, prompt_mask = gemma_prompt_embeds( | |
| args.prompt, tokenizer, text_encoder, device=device, max_sequence_length=1024, | |
| scale_factor=8, dtype=torch_dtype, num_videos_per_prompt=1 | |
| ) | |
| neg_embeds, neg_mask = gemma_prompt_embeds( | |
| args.negative or "", tokenizer, text_encoder, device=device, max_sequence_length=1024, | |
| scale_factor=8, dtype=torch_dtype, num_videos_per_prompt=1 | |
| ) | |
| # Delete text encoder completely before loading transformer | |
| print("[text] Deleting text encoder and clearing VRAM ...") | |
| text_encoder.to("cpu") | |
| del text_encoder | |
| cuda_gc() | |
| # ------------------------------- | |
| # Phase 2: transformer only (qint8) | |
| # ------------------------------- | |
| tr_cache = cache_dir / "transformer_ltx2_qint8.pt" | |
| def build_transformer(): | |
| tr = LTX2VideoTransformer3DModel.from_pretrained(model_id, subfolder="transformer", torch_dtype=torch_dtype).cpu() | |
| tr = quantize_quanto_int8(tr, "transformer") | |
| return tr | |
| transformer = load_or_build_quantized(tr_cache, build_transformer, "transformer") | |
| # Attach N-block-at-a-time swaps | |
| print(f"[swap] Attaching group swap hooks: group_size={args.group_size}, keep_non_block_on_gpu={args.keep_non_block_on_gpu}") | |
| swap_state = attach_group_swap_hooks( | |
| transformer, | |
| group_size=args.group_size, | |
| device=device, | |
        keep_non_block_on_gpu=True,  # hard-coded to True (recommended); args.keep_non_block_on_gpu is currently ignored
| ) | |
| def debug_once(module, inputs): | |
| # only print once | |
| if getattr(module, "_printed", False): | |
| return | |
| module._printed = True | |
| # inspect first tensor input device | |
| hs = inputs[0] if len(inputs) else None | |
| if isinstance(hs, (tuple, list)) and len(hs): | |
| hs = hs[0] | |
| if torch.is_tensor(hs): | |
| print("[debug] transformer hidden_states device:", hs.device, "dtype:", hs.dtype) | |
| transformer.register_forward_pre_hook(debug_once) | |
| # Load image | |
| image = load_image(args.image) | |
| width, height = image.size | |
| width //= 2 | |
| height //= 2 | |
    if (height % 32) != 0 or (width % 32) != 0:
        raise ValueError(f"Halved image dimensions must be divisible by 32, got {width}x{height}")
| # ------------------------------- | |
| # Run denoising only -> latents | |
| # ------------------------------- | |
| print("[infer] Running denoising to latents (transformer swaps active) ...") | |
| keep_non_block = True # strongly recommended | |
| video_latents, audio_latents = ltx2_denoise_to_latents( | |
| image=image, | |
| prompt_embeds=prompt_embeds, | |
| prompt_attention_mask=prompt_mask, | |
| negative_prompt_embeds=neg_embeds, | |
| negative_prompt_attention_mask=neg_mask, | |
| scheduler=scheduler, | |
| vae=vae, | |
| audio_vae=audio_vae, | |
| connectors=connectors, | |
| transformer=transformer, | |
| num_frames=args.num_frames, | |
| frame_rate=args.fps, | |
| height=height, | |
| width=width, | |
| num_inference_steps=args.steps, | |
| guidance_scale=args.guidance, | |
| guidance_rescale=0.0, | |
| device=device, | |
| dtype=torch_dtype, | |
| generator=generator, | |
| swap_state=swap_state, | |
| keep_non_block_on_gpu=keep_non_block, | |
| step_log_every=1, | |
| ) | |
    # Free the transformer before decoding to keep VRAM headroom for the VAE and vocoder
| print("[swap] Freeing transformer and clearing VRAM before decode ...") | |
| try: | |
| swap_state.unload_all() | |
| except Exception: | |
| pass | |
| transformer.to("cpu") | |
| del transformer | |
| cuda_gc() | |
| # ------------------------------- | |
| # Decode latents -> final outputs | |
| # ------------------------------- | |
| print("[decode] Decoding video/audio ...") | |
| video, audio = decode_from_latents_components( | |
| vae=vae, | |
| audio_vae=audio_vae, | |
| vocoder=vocoder, | |
| video_latents=video_latents, | |
| audio_latents=audio_latents, | |
| device=device, | |
| output_type="np", | |
| ) | |
| # Write MP4 + audio using diffusers encode_video if available, else ffmpeg mux fallback | |
| out_path = str(Path(args.out).resolve()) | |
    # video is float in [0, 1] with shape [B, F, H, W, C] after postprocess_video(output_type="np")
    # video = (video + 1) / 2
    video = 1.0 - video  # invert frame values before converting to uint8
    video_u8 = (video.astype(np.float32) * 255.0).round().clip(0, 255).astype(np.uint8)
| # video_u8 = 255 - video_u8 | |
| # audio is torch tensor (from vocoder). Usually [B, T]. | |
| sr = int(getattr(vocoder.config, "output_sampling_rate", 24000)) | |
| out_path = write_mp4_with_audio_ffmpeg( | |
| out_path=str(Path(args.out).resolve()), | |
| video_u8=video_u8, | |
| fps=args.fps, | |
| audio=audio, | |
| audio_sample_rate=sr, | |
| macro_block_size=16, # set to 1 if you ever output odd sizes | |
| ) | |
| print(f"[done] Wrote: {out_path}") | |
| if __name__ == "__main__": | |
| main() |