| import os | |
| import re | |
| import time | |
| import math | |
| import uuid | |
| import json | |
| import shutil | |
| import random | |
| import types | |
| import hashlib | |
| import importlib | |
| import gc | |
| import copy | |
| from dataclasses import dataclass | |
| from collections import namedtuple, deque | |
| from contextlib import contextmanager | |
| from typing import List, Tuple, Optional, Literal, TypeAlias | |
| import numpy as np | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| import torch.nn.init as init | |
| from torch.utils.checkpoint import checkpoint | |
| # --- START FIX: TRITON & FLASH ATTN BYPASS FOR MAC --- | |
| try: | |
| import triton | |
| import triton.language as tl | |
| TRITON_AVAILABLE = True | |
| except ImportError: | |
| TRITON_AVAILABLE = False | |
| # Dummy objects so the @triton.jit decorator and tl constants become no-ops when Triton is missing | |
| def dummy_jit(func): return func | |
| triton = type('obj', (object,), {'jit': dummy_jit, 'cdiv': lambda a, b: (a + b - 1) // b}) | |
| tl = type('obj', (object,), {'constexpr': int, 'float32': torch.float32, 'int8': torch.int8, 'float16': torch.float16}) | |
| FLASH_ATTN_3_AVAILABLE = False | |
| FLASH_ATTN_2_AVAILABLE = False | |
| SAGE_ATTN_AVAILABLE = False | |
| BLOCK_ATTN_AVAILABLE = False | |
| USE_BLOCK_ATTN = False | |
| # --- END FIX --- | |
| from PIL import Image | |
| from tqdm.auto import tqdm | |
| from huggingface_hub import snapshot_download | |
| from safetensors import safe_open | |
| from einops import rearrange, repeat | |
| from torchvision.transforms import GaussianBlur | |
| import imageio | |
| import ffmpeg | |
| # ============================================================================== | |
| # Configuration & Globals | |
| # ============================================================================== | |
| @dataclass | |
| class Args: | |
| input: str = "/content/REqq8zq3pSFVsvUn.mp4" | |
| output_folder: str = "./outputs" | |
| scale: int = 4 | |
| version: str = "10" | |
| mode: str = "tiny" | |
| tiled_vae: bool = True # Must stay True on 16 GB machines | |
| tiled_dit: bool = True # Must stay True on 16 GB machines | |
| tile_size: int = 16 # Reduced from the original 256 to stay within 16 GB | |
| overlap: int = 8 # Overlap shrunk to match the smaller tile size | |
| unload_dit: bool = True # Unload the DiT between stages to save RAM | |
| color_fix: bool = False | |
| seed: int = 0 | |
| dtype: str = "fp16" # fp16 runs better than bf16 on Apple MPS | |
| device: str = "auto" # Auto-detects MPS (Mac GPU) when available | |
| fps: int = 30 | |
| quality: int = 6 | |
| attention: str = "sdpa" # Use PyTorch's built-in scaled_dot_product_attention | |
| args = Args() | |
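| # CACHE_T: number of trailing frames kept in the causal Conv3d feature cache between chunks | |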
| CACHE_T = 2 | |
| root = os.path.dirname(os.path.abspath(__file__)) if "__file__" in globals() else os.getcwd() # Fall back to the working directory when run in a notebook | |
| temp = os.path.join(root, "_temp") | |
| def log(message: str, message_type: str = "normal"): | |
| if message_type == "error": | |
| message = "\033[1;41m" + message + "\033[m" | |
| elif message_type == "warning": | |
| message = "\033[1;31m" + message + "\033[m" | |
| elif message_type == "finish": | |
| message = "\033[1;32m" + message + "\033[m" | |
| elif message_type == "info": | |
| message = "\033[1;33m" + message + "\033[m" | |
| print(f"{message}") | |
| # ============================================================================== | |
| # Attention Imports & Triton Kernels | |
| # ============================================================================== | |
| try: | |
| import flash_attn_interface | |
| FLASH_ATTN_3_AVAILABLE = True | |
| except ModuleNotFoundError: | |
| FLASH_ATTN_3_AVAILABLE = False | |
| try: | |
| import flash_attn | |
| FLASH_ATTN_2_AVAILABLE = True | |
| except ModuleNotFoundError: | |
| FLASH_ATTN_2_AVAILABLE = False | |
| try: | |
| from sageattention import sageattn | |
| SAGE_ATTN_AVAILABLE = True | |
| except ModuleNotFoundError: | |
| SAGE_ATTN_AVAILABLE = False | |
| try: | |
| from block_sparse_attn import block_sparse_attn_func | |
| BLOCK_ATTN_AVAILABLE = True | |
| except: | |
| BLOCK_ATTN_AVAILABLE = False | |
| USE_BLOCK_ATTN = False | |
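| # Triton kernel: per-block INT8 quantization; each BLK x C tile is scaled by max(|x|)/127 and stored with its per-block scale | |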
| @triton.jit | |
| def quant_per_block_int8_kernel( | |
| Input, Output, Scale, L, | |
| stride_iz, stride_ih, stride_in, | |
| stride_oz, stride_oh, stride_on, | |
| stride_sz, stride_sh, | |
| sm_scale, C: tl.constexpr, BLK: tl.constexpr, | |
| ): | |
| off_blk = tl.program_id(0) | |
| off_h = tl.program_id(1) | |
| off_b = tl.program_id(2) | |
| offs_n = off_blk * BLK + tl.arange(0, BLK) | |
| offs_k = tl.arange(0, C) | |
| input_ptrs = (Input + off_b * stride_iz + off_h * stride_ih + offs_n[:, None] * stride_in + offs_k[None, :]) | |
| output_ptrs = (Output + off_b * stride_oz + off_h * stride_oh + offs_n[:, None] * stride_on + offs_k[None, :]) | |
| scale_ptrs = Scale + off_b * stride_sz + off_h * stride_sh + off_blk | |
| x = tl.load(input_ptrs, mask=offs_n[:, None] < L) | |
| x = x.to(tl.float32) | |
| x *= sm_scale | |
| scale = tl.max(tl.abs(x)) / 127.0 | |
| x_int8 = x / scale | |
| x_int8 += 0.5 * tl.where(x_int8 >= 0, 1, -1) | |
| x_int8 = x_int8.to(tl.int8) | |
| tl.store(output_ptrs, x_int8, mask=offs_n[:, None] < L) | |
| tl.store(scale_ptrs, scale) | |
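| # Inner loop of the block-sparse attention forward: online softmax over the K/V blocks selected via K_bid_ptr | |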
| @triton.jit | |
| def _attn_fwd_inner( | |
| acc, l_i, old_m, q, q_scale, kv_len, | |
| K_ptrs, K_bid_ptr, K_scale_ptr, V_ptrs, | |
| stride_kn, stride_vn, start_m, | |
| BLOCK_M: tl.constexpr, HEAD_DIM: tl.constexpr, BLOCK_N: tl.constexpr, | |
| STAGE: tl.constexpr, offs_m: tl.constexpr, offs_n: tl.constexpr, | |
| ): | |
| if STAGE == 1: | |
| lo, hi = 0, start_m * BLOCK_M | |
| elif STAGE == 2: | |
| lo, hi = start_m * BLOCK_M, (start_m + 1) * BLOCK_M | |
| lo = tl.multiple_of(lo, BLOCK_M) | |
| K_scale_ptr += lo // BLOCK_N | |
| K_ptrs += stride_kn * lo | |
| V_ptrs += stride_vn * lo | |
| elif STAGE == 3: | |
| lo, hi = 0, kv_len | |
| for start_n in range(lo, hi, BLOCK_N): | |
| kbid = tl.load(K_bid_ptr + start_n // BLOCK_N) | |
| if kbid: | |
| k_mask = offs_n[None, :] < (kv_len - start_n) | |
| k = tl.load(K_ptrs, mask=k_mask) | |
| k_scale = tl.load(K_scale_ptr) | |
| qk = tl.dot(q, k).to(tl.float32) * q_scale * k_scale | |
| if STAGE == 2: | |
| mask = offs_m[:, None] >= (start_n + offs_n[None, :]) | |
| qk = qk + tl.where(mask, 0, -1.0e6) | |
| local_m = tl.max(qk, 1) | |
| new_m = tl.maximum(old_m, local_m) | |
| qk -= new_m[:, None] | |
| else: | |
| local_m = tl.max(qk, 1) | |
| new_m = tl.maximum(old_m, local_m) | |
| qk = qk - new_m[:, None] | |
| p = tl.math.exp2(qk) | |
| l_ij = tl.sum(p, 1) | |
| alpha = tl.math.exp2(old_m - new_m) | |
| l_i = l_i * alpha + l_ij | |
| acc = acc * alpha[:, None] | |
| v = tl.load(V_ptrs, mask=offs_n[:, None] < (kv_len - start_n)) | |
| p = p.to(tl.float16) | |
| acc += tl.dot(p, v, out_dtype=tl.float16) | |
| old_m = new_m | |
| K_ptrs += BLOCK_N * stride_kn | |
| K_scale_ptr += 1 | |
| V_ptrs += BLOCK_N * stride_vn | |
| return acc, l_i, old_m | |
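| # Outer attention kernel: loads one query block, runs the inner loop over K/V blocks, then normalizes by the softmax denominator | |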
| @triton.jit | |
| def _attn_fwd( | |
| Q, K, K_blkid, V, Q_scale, K_scale, Out, | |
| stride_qz, stride_qh, stride_qn, | |
| stride_kz, stride_kh, stride_kn, | |
| stride_vz, stride_vh, stride_vn, | |
| stride_oz, stride_oh, stride_on, | |
| stride_kbidq, stride_kbidk, | |
| qo_len, kv_len, H: tl.constexpr, num_kv_groups: tl.constexpr, | |
| HEAD_DIM: tl.constexpr, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, | |
| STAGE: tl.constexpr, | |
| ): | |
| start_m = tl.program_id(0) | |
| off_z = tl.program_id(2).to(tl.int64) | |
| off_h = tl.program_id(1).to(tl.int64) | |
| q_scale_offset = (off_z * H + off_h) * tl.cdiv(qo_len, BLOCK_M) | |
| k_scale_offset = (off_z * (H // num_kv_groups) + off_h // num_kv_groups) * tl.cdiv(kv_len, BLOCK_N) | |
| k_bid_offset = (off_z * (H // num_kv_groups) + off_h // num_kv_groups) * stride_kbidq | |
| offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) | |
| offs_n = tl.arange(0, BLOCK_N) | |
| offs_k = tl.arange(0, HEAD_DIM) | |
| Q_ptrs = (Q + (off_z * stride_qz + off_h * stride_qh) + offs_m[:, None] * stride_qn + offs_k[None, :]) | |
| Q_scale_ptr = Q_scale + q_scale_offset + start_m | |
| K_ptrs = (K + (off_z * stride_kz + (off_h // num_kv_groups) * stride_kh) + offs_n[None, :] * stride_kn + offs_k[:, None]) | |
| K_scale_ptr = K_scale + k_scale_offset | |
| K_bid_ptr = K_blkid + k_bid_offset + start_m * stride_kbidk | |
| V_ptrs = (V + (off_z * stride_vz + (off_h // num_kv_groups) * stride_vh) + offs_n[:, None] * stride_vn + offs_k[None, :]) | |
| O_block_ptr = (Out + (off_z * stride_oz + off_h * stride_oh) + offs_m[:, None] * stride_on + offs_k[None, :]) | |
| m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") | |
| l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + 1.0 | |
| acc = tl.zeros([BLOCK_M, HEAD_DIM], dtype=tl.float32) | |
| q = tl.load(Q_ptrs, mask=offs_m[:, None] < qo_len) | |
| q_scale = tl.load(Q_scale_ptr) | |
| acc, l_i, m_i = _attn_fwd_inner( | |
| acc, l_i, m_i, q, q_scale, kv_len, | |
| K_ptrs, K_bid_ptr, K_scale_ptr, V_ptrs, | |
| stride_kn, stride_vn, start_m, | |
| BLOCK_M, HEAD_DIM, BLOCK_N, 4 - STAGE, | |
| offs_m, offs_n, | |
| ) | |
| if STAGE != 1: | |
| acc, l_i, _ = _attn_fwd_inner( | |
| acc, l_i, m_i, q, q_scale, kv_len, | |
| K_ptrs, K_bid_ptr, K_scale_ptr, V_ptrs, | |
| stride_kn, stride_vn, start_m, | |
| BLOCK_M, HEAD_DIM, BLOCK_N, 2, | |
| offs_m, offs_n, | |
| ) | |
| acc = acc / l_i[:, None] | |
| tl.store(O_block_ptr, acc.to(Out.type.element_ty), mask=(offs_m[:, None] < qo_len)) | |
| def per_block_int8(q, k, km=None, BLKQ=128, BLKK=64, sm_scale=None, tensor_layout="HND"): | |
| """ | |
| Bypass version for Mac/MPS: skip int8 quantization to avoid kernel errors. | |
| Returns the original tensors with scale = 1.0. | |
| """ | |
| if not TRITON_AVAILABLE or q.device.type == 'mps': | |
| if km is not None: | |
| k = k - km | |
| # Return the tensors unchanged with dummy scales | |
| b = q.shape[0] | |
| # The scale layout depends on the implementation; return dummies here so nothing crashes. | |
| # The attention logic later on uses SDPA anyway, so the scales barely matter as long as we pass q, k as floats. | |
| q_scale = torch.ones((1,), device=q.device, dtype=torch.float32) | |
| k_scale = torch.ones((1,), device=q.device, dtype=torch.float32) | |
| # Pretend these are int8 but actually keep the original dtype (or cast lightly if needed). | |
| # To stay safe with SDPA on Mac, return float tensors. | |
| return q, q_scale, k, k_scale | |
| # Original NVIDIA path (the old Triton logic would go here when Triton is available) | |
| q_int8 = torch.empty(q.shape, dtype=torch.int8, device=q.device) | |
| k_int8 = torch.empty(k.shape, dtype=torch.int8, device=k.device) | |
| # ... (keep the old Triton quantization here if desired, or delete it since we run on Mac) | |
| # For brevity, return None here because this path is never taken on Mac | |
| return None, None, None, None | |
| def sparse_sageattn_fwd(q, k, k_block_id, v, q_scale, k_scale, is_causal=False, tensor_layout="HND", output_dtype=torch.float16): | |
| # Only called from the Triton flow; on Mac we bypass it from sparse_sageattn | |
| return torch.zeros_like(q, dtype=output_dtype) | |
| def sparse_sageattn(q, k, v, mask_id=None, is_causal=False, tensor_layout="HND"): | |
| """ | |
| Replace sparse attention with standard attention (SDPA) on Mac. | |
| """ | |
| if not TRITON_AVAILABLE or q.device.type == 'mps': | |
| # Bring the layout to (Batch, Head, Seq, Dim) for SDPA | |
| if tensor_layout == "HND": # already (Batch, Head, Seq, Dim); no permutation needed | |
| pass | |
| elif tensor_layout == "NHD": # (Batch, Seq, Head, Dim) | |
| q = q.permute(0, 2, 1, 3) | |
| k = k.permute(0, 2, 1, 3) | |
| v = v.permute(0, 2, 1, 3) | |
| # PyTorch SDPA is well supported on MPS | |
| out = F.scaled_dot_product_attention(q, k, v, is_causal=is_causal) | |
| # Restore the original layout | |
| if tensor_layout == "NHD": | |
| out = out.permute(0, 2, 1, 3) | |
| return out | |
| # Original (Triton) path - keep or delete | |
| return None | |
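| # Usage sketch (illustrative, not part of the original call sites): with tensor_layout="NHD", | |
| # q/k/v of shape (B, S, H, D) are permuted to (B, H, S, D) for SDPA and the output is | |
| # permuted back, so callers receive the same layout they passed in. | |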
| # ============================================================================== | |
| # Math, Embeddings & Helpers | |
| # ============================================================================== | |
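| # 1D sinusoidal timestep embedding: concatenated cos/sin over log-spaced frequencies | |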
| def sinusoidal_embedding_1d(dim, position): | |
| sinusoid = torch.outer( | |
| position.type(torch.float64), | |
| torch.pow(10000, -torch.arange(dim // 2, dtype=torch.float64, device=position.device).div(dim // 2)), | |
| ) | |
| x = torch.cat([torch.cos(sinusoid), torch.sin(sinusoid)], dim=1) | |
| return x.to(position.dtype) | |
| def precompute_freqs_cis_3d(dim: int, end: int = 1024, theta: float = 10000.0): | |
| f_freqs_cis = precompute_freqs_cis(dim - 2 * (dim // 3), end, theta) | |
| h_freqs_cis = precompute_freqs_cis(dim // 3, end, theta) | |
| w_freqs_cis = precompute_freqs_cis(dim // 3, end, theta) | |
| return f_freqs_cis, h_freqs_cis, w_freqs_cis | |
| def precompute_freqs_cis(dim: int, end: int = 1024, theta: float = 10000.0): | |
| freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].double() / dim)) | |
| freqs = torch.outer(torch.arange(end, device=freqs.device), freqs) | |
| freqs_cis = torch.polar(torch.ones_like(freqs), freqs) | |
| return freqs_cis | |
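| # Apply rotary position embeddings (RoPE) per head via complex multiplication | |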
| def rope_apply(x, freqs, num_heads): | |
| x = rearrange(x, "b s (n d) -> b s n d", n=num_heads) | |
| x_out = torch.view_as_complex(x.to(torch.float64).reshape(x.shape[0], x.shape[1], x.shape[2], -1, 2)) | |
| x_out = torch.view_as_real(x_out * freqs).flatten(2) | |
| return x_out.to(x.dtype) | |
| def modulate(x: torch.Tensor, shift: torch.Tensor, scale: torch.Tensor): | |
| return x * (1 + scale) + shift | |
| def check_is_instance(model, module_class): | |
| if isinstance(model, module_class): | |
| return True | |
| if hasattr(model, "module") and isinstance(model.module, module_class): | |
| return True | |
| return False | |
| def count_conv3d(model): | |
| count = 0 | |
| for m in model.modules(): | |
| if check_is_instance(m, CausalConv3dZeroPad): | |
| count += 1 | |
| return count | |
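| # Boolean mask over flattened block positions: entry (i, j) is True when block j lies inside the win_h x win_w window centered on block i | |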
| @torch.no_grad() | |
| def build_local_block_mask_shifted_vec_normal_slide( | |
| block_h: int, block_w: int, win_h: int = 6, win_w: int = 6, | |
| include_self: bool = True, device=None, | |
| ) -> torch.Tensor: | |
| device = device or torch.device("cpu") | |
| H, W = block_h, block_w | |
| r = torch.arange(H, device=device) | |
| c = torch.arange(W, device=device) | |
| YY, XX = torch.meshgrid(r, c, indexing="ij") | |
| r_all = YY.reshape(-1) | |
| c_all = XX.reshape(-1) | |
| r_half = win_h // 2 | |
| c_half = win_w // 2 | |
| start_r = r_all - r_half | |
| end_r = start_r + win_h - 1 | |
| start_c = c_all - c_half | |
| end_c = start_c + win_w - 1 | |
| in_row = (r_all[None, :] >= start_r[:, None]) & (r_all[None, :] <= end_r[:, None]) | |
| in_col = (c_all[None, :] >= start_c[:, None]) & (c_all[None, :] <= end_c[:, None]) | |
| mask = in_row & in_col | |
| if not include_self: | |
| mask.fill_diagonal_(False) | |
| return mask | |
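| # Draft block-sparse mask: average-pool Q/K within windows, score them, apply the local mask, and keep the top-k entries per head/iteration | |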
| @torch.no_grad() | |
| def generate_draft_block_mask(batch_size, nheads, seqlen, q_w, k_w, topk=10, local_attn_mask=None): | |
| assert batch_size == 1 | |
| assert local_attn_mask is not None | |
| avgpool_q = torch.mean(q_w, dim=1) | |
| avgpool_k = torch.mean(k_w, dim=1) | |
| avgpool_q = rearrange(avgpool_q, "s (h d) -> s h d", h=nheads) | |
| avgpool_k = rearrange(avgpool_k, "s (h d) -> s h d", h=nheads) | |
| q_heads = avgpool_q.permute(1, 0, 2) | |
| k_heads = avgpool_k.permute(1, 0, 2) | |
| D = avgpool_q.shape[-1] | |
| scores = torch.einsum("hld,hmd->hlm", q_heads, k_heads) / math.sqrt(D) | |
| repeat_head = scores.shape[0] | |
| repeat_len = scores.shape[1] // local_attn_mask.shape[0] | |
| repeat_num = scores.shape[2] // local_attn_mask.shape[1] | |
| local_attn_mask = local_attn_mask.unsqueeze(1).unsqueeze(0).repeat(repeat_len, 1, repeat_num, 1) | |
| local_attn_mask = rearrange(local_attn_mask, "x a y b -> (x a) (y b)") | |
| local_attn_mask = local_attn_mask.unsqueeze(0).repeat(repeat_head, 1, 1).to(torch.float32) | |
| local_attn_mask = local_attn_mask.masked_fill(local_attn_mask == False, -float("inf")) | |
| local_attn_mask = local_attn_mask.masked_fill(local_attn_mask == True, 0) | |
| scores = scores + local_attn_mask | |
| attn_map = torch.softmax(scores, dim=-1) | |
| attn_map = rearrange(attn_map, "h (it s1) s2 -> (h it) s1 s2", it=seqlen) | |
| loop_num, s1, s2 = attn_map.shape | |
| flat = attn_map.reshape(loop_num, -1) | |
| apply_topk = min(flat.shape[1] - 1, topk) | |
| thresholds = torch.topk(flat, k=apply_topk + 1, dim=1, largest=True).values[:, -1] | |
| thresholds = thresholds.unsqueeze(1) | |
| mask_new = (flat > thresholds).reshape(loop_num, s1, s2) | |
| mask_new = rearrange(mask_new, "(h it) s1 s2 -> h (it s1) s2", it=seqlen) | |
| mask = mask_new.unsqueeze(0).repeat(batch_size, 1, 1, 1) | |
| return mask | |
| @torch.no_grad() | |
| def generate_draft_block_mask_sage(batch_size, nheads, seqlen, q_w, k_w, topk=10, local_attn_mask=None): | |
| assert batch_size == 1 | |
| assert local_attn_mask is not None | |
| avgpool_q = torch.mean(q_w, dim=1) | |
| avgpool_q = rearrange(avgpool_q, "s (h d) -> s h d", h=nheads) | |
| q_heads = avgpool_q.permute(1, 0, 2) | |
| D = avgpool_q.shape[-1] | |
| k_w_split = k_w.view(k_w.shape[0], 2, 64, k_w.shape[2]) | |
| avgpool_k_split = torch.mean(k_w_split, dim=2) | |
| avgpool_k_refined = rearrange(avgpool_k_split, "s two d -> (s two) d", two=2) | |
| avgpool_k_refined = rearrange(avgpool_k_refined, "s (h d) -> s h d", h=nheads) | |
| k_heads_doubled = avgpool_k_refined.permute(1, 0, 2) | |
| k_heads_1, k_heads_2 = torch.chunk(k_heads_doubled, 2, dim=1) | |
| scores_1 = torch.einsum("hld,hmd->hlm", q_heads, k_heads_1) / math.sqrt(D) | |
| scores_2 = torch.einsum("hld,hmd->hlm", q_heads, k_heads_2) / math.sqrt(D) | |
| scores = torch.cat([scores_1, scores_2], dim=-1) | |
| repeat_head = scores.shape[0] | |
| repeat_len = scores.shape[1] // local_attn_mask.shape[0] | |
| repeat_num = (scores.shape[2] // 2) // local_attn_mask.shape[1] | |
| local_attn_mask = local_attn_mask.unsqueeze(1).unsqueeze(0).repeat(repeat_len, 1, repeat_num, 1) | |
| local_attn_mask = rearrange(local_attn_mask, "x a y b -> (x a) (y b)") | |
| local_attn_mask = local_attn_mask.repeat_interleave(2, dim=1) | |
| local_attn_mask = local_attn_mask.unsqueeze(0).repeat(repeat_head, 1, 1).to(torch.float32) | |
| local_attn_mask = local_attn_mask.masked_fill(local_attn_mask == False, -float("inf")) | |
| local_attn_mask = local_attn_mask.masked_fill(local_attn_mask == True, 0) | |
| scores = scores + local_attn_mask | |
| attn_map = torch.softmax(scores, dim=-1) | |
| attn_map = rearrange(attn_map, "h (it s1) s2 -> (h it) s1 s2", it=seqlen) | |
| loop_num, s1, s2 = attn_map.shape | |
| flat = attn_map.reshape(loop_num, -1) | |
| apply_topk = min(flat.shape[1] - 1, topk) | |
| if apply_topk <= 0: | |
| mask_new = torch.zeros_like(flat, dtype=torch.bool).reshape(loop_num, s1, s2) | |
| else: | |
| thresholds = torch.topk(flat, k=apply_topk + 1, dim=1, largest=True).values[:, -1] | |
| thresholds = thresholds.unsqueeze(1) | |
| mask_new = (flat > thresholds).reshape(loop_num, s1, s2) | |
| mask_new = rearrange(mask_new, "(h it) s1 s2 -> h (it s1) s2", it=seqlen) | |
| mask = mask_new.unsqueeze(0).repeat(batch_size, 1, 1, 1) | |
| return mask | |
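| # Partition a (B, F, H, W, C) volume into non-overlapping 3D windows and merge the windows back | |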
| class WindowPartition3D: | |
| @staticmethod | |
| def partition(x: torch.Tensor, win: Tuple[int, int, int]): | |
| B, F, H, W, C = x.shape | |
| wf, wh, ww = win | |
| assert F % wf == 0 and H % wh == 0 and W % ww == 0 | |
| x = x.view(B, F // wf, wf, H // wh, wh, W // ww, ww, C) | |
| x = x.permute(0, 1, 3, 5, 2, 4, 6, 7).contiguous() | |
| return x.view(-1, wf * wh * ww, C) | |
| @staticmethod | |
| def reverse(windows: torch.Tensor, win: Tuple[int, int, int], orig: Tuple[int, int, int]): | |
| F, H, W = orig | |
| wf, wh, ww = win | |
| nf, nh, nw = F // wf, H // wh, W // ww | |
| B = windows.size(0) // (nf * nh * nw) | |
| x = windows.view(B, nf, nh, nw, wf, wh, ww, -1) | |
| x = x.permute(0, 1, 4, 2, 5, 3, 6, 7).contiguous() | |
| return x.view(B, F, H, W, -1) | |
| # ============================================================================== | |
| # Basic Layers (Norms, Convs) | |
| # ============================================================================== | |
| class RMS_norm_General(nn.Module): | |
| """General RMS Norm (supports images arg for channel_first). Used in VAE/Projs.""" | |
| def __init__(self, dim, channel_first=True, images=True, bias=False): | |
| super().__init__() | |
| broadcastable_dims = (1, 1, 1) if not images else (1, 1) | |
| shape = (dim, *broadcastable_dims) if channel_first else (dim,) | |
| self.channel_first = channel_first | |
| self.scale = dim**0.5 | |
| self.gamma = nn.Parameter(torch.ones(shape)) | |
| self.bias = nn.Parameter(torch.zeros(shape)) if bias else 0.0 | |
| def forward(self, x): | |
| return (F.normalize(x, dim=(1 if self.channel_first else -1)) * self.scale * self.gamma + self.bias) | |
| class RMSNorm(nn.Module): | |
| """Simple RMS Norm used in DiT backbone.""" | |
| def __init__(self, dim, eps=1e-5): | |
| super().__init__() | |
| self.eps = eps | |
| self.weight = nn.Parameter(torch.ones(dim)) | |
| def norm(self, x): | |
| return x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps) | |
| def forward(self, x): | |
| dtype = x.dtype | |
| return self.norm(x.float()).to(dtype) * self.weight | |
| class CausalConv3dZeroPad(nn.Conv3d): | |
| """Causal Conv3d with zero padding (Used in VAE/Encoder/Decoder).""" | |
| def __init__(self, *args, **kwargs): | |
| super().__init__(*args, **kwargs) | |
| self._padding = (self.padding[2], self.padding[2], self.padding[1], self.padding[1], 2 * self.padding[0], 0) | |
| self.padding = (0, 0, 0) | |
| def forward(self, x, cache_x=None): | |
| padding = list(self._padding) | |
| if cache_x is not None and self._padding[4] > 0: | |
| cache_x = cache_x.to(x.device) | |
| x = torch.cat([cache_x, x], dim=2) | |
| padding[4] -= cache_x.shape[2] | |
| x = F.pad(x, padding) | |
| return super().forward(x) | |
| class CausalConv3dReplicate(nn.Conv3d): | |
| """Causal Conv3d with replicate padding (Used in LQ Projections).""" | |
| def __init__(self, *args, **kwargs): | |
| super().__init__(*args, **kwargs) | |
| self._padding = (self.padding[2], self.padding[2], self.padding[1], self.padding[1], 2 * self.padding[0], 0) | |
| self.padding = (0, 0, 0) | |
| def forward(self, x, cache_x=None): | |
| padding = list(self._padding) | |
| if cache_x is not None and self._padding[4] > 0: | |
| cache_x = cache_x.to(x.device) | |
| x = torch.cat([cache_x, x], dim=2) | |
| padding[4] -= cache_x.shape[2] | |
| x = F.pad(x, padding, mode="replicate") | |
| return super().forward(x) | |
| class IdentityConv2d(nn.Conv2d): | |
| def __init__(self, C, kernel_size=3, bias=False): | |
| pad = kernel_size // 2 | |
| super().__init__(C, C, kernel_size, padding=pad, bias=bias) | |
| with torch.no_grad(): | |
| init.dirac_(self.weight) | |
| if self.bias is not None: | |
| self.bias.zero_() | |
| class Clamp(nn.Module): | |
| def forward(self, x): | |
| return torch.tanh(x / 3) * 3 | |
| class Upsample(nn.Upsample): | |
| def forward(self, x): | |
| return super().forward(x.float()).type_as(x) | |
| class Resample(nn.Module): | |
| def __init__(self, dim, mode): | |
| assert mode in ("none", "upsample2d", "upsample3d", "downsample2d", "downsample3d") | |
| super().__init__() | |
| self.dim = dim | |
| self.mode = mode | |
| if mode == "upsample2d": | |
| self.resample = nn.Sequential(Upsample(scale_factor=(2.0, 2.0), mode="nearest-exact"), nn.Conv2d(dim, dim // 2, 3, padding=1)) | |
| elif mode == "upsample3d": | |
| self.resample = nn.Sequential(Upsample(scale_factor=(2.0, 2.0), mode="nearest-exact"), nn.Conv2d(dim, dim // 2, 3, padding=1)) | |
| self.time_conv = CausalConv3dZeroPad(dim, dim * 2, (3, 1, 1), padding=(1, 0, 0)) | |
| elif mode == "downsample2d": | |
| self.resample = nn.Sequential(nn.ZeroPad2d((0, 1, 0, 1)), nn.Conv2d(dim, dim, 3, stride=(2, 2))) | |
| elif mode == "downsample3d": | |
| self.resample = nn.Sequential(nn.ZeroPad2d((0, 1, 0, 1)), nn.Conv2d(dim, dim, 3, stride=(2, 2))) | |
| self.time_conv = CausalConv3dZeroPad(dim, dim, (3, 1, 1), stride=(2, 1, 1), padding=(0, 0, 0)) | |
| else: | |
| self.resample = nn.Identity() | |
| def forward(self, x, feat_cache=None, feat_idx=[0]): | |
| b, c, t, h, w = x.size() | |
| if self.mode == "upsample3d": | |
| if feat_cache is not None: | |
| idx = feat_idx[0] | |
| if feat_cache[idx] is None: | |
| feat_cache[idx] = "Rep" | |
| feat_idx[0] += 1 | |
| else: | |
| cache_x = x[:, :, -CACHE_T:, :, :].clone() | |
| if cache_x.shape[2] < 2 and feat_cache[idx] is not None and feat_cache[idx] != "Rep": | |
| cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2) | |
| if cache_x.shape[2] < 2 and feat_cache[idx] is not None and feat_cache[idx] == "Rep": | |
| cache_x = torch.cat([torch.zeros_like(cache_x).to(cache_x.device), cache_x], dim=2) | |
| if feat_cache[idx] == "Rep": | |
| x = self.time_conv(x) | |
| else: | |
| x = self.time_conv(x, feat_cache[idx]) | |
| feat_cache[idx] = cache_x | |
| feat_idx[0] += 1 | |
| x = x.reshape(b, 2, c, t, h, w) | |
| x = torch.stack((x[:, 0, :, :, :, :], x[:, 1, :, :, :, :]), 3) | |
| x = x.reshape(b, c, t * 2, h, w) | |
| t = x.shape[2] | |
| x = rearrange(x, "b c t h w -> (b t) c h w") | |
| x = self.resample(x) | |
| x = rearrange(x, "(b t) c h w -> b c t h w", t=t) | |
| if self.mode == "downsample3d": | |
| if feat_cache is not None: | |
| idx = feat_idx[0] | |
| if feat_cache[idx] is None: | |
| feat_cache[idx] = x.clone() | |
| feat_idx[0] += 1 | |
| else: | |
| cache_x = x[:, :, -1:, :, :].clone() | |
| x = self.time_conv(torch.cat([feat_cache[idx][:, :, -1:, :, :], x], 2)) | |
| feat_cache[idx] = cache_x | |
| feat_idx[0] += 1 | |
| return x | |
| class PixelShuffle3d(nn.Module): | |
| def __init__(self, ff, hh, ww): | |
| super().__init__() | |
| self.ff, self.hh, self.ww = ff, hh, ww | |
| def forward(self, x): | |
| return rearrange(x, "b c (f ff) (h hh) (w ww) -> b (c ff hh ww) f h w", ff=self.ff, hh=self.hh, ww=self.ww) | |
| class PixelShuffle3dTAEHV(nn.Module): | |
| def __init__(self, ff, hh, ww): | |
| super().__init__() | |
| self.ff, self.hh, self.ww = ff, hh, ww | |
| def forward(self, x): | |
| B, C, F, H, W = x.shape | |
| if F % self.ff != 0: | |
| first_frame = x[:, :, 0:1, :, :].repeat(1, 1, self.ff - F % self.ff, 1, 1) | |
| x = torch.cat([first_frame, x], dim=2) | |
| return rearrange(x, "b c (f ff) (h hh) (w ww) -> b (c ff hh ww) f h w", ff=self.ff, hh=self.hh, ww=self.ww).transpose(1, 2) | |
| # ============================================================================== | |
| # Blocks (Residual, Attention, MemBlock) | |
| # ============================================================================== | |
| class ResidualBlock(nn.Module): | |
| def __init__(self, in_dim, out_dim, dropout=0.0): | |
| super().__init__() | |
| self.residual = nn.Sequential( | |
| RMS_norm_General(in_dim, images=False), | |
| nn.SiLU(), | |
| CausalConv3dZeroPad(in_dim, out_dim, 3, padding=1), | |
| RMS_norm_General(out_dim, images=False), | |
| nn.SiLU(), | |
| nn.Dropout(dropout), | |
| CausalConv3dZeroPad(out_dim, out_dim, 3, padding=1), | |
| ) | |
| self.shortcut = CausalConv3dZeroPad(in_dim, out_dim, 1) if in_dim != out_dim else nn.Identity() | |
| def forward(self, x, feat_cache=None, feat_idx=[0]): | |
| h = self.shortcut(x) | |
| for layer in self.residual: | |
| if check_is_instance(layer, CausalConv3dZeroPad) and feat_cache is not None: | |
| idx = feat_idx[0] | |
| cache_x = x[:, :, -CACHE_T:, :, :].clone() | |
| if cache_x.shape[2] < 2 and feat_cache[idx] is not None: | |
| cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2) | |
| x = layer(x, feat_cache[idx]) | |
| feat_cache[idx] = cache_x | |
| feat_idx[0] += 1 | |
| else: | |
| x = layer(x) | |
| return x + h | |
| class AttentionBlock(nn.Module): | |
| def __init__(self, dim): | |
| super().__init__() | |
| self.norm = RMS_norm_General(dim) | |
| self.to_qkv = nn.Conv2d(dim, dim * 3, 1) | |
| self.proj = nn.Conv2d(dim, dim, 1) | |
| nn.init.zeros_(self.proj.weight) | |
| def forward(self, x): | |
| identity = x | |
| b, c, t, h, w = x.size() | |
| x = rearrange(x, "b c t h w -> (b t) c h w") | |
| x = self.norm(x) | |
| q, k, v = self.to_qkv(x).reshape(b * t, 1, c * 3, -1).permute(0, 1, 3, 2).contiguous().chunk(3, dim=-1) | |
| x = F.scaled_dot_product_attention(q, k, v) | |
| x = x.squeeze(1).permute(0, 2, 1).reshape(b * t, c, h, w) | |
| x = self.proj(x) | |
| x = rearrange(x, "(b t) c h w-> b c t h w", t=t) | |
| return x + identity | |
| def flash_attention(q, k, v, num_heads, compatibility_mode=False, attention_mask=None, return_KV=False): | |
| """ | |
| Hàm wrapper chính cho Attention. Tự động chuyển sang SDPA trên Mac. | |
| """ | |
| # --- MAC / MPS SUPPORT --- | |
| if q.device.type == 'mps' or not TRITON_AVAILABLE: | |
| # Input format thường là (b, s, n*d) hoặc (b, s, n, d) tùy chỗ gọi | |
| # Ở WanModel code này, input vào là (b, s, n*d) | |
| # Reshape: (b, s, n*d) -> (b, n, s, d) | |
| q = rearrange(q, "b s (n d) -> b n s d", n=num_heads) | |
| k = rearrange(k, "b s (n d) -> b n s d", n=num_heads) | |
| v = rearrange(v, "b s (n d) -> b n s d", n=num_heads) | |
| # Use native SDPA (fastest option on Mac) | |
| x = F.scaled_dot_product_attention(q, k, v, attn_mask=None, is_causal=False) | |
| # Reshape the output back: (b, n, s, d) -> (b, s, n*d) | |
| x = rearrange(x, "b n s d -> b s (n d)", n=num_heads) | |
| return x | |
| # --- END MAC SUPPORT --- | |
| if attention_mask is not None: | |
| seqlen = q.shape[1] | |
| # ... (the original NVIDIA code path is kept here) | |
| # Fall back to SDPA so the code still runs if the imports were not removed | |
| q = rearrange(q, "b s (n d) -> b n s d", n=num_heads) | |
| k = rearrange(k, "b s (n d) -> b n s d", n=num_heads) | |
| v = rearrange(v, "b s (n d) -> b n s d", n=num_heads) | |
| x = F.scaled_dot_product_attention(q, k, v) | |
| x = rearrange(x, "b n s d -> b s (n d)", n=num_heads) | |
| return x | |
| # Final fallback | |
| q = rearrange(q, "b s (n d) -> b n s d", n=num_heads) | |
| k = rearrange(k, "b s (n d) -> b n s d", n=num_heads) | |
| v = rearrange(v, "b s (n d) -> b n s d", n=num_heads) | |
| x = F.scaled_dot_product_attention(q, k, v) | |
| x = rearrange(x, "b n s d -> b s (n d)", n=num_heads) | |
| return x | |
| class AttentionModule(nn.Module): | |
| def __init__(self, num_heads): | |
| super().__init__() | |
| self.num_heads = num_heads | |
| def forward(self, q, k, v, attention_mask=None): | |
| return flash_attention(q=q, k=k, v=v, num_heads=self.num_heads, attention_mask=attention_mask) | |
| class SelfAttention(nn.Module): | |
| def __init__(self, dim: int, num_heads: int, eps: float = 1e-6): | |
| super().__init__() | |
| self.dim = dim | |
| self.num_heads = num_heads | |
| self.head_dim = dim // num_heads | |
| self.q = nn.Linear(dim, dim) | |
| self.k = nn.Linear(dim, dim) | |
| self.v = nn.Linear(dim, dim) | |
| self.o = nn.Linear(dim, dim) | |
| self.norm_q = RMSNorm(dim, eps=eps) | |
| self.norm_k = RMSNorm(dim, eps=eps) | |
| self.attn = AttentionModule(self.num_heads) | |
| self.local_attn_mask = None | |
| def forward( | |
| self, x, freqs, f=None, h=None, w=None, local_num=None, topk=None, train_img=False, | |
| block_id=None, kv_len=None, is_full_block=False, is_stream=False, | |
| pre_cache_k=None, pre_cache_v=None, local_range=9, | |
| ): | |
| B, L, D = x.shape | |
| if is_stream and pre_cache_k is not None and pre_cache_v is not None: | |
| assert f == 2, "f must be 2" | |
| if is_stream and (pre_cache_k is None or pre_cache_v is None): | |
| assert f == 6, " start f must be 6" | |
| assert L == f * h * w, "Sequence length mismatch with provided (f,h,w)." | |
| q = self.norm_q(self.q(x)) | |
| k = self.norm_k(self.k(x)) | |
| v = self.v(x) | |
| q = rope_apply(q, freqs, self.num_heads) | |
| k = rope_apply(k, freqs, self.num_heads) | |
| # Window partition configuration | |
| win = (2, 8, 8) | |
| q = q.view(B, f, h, w, D) | |
| k = k.view(B, f, h, w, D) | |
| v = v.view(B, f, h, w, D) | |
| q_w = WindowPartition3D.partition(q, win) | |
| k_w = WindowPartition3D.partition(k, win) | |
| v_w = WindowPartition3D.partition(v, win) | |
| seqlen = f // win[0] | |
| one_len = k_w.shape[0] // B // seqlen | |
| if pre_cache_k is not None and pre_cache_v is not None: | |
| k_w = torch.cat([pre_cache_k, k_w], dim=0) | |
| v_w = torch.cat([pre_cache_v, v_w], dim=0) | |
| block_n = q_w.shape[0] // B | |
| block_s = q_w.shape[1] | |
| block_n_kv = k_w.shape[0] // B | |
| # Reshape for attention | |
| reorder_q = rearrange(q_w, "(b block_n) (block_s) d -> b (block_n block_s) d", block_n=block_n, block_s=block_s) | |
| reorder_k = rearrange(k_w, "(b block_n) (block_s) d -> b (block_n block_s) d", block_n=block_n_kv, block_s=block_s) | |
| reorder_v = rearrange(v_w, "(b block_n) (block_s) d -> b (block_n block_s) d", block_n=block_n_kv, block_s=block_s) | |
| # --- MAC FIX (skip the complex mask-building logic) --- | |
| attention_mask = None | |
| if x.device.type != 'mps': | |
| # Only run this mask-building logic on an NVIDIA GPU | |
| if (self.local_attn_mask is None or self.local_attn_mask_h != h // 8 or self.local_attn_mask_w != w // 8 or self.local_range != local_range): | |
| self.local_attn_mask = build_local_block_mask_shifted_vec_normal_slide( | |
| h // 8, w // 8, local_range, local_range, include_self=True, device=k_w.device, | |
| ) | |
| self.local_attn_mask_h = h // 8 | |
| self.local_attn_mask_w = w // 8 | |
| self.local_range = local_range | |
| if USE_BLOCK_ATTN and BLOCK_ATTN_AVAILABLE: | |
| attention_mask = generate_draft_block_mask(B, self.num_heads, seqlen, q_w, k_w, topk=topk, local_attn_mask=self.local_attn_mask) | |
| else: | |
| attention_mask = generate_draft_block_mask_sage(B, self.num_heads, seqlen, q_w, k_w, topk=topk, local_attn_mask=self.local_attn_mask) | |
| # ---------------------------------------------------- | |
| x = self.attn(reorder_q, reorder_k, reorder_v, attention_mask) | |
| # Cache management for streaming | |
| cur_block_n, cur_block_s, _ = k_w.shape | |
| cache_num = cur_block_n // one_len | |
| if cache_num > kv_len: | |
| cache_k = k_w[one_len:, :, :] | |
| cache_v = v_w[one_len:, :, :] | |
| else: | |
| cache_k = k_w | |
| cache_v = v_w | |
| # Reshape back to the original layout | |
| x = rearrange(x, "b (block_n block_s) d -> (b block_n) (block_s) d", block_n=block_n, block_s=block_s) | |
| x = WindowPartition3D.reverse(x, win, (f, h, w)) | |
| x = x.view(B, f * h * w, D) | |
| if is_stream: | |
| return self.o(x), cache_k, cache_v | |
| return self.o(x) | |
| class CrossAttention(nn.Module): | |
| def __init__(self, dim: int, num_heads: int, eps: float = 1e-6): | |
| super().__init__() | |
| self.dim = dim | |
| self.num_heads = num_heads | |
| self.head_dim = dim // num_heads | |
| self.q = nn.Linear(dim, dim) | |
| self.k = nn.Linear(dim, dim) | |
| self.v = nn.Linear(dim, dim) | |
| self.o = nn.Linear(dim, dim) | |
| self.norm_q = RMSNorm(dim, eps=eps) | |
| self.norm_k = RMSNorm(dim, eps=eps) | |
| self.attn = AttentionModule(self.num_heads) | |
| self.cache_k = None | |
| self.cache_v = None | |
| @torch.no_grad() | |
| def init_cache(self, ctx: torch.Tensor): | |
| self.cache_k = self.norm_k(self.k(ctx)) | |
| self.cache_v = self.v(ctx) | |
| def clear_cache(self): | |
| self.cache_k = None | |
| self.cache_v = None | |
| def forward(self, x: torch.Tensor, y: torch.Tensor, is_stream: bool = False): | |
| q = self.norm_q(self.q(x)) | |
| assert self.cache_k is not None and self.cache_v is not None | |
| k, v = self.cache_k, self.cache_v | |
| x = self.attn(q, k, v) | |
| return self.o(x) | |
| class GateModule(nn.Module): | |
| def forward(self, x, gate, residual): | |
| return x + gate * residual | |
| class DiTBlock(nn.Module): | |
| def __init__(self, dim: int, num_heads: int, ffn_dim: int, eps: float = 1e-6): | |
| super().__init__() | |
| self.dim = dim | |
| self.num_heads = num_heads | |
| self.ffn_dim = ffn_dim | |
| self.self_attn = SelfAttention(dim, num_heads, eps) | |
| self.cross_attn = CrossAttention(dim, num_heads, eps) | |
| self.norm1 = nn.LayerNorm(dim, eps=eps, elementwise_affine=False) | |
| self.norm2 = nn.LayerNorm(dim, eps=eps, elementwise_affine=False) | |
| self.norm3 = nn.LayerNorm(dim, eps=eps) | |
| self.ffn = nn.Sequential(nn.Linear(dim, ffn_dim), nn.GELU(approximate="tanh"), nn.Linear(ffn_dim, dim)) | |
| self.modulation = nn.Parameter(torch.randn(1, 6, dim) / dim**0.5) | |
| self.gate = GateModule() | |
| def forward( | |
| self, x, context, t_mod, freqs, f, h, w, local_num=None, topk=None, train_img=False, | |
| block_id=None, kv_len=None, is_full_block=False, is_stream=False, | |
| pre_cache_k=None, pre_cache_v=None, local_range=9, | |
| ): | |
| shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = ( | |
| self.modulation.to(dtype=t_mod.dtype, device=t_mod.device) + t_mod | |
| ).chunk(6, dim=1) | |
| input_x = modulate(self.norm1(x), shift_msa, scale_msa) | |
| self_attn_output, self_attn_cache_k, self_attn_cache_v = self.self_attn( | |
| input_x, freqs, f, h, w, local_num, topk, train_img, block_id, | |
| kv_len=kv_len, is_full_block=is_full_block, is_stream=is_stream, | |
| pre_cache_k=pre_cache_k, pre_cache_v=pre_cache_v, local_range=local_range, | |
| ) | |
| x = self.gate(x, gate_msa, self_attn_output) | |
| x = x + self.cross_attn(self.norm3(x), context, is_stream=is_stream) | |
| input_x = modulate(self.norm2(x), shift_mlp, scale_mlp) | |
| x = self.gate(x, gate_mlp, self.ffn(input_x)) | |
| if is_stream: | |
| return x, self_attn_cache_k, self_attn_cache_v | |
| return x | |
| class Head(nn.Module): | |
| def __init__(self, dim: int, out_dim: int, patch_size: Tuple[int, int, int], eps: float): | |
| super().__init__() | |
| self.dim = dim | |
| self.patch_size = patch_size | |
| self.norm = nn.LayerNorm(dim, eps=eps, elementwise_affine=False) | |
| self.head = nn.Linear(dim, out_dim * math.prod(patch_size)) | |
| self.modulation = nn.Parameter(torch.randn(1, 2, dim) / dim**0.5) | |
| def forward(self, x, t_mod): | |
| shift, scale = (self.modulation.to(dtype=t_mod.dtype, device=t_mod.device) + t_mod).chunk(2, dim=1) | |
| x = self.head(self.norm(x) * (1 + scale) + shift) | |
| return x | |
| def conv(n_in, n_out, **kwargs): | |
| return nn.Conv2d(n_in, n_out, 3, padding=1, **kwargs) | |
| class MemBlock(nn.Module): | |
| def __init__(self, n_in, n_out): | |
| super().__init__() | |
| self.conv = nn.Sequential( | |
| conv(n_in * 2, n_out), nn.ReLU(inplace=True), | |
| conv(n_out, n_out), nn.ReLU(inplace=True), | |
| conv(n_out, n_out), | |
| ) | |
| self.skip = nn.Conv2d(n_in, n_out, 1, bias=False) if n_in != n_out else nn.Identity() | |
| self.act = nn.ReLU(inplace=True) | |
| def forward(self, x, past): | |
| return self.act(self.conv(torch.cat([x, past], 1)) + self.skip(x)) | |
| class TGrow(nn.Module): | |
| def __init__(self, n_f, stride): | |
| super().__init__() | |
| self.stride = stride | |
| self.conv = nn.Conv2d(n_f, n_f * stride, 1, bias=False) | |
| def forward(self, x): | |
| _NT, C, H, W = x.shape | |
| x = self.conv(x) | |
| return x.reshape(-1, C, H, W) | |
| # ============================================================================== | |
| # Core Models: WanModel, VAE, Proj, TAEHV | |
| # ============================================================================== | |
| class WanModelStateDictConverter: | |
| def __init__(self): | |
| pass | |
| def from_diffusers(self, state_dict): | |
| rename_dict = { | |
| "blocks.0.attn1.norm_k.weight": "blocks.0.self_attn.norm_k.weight", | |
| "blocks.0.attn1.norm_q.weight": "blocks.0.self_attn.norm_q.weight", | |
| "blocks.0.attn1.to_k.bias": "blocks.0.self_attn.k.bias", | |
| "blocks.0.attn1.to_k.weight": "blocks.0.self_attn.k.weight", | |
| "blocks.0.attn1.to_out.0.bias": "blocks.0.self_attn.o.bias", | |
| "blocks.0.attn1.to_out.0.weight": "blocks.0.self_attn.o.weight", | |
| "blocks.0.attn1.to_q.bias": "blocks.0.self_attn.q.bias", | |
| "blocks.0.attn1.to_q.weight": "blocks.0.self_attn.q.weight", | |
| "blocks.0.attn1.to_v.bias": "blocks.0.self_attn.v.bias", | |
| "blocks.0.attn1.to_v.weight": "blocks.0.self_attn.v.weight", | |
| "blocks.0.attn2.norm_k.weight": "blocks.0.cross_attn.norm_k.weight", | |
| "blocks.0.attn2.norm_q.weight": "blocks.0.cross_attn.norm_q.weight", | |
| "blocks.0.attn2.to_k.bias": "blocks.0.cross_attn.k.bias", | |
| "blocks.0.attn2.to_k.weight": "blocks.0.cross_attn.k.weight", | |
| "blocks.0.attn2.to_out.0.bias": "blocks.0.cross_attn.o.bias", | |
| "blocks.0.attn2.to_out.0.weight": "blocks.0.cross_attn.o.weight", | |
| "blocks.0.attn2.to_q.bias": "blocks.0.cross_attn.q.bias", | |
| "blocks.0.attn2.to_q.weight": "blocks.0.cross_attn.q.weight", | |
| "blocks.0.attn2.to_v.bias": "blocks.0.cross_attn.v.bias", | |
| "blocks.0.attn2.to_v.weight": "blocks.0.cross_attn.v.weight", | |
| "blocks.0.ffn.net.0.proj.bias": "blocks.0.ffn.0.bias", | |
| "blocks.0.ffn.net.0.proj.weight": "blocks.0.ffn.0.weight", | |
| "blocks.0.ffn.net.2.bias": "blocks.0.ffn.2.bias", | |
| "blocks.0.ffn.net.2.weight": "blocks.0.ffn.2.weight", | |
| "blocks.0.norm2.bias": "blocks.0.norm3.bias", | |
| "blocks.0.norm2.weight": "blocks.0.norm3.weight", | |
| "blocks.0.scale_shift_table": "blocks.0.modulation", | |
| "condition_embedder.text_embedder.linear_1.bias": "text_embedding.0.bias", | |
| "condition_embedder.text_embedder.linear_1.weight": "text_embedding.0.weight", | |
| "condition_embedder.text_embedder.linear_2.bias": "text_embedding.2.bias", | |
| "condition_embedder.text_embedder.linear_2.weight": "text_embedding.2.weight", | |
| "condition_embedder.time_embedder.linear_1.bias": "time_embedding.0.bias", | |
| "condition_embedder.time_embedder.linear_1.weight": "time_embedding.0.weight", | |
| "condition_embedder.time_embedder.linear_2.bias": "time_embedding.2.bias", | |
| "condition_embedder.time_embedder.linear_2.weight": "time_embedding.2.weight", | |
| "condition_embedder.time_proj.bias": "time_projection.1.bias", | |
| "condition_embedder.time_proj.weight": "time_projection.1.weight", | |
| "patch_embedding.bias": "patch_embedding.bias", | |
| "patch_embedding.weight": "patch_embedding.weight", | |
| "scale_shift_table": "head.modulation", | |
| "proj_out.bias": "head.head.bias", | |
| "proj_out.weight": "head.head.weight", | |
| } | |
| state_dict_ = {} | |
| for name, param in state_dict.items(): | |
| if name in rename_dict: | |
| state_dict_[rename_dict[name]] = param | |
| else: | |
| name_ = ".".join(name.split(".")[:1] + ["0"] + name.split(".")[2:]) | |
| if name_ in rename_dict: | |
| name_ = rename_dict[name_] | |
| name_ = ".".join(name_.split(".")[:1] + [name.split(".")[1]] + name_.split(".")[2:]) | |
| state_dict_[name_] = param | |
| if hash_state_dict_keys(state_dict) == "cb104773c6c2cb6df4f9529ad5c60d0b": | |
| config = { | |
| "model_type": "t2v", "patch_size": (1, 2, 2), "text_len": 512, "in_dim": 16, | |
| "dim": 5120, "ffn_dim": 13824, "freq_dim": 256, "text_dim": 4096, "out_dim": 16, | |
| "num_heads": 40, "num_layers": 40, "window_size": (-1, -1), "qk_norm": True, | |
| "cross_attn_norm": True, "eps": 1e-6, | |
| } | |
| else: | |
| config = {} | |
| return state_dict_, config | |
| def from_civitai(self, state_dict): | |
| state_dict = {name: param for name, param in state_dict.items() if not name.startswith("vace")} | |
| h = hash_state_dict_keys(state_dict) | |
| config = {} | |
| if h == "9269f8db9040a9d860eaca435be61814": | |
| config = {"has_image_input": False, "patch_size": [1, 2, 2], "in_dim": 16, "dim": 1536, "ffn_dim": 8960, "freq_dim": 256, "text_dim": 4096, "out_dim": 16, "num_heads": 12, "num_layers": 30, "eps": 1e-6} | |
| elif h == "aafcfd9672c3a2456dc46e1cb6e52c70": | |
| config = {"has_image_input": False, "patch_size": [1, 2, 2], "in_dim": 16, "dim": 5120, "ffn_dim": 13824, "freq_dim": 256, "text_dim": 4096, "out_dim": 16, "num_heads": 40, "num_layers": 40, "eps": 1e-6} | |
| elif h == "6bfcfb3b342cb286ce886889d519a77e": | |
| config = {"has_image_input": False, "patch_size": [1, 2, 2], "in_dim": 36, "dim": 5120, "ffn_dim": 13824, "freq_dim": 256, "text_dim": 4096, "out_dim": 16, "num_heads": 40, "num_layers": 40, "eps": 1e-6} | |
| elif h == "6d6ccde6845b95ad9114ab993d917893": | |
| config = {"has_image_input": False, "patch_size": [1, 2, 2], "in_dim": 36, "dim": 1536, "ffn_dim": 8960, "freq_dim": 256, "text_dim": 4096, "out_dim": 16, "num_heads": 12, "num_layers": 30, "eps": 1e-6} | |
| elif h == "349723183fc063b2bfc10bb2835cf677": | |
| config = {"has_image_input": False, "patch_size": [1, 2, 2], "in_dim": 48, "dim": 1536, "ffn_dim": 8960, "freq_dim": 256, "text_dim": 4096, "out_dim": 16, "num_heads": 12, "num_layers": 30, "eps": 1e-6} | |
| elif h == "efa44cddf936c70abd0ea28b6cbe946c": | |
| config = {"has_image_input": False, "patch_size": [1, 2, 2], "in_dim": 48, "dim": 5120, "ffn_dim": 13824, "freq_dim": 256, "text_dim": 4096, "out_dim": 16, "num_heads": 40, "num_layers": 40, "eps": 1e-6} | |
| elif h == "3ef3b1f8e1dab83d5b71fd7b617f859f": | |
| config = {"has_image_input": False, "patch_size": [1, 2, 2], "in_dim": 36, "dim": 5120, "ffn_dim": 13824, "freq_dim": 256, "text_dim": 4096, "out_dim": 16, "num_heads": 40, "num_layers": 40, "eps": 1e-6, "has_image_pos_emb": False} | |
| return state_dict, config | |
| class WanModel(torch.nn.Module): | |
| def __init__(self, dim: int, in_dim: int, ffn_dim: int, out_dim: int, text_dim: int, freq_dim: int, eps: float, patch_size: Tuple[int, int, int], num_heads: int, num_layers: int, has_image_input: bool = False): | |
| super().__init__() | |
| self.dim = dim | |
| self.freq_dim = freq_dim | |
| self.patch_size = patch_size | |
| self.patch_embedding = nn.Conv3d(in_dim, dim, kernel_size=patch_size, stride=patch_size) | |
| self.text_embedding = nn.Sequential(nn.Linear(text_dim, dim), nn.GELU(approximate="tanh"), nn.Linear(dim, dim)) | |
| self.time_embedding = nn.Sequential(nn.Linear(freq_dim, dim), nn.SiLU(), nn.Linear(dim, dim)) | |
| self.time_projection = nn.Sequential(nn.SiLU(), nn.Linear(dim, dim * 6)) | |
| self.blocks = nn.ModuleList([DiTBlock(dim, num_heads, ffn_dim, eps) for _ in range(num_layers)]) | |
| self.head = Head(dim, out_dim, patch_size, eps) | |
| head_dim = dim // num_heads | |
| self.freqs = precompute_freqs_cis_3d(head_dim) | |
| self._cross_kv_initialized = False | |
| def clear_cross_kv(self): | |
| for blk in self.blocks: | |
| blk.cross_attn.clear_cache() | |
| self._cross_kv_initialized = False | |
| @torch.no_grad() | |
| def reinit_cross_kv(self, new_context: torch.Tensor): | |
| ctx_txt = self.text_embedding(new_context) | |
| for blk in self.blocks: | |
| blk.cross_attn.init_cache(ctx_txt) | |
| self._cross_kv_initialized = True | |
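| # Patchify: Conv3d patch embedding, then flatten the (f, h, w) grid into a token sequence | |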
| def patchify(self, x: torch.Tensor): | |
| x = self.patch_embedding(x) | |
| grid_size = x.shape[2:] | |
| x = rearrange(x, "b c f h w -> b (f h w) c").contiguous() | |
| return x, grid_size | |
| def unpatchify(self, x: torch.Tensor, grid_size: torch.Tensor): | |
| return rearrange(x, "b (f h w) (x y z c) -> b c (f x) (h y) (w z)", f=grid_size[0], h=grid_size[1], w=grid_size[2], x=self.patch_size[0], y=self.patch_size[1], z=self.patch_size[2]) | |
| def forward( | |
| self, x: torch.Tensor, timestep: torch.Tensor, context: torch.Tensor, use_gradient_checkpointing: bool = False, | |
| use_gradient_checkpointing_offload: bool = False, LQ_latents: Optional[List[torch.Tensor]] = None, | |
| train_img: bool = False, topk_ratio: Optional[float] = None, kv_ratio: Optional[float] = None, | |
| local_num: Optional[int] = None, is_full_block: bool = False, causal_idx: Optional[int] = None, **kwargs, | |
| ): | |
| t = self.time_embedding(sinusoidal_embedding_1d(self.freq_dim, timestep)) | |
| t_mod = self.time_projection(t).unflatten(1, (6, self.dim)) | |
| x, (f, h, w) = self.patchify(x) | |
| win = (2, 8, 8) | |
| seqlen = f // win[0] | |
| if local_num is None: | |
| local_random = random.random() | |
| if local_random < 0.3: local_num = seqlen - 3 | |
| elif local_random < 0.4: local_num = seqlen - 4 | |
| elif local_random < 0.5: local_num = seqlen - 2 | |
| else: local_num = seqlen | |
| window_size = win[0] * h * w // 128 | |
| square_num = window_size * window_size | |
| topk_ratio = 2.0 | |
| topk = min(max(int(square_num * topk_ratio), 1), int(square_num * seqlen) - 1) | |
| if kv_ratio is None: | |
| kv_ratio = (random.uniform(0.0, 1.0) ** 2) * (local_num - 2 - 2) + 2 | |
| kv_len = min(max(int(window_size * kv_ratio), 1), int(window_size * seqlen) - 1) | |
| freqs = (torch.cat([self.freqs[0][:f].view(f, 1, 1, -1).expand(f, h, w, -1), self.freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1), self.freqs[2][:w].view(1, 1, w, -1).expand(f, h, w, -1)], dim=-1).reshape(f * h * w, 1, -1).to(x.device)) | |
| def create_custom_forward(module): | |
| def custom_forward(*inputs): return module(*inputs) | |
| return custom_forward | |
| for block_id, block in enumerate(self.blocks): | |
| if LQ_latents is not None and block_id < len(LQ_latents): | |
| x += LQ_latents[block_id] | |
| if self.training and use_gradient_checkpointing: | |
| if use_gradient_checkpointing_offload: | |
| with torch.autograd.graph.save_on_cpu(): | |
| x = torch.utils.checkpoint.checkpoint(create_custom_forward(block), x, context, t_mod, freqs, f, h, w, local_num, topk, train_img, block_id, kv_len, is_full_block, False, None, None, use_reentrant=False) | |
| else: | |
| x = torch.utils.checkpoint.checkpoint(create_custom_forward(block), x, context, t_mod, freqs, f, h, w, local_num, topk, train_img, block_id, kv_len, is_full_block, False, None, None, use_reentrant=False) | |
| else: | |
| x = block(x, context, t_mod, freqs, f, h, w, local_num, topk, train_img, block_id, kv_len, is_full_block, False, None, None) | |
| x = self.head(x, t) | |
| x = self.unpatchify(x, (f, h, w)) | |
| return x | |
| @staticmethod | |
| def state_dict_converter(): | |
| return WanModelStateDictConverter() | |
| class WanVideoVAEStateDictConverter: | |
| def __init__(self): | |
| pass | |
| def from_civitai(self, state_dict): | |
| state_dict_ = {} | |
| if "model_state" in state_dict: | |
| state_dict = state_dict["model_state"] | |
| for name in state_dict: | |
| state_dict_["model." + name] = state_dict[name] | |
| return state_dict_ | |
| class Encoder3d(nn.Module): | |
| def __init__(self, dim=128, z_dim=4, dim_mult=[1, 2, 4, 4], num_res_blocks=2, attn_scales=[], temperal_downsample=[True, True, False], dropout=0.0): | |
| super().__init__() | |
| self.dim = dim | |
| self.z_dim = z_dim | |
| self.dim_mult = dim_mult | |
| self.num_res_blocks = num_res_blocks | |
| self.attn_scales = attn_scales | |
| self.temperal_downsample = temperal_downsample | |
| dims = [dim * u for u in [1] + dim_mult] | |
| scale = 1.0 | |
| self.conv1 = CausalConv3dZeroPad(3, dims[0], 3, padding=1) | |
| downsamples = [] | |
| for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])): | |
| for _ in range(num_res_blocks): | |
| downsamples.append(ResidualBlock(in_dim, out_dim, dropout)) | |
| if scale in attn_scales: downsamples.append(AttentionBlock(out_dim)) | |
| in_dim = out_dim | |
| if i != len(dim_mult) - 1: | |
| mode = "downsample3d" if temperal_downsample[i] else "downsample2d" | |
| downsamples.append(Resample(out_dim, mode=mode)) | |
| scale /= 2.0 | |
| self.downsamples = nn.Sequential(*downsamples) | |
| self.middle = nn.Sequential(ResidualBlock(out_dim, out_dim, dropout), AttentionBlock(out_dim), ResidualBlock(out_dim, out_dim, dropout)) | |
| self.head = nn.Sequential(RMS_norm_General(out_dim, images=False), nn.SiLU(), CausalConv3dZeroPad(out_dim, z_dim, 3, padding=1)) | |
| def forward(self, x, feat_cache=None, feat_idx=[0]): | |
| if feat_cache is not None: | |
| idx = feat_idx[0] | |
| cache_x = x[:, :, -CACHE_T:, :, :].clone() | |
| if cache_x.shape[2] < 2 and feat_cache[idx] is not None: | |
| cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2) | |
| x = self.conv1(x, feat_cache[idx]) | |
| feat_cache[idx] = cache_x | |
| feat_idx[0] += 1 | |
| else: | |
| x = self.conv1(x) | |
| for layer in self.downsamples: | |
| if feat_cache is not None: x = layer(x, feat_cache, feat_idx) | |
| else: x = layer(x) | |
| for layer in self.middle: | |
| if check_is_instance(layer, ResidualBlock) and feat_cache is not None: x = layer(x, feat_cache, feat_idx) | |
| else: x = layer(x) | |
| for layer in self.head: | |
| if check_is_instance(layer, CausalConv3dZeroPad) and feat_cache is not None: | |
| idx = feat_idx[0] | |
| cache_x = x[:, :, -CACHE_T:, :, :].clone() | |
| if cache_x.shape[2] < 2 and feat_cache[idx] is not None: | |
| cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2) | |
| x = layer(x, feat_cache[idx]) | |
| feat_cache[idx] = cache_x | |
| feat_idx[0] += 1 | |
| else: x = layer(x) | |
| return x | |
| class Decoder3d(nn.Module): | |
| def __init__(self, dim=128, z_dim=4, dim_mult=[1, 2, 4, 4], num_res_blocks=2, attn_scales=[], temperal_upsample=[False, True, True], dropout=0.0): | |
| super().__init__() | |
| self.dim, self.z_dim, self.dim_mult, self.num_res_blocks, self.attn_scales, self.temperal_upsample = dim, z_dim, dim_mult, num_res_blocks, attn_scales, temperal_upsample | |
| dims = [dim * u for u in [dim_mult[-1]] + dim_mult[::-1]] | |
| scale = 1.0 / 2 ** (len(dim_mult) - 2) | |
| self.conv1 = CausalConv3dZeroPad(z_dim, dims[0], 3, padding=1) | |
| self.middle = nn.Sequential(ResidualBlock(dims[0], dims[0], dropout), AttentionBlock(dims[0]), ResidualBlock(dims[0], dims[0], dropout)) | |
| upsamples = [] | |
| for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])): | |
| if i == 1 or i == 2 or i == 3: in_dim = in_dim // 2 | |
| for _ in range(num_res_blocks + 1): | |
| upsamples.append(ResidualBlock(in_dim, out_dim, dropout)) | |
| if scale in attn_scales: upsamples.append(AttentionBlock(out_dim)) | |
| in_dim = out_dim | |
| if i != len(dim_mult) - 1: | |
| mode = "upsample3d" if temperal_upsample[i] else "upsample2d" | |
| upsamples.append(Resample(out_dim, mode=mode)) | |
| scale *= 2.0 | |
| self.upsamples = nn.Sequential(*upsamples) | |
| self.head = nn.Sequential(RMS_norm_General(out_dim, images=False), nn.SiLU(), CausalConv3dZeroPad(out_dim, 3, 3, padding=1)) | |
| def forward(self, x, feat_cache=None, feat_idx=[0]): | |
| if feat_cache is not None: | |
| idx = feat_idx[0] | |
| cache_x = x[:, :, -CACHE_T:, :, :].clone() | |
| if cache_x.shape[2] < 2 and feat_cache[idx] is not None: | |
| cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2) | |
| x = self.conv1(x, feat_cache[idx]) | |
| feat_cache[idx] = cache_x | |
| feat_idx[0] += 1 | |
| else: x = self.conv1(x) | |
| for layer in self.middle: | |
| if check_is_instance(layer, ResidualBlock) and feat_cache is not None: x = layer(x, feat_cache, feat_idx) | |
| else: x = layer(x) | |
| for layer in self.upsamples: | |
| if feat_cache is not None: x = layer(x, feat_cache, feat_idx) | |
| else: x = layer(x) | |
| for layer in self.head: | |
| if check_is_instance(layer, CausalConv3dZeroPad) and feat_cache is not None: | |
| idx = feat_idx[0] | |
| cache_x = x[:, :, -CACHE_T:, :, :].clone() | |
| if cache_x.shape[2] < 2 and feat_cache[idx] is not None: | |
| cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2) | |
| x = layer(x, feat_cache[idx]) | |
| feat_cache[idx] = cache_x | |
| feat_idx[0] += 1 | |
| else: x = layer(x) | |
| return x | |
| class VideoVAE_(nn.Module): | |
| def __init__(self, dim=96, z_dim=16, dim_mult=[1, 2, 4, 4], num_res_blocks=2, attn_scales=[], temperal_downsample=[False, True, True], dropout=0.0): | |
| super().__init__() | |
| self.dim, self.z_dim, self.dim_mult, self.num_res_blocks, self.attn_scales, self.temperal_downsample = dim, z_dim, dim_mult, num_res_blocks, attn_scales, temperal_downsample | |
| self.temperal_upsample = temperal_downsample[::-1] | |
| self.encoder = Encoder3d(dim, z_dim * 2, dim_mult, num_res_blocks, attn_scales, self.temperal_downsample, dropout) | |
| self.conv1 = CausalConv3dZeroPad(z_dim * 2, z_dim * 2, 1) | |
| self.conv2 = CausalConv3dZeroPad(z_dim, z_dim, 1) | |
| self.decoder = Decoder3d(dim, z_dim, dim_mult, num_res_blocks, attn_scales, self.temperal_upsample, dropout) | |
| def forward(self, x): | |
| mu, log_var = self.encode(x) | |
| z = self.reparameterize(mu, log_var) | |
| x_recon = self.decode(z) | |
| return x_recon, mu, log_var | |
| def encode(self, x, scale): | |
| self.clear_cache() | |
| t = x.shape[2] | |
| iter_ = 1 + (t - 1) // 4 | |
| for i in range(iter_): | |
| self._enc_conv_idx = [0] | |
| if i == 0: out = self.encoder(x[:, :, :1, :, :], feat_cache=self._enc_feat_map, feat_idx=self._enc_conv_idx) | |
| else: | |
| out_ = self.encoder(x[:, :, 1 + 4 * (i - 1) : 1 + 4 * i, :, :], feat_cache=self._enc_feat_map, feat_idx=self._enc_conv_idx) | |
| out = torch.cat([out, out_], 2) | |
| mu, log_var = self.conv1(out).chunk(2, dim=1) | |
| if isinstance(scale[0], torch.Tensor): | |
| scale = [s.to(dtype=mu.dtype, device=mu.device) for s in scale] | |
| mu = (mu - scale[0].view(1, self.z_dim, 1, 1, 1)) * scale[1].view(1, self.z_dim, 1, 1, 1) | |
| else: | |
| scale = scale.to(dtype=mu.dtype, device=mu.device) | |
| mu = (mu - scale[0]) * scale[1] | |
| return mu | |
| def decode(self, z, scale): | |
| self.clear_cache() | |
| if isinstance(scale[0], torch.Tensor): | |
| scale = [s.to(dtype=z.dtype, device=z.device) for s in scale] | |
| z = z / scale[1].view(1, self.z_dim, 1, 1, 1) + scale[0].view(1, self.z_dim, 1, 1, 1) | |
| else: | |
| scale = scale.to(dtype=z.dtype, device=z.device) | |
| z = z / scale[1] + scale[0] | |
| iter_ = z.shape[2] | |
| x = self.conv2(z) | |
| for i in range(iter_): | |
| self._conv_idx = [0] | |
| if i == 0: out = self.decoder(x[:, :, i : i + 1, :, :], feat_cache=self._feat_map, feat_idx=self._conv_idx) | |
| else: | |
| out_ = self.decoder(x[:, :, i : i + 1, :, :], feat_cache=self._feat_map, feat_idx=self._conv_idx) | |
| out = torch.cat([out, out_], 2) | |
| return out | |
| def stream_decode(self, z, scale): | |
| if isinstance(scale[0], torch.Tensor): | |
| scale = [s.to(dtype=z.dtype, device=z.device) for s in scale] | |
| z = z / scale[1].view(1, self.z_dim, 1, 1, 1) + scale[0].view(1, self.z_dim, 1, 1, 1) | |
| else: | |
| scale = scale.to(dtype=z.dtype, device=z.device) | |
| z = z / scale[1] + scale[0] | |
| iter_ = z.shape[2] | |
| x = self.conv2(z) | |
| for i in range(iter_): | |
| self._conv_idx = [0] | |
| if i == 0: out = self.decoder(x[:, :, i : i + 1, :, :], feat_cache=self._feat_map, feat_idx=self._conv_idx) | |
| else: | |
| out_ = self.decoder(x[:, :, i : i + 1, :, :], feat_cache=self._feat_map, feat_idx=self._conv_idx) | |
| out = torch.cat([out, out_], 2) | |
| return out | |
| def reparameterize(self, mu, log_var): | |
| std = torch.exp(0.5 * log_var) | |
| eps = torch.randn_like(std) | |
| return eps * std + mu | |
| def clear_cache(self): | |
| self._conv_num = count_conv3d(self.decoder) | |
| self._conv_idx = [0] | |
| self._feat_map = [None] * self._conv_num | |
| if self.encoder is not None: | |
| self._enc_conv_num = count_conv3d(self.encoder) | |
| self._enc_conv_idx = [0] | |
| self._enc_feat_map = [None] * self._enc_conv_num | |
| class WanVideoVAE(nn.Module): | |
| def __init__(self, z_dim=16, dim=96): | |
| super().__init__() | |
| self.mean = torch.tensor([-0.7571, -0.7089, -0.9113, 0.1075, -0.1745, 0.9653, -0.1517, 1.5508, 0.4134, -0.0715, 0.5517, -0.3632, -0.1922, -0.9497, 0.2503, -0.2921]) | |
| self.std = torch.tensor([2.8184, 1.4541, 2.3275, 2.6558, 1.2196, 1.7708, 2.6052, 2.0743, 3.2687, 2.1526, 2.8652, 1.5579, 1.6382, 1.1253, 2.8251, 1.9160]) | |
| self.scale = [self.mean, 1.0 / self.std] | |
| self.model = VideoVAE_(z_dim=z_dim, dim=dim).eval().requires_grad_(False) | |
| self.upsampling_factor = 8 | |
| def build_1d_mask(self, length, left_bound, right_bound, border_width): | |
| x = torch.ones((length,)) | |
| if not left_bound: x[:border_width] = (torch.arange(border_width) + 1) / border_width | |
| if not right_bound: x[-border_width:] = torch.flip((torch.arange(border_width) + 1) / border_width, dims=(0,)) | |
| return x | |
| def build_mask(self, data, is_bound, border_width): | |
| _, _, _, H, W = data.shape | |
| h = self.build_1d_mask(H, is_bound[0], is_bound[1], border_width[0]) | |
| w = self.build_1d_mask(W, is_bound[2], is_bound[3], border_width[1]) | |
| h = repeat(h, "H -> H W", H=H, W=W) | |
| w = repeat(w, "W -> H W", H=H, W=W) | |
| mask = torch.stack([h, w]).min(dim=0).values | |
| mask = rearrange(mask, "H W -> 1 1 1 H W") | |
| return mask | |
| def tiled_decode(self, hidden_states, device, tile_size, tile_stride): | |
| _, _, T, H, W = hidden_states.shape | |
| size_h, size_w = tile_size | |
| stride_h, stride_w = tile_stride | |
| tasks = [] | |
| for h in range(0, H, stride_h): | |
| if h - stride_h >= 0 and h - stride_h + size_h >= H: continue | |
| for w in range(0, W, stride_w): | |
| if w - stride_w >= 0 and w - stride_w + size_w >= W: continue | |
| tasks.append((h, h + size_h, w, w + size_w)) | |
| data_device, computation_device = "cpu", device | |
| out_T = T * 4 - 3 | |
| weight = torch.zeros((1, 1, out_T, H * self.upsampling_factor, W * self.upsampling_factor), dtype=hidden_states.dtype, device=data_device) | |
| values = torch.zeros((1, 3, out_T, H * self.upsampling_factor, W * self.upsampling_factor), dtype=hidden_states.dtype, device=data_device) | |
| for h, h_, w, w_ in tqdm(tasks, desc="VAE decoding"): | |
| hidden_states_batch = hidden_states[:, :, :, h:h_, w:w_].to(computation_device) | |
| hidden_states_batch = self.model.decode(hidden_states_batch, self.scale).to(data_device) | |
| mask = self.build_mask(hidden_states_batch, is_bound=(h == 0, h_ >= H, w == 0, w_ >= W), border_width=((size_h - stride_h) * self.upsampling_factor, (size_w - stride_w) * self.upsampling_factor)).to(dtype=hidden_states.dtype, device=data_device) | |
| target_h, target_w = h * self.upsampling_factor, w * self.upsampling_factor | |
| values[:, :, :, target_h : target_h + hidden_states_batch.shape[3], target_w : target_w + hidden_states_batch.shape[4]] += (hidden_states_batch * mask) | |
| weight[:, :, :, target_h : target_h + hidden_states_batch.shape[3], target_w : target_w + hidden_states_batch.shape[4]] += mask | |
| values = values / weight | |
| values = values.clamp_(-1, 1) | |
| return values | |
| def tiled_encode(self, video, device, tile_size, tile_stride): | |
| _, _, T, H, W = video.shape | |
| size_h, size_w = tile_size | |
| stride_h, stride_w = tile_stride | |
| tasks = [] | |
| for h in range(0, H, stride_h): | |
| if h - stride_h >= 0 and h - stride_h + size_h >= H: continue | |
| for w in range(0, W, stride_w): | |
| if w - stride_w >= 0 and w - stride_w + size_w >= W: continue | |
| tasks.append((h, h + size_h, w, w + size_w)) | |
| data_device, computation_device = "cpu", device | |
| out_T = (T + 3) // 4 | |
| weight = torch.zeros((1, 1, out_T, H // self.upsampling_factor, W // self.upsampling_factor), dtype=video.dtype, device=data_device) | |
| values = torch.zeros((1, 16, out_T, H // self.upsampling_factor, W // self.upsampling_factor), dtype=video.dtype, device=data_device) | |
| for h, h_, w, w_ in tqdm(tasks, desc="VAE encoding"): | |
| hidden_states_batch = video[:, :, :, h:h_, w:w_].to(computation_device) | |
| hidden_states_batch = self.model.encode(hidden_states_batch, self.scale).to(data_device) | |
| mask = self.build_mask(hidden_states_batch, is_bound=(h == 0, h_ >= H, w == 0, w_ >= W), border_width=((size_h - stride_h) // self.upsampling_factor, (size_w - stride_w) // self.upsampling_factor)).to(dtype=video.dtype, device=data_device) | |
| target_h, target_w = h // self.upsampling_factor, w // self.upsampling_factor | |
| values[:, :, :, target_h : target_h + hidden_states_batch.shape[3], target_w : target_w + hidden_states_batch.shape[4]] += (hidden_states_batch * mask) | |
| weight[:, :, :, target_h : target_h + hidden_states_batch.shape[3], target_w : target_w + hidden_states_batch.shape[4]] += mask | |
| values = values / weight | |
| return values | |
| def single_encode(self, video, device): | |
| video = video.to(device) | |
| x = self.model.encode(video, self.scale) | |
| return x | |
| def single_decode(self, hidden_state, device): | |
| hidden_state = hidden_state.to(device) | |
| video = self.model.decode(hidden_state, self.scale) | |
| return video.clamp_(-1, 1) | |
| def encode(self, videos, device, tiled=False, tile_size=(34, 34), tile_stride=(18, 16)): | |
| videos = [video.to("cpu") for video in videos] | |
| hidden_states = [] | |
| for video in videos: | |
| video = video.unsqueeze(0) | |
| if tiled: | |
| tile_size_ = (tile_size[0] * 8, tile_size[1] * 8) | |
| tile_stride_ = (tile_stride[0] * 8, tile_stride[1] * 8) | |
| hidden_state = self.tiled_encode(video, device, tile_size_, tile_stride_) | |
| else: hidden_state = self.single_encode(video, device) | |
| hidden_state = hidden_state.squeeze(0) | |
| hidden_states.append(hidden_state) | |
| return torch.stack(hidden_states) | |
| def decode(self, hidden_states, device, tiled=False, tile_size=(34, 34), tile_stride=(18, 16)): | |
| hidden_states = [hidden_state.to("cpu") for hidden_state in hidden_states] | |
| videos = [] | |
| for hidden_state in hidden_states: | |
| hidden_state = hidden_state.unsqueeze(0) | |
| if tiled: video = self.tiled_decode(hidden_state, device, tile_size, tile_stride) | |
| else: video = self.single_decode(hidden_state, device) | |
| video = video.squeeze(0) | |
| videos.append(video) | |
| return torch.stack(videos) | |
| def clear_cache(self): | |
| self.model.clear_cache() | |
| def stream_decode(self, hidden_states, tiled=False, tile_size=(34, 34), tile_stride=(18, 16)): | |
| hidden_states = [hidden_state for hidden_state in hidden_states] | |
| assert len(hidden_states) == 1 | |
| return self.model.stream_decode(hidden_states[0], self.scale) | |
| @staticmethod | |
| def state_dict_converter(): | |
| return WanVideoVAEStateDictConverter() | |
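| # --- Hedged usage sketch (added for illustration; never called in this script). --- | |
| # It shows how WanVideoVAE.decode() cuts a latent clip into overlapping spatial tiles | |
| # and blends them back with the feathered masks built above. The tile/stride values | |
| # are the method defaults; the latent shape is an illustrative assumption. | |
| def _example_tiled_vae_decode(): | |
|     vae = WanVideoVAE()  # random weights; the real weights come from ModelManager | |
|     latent = torch.randn(16, 2, 40, 40)  # one video's latent: (C, T, H, W) | |
|     video = vae.decode([latent], device="cpu", tiled=True, tile_size=(34, 34), tile_stride=(18, 16)) | |
|     return video.shape  # (1, 3, 2 * 4 - 3, 40 * 8, 40 * 8) | |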
| class Buffer_LQ4x_Proj(nn.Module): | |
| def __init__(self, in_dim, out_dim, layer_num=30): | |
| super().__init__() | |
| self.ff, self.hh, self.ww = 1, 16, 16 | |
| self.hidden_dim1, self.hidden_dim2, self.layer_num = 2048, 3072, layer_num | |
| self.pixel_shuffle = PixelShuffle3d(self.ff, self.hh, self.ww) | |
| self.conv1 = CausalConv3dReplicate(in_dim * self.ff * self.hh * self.ww, self.hidden_dim1, (4, 3, 3), stride=(2, 1, 1), padding=(1, 1, 1)) | |
| self.norm1 = RMS_norm_General(self.hidden_dim1, images=False) | |
| self.act1 = nn.SiLU() | |
| self.conv2 = CausalConv3dReplicate(self.hidden_dim1, self.hidden_dim2, (4, 3, 3), stride=(2, 1, 1), padding=(1, 1, 1)) | |
| self.norm2 = RMS_norm_General(self.hidden_dim2, images=False) | |
| self.act2 = nn.SiLU() | |
| self.linear_layers = nn.ModuleList([nn.Linear(self.hidden_dim2, out_dim) for _ in range(layer_num)]) | |
| self.clip_idx = 0 | |
| def forward(self, video): | |
| self.clear_cache() | |
| t = video.shape[2] | |
| iter_ = 1 + (t - 1) // 4 | |
| first_frame = video[:, :, :1, :, :].repeat(1, 1, 3, 1, 1) | |
| video = torch.cat([first_frame, video], dim=2) | |
| out_x = [] | |
| for i in range(iter_): | |
| x = self.pixel_shuffle(video[:, :, i * 4 : (i + 1) * 4, :, :]) | |
| cache1_x = x[:, :, -CACHE_T:, :, :].clone() | |
| self.cache["conv1"] = cache1_x | |
| x = self.conv1(x, self.cache["conv1"]) | |
| x = self.act1(self.norm1(x)) | |
| cache2_x = x[:, :, -CACHE_T:, :, :].clone() | |
| self.cache["conv2"] = cache2_x | |
| if i == 0: continue | |
| x = self.conv2(x, self.cache["conv2"]) | |
| x = self.act2(self.norm2(x)) | |
| out_x.append(x) | |
| out_x = rearrange(torch.cat(out_x, dim=2), "b c f h w -> b (f h w) c") | |
| return [self.linear_layers[i](out_x) for i in range(self.layer_num)] | |
| def clear_cache(self): | |
| self.cache = {"conv1": None, "conv2": None} | |
| self.clip_idx = 0 | |
| def stream_forward(self, video_clip): | |
| if self.clip_idx == 0: | |
| first_frame = video_clip[:, :, :1, :, :].repeat(1, 1, 3, 1, 1) | |
| video_clip = torch.cat([first_frame, video_clip], dim=2) | |
| x = self.pixel_shuffle(video_clip) | |
| self.cache["conv1"] = x[:, :, -CACHE_T:, :, :].clone() | |
| x = self.act1(self.norm1(self.conv1(x, self.cache["conv1"]))) | |
| self.cache["conv2"] = x[:, :, -CACHE_T:, :, :].clone() | |
| self.clip_idx += 1 | |
| return None | |
| else: | |
| x = self.pixel_shuffle(video_clip) | |
| self.cache["conv1"] = x[:, :, -CACHE_T:, :, :].clone() | |
| x = self.act1(self.norm1(self.conv1(x, self.cache["conv1"]))) | |
| self.cache["conv2"] = x[:, :, -CACHE_T:, :, :].clone() | |
| x = self.act2(self.norm2(self.conv2(x, self.cache["conv2"]))) | |
| out_x = rearrange(x, "b c f h w -> b (f h w) c") | |
| self.clip_idx += 1 | |
| return [self.linear_layers[i](out_x) for i in range(self.layer_num)] | |
| class Causal_LQ4x_Proj(nn.Module): | |
| def __init__(self, in_dim, out_dim, layer_num=30): | |
| super().__init__() | |
| self.ff, self.hh, self.ww = 1, 16, 16 | |
| self.hidden_dim1, self.hidden_dim2, self.layer_num = 2048, 3072, layer_num | |
| self.pixel_shuffle = PixelShuffle3d(self.ff, self.hh, self.ww) | |
| self.conv1 = CausalConv3dReplicate(in_dim * self.ff * self.hh * self.ww, self.hidden_dim1, (4, 3, 3), stride=(2, 1, 1), padding=(1, 1, 1)) | |
| self.norm1 = RMS_norm_General(self.hidden_dim1, images=False) | |
| self.act1 = nn.SiLU() | |
| self.conv2 = CausalConv3dReplicate(self.hidden_dim1, self.hidden_dim2, (4, 3, 3), stride=(2, 1, 1), padding=(1, 1, 1)) | |
| self.norm2 = RMS_norm_General(self.hidden_dim2, images=False) | |
| self.act2 = nn.SiLU() | |
| self.linear_layers = nn.ModuleList([nn.Linear(self.hidden_dim2, out_dim) for _ in range(layer_num)]) | |
| self.clip_idx = 0 | |
| def forward(self, video): | |
| self.clear_cache() | |
| t = video.shape[2] | |
| iter_ = 1 + (t - 1) // 4 | |
| first_frame = video[:, :, :1, :, :].repeat(1, 1, 3, 1, 1) | |
| video = torch.cat([first_frame, video], dim=2) | |
| out_x = [] | |
| for i in range(iter_): | |
| x = self.pixel_shuffle(video[:, :, i * 4 : (i + 1) * 4, :, :]) | |
| cache1_x = x[:, :, -CACHE_T:, :, :].clone() | |
| x = self.conv1(x, self.cache["conv1"]) | |
| self.cache["conv1"] = cache1_x | |
| x = self.act1(self.norm1(x)) | |
| cache2_x = x[:, :, -CACHE_T:, :, :].clone() | |
| if i == 0: | |
| self.cache["conv2"] = cache2_x | |
| continue | |
| x = self.conv2(x, self.cache["conv2"]) | |
| self.cache["conv2"] = cache2_x | |
| x = self.act2(self.norm2(x)) | |
| out_x.append(x) | |
| out_x = rearrange(torch.cat(out_x, dim=2), "b c f h w -> b (f h w) c") | |
| return [self.linear_layers[i](out_x) for i in range(self.layer_num)] | |
| def clear_cache(self): | |
| self.cache = {"conv1": None, "conv2": None} | |
| self.clip_idx = 0 | |
| def stream_forward(self, video_clip): | |
| if self.clip_idx == 0: | |
| first_frame = video_clip[:, :, :1, :, :].repeat(1, 1, 3, 1, 1) | |
| video_clip = torch.cat([first_frame, video_clip], dim=2) | |
| x = self.pixel_shuffle(video_clip) | |
| cache1_x = x[:, :, -CACHE_T:, :, :].clone() | |
| x = self.conv1(x, self.cache["conv1"]) | |
| self.cache["conv1"] = cache1_x | |
| x = self.act1(self.norm1(x)) | |
| self.cache["conv2"] = x[:, :, -CACHE_T:, :, :].clone() | |
| self.clip_idx += 1 | |
| return None | |
| else: | |
| x = self.pixel_shuffle(video_clip) | |
| cache1_x = x[:, :, -CACHE_T:, :, :].clone() | |
| x = self.conv1(x, self.cache["conv1"]) | |
| self.cache["conv1"] = cache1_x | |
| x = self.act1(self.norm1(x)) | |
| cache2_x = x[:, :, -CACHE_T:, :, :].clone() | |
| x = self.conv2(x, self.cache["conv2"]) | |
| self.cache["conv2"] = cache2_x | |
| x = self.act2(self.norm2(x)) | |
| out_x = rearrange(x, "b c f h w -> b (f h w) c") | |
| self.clip_idx += 1 | |
| return [self.linear_layers[i](out_x) for i in range(self.layer_num)] | |
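| # --- Hedged usage sketch (added; never called). --- | |
| # The LQ projection is fed short frame chunks in streaming mode: the first call only | |
| # primes the causal conv caches and returns None, later calls return one conditioning | |
| # tensor per DiT block. The sizes below (2 blocks, 32x32 input) are assumptions. | |
| def _example_lq_proj_stream(): | |
|     proj = Causal_LQ4x_Proj(in_dim=3, out_dim=8, layer_num=2) | |
|     proj.clear_cache() | |
|     first = proj.stream_forward(torch.randn(1, 3, 1, 32, 32))  # warm-up clip -> None | |
|     later = proj.stream_forward(torch.randn(1, 3, 4, 32, 32))  # list of per-block tokens | |
|     return first, [t.shape for t in later] | |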
| class TAEHV(nn.Module): | |
| image_channels = 3 | |
| def __init__(self, checkpoint_path="taehv.pth", decoder_time_upscale=(True, True), decoder_space_upscale=(True, True, True), channels=[256, 128, 64, 64], latent_channels=16): | |
| super().__init__() | |
| self.latent_channels = latent_channels | |
| n_f = channels | |
| self.frames_to_trim = 2 ** sum(decoder_time_upscale) - 1 | |
| base_decoder = nn.Sequential( | |
| Clamp(), conv(self.latent_channels, n_f[0]), nn.ReLU(inplace=True), | |
| MemBlock(n_f[0], n_f[0]), MemBlock(n_f[0], n_f[0]), MemBlock(n_f[0], n_f[0]), | |
| nn.Upsample(scale_factor=2 if decoder_space_upscale[0] else 1), TGrow(n_f[0], 1), | |
| conv(n_f[0], n_f[1], bias=False), | |
| MemBlock(n_f[1], n_f[1]), MemBlock(n_f[1], n_f[1]), MemBlock(n_f[1], n_f[1]), | |
| nn.Upsample(scale_factor=2 if decoder_space_upscale[1] else 1), TGrow(n_f[1], 2 if decoder_time_upscale[0] else 1), | |
| conv(n_f[1], n_f[2], bias=False), | |
| MemBlock(n_f[2], n_f[2]), MemBlock(n_f[2], n_f[2]), MemBlock(n_f[2], n_f[2]), | |
| nn.Upsample(scale_factor=2 if decoder_space_upscale[2] else 1), TGrow(n_f[2], 2 if decoder_time_upscale[1] else 1), | |
| conv(n_f[2], n_f[3], bias=False), nn.ReLU(inplace=True), | |
| conv(n_f[3], TAEHV.image_channels), | |
| ) | |
| self.decoder = self._apply_identity_deepen(base_decoder, how_many_each=1, k=3) | |
| self.pixel_shuffle = PixelShuffle3dTAEHV(4, 8, 8) | |
| if checkpoint_path is not None: | |
| missing_keys = self.load_state_dict(self.patch_tgrow_layers(torch.load(checkpoint_path, map_location="cpu", weights_only=True)), strict=False) | |
| print("missing_keys", missing_keys) | |
| self.mem = [None] * len(self.decoder) | |
| @staticmethod | |
| def _apply_identity_deepen(decoder: nn.Sequential, how_many_each=1, k=3) -> nn.Sequential: | |
| new_layers = [] | |
| for b in decoder: | |
| new_layers.append(b) | |
| if isinstance(b, nn.ReLU): | |
| C = None | |
| if len(new_layers) >= 2 and isinstance(new_layers[-2], nn.Conv2d): C = new_layers[-2].out_channels | |
| elif len(new_layers) >= 2 and isinstance(new_layers[-2], MemBlock): C = new_layers[-2].conv[-1].out_channels | |
| if C is not None: | |
| for _ in range(how_many_each): | |
| new_layers.append(IdentityConv2d(C, kernel_size=k, bias=False)) | |
| new_layers.append(nn.ReLU(inplace=True)) | |
| return nn.Sequential(*new_layers) | |
| def patch_tgrow_layers(self, sd): | |
| new_sd = self.state_dict() | |
| for i, layer in enumerate(self.decoder): | |
| if isinstance(layer, TGrow): | |
| key = f"decoder.{i}.conv.weight" | |
| if key in sd and sd[key].shape[0] > new_sd[key].shape[0]: sd[key] = sd[key][-new_sd[key].shape[0] :] | |
| return sd | |
| def decode_video(self, x, parallel=True, show_progress_bar=False, cond=None): | |
| trim_flag = self.mem[-8] is None | |
| if cond is not None: x = torch.cat([self.pixel_shuffle(cond), x], dim=2) | |
| x, self.mem = apply_model_with_memblocks(self.decoder, x, parallel, show_progress_bar, mem=self.mem) | |
| if trim_flag: return x[:, self.frames_to_trim :] | |
| return x | |
| def forward(self, *args, **kwargs): raise NotImplementedError("Decoder-only model: call decode_video(...) instead.") | |
| def clean_mem(self): self.mem = [None] * len(self.decoder) | |
| TWorkItem = namedtuple("TWorkItem", ("input_tensor", "block_index")) | |
| class TPool(nn.Module): | |
| def __init__(self, n_f, stride): | |
| super().__init__() | |
| self.stride = stride | |
| self.conv = nn.Conv2d(n_f * stride, n_f, 1, bias=False) | |
| def forward(self, x): | |
| _NT, C, H, W = x.shape | |
| return self.conv(x.reshape(-1, self.stride * C, H, W)) | |
| class TGrow(nn.Module): | |
| def __init__(self, n_f, stride): | |
| super().__init__() | |
| self.stride = stride | |
| self.conv = nn.Conv2d(n_f, n_f * stride, 1, bias=False) | |
| def forward(self, x): | |
| _NT, C, H, W = x.shape | |
| x = self.conv(x) | |
| return x.reshape(-1, C, H, W) | |
| def apply_model_with_memblocks(model, x, parallel, show_progress_bar, mem=None): | |
| assert x.ndim == 5 | |
| N, T, C, H, W = x.shape | |
| if parallel: | |
| x = x.reshape(N * T, C, H, W) | |
| for b in tqdm(model, disable=not show_progress_bar): | |
| if isinstance(b, MemBlock): | |
| NT, C, H, W = x.shape | |
| T = NT // N | |
| _x = x.reshape(N, T, C, H, W) | |
| mem = F.pad(_x, (0, 0, 0, 0, 0, 0, 1, 0), value=0)[:, :T].reshape(x.shape) | |
| x = b(x, mem) | |
| else: x = b(x) | |
| NT, C, H, W = x.shape | |
| T = NT // N | |
| x = x.view(N, T, C, H, W) | |
| else: | |
| out = [] | |
| work_queue = [TWorkItem(xt, 0) for t, xt in enumerate(x.reshape(N, T * C, H, W).chunk(T, dim=1))] | |
| progress_bar = tqdm(range(T), disable=not show_progress_bar) | |
| while work_queue: | |
| xt, i = work_queue.pop(0) | |
| if i == 0: progress_bar.update(1) | |
| if i == len(model): out.append(xt) | |
| else: | |
| b = model[i] | |
| if isinstance(b, MemBlock): | |
| if mem[i] is None: | |
| xt_new = b(xt, xt * 0) | |
| mem[i] = xt | |
| else: | |
| xt_new = b(xt, mem[i]) | |
| mem[i].copy_(xt) | |
| work_queue.insert(0, TWorkItem(xt_new, i + 1)) | |
| elif isinstance(b, TPool): # TPool belongs to TAEHV's encoder path; the decoder assembled above contains no TPool layers, so this branch is never reached here. | |
| pass | |
| elif isinstance(b, TGrow): | |
| xt = b(xt) | |
| NT, C_, H_, W_ = xt.shape | |
| for xt_next in reversed(xt.view(N, b.stride * C_, H_, W_).chunk(b.stride, 1)): | |
| work_queue.insert(0, TWorkItem(xt_next, i + 1)) | |
| else: | |
| xt = b(xt) | |
| work_queue.insert(0, TWorkItem(xt, i + 1)) | |
| progress_bar.close() | |
| x = torch.stack(out, 1) | |
| return x, mem | |
| def build_tcdecoder(new_channels=[512, 256, 128, 128], device="cuda", dtype=torch.bfloat16, new_latent_channels=None): | |
| if new_latent_channels is not None: | |
| big = TAEHV(checkpoint_path=None, channels=new_channels, latent_channels=new_latent_channels).to(device).to(dtype).train() | |
| else: | |
| big = TAEHV(checkpoint_path=None, channels=new_channels).to(device).to(dtype).train() | |
| big.clean_mem() | |
| return big | |
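| # --- Hedged usage sketch (added; never called). --- | |
| # Builds the tiny causal decoder on the CPU in float32 and decodes a short random | |
| # latent clip with the sequential (frame-by-frame) path. The latent shape is an | |
| # illustrative assumption; real runs load trained TCDecoder weights elsewhere. | |
| def _example_build_tcdecoder(): | |
|     tcd = build_tcdecoder(device="cpu", dtype=torch.float32) | |
|     latents = torch.randn(1, 2, 16, 8, 8)  # (N, T, C, H, W) latent clip | |
|     frames = tcd.decode_video(latents, parallel=False) | |
|     return frames.shape  # roughly (N, T * 4 - 3, 3, H * 8, W * 8) | |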
| # ============================================================================== | |
| # Model Loading & Management | |
| # ============================================================================== | |
| model_loader_configs = [ | |
| (None, "9269f8db9040a9d860eaca435be61814", ["wan_video_dit"], [WanModel], "civitai"), | |
| (None, "aafcfd9672c3a2456dc46e1cb6e52c70", ["wan_video_dit"], [WanModel], "civitai"), | |
| (None, "6bfcfb3b342cb286ce886889d519a77e", ["wan_video_dit"], [WanModel], "civitai"), | |
| (None, "6d6ccde6845b95ad9114ab993d917893", ["wan_video_dit"], [WanModel], "civitai"), | |
| (None, "349723183fc063b2bfc10bb2835cf677", ["wan_video_dit"], [WanModel], "civitai"), | |
| (None, "efa44cddf936c70abd0ea28b6cbe946c", ["wan_video_dit"], [WanModel], "civitai"), | |
| (None, "3ef3b1f8e1dab83d5b71fd7b617f859f", ["wan_video_dit"], [WanModel], "civitai"), | |
| (None, "cb104773c6c2cb6df4f9529ad5c60d0b", ["wan_video_dit"], [WanModel], "diffusers"), | |
| (None, "1378ea763357eea97acdef78e65d6d96", ["wan_video_vae"], [WanVideoVAE], "civitai"), | |
| (None, "ccc42284ea13e1ad04693284c7a09be6", ["wan_video_vae"], [WanVideoVAE], "civitai"), | |
| ] | |
| huggingface_model_loader_configs = [] | |
| patch_model_loader_configs = [] | |
| @contextmanager | |
| def init_weights_on_device(device=torch.device("meta"), include_buffers: bool = False): | |
| old_register_parameter = torch.nn.Module.register_parameter | |
| old_register_buffer = torch.nn.Module.register_buffer if include_buffers else None | |
| def register_empty_parameter(module, name, param): | |
| old_register_parameter(module, name, param) | |
| if param is not None: | |
| param_cls = type(module._parameters[name]) | |
| kwargs = module._parameters[name].__dict__ | |
| kwargs["requires_grad"] = param.requires_grad | |
| module._parameters[name] = param_cls(module._parameters[name].to(device), **kwargs) | |
| def register_empty_buffer(module, name, buffer, persistent=True): | |
| old_register_buffer(module, name, buffer, persistent=persistent) | |
| if buffer is not None: module._buffers[name] = module._buffers[name].to(device) | |
| def patch_tensor_constructor(fn): | |
| def wrapper(*args, **kwargs): | |
| kwargs["device"] = device | |
| return fn(*args, **kwargs) | |
| return wrapper | |
| if include_buffers: | |
| tensor_constructors_to_patch = {name: getattr(torch, name) for name in ["empty", "zeros", "ones", "full"]} | |
| else: tensor_constructors_to_patch = {} | |
| try: | |
| torch.nn.Module.register_parameter = register_empty_parameter | |
| if include_buffers: torch.nn.Module.register_buffer = register_empty_buffer | |
| for name in tensor_constructors_to_patch: setattr(torch, name, patch_tensor_constructor(getattr(torch, name))) | |
| yield | |
| finally: | |
| torch.nn.Module.register_parameter = old_register_parameter | |
| if include_buffers: torch.nn.Module.register_buffer = old_register_buffer | |
| for name, old_fn in tensor_constructors_to_patch.items(): setattr(torch, name, old_fn) | |
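| # --- Hedged sketch (added; never called). --- | |
| # The context manager above lets a module be built with meta-device parameters so that | |
| # load_state_dict(..., assign=True) can later bind the real tensors without allocating | |
| # twice; load_model_from_single_file() below relies on exactly this pattern. | |
| def _example_meta_init(): | |
|     with init_weights_on_device(): | |
|         layer = torch.nn.Linear(4, 4)  # weight/bias live on the meta device here | |
|     real_weights = torch.nn.Linear(4, 4).state_dict() | |
|     layer.load_state_dict(real_weights, assign=True) | |
|     return layer.weight.device  # cpu | |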
| def load_state_dict(file_path, torch_dtype=None): | |
| if file_path.endswith(".safetensors"): | |
| state_dict = {} | |
| with safe_open(file_path, framework="pt", device="cpu") as f: | |
| for k in f.keys(): | |
| state_dict[k] = f.get_tensor(k) | |
| if torch_dtype is not None: state_dict[k] = state_dict[k].to(torch_dtype) | |
| return state_dict | |
| else: | |
| state_dict = torch.load(file_path, map_location="cpu", weights_only=True) | |
| if torch_dtype is not None: | |
| for i in state_dict: | |
| if isinstance(state_dict[i], torch.Tensor): state_dict[i] = state_dict[i].to(torch_dtype) | |
| return state_dict | |
| def convert_state_dict_keys_to_single_str(state_dict, with_shape=True): | |
| keys = [] | |
| for key, value in state_dict.items(): | |
| if isinstance(key, str): | |
| if isinstance(value, torch.Tensor): | |
| if with_shape: keys.append(key + ":" + "_".join(map(str, list(value.shape)))) | |
| keys.append(key) | |
| elif isinstance(value, dict): keys.append(key + "|" + convert_state_dict_keys_to_single_str(value, with_shape=with_shape)) | |
| keys.sort() | |
| return ",".join(keys) | |
| def hash_state_dict_keys(state_dict, with_shape=True): | |
| keys_str = convert_state_dict_keys_to_single_str(state_dict, with_shape=with_shape) | |
| return hashlib.md5(keys_str.encode(encoding="UTF-8")).hexdigest() | |
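| # --- Hedged sketch (added; never called). --- | |
| # The detectors below identify a checkpoint by an MD5 over its sorted "key:shape" | |
| # strings, so any file with the same tensor names and shapes maps to the same entry | |
| # in model_loader_configs, regardless of the actual weight values. | |
| def _example_hash_state_dict(): | |
|     sd = torch.nn.Linear(4, 4).state_dict() | |
|     return hash_state_dict_keys(sd, with_shape=True)  # deterministic hex digest | |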
| def split_state_dict_with_prefix(state_dict): | |
| keys = sorted([key for key in state_dict if isinstance(key, str)]) | |
| prefix_dict = {} | |
| for key in keys: | |
| prefix = key if "." not in key else key.split(".")[0] | |
| if prefix not in prefix_dict: prefix_dict[prefix] = [] | |
| prefix_dict[prefix].append(key) | |
| state_dicts = [] | |
| for prefix, keys in prefix_dict.items(): | |
| state_dicts.append({key: state_dict[key] for key in keys}) | |
| return state_dicts | |
| def load_model_from_single_file(state_dict, model_names, model_classes, model_resource, torch_dtype, device): | |
| loaded_model_names, loaded_models = [], [] | |
| for model_name, model_class in zip(model_names, model_classes): | |
| state_dict_converter = model_class.state_dict_converter() | |
| if model_resource == "civitai": state_dict_results = state_dict_converter.from_civitai(state_dict) | |
| elif model_resource == "diffusers": state_dict_results = state_dict_converter.from_diffusers(state_dict) | |
| if isinstance(state_dict_results, tuple): model_state_dict, extra_kwargs = state_dict_results | |
| else: model_state_dict, extra_kwargs = state_dict_results, {} | |
| torch_dtype = torch.float32 if extra_kwargs.get("upcast_to_float32", False) else torch_dtype | |
| with init_weights_on_device(): model = model_class(**extra_kwargs) | |
| if hasattr(model, "eval"): model = model.eval() | |
| model.load_state_dict(model_state_dict, assign=True) | |
| model = model.to(dtype=torch_dtype, device=device) | |
| loaded_model_names.append(model_name) | |
| loaded_models.append(model) | |
| return loaded_model_names, loaded_models | |
| class ModelDetectorFromSingleFile: | |
| def __init__(self, model_loader_configs=[]): | |
| self.keys_hash_with_shape_dict = {} | |
| self.keys_hash_dict = {} | |
| for metadata in model_loader_configs: self.add_model_metadata(*metadata) | |
| def add_model_metadata(self, keys_hash, keys_hash_with_shape, model_names, model_classes, model_resource): | |
| self.keys_hash_with_shape_dict[keys_hash_with_shape] = (model_names, model_classes, model_resource) | |
| if keys_hash is not None: self.keys_hash_dict[keys_hash] = (model_names, model_classes, model_resource) | |
| def match(self, file_path="", state_dict={}): | |
| if isinstance(file_path, str) and os.path.isdir(file_path): return False | |
| if len(state_dict) == 0: state_dict = load_state_dict(file_path) | |
| keys_hash_with_shape = hash_state_dict_keys(state_dict, with_shape=True) | |
| if keys_hash_with_shape in self.keys_hash_with_shape_dict: return True | |
| keys_hash = hash_state_dict_keys(state_dict, with_shape=False) | |
| if keys_hash in self.keys_hash_dict: return True | |
| return False | |
| def load(self, file_path="", state_dict={}, device="cuda", torch_dtype=torch.float16, **kwargs): | |
| if len(state_dict) == 0: state_dict = load_state_dict(file_path) | |
| keys_hash_with_shape = hash_state_dict_keys(state_dict, with_shape=True) | |
| if keys_hash_with_shape in self.keys_hash_with_shape_dict: | |
| return load_model_from_single_file(state_dict, *self.keys_hash_with_shape_dict[keys_hash_with_shape], torch_dtype, device) | |
| keys_hash = hash_state_dict_keys(state_dict, with_shape=False) | |
| if keys_hash in self.keys_hash_dict: | |
| return load_model_from_single_file(state_dict, *self.keys_hash_dict[keys_hash], torch_dtype, device) | |
| return [], [] | |
| class ModelDetectorFromSplitedSingleFile(ModelDetectorFromSingleFile): | |
| def match(self, file_path="", state_dict={}): | |
| if isinstance(file_path, str) and os.path.isdir(file_path): return False | |
| if len(state_dict) == 0: state_dict = load_state_dict(file_path) | |
| splited_state_dict = split_state_dict_with_prefix(state_dict) | |
| for sub_state_dict in splited_state_dict: | |
| if super().match(file_path, sub_state_dict): return True | |
| return False | |
| def load(self, file_path="", state_dict={}, device="cuda", torch_dtype=torch.float16, **kwargs): | |
| if len(state_dict) == 0: state_dict = load_state_dict(file_path) | |
| splited_state_dict = split_state_dict_with_prefix(state_dict) | |
| valid_state_dict = {} | |
| for sub_state_dict in splited_state_dict: | |
| if super().match(file_path, sub_state_dict): valid_state_dict.update(sub_state_dict) | |
| if super().match(file_path, valid_state_dict): return super().load(file_path, valid_state_dict, device, torch_dtype) | |
| loaded_model_names, loaded_models = [], [] | |
| for sub_state_dict in splited_state_dict: | |
| if super().match(file_path, sub_state_dict): | |
| ln, lm = super().load(file_path, sub_state_dict, device, torch_dtype) # load each matching sub-dict, not the merged one | |
| loaded_model_names += ln | |
| loaded_models += lm | |
| return loaded_model_names, loaded_models | |
| class ModelManager: | |
| def __init__(self, torch_dtype=torch.float16, device="cuda", file_path_list: List[str] = []): | |
| self.torch_dtype, self.device = torch_dtype, device | |
| self.model, self.model_path, self.model_name = [], [], [] | |
| self.model_detector = [ModelDetectorFromSingleFile(model_loader_configs), ModelDetectorFromSplitedSingleFile(model_loader_configs)] | |
| self.load_models(file_path_list) | |
| def load_model(self, file_path, model_names=None, device=None, torch_dtype=None): | |
| if device is None: device = self.device | |
| if torch_dtype is None: torch_dtype = self.torch_dtype | |
| if isinstance(file_path, list): | |
| state_dict = {} | |
| for path in file_path: state_dict.update(load_state_dict(path)) | |
| elif os.path.isfile(file_path): state_dict = load_state_dict(file_path) | |
| else: state_dict = None | |
| for model_detector in self.model_detector: | |
| if model_detector.match(file_path, state_dict): | |
| model_names, models = model_detector.load(file_path, state_dict, device=device, torch_dtype=torch_dtype, allowed_model_names=model_names, model_manager=self) | |
| for model_name, model in zip(model_names, models): | |
| self.model.append(model) | |
| self.model_path.append(file_path) | |
| self.model_name.append(model_name) | |
| break | |
| else: print(f" We cannot detect the model type. No models are loaded.") | |
| def load_models(self, file_path_list, model_names=None, device=None, torch_dtype=None): | |
| for file_path in file_path_list: self.load_model(file_path, model_names, device=device, torch_dtype=torch_dtype) | |
| def fetch_model(self, model_name, file_path=None, require_model_path=False): | |
| fetched_models, fetched_model_paths = [], [] | |
| for model, model_path, model_name_ in zip(self.model, self.model_path, self.model_name): | |
| if file_path is not None and file_path != model_path: continue | |
| if model_name == model_name_: | |
| fetched_models.append(model) | |
| fetched_model_paths.append(model_path) | |
| if len(fetched_models) == 0: return None | |
| if len(fetched_models) == 1: print(f"Using {model_name} from {fetched_model_paths[0]}") | |
| else: print(f"More than one {model_name} models are loaded. Using {model_name} from {fetched_model_paths[0]}") | |
| if require_model_path: return fetched_models[0], fetched_model_paths[0] | |
| else: return fetched_models[0] | |
| def to(self, device): | |
| for model in self.model: model.to(device) | |
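| # --- Hedged usage sketch (added; never called). --- | |
| # The file paths are placeholders, not the checkpoint names this script actually uses: | |
| # ModelManager hashes each file's keys, matches them against model_loader_configs and | |
| # instantiates the corresponding model class on the requested device/dtype. | |
| def _example_model_manager(): | |
|     mm = ModelManager(torch_dtype=torch.float16, device="cpu") | |
|     mm.load_models(["/path/to/wan_dit.safetensors", "/path/to/wan_vae.safetensors"]) | |
|     dit = mm.fetch_model("wan_video_dit") | |
|     vae = mm.fetch_model("wan_video_vae") | |
|     return dit, vae | |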
| def cast_to(weight, dtype, device): | |
| r = torch.empty_like(weight, dtype=dtype, device=device) | |
| r.copy_(weight) | |
| return r | |
| class AutoWrappedModule(torch.nn.Module): | |
| def __init__(self, module: torch.nn.Module, offload_dtype, offload_device, onload_dtype, onload_device, computation_dtype, computation_device): | |
| super().__init__() | |
| self.module = module.to(dtype=offload_dtype, device=offload_device) | |
| self.offload_dtype, self.offload_device = offload_dtype, offload_device | |
| self.onload_dtype, self.onload_device = onload_dtype, onload_device | |
| self.computation_dtype, self.computation_device = computation_dtype, computation_device | |
| self.state = 0 | |
| def offload(self): | |
| if self.state == 1 and (self.offload_dtype != self.onload_dtype or self.offload_device != self.onload_device): | |
| self.module.to(dtype=self.offload_dtype, device=self.offload_device) | |
| self.state = 0 | |
| def onload(self): | |
| if self.state == 0 and (self.offload_dtype != self.onload_dtype or self.offload_device != self.onload_device): | |
| self.module.to(dtype=self.onload_dtype, device=self.onload_device) | |
| self.state = 1 | |
| def forward(self, *args, **kwargs): | |
| if self.onload_dtype == self.computation_dtype and self.onload_device == self.computation_device: module = self.module | |
| else: module = copy.deepcopy(self.module).to(dtype=self.computation_dtype, device=self.computation_device) | |
| return module(*args, **kwargs) | |
| class AutoWrappedLinear(torch.nn.Linear): | |
| def __init__(self, module: torch.nn.Linear, offload_dtype, offload_device, onload_dtype, onload_device, computation_dtype, computation_device): | |
| with init_weights_on_device(device=torch.device("meta")): | |
| super().__init__(in_features=module.in_features, out_features=module.out_features, bias=module.bias is not None, dtype=offload_dtype, device=offload_device) | |
| self.weight, self.bias = module.weight, module.bias | |
| self.offload_dtype, self.offload_device = offload_dtype, offload_device | |
| self.onload_dtype, self.onload_device = onload_dtype, onload_device | |
| self.computation_dtype, self.computation_device = computation_dtype, computation_device | |
| self.state = 0 | |
| def offload(self): | |
| if self.state == 1 and (self.offload_dtype != self.onload_dtype or self.offload_device != self.onload_device): | |
| self.to(dtype=self.offload_dtype, device=self.offload_device) | |
| self.state = 0 | |
| def onload(self): | |
| if self.state == 0 and (self.offload_dtype != self.onload_dtype or self.offload_device != self.onload_device): | |
| self.to(dtype=self.onload_dtype, device=self.onload_device) | |
| self.state = 1 | |
| def forward(self, x, *args, **kwargs): | |
| if self.onload_dtype == self.computation_dtype and self.onload_device == self.computation_device: weight, bias = self.weight, self.bias | |
| else: | |
| weight = cast_to(self.weight, self.computation_dtype, self.computation_device) | |
| bias = None if self.bias is None else cast_to(self.bias, self.computation_dtype, self.computation_device) | |
| return torch.nn.functional.linear(x, weight, bias) | |
| def enable_vram_management_recursively(model: torch.nn.Module, module_map: dict, module_config: dict, max_num_param=None, overflow_module_config: dict = None, total_num_param=0): | |
| for name, module in model.named_children(): | |
| for source_module, target_module in module_map.items(): | |
| if isinstance(module, source_module): | |
| num_param = sum(p.numel() for p in module.parameters()) | |
| if max_num_param is not None and total_num_param + num_param > max_num_param: module_config_ = overflow_module_config | |
| else: module_config_ = module_config | |
| module_ = target_module(module, **module_config_) | |
| setattr(model, name, module_) | |
| total_num_param += num_param | |
| break | |
| else: total_num_param = enable_vram_management_recursively(module, module_map, module_config, max_num_param, overflow_module_config, total_num_param) | |
| return total_num_param | |
| def enable_vram_management_(model: torch.nn.Module, module_map: dict, module_config: dict, max_num_param=None, overflow_module_config: dict = None): | |
| enable_vram_management_recursively(model, module_map, module_config, max_num_param, overflow_module_config, total_num_param=0) | |
| model.vram_management_enabled = True | |
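| # --- Hedged sketch (added; never called). --- | |
| # Wraps every Linear in a toy model the same way enable_vram_management() in the | |
| # pipeline wraps the DiT; here offload, onload and computation all stay on the CPU in | |
| # float32 so the snippet is device-agnostic (the pipeline uses cpu <-> GPU/MPS instead). | |
| def _example_vram_management(): | |
|     toy = torch.nn.Sequential(torch.nn.Linear(8, 8), torch.nn.SiLU(), torch.nn.Linear(8, 8)) | |
|     cfg = dict(offload_dtype=torch.float32, offload_device="cpu", | |
|                onload_dtype=torch.float32, onload_device="cpu", | |
|                computation_dtype=torch.float32, computation_device="cpu") | |
|     enable_vram_management_(toy, module_map={torch.nn.Linear: AutoWrappedLinear}, module_config=cfg) | |
|     return toy(torch.randn(1, 8)) | |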
| # ============================================================================== | |
| # Pipeline & Schedulers | |
| # ============================================================================== | |
| class FlowMatchScheduler: | |
| def __init__(self, num_inference_steps=100, num_train_timesteps=1000, shift=3.0, sigma_max=1.0, sigma_min=0.003 / 1.002, inverse_timesteps=False, extra_one_step=False, reverse_sigmas=False): | |
| self.num_train_timesteps = num_train_timesteps | |
| self.shift, self.sigma_max, self.sigma_min = shift, sigma_max, sigma_min | |
| self.inverse_timesteps, self.extra_one_step, self.reverse_sigmas = inverse_timesteps, extra_one_step, reverse_sigmas | |
| self.set_timesteps(num_inference_steps) | |
| def set_timesteps(self, num_inference_steps=100, denoising_strength=1.0, training=False, shift=None): | |
| if shift is not None: self.shift = shift | |
| sigma_start = self.sigma_min + (self.sigma_max - self.sigma_min) * denoising_strength | |
| if self.extra_one_step: self.sigmas = torch.linspace(sigma_start, self.sigma_min, num_inference_steps + 1)[:-1] | |
| else: self.sigmas = torch.linspace(sigma_start, self.sigma_min, num_inference_steps) | |
| if self.inverse_timesteps: self.sigmas = torch.flip(self.sigmas, dims=[0]) | |
| self.sigmas = self.shift * self.sigmas / (1 + (self.shift - 1) * self.sigmas) | |
| if self.reverse_sigmas: self.sigmas = 1 - self.sigmas | |
| self.timesteps = self.sigmas * self.num_train_timesteps | |
| def step(self, model_output, timestep, sample, to_final=False, **kwargs): | |
| if isinstance(timestep, torch.Tensor): timestep = timestep.cpu() | |
| timestep_id = torch.argmin((self.timesteps - timestep).abs()) | |
| sigma = self.sigmas[timestep_id] | |
| if to_final or timestep_id + 1 >= len(self.timesteps): sigma_ = 1 if (self.inverse_timesteps or self.reverse_sigmas) else 0 | |
| else: sigma_ = self.sigmas[timestep_id + 1] | |
| prev_sample = sample + model_output * (sigma_ - sigma) | |
| return prev_sample | |
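| # --- Hedged sketch (added; never called). --- | |
| # With a single timestep at shift=5 (the configuration the tiny pipeline uses), sigma | |
| # is 1.0 and the update collapses to sample - model_output, matching the one-step | |
| # denoise performed directly inside __call__ below. | |
| def _example_flow_match_step(): | |
|     sched = FlowMatchScheduler(shift=5, sigma_min=0.0, extra_one_step=True) | |
|     sched.set_timesteps(1, denoising_strength=1.0, shift=5.0) | |
|     sample = torch.randn(1, 16, 2, 8, 8) | |
|     model_output = torch.randn_like(sample) | |
|     return sched.step(model_output, sched.timesteps[0], sample)  # == sample - model_output | |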
| class BasePipeline(torch.nn.Module): | |
| def __init__(self, device="cuda", torch_dtype=torch.float16, height_division_factor=64, width_division_factor=64): | |
| super().__init__() | |
| self.device, self.torch_dtype = device, torch_dtype | |
| self.height_division_factor, self.width_division_factor = height_division_factor, width_division_factor | |
| self.cpu_offload, self.model_names = False, [] | |
| def check_resize_height_width(self, height, width): | |
| if height % self.height_division_factor != 0: | |
| height = ((height + self.height_division_factor - 1) // self.height_division_factor * self.height_division_factor) | |
| print(f"The height cannot be evenly divided by {self.height_division_factor}. We round it up to {height}.") | |
| if width % self.width_division_factor != 0: | |
| width = ((width + self.width_division_factor - 1) // self.width_division_factor * self.width_division_factor) | |
| print(f"The width cannot be evenly divided by {self.width_division_factor}. We round it up to {width}.") | |
| return height, width | |
| def preprocess_image(self, image): | |
| return torch.Tensor(np.array(image, dtype=np.float32) * (2 / 255) - 1).permute(2, 0, 1).unsqueeze(0) | |
| def enable_cpu_offload(self): | |
| self.cpu_offload = True | |
| def load_models_to_device(self, loadmodel_names=[]): | |
| if not self.cpu_offload: return | |
| for model_name in self.model_names: | |
| if model_name not in loadmodel_names: | |
| model = getattr(self, model_name) | |
| if model is not None: | |
| if hasattr(model, "vram_management_enabled") and model.vram_management_enabled: | |
| for module in model.modules(): | |
| if hasattr(module, "offload"): module.offload() | |
| else: model.cpu() | |
| for model_name in loadmodel_names: | |
| model = getattr(self, model_name) | |
| if model is not None: | |
| if hasattr(model, "vram_management_enabled") and model.vram_management_enabled: | |
| for module in model.modules(): | |
| if hasattr(module, "onload"): module.onload() | |
| else: model.to(self.device) | |
| if torch.cuda.is_available(): torch.cuda.empty_cache() | |
| if torch.backends.mps.is_available(): torch.mps.empty_cache() | |
| def generate_noise(self, shape, seed=None, device="cpu", dtype=torch.float16): | |
| generator = None if seed is None else torch.Generator(device).manual_seed(seed) | |
| noise = torch.randn(shape, generator=generator, device=device, dtype=dtype) | |
| return noise | |
| def to(self, device, dtype=None): | |
| self.device = device | |
| if dtype: self.torch_dtype = dtype | |
| super().to(device) | |
| def _calc_mean_std(feat: torch.Tensor, eps: float = 1e-5) -> Tuple[torch.Tensor, torch.Tensor]: | |
| N, C = feat.shape[:2] | |
| var = feat.view(N, C, -1).var(dim=2, unbiased=False) + eps | |
| std = var.sqrt().view(N, C, 1, 1) | |
| mean = feat.view(N, C, -1).mean(dim=2).view(N, C, 1, 1) | |
| return mean, std | |
| def _adain(content_feat: torch.Tensor, style_feat: torch.Tensor) -> torch.Tensor: | |
| size = content_feat.size() | |
| style_mean, style_std = _calc_mean_std(style_feat) | |
| content_mean, content_std = _calc_mean_std(content_feat) | |
| normalized = (content_feat - content_mean.expand(size)) / content_std.expand(size) | |
| return normalized * style_std.expand(size) + style_mean.expand(size) | |
| def _make_gaussian3x3_kernel(dtype, device) -> torch.Tensor: | |
| return torch.tensor([[0.0625, 0.125, 0.0625], [0.125, 0.25, 0.125], [0.0625, 0.125, 0.0625]], dtype=dtype, device=device) | |
| def _wavelet_blur(x: torch.Tensor, radius: int) -> torch.Tensor: | |
| N, C, H, W = x.shape | |
| base = _make_gaussian3x3_kernel(x.dtype, x.device) | |
| weight = base.view(1, 1, 3, 3).repeat(C, 1, 1, 1) | |
| x_pad = F.pad(x, (radius, radius, radius, radius), mode="replicate") | |
| out = F.conv2d(x_pad, weight, bias=None, stride=1, padding=0, dilation=radius, groups=C) | |
| return out | |
| def _wavelet_decompose(x: torch.Tensor, levels: int = 5) -> Tuple[torch.Tensor, torch.Tensor]: | |
| high, low = torch.zeros_like(x), x | |
| for i in range(levels): | |
| radius = 2**i | |
| blurred = _wavelet_blur(low, radius) | |
| high = high + (low - blurred) | |
| low = blurred | |
| return high, low | |
| def _wavelet_reconstruct(content: torch.Tensor, style: torch.Tensor, levels: int = 5) -> torch.Tensor: | |
| c_high, _ = _wavelet_decompose(content, levels=levels) | |
| _, s_low = _wavelet_decompose(style, levels=levels) | |
| return c_high + s_low | |
| class TorchColorCorrectorWavelet(nn.Module): | |
| def __init__(self, levels: int = 5): | |
| super().__init__() | |
| self.levels = levels | |
| @staticmethod | |
| def _flatten_time(x: torch.Tensor) -> Tuple[torch.Tensor, int, int]: | |
| B, C, f, H, W = x.shape | |
| return x.permute(0, 2, 1, 3, 4).reshape(B * f, C, H, W), B, f | |
| @staticmethod | |
| def _unflatten_time(y: torch.Tensor, B: int, f: int) -> torch.Tensor: | |
| return y.reshape(B, f, -1, y.shape[2], y.shape[3]).permute(0, 2, 1, 3, 4) | |
| def forward(self, hq_image: torch.Tensor, lq_image: torch.Tensor, clip_range: Tuple[float, float] = (-1.0, 1.0), method: Literal["wavelet", "adain"] = "wavelet", chunk_size: Optional[int] = None) -> torch.Tensor: | |
| B, C, f, H, W = hq_image.shape | |
| if chunk_size is None or chunk_size >= f: | |
| hq4, B, f = self._flatten_time(hq_image) | |
| lq4, _, _ = self._flatten_time(lq_image) | |
| if method == "wavelet": out4 = _wavelet_reconstruct(hq4, lq4, levels=self.levels) | |
| elif method == "adain": out4 = _adain(hq4, lq4) | |
| return self._unflatten_time(torch.clamp(out4, *clip_range), B, f) | |
| outs = [] | |
| for start in range(0, f, chunk_size): | |
| end = min(start + chunk_size, f) | |
| hq4, B_, f_ = self._flatten_time(hq_image[:, :, start:end]) | |
| lq4, _, _ = self._flatten_time(lq_image[:, :, start:end]) | |
| if method == "wavelet": out4 = _wavelet_reconstruct(hq4, lq4, levels=self.levels) | |
| elif method == "adain": out4 = _adain(hq4, lq4) | |
| outs.append(self._unflatten_time(torch.clamp(out4, *clip_range), B_, f_)) | |
| return torch.cat(outs, dim=2) | |
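| # --- Hedged sketch (added; never called). --- | |
| # The corrector transfers the low-frequency color statistics of the low-quality input | |
| # onto the upscaled frames, processed in chunks along the frame axis; the tensor sizes | |
| # below are illustrative assumptions. | |
| def _example_color_fix(): | |
|     corrector = TorchColorCorrectorWavelet(levels=5) | |
|     hq = torch.rand(1, 3, 8, 64, 64) * 2 - 1  # upscaled frames, (B, C, F, H, W) in [-1, 1] | |
|     lq = torch.rand(1, 3, 8, 64, 64) * 2 - 1  # reference frames carrying the original colors | |
|     return corrector(hq, lq, clip_range=(-1.0, 1.0), method="wavelet", chunk_size=4) | |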
| def model_fn_wan_video(dit: WanModel, x: torch.Tensor, timestep: torch.Tensor, context: torch.Tensor, tea_cache: Optional[object] = None, use_unified_sequence_parallel: bool = False, LQ_latents: Optional[torch.Tensor] = None, is_full_block: bool = False, is_stream: bool = False, pre_cache_k: Optional[list[torch.Tensor]] = None, pre_cache_v: Optional[list[torch.Tensor]] = None, topk_ratio: float = 2.0, kv_ratio: float = 3.0, cur_process_idx: int = 0, t_mod: torch.Tensor = None, t: torch.Tensor = None, local_range: int = 9, **kwargs): | |
| x, (f, h, w) = dit.patchify(x) | |
| win = (2, 8, 8) | |
| seqlen = f // win[0] | |
| local_num = seqlen | |
| window_size = win[0] * h * w // 128 | |
| square_num = window_size * window_size | |
| topk = int(square_num * topk_ratio) - 1 | |
| kv_len = int(kv_ratio) | |
| if cur_process_idx == 0: | |
| freqs = (torch.cat([dit.freqs[0][:f].view(f, 1, 1, -1).expand(f, h, w, -1), dit.freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1), dit.freqs[2][:w].view(1, 1, w, -1).expand(f, h, w, -1)], dim=-1).reshape(f * h * w, 1, -1).to(x.device)) | |
| else: | |
| freqs = (torch.cat([dit.freqs[0][4 + cur_process_idx * 2 : 4 + cur_process_idx * 2 + f].view(f, 1, 1, -1).expand(f, h, w, -1), dit.freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1), dit.freqs[2][:w].view(1, 1, w, -1).expand(f, h, w, -1)], dim=-1).reshape(f * h * w, 1, -1).to(x.device)) | |
| tea_cache_update = (tea_cache.check(dit, x, t_mod) if tea_cache is not None else False) | |
| if use_unified_sequence_parallel: | |
| import torch.distributed as dist | |
| from xfuser.core.distributed import get_sequence_parallel_rank, get_sequence_parallel_world_size | |
| if dist.is_initialized() and dist.get_world_size() > 1: x = torch.chunk(x, get_sequence_parallel_world_size(), dim=1)[get_sequence_parallel_rank()] | |
| if tea_cache_update: x = tea_cache.update(x) | |
| else: | |
| for block_id, block in enumerate(dit.blocks): | |
| if LQ_latents is not None and block_id < len(LQ_latents): x = x + LQ_latents[block_id] | |
| x, last_pre_cache_k, last_pre_cache_v = block(x, context, t_mod, freqs, f, h, w, local_num, topk, block_id=block_id, kv_len=kv_len, is_full_block=is_full_block, is_stream=is_stream, pre_cache_k=pre_cache_k[block_id] if pre_cache_k is not None else None, pre_cache_v=pre_cache_v[block_id] if pre_cache_v is not None else None, local_range=local_range) | |
| if pre_cache_k is not None: pre_cache_k[block_id] = last_pre_cache_k | |
| if pre_cache_v is not None: pre_cache_v[block_id] = last_pre_cache_v | |
| x = dit.head(x, t) | |
| if use_unified_sequence_parallel: | |
| from xfuser.core.distributed import get_sp_group | |
| if torch.distributed.is_initialized() and torch.distributed.get_world_size() > 1: x = get_sp_group().all_gather(x, dim=1) | |
| x = dit.unpatchify(x, (f, h, w)) | |
| return x, pre_cache_k, pre_cache_v | |
| class FlashVSRTinyPipeline(BasePipeline): | |
| def __init__(self, device="cuda", torch_dtype=torch.float16): | |
| super().__init__(device=device, torch_dtype=torch_dtype) | |
| self.scheduler = FlowMatchScheduler(shift=5, sigma_min=0.0, extra_one_step=True) | |
| self.dit: WanModel = None | |
| self.vae: WanVideoVAE = None | |
| self.model_names = ["dit", "vae"] | |
| self.height_division_factor = 16 | |
| self.width_division_factor = 16 | |
| self.use_unified_sequence_parallel = False | |
| self.prompt_emb_posi = None | |
| self.ColorCorrector = TorchColorCorrectorWavelet(levels=5) | |
| print(r""" | |
| ███████╗██╗ █████╗ ███████╗██╗ ██╗██╗ ██╗███████╗█████╗ | |
| ██╔════╝██║ ██╔══██╗██╔════╝██║ ██║██║ ██║██╔════╝██╔══██╗ ██╗ | |
| █████╗ ██║ ███████║███████╗███████║╚██╗ ██╔╝███████╗███████║ ██████╗ | |
| ██╔══╝ ██║ ██╔══██║╚════██║██╔══██║ ╚████╔╝ ╚════██║██╔═██║ ██╔═╝ | |
| ██║ ███████╗██║ ██║███████║██║ ██║ ╚██╔╝ ███████║██║ ██║ ╚═╝ | |
| ╚═╝ ╚══════╝╚═╝ ╚═╝╚══════╝╚═╝ ╚═╝ ╚═╝ ╚══════╝╚═╝ ╚═╝ | |
| """) | |
| def enable_vram_management(self, num_persistent_param_in_dit=None): | |
| dtype = next(iter(self.dit.parameters())).dtype | |
| enable_vram_management_( | |
| self.dit, | |
| module_map={torch.nn.Linear: AutoWrappedLinear, torch.nn.Conv3d: AutoWrappedModule, torch.nn.LayerNorm: AutoWrappedModule, RMSNorm: AutoWrappedModule}, | |
| module_config=dict(offload_dtype=dtype, offload_device="cpu", onload_dtype=dtype, onload_device=self.device, computation_dtype=self.torch_dtype, computation_device=self.device), | |
| max_num_param=num_persistent_param_in_dit, | |
| overflow_module_config=dict(offload_dtype=dtype, offload_device="cpu", onload_dtype=dtype, onload_device="cpu", computation_dtype=self.torch_dtype, computation_device=self.device), | |
| ) | |
| self.enable_cpu_offload() | |
| def fetch_models(self, model_manager: ModelManager): | |
| self.dit = model_manager.fetch_model("wan_video_dit") | |
| self.vae = model_manager.fetch_model("wan_video_vae") | |
| @staticmethod | |
| def from_model_manager(model_manager: ModelManager, torch_dtype=None, device=None, use_usp=False): | |
| if device is None: device = model_manager.device | |
| if torch_dtype is None: torch_dtype = model_manager.torch_dtype | |
| pipe = FlashVSRTinyPipeline(device=device, torch_dtype=torch_dtype) | |
| pipe.fetch_models(model_manager) | |
| pipe.use_unified_sequence_parallel = False | |
| return pipe | |
| def denoising_model(self): | |
| return self.dit | |
| def init_cross_kv(self, context_tensor: Optional[torch.Tensor] = None, prompt_path=None): | |
| self.load_models_to_device(["dit"]) | |
| if self.dit is None: raise RuntimeError("Please initialize self.dit first") | |
| if context_tensor is None: | |
| if prompt_path is None: raise ValueError("Provide prompt_path or context_tensor") | |
| ctx = torch.load(prompt_path, map_location=self.device) | |
| else: ctx = context_tensor | |
| ctx = ctx.to(dtype=self.torch_dtype, device=self.device) | |
| if self.prompt_emb_posi is None: self.prompt_emb_posi = {} | |
| self.prompt_emb_posi["context"] = ctx | |
| self.prompt_emb_posi["stats"] = "load" | |
| if hasattr(self.dit, "reinit_cross_kv"): self.dit.reinit_cross_kv(ctx) | |
| else: raise AttributeError("WanModel missing reinit_cross_kv") | |
| self.timestep = torch.tensor([1000.0], device=self.device, dtype=self.torch_dtype) | |
| self.t = self.dit.time_embedding(sinusoidal_embedding_1d(self.dit.freq_dim, self.timestep)) | |
| self.t_mod = self.dit.time_projection(self.t).unflatten(1, (6, self.dit.dim)) | |
| self.scheduler.set_timesteps(1, denoising_strength=1.0, shift=5.0) | |
| self.load_models_to_device([]) | |
| def offload_model(self, keep_vae=False): | |
| self.dit.clear_cross_kv() | |
| self.prompt_emb_posi["stats"] = "offload" | |
| self.load_models_to_device([]) | |
| if hasattr(self.dit, "LQ_proj_in"): self.dit.LQ_proj_in.to("cpu") | |
| if not keep_vae: self.TCDecoder.to("cpu") | |
| @torch.no_grad() | |
| def __call__( | |
| self, prompt=None, negative_prompt="", denoising_strength=1.0, seed=None, rand_device="gpu", height=480, width=832, | |
| num_frames=81, cfg_scale=5.0, num_inference_steps=50, sigma_shift=5.0, tiled=True, tile_size=(60, 104), tile_stride=(30, 52), | |
| tea_cache_l1_thresh=None, tea_cache_model_id="Wan2.1-T2V-1.3B", progress_bar_cmd=tqdm, progress_bar_st=None, | |
| LQ_video=None, is_full_block=False, if_buffer=False, topk_ratio=2.0, kv_ratio=3.0, local_range=9, color_fix=True, | |
| unload_dit=False, force_offload=False, **kwargs, | |
| ): | |
| assert cfg_scale == 1.0, "cfg_scale must be 1.0" | |
| if self.prompt_emb_posi is None or "context" not in self.prompt_emb_posi: raise RuntimeError("Call init_cross_kv() first") | |
| height, width = self.check_resize_height_width(height, width) | |
| if num_frames % 4 != 1: | |
| num_frames = (num_frames + 2) // 4 * 4 + 1 | |
| print(f"Rounding frames to {num_frames}.") | |
| if if_buffer: noise = self.generate_noise((1, 16, (num_frames - 1) // 4, height // 8, width // 8), seed=seed, device=self.device, dtype=self.torch_dtype) | |
| else: noise = self.generate_noise((1, 16, (num_frames - 1) // 4 + 1, height // 8, width // 8), seed=seed, device=self.device, dtype=self.torch_dtype) | |
| latents = noise | |
| process_total_num = (num_frames - 1) // 8 - 2 | |
| is_stream = True | |
| if self.prompt_emb_posi["stats"] == "offload": self.init_cross_kv(context_tensor=self.prompt_emb_posi["context"]) | |
| self.load_models_to_device(["dit"]) | |
| self.dit.LQ_proj_in.to(self.device) | |
| self.TCDecoder.to(self.device) | |
| if hasattr(self.dit, "LQ_proj_in"): self.dit.LQ_proj_in.clear_cache() | |
| latents_total = [] | |
| self.TCDecoder.clean_mem() | |
| LQ_pre_idx = 0 | |
| LQ_cur_idx = 0 | |
| with torch.no_grad(): | |
| for cur_process_idx in progress_bar_cmd(range(process_total_num)): | |
| if cur_process_idx == 0: | |
| pre_cache_k = [None] * len(self.dit.blocks) | |
| pre_cache_v = [None] * len(self.dit.blocks) | |
| LQ_latents = None | |
| inner_loop_num = 7 | |
| for inner_idx in range(inner_loop_num): | |
| cur = (self.denoising_model().LQ_proj_in.stream_forward(LQ_video[:, :, max(0, inner_idx * 4 - 3) : (inner_idx + 1) * 4 - 3, :, :]) if LQ_video is not None else None) | |
| if cur is None: continue | |
| if LQ_latents is None: LQ_latents = cur | |
| else: | |
| for layer_idx in range(len(LQ_latents)): LQ_latents[layer_idx] = torch.cat([LQ_latents[layer_idx], cur[layer_idx]], dim=1) | |
| LQ_cur_idx = (inner_loop_num - 1) * 4 - 3 | |
| cur_latents = latents[:, :, :6, :, :] | |
| else: | |
| LQ_latents = None | |
| inner_loop_num = 2 | |
| for inner_idx in range(inner_loop_num): | |
| cur = (self.denoising_model().LQ_proj_in.stream_forward(LQ_video[:, :, cur_process_idx * 8 + 17 + inner_idx * 4 : cur_process_idx * 8 + 21 + inner_idx * 4, :, :]) if LQ_video is not None else None) | |
| if cur is None: continue | |
| if LQ_latents is None: LQ_latents = cur | |
| else: | |
| for layer_idx in range(len(LQ_latents)): LQ_latents[layer_idx] = torch.cat([LQ_latents[layer_idx], cur[layer_idx]], dim=1) | |
| LQ_cur_idx = cur_process_idx * 8 + 21 + (inner_loop_num - 2) * 4 | |
| cur_latents = latents[:, :, 4 + cur_process_idx * 2 : 6 + cur_process_idx * 2, :, :] | |
| noise_pred_posi, pre_cache_k, pre_cache_v = model_fn_wan_video( | |
| self.dit, x=cur_latents, timestep=self.timestep, context=None, tea_cache=None, | |
| use_unified_sequence_parallel=False, LQ_latents=LQ_latents, is_full_block=is_full_block, | |
| is_stream=is_stream, pre_cache_k=pre_cache_k, pre_cache_v=pre_cache_v, | |
| topk_ratio=topk_ratio, kv_ratio=kv_ratio, cur_process_idx=cur_process_idx, | |
| t_mod=self.t_mod, t=self.t, local_range=local_range, | |
| ) | |
| cur_latents = cur_latents - noise_pred_posi | |
| latents_total.append(cur_latents) | |
| LQ_pre_idx = LQ_cur_idx | |
| if hasattr(self.dit, "LQ_proj_in"): self.dit.LQ_proj_in.clear_cache() | |
| if unload_dit and hasattr(self, "dit") and not next(self.dit.parameters()).is_cpu: | |
| print("[FlashVSR] Offloading DiT to the CPU to free up VRAM...") | |
| self.offload_model(keep_vae=True) | |
| latents = torch.cat(latents_total, dim=2) | |
| print("[FlashVSR] Starting VAE decoding...") | |
| frames = (self.TCDecoder.decode_video(latents.transpose(1, 2), parallel=False, show_progress_bar=False, cond=LQ_video[:, :, :LQ_cur_idx, :, :]).transpose(1, 2).mul_(2).sub_(1)) | |
| self.TCDecoder.clean_mem() | |
| if force_offload: self.offload_model() | |
| try: | |
| if color_fix: frames = self.ColorCorrector(frames.to(device=LQ_video.device), LQ_video[:, :, : frames.shape[2], :, :], clip_range=(-1, 1), chunk_size=16, method="adain") | |
| except Exception as e: log(f"[FlashVSR] Color fix failed, returning uncorrected frames: {e}", message_type="warning") | |
| return frames[0] | |
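| # Note: the value returned above is a (C, F, H, W) tensor scaled to [-1, 1]; tensor2video() further below converts it to (F, H, W, C) in [0, 1] for saving. | |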
| # Assuming FlashVSRTinyLongPipeline functionality is covered by FlashVSRTinyPipeline or is an alias | |
| FlashVSRTinyLongPipeline = FlashVSRTinyPipeline | |
| FlashVSRFullPipeline = FlashVSRTinyPipeline # Fallback alias for safety | |
| # ============================================================================== | |
| # Helper Utils (IO, Tiling, Processing) | |
| # ============================================================================== | |
| def clean_vram(): | |
| gc.collect() | |
| if torch.cuda.is_available(): | |
| torch.cuda.empty_cache() | |
| torch.cuda.ipc_collect() | |
| if torch.backends.mps.is_available(): | |
| torch.mps.empty_cache() # Also release cached memory on Mac (MPS) | |
| def get_device_list(): | |
| devs = [] | |
| try: | |
| if torch.cuda.is_available(): devs += [f"cuda:{i}" for i in range(torch.cuda.device_count())] | |
| except: pass | |
| try: | |
| # Check for MPS (Apple Silicon GPU) | |
| if torch.backends.mps.is_available(): devs += ["mps"] | |
| except: pass | |
| # If nothing was found, fall back to CPU | |
| if not devs: devs = ["cpu"] | |
| return devs | |
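| # Example results: ["cuda:0"] on an NVIDIA machine, ["mps"] on Apple Silicon, ["cpu"] when no accelerator is available. | |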
| devices = get_device_list() | |
| def model_downlod(model_name="JunhaoZhuang/FlashVSR"): | |
| model_dir = os.path.join(root, "models", model_name.split("/")[-1]) | |
| print(f"model dir: {model_dir}") | |
| if not os.path.exists(model_dir): | |
| log(f"Downloading model '{model_name}' from huggingface...", message_type="info") | |
| snapshot_download(repo_id=model_name, local_dir=model_dir, local_dir_use_symlinks=False, resume_download=True) | |
| def is_ffmpeg_available(): | |
| if shutil.which("ffmpeg") is None: | |
| log("[FlashVSR] FFmpeg not found!", message_type="warning") | |
| return False | |
| return True | |
| def tensor2video(frames: torch.Tensor): | |
| return (rearrange(frames.squeeze(0), "C F H W -> F H W C").float() + 1.0) / 2.0 | |
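| # tensor2video(): (1, C, F, H, W) or (C, F, H, W) in [-1, 1] -> (F, H, W, C) float frames in [0, 1], ready for imageio. | |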
| def natural_key(name: str): | |
| return [int(t) if t.isdigit() else t.lower() for t in re.split(r"([0-9]+)", os.path.basename(name))] | |
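| # Example: natural_key("frame_10.png") -> ["frame_", 10, ".png"], so "frame_2.png" sorts before "frame_10.png" instead of after it. | |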
| def list_images_natural(folder: str): | |
| exts = (".png", ".jpg", ".jpeg", ".PNG", ".JPG", ".JPEG") | |
| fs = [os.path.join(folder, f) for f in os.listdir(folder) if f.endswith(exts)] | |
| fs.sort(key=natural_key) | |
| return fs | |
| def largest_8n1_leq(n): return 0 if n < 1 else ((n - 1) // 8) * 8 + 1 | |
| def next_8n5(n): return 21 if n < 21 else ((n - 5 + 7) // 8) * 8 + 5 | |
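| # Worked examples (illustrative): largest_8n1_leq(100) == 97 and largest_8n1_leq(25) == 25 (largest 8n+1 value <= n); next_8n5(30) == 37 and next_8n5(10) == 21 (smallest 8n+5 value >= n, floored at 21). | |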
| def is_video(path): return os.path.isfile(path) and path.lower().endswith((".mp4", ".mov", ".avi", ".mkv")) | |
| def save_video(frames, save_path, fps=30, quality=5): | |
| os.makedirs(os.path.dirname(save_path), exist_ok=True) | |
| frames_np = (frames.cpu().float() * 255.0).clip(0, 255).numpy().astype(np.uint8) | |
| w = imageio.get_writer(save_path, fps=fps, quality=quality) | |
| for frame_np in tqdm(frames_np, desc="[FlashVSR] Saving video"): w.append_data(frame_np) | |
| w.close() | |
| def merge_video_with_audio(video_path, audio_source_path): | |
| temp_path = video_path + "temp.mp4" | |
| if os.path.isdir(audio_source_path) or not is_ffmpeg_available(): | |
| log(f"[FlashVSR] Output video saved to '{video_path}'", message_type="info") | |
| return | |
| try: | |
| if not [s for s in ffmpeg.probe(audio_source_path)["streams"] if s["codec_type"] == "audio"]: log(f"[FlashVSR] No audio track found. Output video saved to '{video_path}'", message_type="info"); return | |
| log("[FlashVSR] Copying audio tracks...") | |
| os.rename(video_path, temp_path) | |
| ffmpeg.output(ffmpeg.input(temp_path)["v"], ffmpeg.input(audio_source_path)["a"], video_path, vcodec="copy", acodec="copy").run(overwrite_output=True, quiet=True) | |
| log(f"[FlashVSR] Output video saved to '{video_path}'", message_type="info") | |
| except Exception as e: | |
| print("[ERROR] FFmpeg error:", e) | |
| log(f"[FlashVSR] Audio merge failed.", message_type="warning") | |
| finally: | |
| if os.path.exists(temp_path): os.remove(temp_path) | |
| def compute_scaled_and_target_dims(w0: int, h0: int, scale: int = 4, multiple: int = 128): | |
| if w0 <= 0 or h0 <= 0: raise ValueError("invalid original size") | |
| sW, sH = w0 * scale, h0 * scale | |
| tW, tH = max(multiple, (sW // multiple) * multiple), max(multiple, (sH // multiple) * multiple) | |
| return sW, sH, tW, tH | |
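| # Example: compute_scaled_and_target_dims(480, 270, scale=4, multiple=128) -> (1920, 1080, 1920, 1024): the 4x size is 1920x1080 and the target crop is the largest 128-multiple that fits, 1920x1024. | |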
| def tensor_upscale_then_center_crop(frame_tensor: torch.Tensor, scale: int, tW: int, tH: int) -> torch.Tensor: | |
| h0, w0, c = frame_tensor.shape | |
| tensor_bchw = frame_tensor.permute(2, 0, 1).unsqueeze(0) | |
| sW, sH = w0 * scale, h0 * scale | |
| upscaled_tensor = F.interpolate(tensor_bchw, size=(sH, sW), mode="bicubic", align_corners=False) | |
| l, t = max(0, (sW - tW) // 2), max(0, (sH - tH) // 2) | |
| return upscaled_tensor[:, :, t : t + tH, l : l + tW].squeeze(0) | |
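| # The helper above takes one H x W x C frame, bicubic-upscales it by `scale`, center-crops to tH x tW, and returns a C x tH x tW tensor. | |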
| def prepare_tensors(path: str, dtype=torch.bfloat16): | |
| if os.path.isdir(path): | |
| paths0 = list_images_natural(path) | |
| if not paths0: raise FileNotFoundError(f"No images in {path}") | |
| with Image.open(paths0[0]) as _img0: w0, h0 = _img0.size | |
| frames = [torch.from_numpy(np.array(Image.open(p).convert("RGB")).astype(np.float32) / 255.0).to(dtype) for p in paths0] | |
| return torch.stack(frames, 0), 30 | |
| if is_video(path): | |
| rdr = imageio.get_reader(path) | |
| meta = rdr.get_meta_data() if hasattr(rdr, "get_meta_data") else {} | |
| fps_val = meta.get("fps", 30) | |
| fps = int(round(fps_val)) if isinstance(fps_val, (int, float)) else 30 | |
| frames = [torch.from_numpy(frame.astype(np.float32) / 255.0).to(dtype) for frame in rdr] | |
| rdr.close() | |
| return torch.stack(frames, 0), fps | |
| raise ValueError(f"Unsupported input: {path}") | |
| def get_input_params(image_tensor, scale): | |
| N0, h0, w0, _ = image_tensor.shape | |
| sW, sH, tW, tH = compute_scaled_and_target_dims(w0, h0, scale=scale, multiple=128) | |
| F = largest_8n1_leq(N0 + 4) | |
| if F == 0: raise RuntimeError(f"Not enough frames after padding. Got {N0 + 4}.") | |
| return tH, tW, F | |
| def input_tensor_generator(image_tensor: torch.Tensor, device, scale: int = 4, dtype=torch.bfloat16): | |
| N0, h0, w0, _ = image_tensor.shape | |
| tH, tW, F = get_input_params(image_tensor, scale) | |
| for i in range(F): | |
| frame_idx = min(i, N0 - 1) | |
| tensor_out = (tensor_upscale_then_center_crop(image_tensor[frame_idx].to(device), scale=scale, tW=tW, tH=tH) * 2.0 - 1.0) | |
| yield tensor_out.to("cpu").to(dtype) | |
| def prepare_input_tensor(image_tensor: torch.Tensor, device, scale: int = 4, dtype=torch.bfloat16): | |
| N0, h0, w0, _ = image_tensor.shape | |
| sW, sH, tW, tH = compute_scaled_and_target_dims(w0, h0, scale=scale, multiple=128) | |
| F = largest_8n1_leq(N0 + 4) | |
| if F == 0: raise RuntimeError(f"Not enough frames after padding. Got {N0 + 4}.") | |
| frames = [] | |
| for i in range(F): | |
| frame_idx = min(i, N0 - 1) | |
| tensor_out = (tensor_upscale_then_center_crop(image_tensor[frame_idx].to(device), scale=scale, tW=tW, tH=tH) * 2.0 - 1.0) | |
| frames.append(tensor_out.to("cpu").to(dtype)) | |
| vid_final = torch.stack(frames, 0).permute(1, 0, 2, 3).unsqueeze(0) | |
| del frames | |
| clean_vram() | |
| return vid_final, tH, tW, F | |
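| # prepare_input_tensor() stacks the prepared frames into a (1, 3, F, tH, tW) tensor in [-1, 1], i.e. the LQ_video layout consumed by the pipeline above. | |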
| def calculate_tile_coords(height, width, tile_size, overlap): | |
| coords = [] | |
| stride = tile_size - overlap | |
| num_rows, num_cols = math.ceil((height - overlap) / stride), math.ceil((width - overlap) / stride) | |
| for r in range(num_rows): | |
| for c in range(num_cols): | |
| y1, x1 = r * stride, c * stride | |
| y2, x2 = min(y1 + tile_size, height), min(x1 + tile_size, width) | |
| if y2 - y1 < tile_size: y1 = max(0, y2 - tile_size) | |
| if x2 - x1 < tile_size: x1 = max(0, x2 - tile_size) | |
| coords.append((x1, y1, x2, y2)) | |
| return coords | |
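| # Example: calculate_tile_coords(height=32, width=48, tile_size=16, overlap=8) uses a stride of 8 and yields a 3 x 5 grid of (x1, y1, x2, y2) boxes from (0, 0, 16, 16) to (32, 16, 48, 32); tiles that would overrun the frame are shifted back so each stays tile_size wide. | |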
| def create_feather_mask(size, overlap): | |
| H, W = size | |
| mask = torch.ones(1, 1, H, W) | |
| ramp = torch.linspace(0, 1, overlap) | |
| mask[:, :, :, :overlap] = torch.minimum(mask[:, :, :, :overlap], ramp.view(1, 1, 1, -1)) | |
| mask[:, :, :, -overlap:] = torch.minimum(mask[:, :, :, -overlap:], ramp.flip(0).view(1, 1, 1, -1)) | |
| mask[:, :, :overlap, :] = torch.minimum(mask[:, :, :overlap, :], ramp.view(1, 1, -1, 1)) | |
| mask[:, :, -overlap:, :] = torch.minimum(mask[:, :, -overlap:, :], ramp.flip(0).view(1, 1, -1, 1)) | |
| return mask | |
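| # The feather mask is 1 in the interior and ramps linearly to 0 over `overlap` pixels at each edge, so overlapping tiles cross-fade when accumulated and later divided by the summed weights. | |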
| def stitch_video_tiles(tile_paths, tile_coords, final_dims, scale, overlap, output_path, fps, quality, cleanup=True, chunk_size=40): | |
| if not tile_paths: return | |
| final_W, final_H = final_dims | |
| readers = [imageio.get_reader(p) for p in tile_paths] | |
| try: | |
| num_frames = readers[0].count_frames() | |
| if num_frames is None or not math.isfinite(num_frames) or num_frames <= 0: num_frames = sum(1 for _ in readers[0]); [r.close() for r in readers]; readers = [imageio.get_reader(p) for p in tile_paths] | |
| with imageio.get_writer(output_path, fps=fps, quality=quality) as writer: | |
| for start_frame in tqdm(range(0, num_frames, chunk_size), desc="[FlashVSR] Stitching Chunks"): | |
| end_frame = min(start_frame + chunk_size, num_frames) | |
| current_chunk_size = end_frame - start_frame | |
| chunk_canvas = np.zeros((current_chunk_size, final_H, final_W, 3), dtype=np.float32) | |
| weight_canvas = np.zeros_like(chunk_canvas, dtype=np.float32) | |
| for i, reader in enumerate(readers): | |
| try: | |
| tile_chunk_frames = [frame.astype(np.float32) / 255.0 for idx, frame in enumerate(reader.iter_data()) if start_frame <= idx < end_frame] | |
| tile_chunk_np = np.stack(tile_chunk_frames, axis=0) | |
| except Exception as e: continue | |
| if tile_chunk_np.shape[0] != current_chunk_size: continue | |
| tile_H, tile_W, _ = tile_chunk_np.shape[1:] | |
| ramp = np.linspace(0, 1, overlap * scale, dtype=np.float32) | |
| mask = np.ones((tile_H, tile_W, 1), dtype=np.float32) | |
| mask[:, : overlap * scale, :] *= ramp[np.newaxis, :, np.newaxis] | |
| mask[:, -overlap * scale :, :] *= np.flip(ramp)[np.newaxis, :, np.newaxis] | |
| mask[: overlap * scale, :, :] *= ramp[:, np.newaxis, np.newaxis] | |
| mask[-overlap * scale :, :, :] *= np.flip(ramp)[:, np.newaxis, np.newaxis] | |
| mask_4d = mask[np.newaxis, :, :, :] | |
| x1_orig, y1_orig, _, _ = tile_coords[i] | |
| out_y1, out_x1 = y1_orig * scale, x1_orig * scale | |
| out_y2, out_x2 = out_y1 + tile_H, out_x1 + tile_W | |
| chunk_canvas[:, out_y1:out_y2, out_x1:out_x2, :] += (tile_chunk_np * mask_4d) | |
| weight_canvas[:, out_y1:out_y2, out_x1:out_x2, :] += mask_4d | |
| weight_canvas[weight_canvas == 0] = 1.0 | |
| stitched_chunk = chunk_canvas / weight_canvas | |
| for frame_idx_in_chunk in range(current_chunk_size): | |
| writer.append_data((np.clip(stitched_chunk[frame_idx_in_chunk], 0, 1) * 255).astype(np.uint8)) | |
| finally: | |
| for reader in readers: reader.close() | |
| if cleanup: | |
| for path in tile_paths: | |
| try: os.remove(path) | |
| except: pass | |
| # ============================================================================== | |
| # Initialization & Main | |
| # ============================================================================== | |
| def init_pipeline(version, mode, device, dtype): | |
| model = "FlashVSR" if version == "10" else "FlashVSR-v1.1" | |
| model_downlod(model_name="JunhaoZhuang/" + model) | |
| model_path = os.path.join(root, "models", model) | |
| if not os.path.exists(model_path): raise RuntimeError(f'Model directory "{model_path}" does not exist!') | |
| ckpt_path = os.path.join(model_path, "diffusion_pytorch_model_streaming_dmd.safetensors") | |
| vae_path = os.path.join(model_path, "Wan2.1_VAE.pth") | |
| lq_path = os.path.join(model_path, "LQ_proj_in.ckpt") | |
| tcd_path = os.path.join(model_path, "TCDecoder.ckpt") | |
| prompt_path = os.path.join(root, "posi_prompt.pth") | |
| if not all(os.path.exists(p) for p in [ckpt_path, vae_path, lq_path, tcd_path]): raise RuntimeError("Missing weights!") | |
| mm = ModelManager(torch_dtype=dtype, device="cpu") | |
| if mode == "full": | |
| mm.load_models([ckpt_path, vae_path]) | |
| pipe = FlashVSRFullPipeline.from_model_manager(mm, device=device) | |
| pipe.vae.model.encoder = None | |
| pipe.vae.model.conv1 = None | |
| else: | |
| mm.load_models([ckpt_path]) | |
| if mode == "tiny": pipe = FlashVSRTinyPipeline.from_model_manager(mm, device=device) | |
| else: pipe = FlashVSRTinyLongPipeline.from_model_manager(mm, device=device) | |
| pipe.TCDecoder = build_tcdecoder(new_channels=[512, 256, 128, 128], device=device, dtype=dtype, new_latent_channels=16 + 768) | |
| pipe.TCDecoder.load_state_dict(torch.load(tcd_path, map_location=device), strict=False) | |
| pipe.TCDecoder.clean_mem() | |
| if model == "FlashVSR": pipe.denoising_model().LQ_proj_in = Buffer_LQ4x_Proj(in_dim=3, out_dim=1536, layer_num=1).to(device, dtype=dtype) | |
| else: pipe.denoising_model().LQ_proj_in = Causal_LQ4x_Proj(in_dim=3, out_dim=1536, layer_num=1).to(device, dtype=dtype) | |
| pipe.denoising_model().LQ_proj_in.load_state_dict(torch.load(lq_path, map_location="cpu"), strict=True) | |
| pipe.denoising_model().LQ_proj_in.to(device) | |
| pipe.to(device, dtype=dtype) | |
| pipe.enable_vram_management(num_persistent_param_in_dit=None) | |
| pipe.init_cross_kv(prompt_path=prompt_path) | |
| pipe.load_models_to_device(["dit", "vae"]) | |
| return pipe | |
| def main(input, version, mode, scale, color_fix, tiled_vae, tiled_dit, tile_size, tile_overlap, unload_dit, dtype, sparse_ratio=2, kv_ratio=3, local_range=11, seed=0, device="auto", quality=6, output=None): | |
| _device = device | |
| if device == "auto": _device = "cuda:0" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else device | |
| if _device == "auto" or _device not in devices: raise RuntimeError("No devices found!") | |
| if _device.startswith("cuda"): torch.cuda.set_device(_device) | |
| if tiled_dit and (tile_overlap > tile_size / 2): raise ValueError("tile_overlap must not exceed tile_size / 2") | |
| _frames, fps = prepare_tensors(input, dtype=dtype) | |
| _fps = fps if is_video(input) else args.fps | |
| add = next_8n5(_frames.shape[0]) - _frames.shape[0] | |
| padding_frames = _frames[-1:, :, :, :].repeat(add, 1, 1, 1) | |
| frames = torch.cat([_frames, padding_frames], dim=0) | |
| frame_count = _frames.shape[0] | |
| del _frames | |
| clean_vram() | |
| log("[FlashVSR] Preparing frames...", message_type="finish") | |
| if tiled_dit: | |
| N, H, W, C = frames.shape | |
| if mode == "tiny-long": | |
| local_temp = os.path.join(temp, str(uuid.uuid4())) | |
| os.makedirs(local_temp, exist_ok=True) | |
| else: | |
| final_output_canvas = torch.zeros((largest_8n1_leq(N + 4) - 4, H * scale, W * scale, C), dtype=dtype, device="cpu") | |
| weight_sum_canvas = torch.zeros_like(final_output_canvas) | |
| tile_coords = calculate_tile_coords(H, W, tile_size, tile_overlap) | |
| temp_videos = [] | |
| pipe = init_pipeline(version, mode, _device, dtype) | |
| for i, (x1, y1, x2, y2) in enumerate(tile_coords): | |
| input_tile = frames[:, y1:y2, x1:x2, :] | |
| if mode == "tiny-long": | |
| temp_name = os.path.join(local_temp, f"{i+1:05d}.mp4") | |
| th, tw, F = get_input_params(input_tile, scale=scale) | |
| LQ_tile = input_tensor_generator(input_tile, _device, scale=scale, dtype=dtype) | |
| else: | |
| LQ_tile, th, tw, F = prepare_input_tensor(input_tile, _device, scale=scale, dtype=dtype) | |
| LQ_tile = LQ_tile.to(_device) | |
| if i == 0: log(f"[FlashVSR] Processing {frame_count} frames...", message_type="info") | |
| log(f"[FlashVSR] Processing tile {i+1}/{len(tile_coords)}: ({x1},{y1}) to ({x2},{y2})", message_type="info") | |
| output_tile_gpu = pipe( | |
| prompt="", negative_prompt="", cfg_scale=1.0, num_inference_steps=1, seed=seed, tiled=tiled_vae, LQ_video=LQ_tile, | |
| num_frames=F, height=th, width=tw, is_full_block=False, if_buffer=True, topk_ratio=sparse_ratio * 768 * 1280 / (th * tw), | |
| kv_ratio=kv_ratio, local_range=local_range, color_fix=color_fix, unload_dit=unload_dit, fps=_fps, quality=10, | |
| output_path=temp_name if mode=="tiny-long" else None, tiled_dit=True, | |
| ) | |
| if mode == "tiny-long": | |
| temp_videos.append(temp_name) | |
| del LQ_tile, input_tile | |
| clean_vram() | |
| continue | |
| processed_tile_cpu = tensor2video(output_tile_gpu).to("cpu") | |
| mask_nchw = create_feather_mask((processed_tile_cpu.shape[1], processed_tile_cpu.shape[2]), tile_overlap * scale).to("cpu") | |
| mask_nhwc = mask_nchw.permute(0, 2, 3, 1) | |
| out_x1, out_y1 = x1 * scale, y1 * scale | |
| out_x2, out_y2 = out_x1 + processed_tile_cpu.shape[2], out_y1 + processed_tile_cpu.shape[1] | |
| final_output_canvas[:, out_y1:out_y2, out_x1:out_x2, :] += (processed_tile_cpu * mask_nhwc) | |
| weight_sum_canvas[:, out_y1:out_y2, out_x1:out_x2, :] += mask_nhwc | |
| del LQ_tile, output_tile_gpu, processed_tile_cpu, input_tile | |
| clean_vram() | |
| if mode == "tiny-long": | |
| stitch_video_tiles(temp_videos, tile_coords, (W * scale, H * scale), scale, tile_overlap, output, _fps, quality, cleanup=True) | |
| shutil.rmtree(local_temp) | |
| return None, _fps # tiles were already stitched and written to `output`, so there is no in-memory result to return | |
| else: | |
| weight_sum_canvas[weight_sum_canvas == 0] = 1.0 | |
| final_output = final_output_canvas / weight_sum_canvas | |
| else: | |
| if mode == "tiny-long": | |
| th, tw, F = get_input_params(frames, scale=scale) | |
| LQ = input_tensor_generator(frames, _device, scale=scale, dtype=dtype) | |
| else: | |
| LQ, th, tw, F = prepare_input_tensor(frames, _device, scale=scale, dtype=dtype) | |
| LQ = LQ.to(_device) | |
| pipe = init_pipeline(version, mode, _device, dtype) | |
| log(f"[FlashVSR] Processing {frame_count} frames...", message_type="info") | |
| video = pipe( | |
| prompt="", negative_prompt="", cfg_scale=1.0, num_inference_steps=1, seed=seed, tiled=tiled_vae, LQ_video=LQ, | |
| num_frames=F, height=th, width=tw, is_full_block=False, if_buffer=True, topk_ratio=sparse_ratio * 768 * 1280 / (th * tw), | |
| kv_ratio=kv_ratio, local_range=local_range, color_fix=color_fix, unload_dit=unload_dit, fps=_fps, output_path=output, tiled_dit=True, | |
| ) | |
| if mode == "tiny-long": | |
| del pipe, LQ | |
| clean_vram() | |
| return video, _fps | |
| final_output = tensor2video(video).to("cpu") | |
| del pipe, video, LQ | |
| clean_vram() | |
| return final_output[:frame_count, :, :, :], fps | |
| if __name__ == "__main__": | |
| dtype_map = {"fp32": torch.float32, "fp16": torch.float16, "bf16": torch.bfloat16} | |
| dtype = dtype_map.get(args.dtype, torch.bfloat16) | |
| if args.attention == "sage": USE_BLOCK_ATTN = False | |
| else: USE_BLOCK_ATTN = True | |
| if os.path.exists(temp): shutil.rmtree(temp) | |
| os.makedirs(temp, exist_ok=True) | |
| final = "/content/output.mp4" | |
| result, fps = main( | |
| args.input, args.version, args.mode, args.scale, args.color_fix, args.tiled_vae, args.tiled_dit, args.tile_size, args.overlap, | |
| args.unload_dit, dtype, seed=args.seed, device=args.device, quality=args.quality, output=final, | |
| ) | |
| if args.mode != "tiny-long": save_video(result, final, fps=fps, quality=args.quality) | |
| merge_video_with_audio(final, args.input) | |
| log("[FlashVSR] Done.", message_type="finish") |