# Standard library
import gc
import hashlib
import importlib
import json
import math
import os
import random
import re
import shutil
import time
import types
import uuid
from collections import deque, namedtuple
from contextlib import contextmanager
from dataclasses import dataclass
from typing import List, Literal, Optional, Tuple
from typing_extensions import TypeAlias

# Third-party
import ffmpeg
import imageio
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.init as init
import triton
import triton.language as tl
from einops import rearrange, repeat
from huggingface_hub import snapshot_download
from PIL import Image
from safetensors import safe_open
from torchvision.transforms import GaussianBlur
from tqdm.auto import tqdm
| @dataclass | |
| class Args: | |
| input: str = "/root/example0.mp4" | |
| output_folder: str = "./outputs" | |
| scale: int = 4 | |
| version: str = "10" | |
| mode: str = "tiny" | |
| tiled_vae: bool = False | |
| tiled_dit: bool = False | |
| tile_size: int = 256 | |
| overlap: int = 24 | |
| unload_dit: bool = False | |
| color_fix: bool = False | |
| seed: int = 0 | |
| dtype: str = "bf16" | |
| device: str = "auto" | |
| fps: int = 30 | |
| quality: int = 6 | |
| attention: str = "sage" | |
| args = Args() | |
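# Illustrative sketch (not part of the original script): one way the string-valued
# `dtype` and `device` fields above might be resolved before use. The helper names
# below are hypothetical.
def _resolve_dtype(name: str) -> torch.dtype:
    return {"fp32": torch.float32, "fp16": torch.float16, "bf16": torch.bfloat16}[name]

def _resolve_device(name: str) -> str:
    # "auto" falls back to CPU when no CUDA device is visible.
    return ("cuda" if torch.cuda.is_available() else "cpu") if name == "auto" else name

# e.g. _resolve_dtype(args.dtype) -> torch.bfloat16, _resolve_device(args.device) -> "cuda" on a GPU host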
def log(message: str, message_type: str = "normal"):
    # ANSI-colored console logging: red background for errors, red for warnings,
    # green for completion, yellow for info, plain text otherwise.
    if message_type == "error":
        message = "\033[1;41m" + message + "\033[m"
    elif message_type == "warning":
        message = "\033[1;31m" + message + "\033[m"
    elif message_type == "finish":
        message = "\033[1;32m" + message + "\033[m"
    elif message_type == "info":
        message = "\033[1;33m" + message + "\033[m"
    print(message)
# ----------------------------
# WanModel (no image branch) - the cross-attention KV cache is created at init time
# ----------------------------
| class WanModel(torch.nn.Module): | |
| def __init__( | |
| self, | |
| dim: int, | |
| in_dim: int, | |
| ffn_dim: int, | |
| out_dim: int, | |
| text_dim: int, | |
| freq_dim: int, | |
| eps: float, | |
| patch_size: Tuple[int, int, int], | |
| num_heads: int, | |
| num_layers: int, | |
# init_context: torch.Tensor,  # <<<< required: used inside __init__ to build the cross-attn KV cache
| has_image_input: bool = False, | |
| ): | |
| super().__init__() | |
| self.dim = dim | |
| self.freq_dim = freq_dim | |
| self.patch_size = patch_size | |
| # patch embed | |
| self.patch_embedding = nn.Conv3d( | |
| in_dim, dim, kernel_size=patch_size, stride=patch_size | |
| ) | |
| # text / time embed | |
| self.text_embedding = nn.Sequential( | |
| nn.Linear(text_dim, dim), nn.GELU(approximate="tanh"), nn.Linear(dim, dim) | |
| ) | |
| self.time_embedding = nn.Sequential( | |
| nn.Linear(freq_dim, dim), nn.SiLU(), nn.Linear(dim, dim) | |
| ) | |
| self.time_projection = nn.Sequential(nn.SiLU(), nn.Linear(dim, dim * 6)) | |
| # blocks | |
| self.blocks = nn.ModuleList( | |
| [DiTBlock(dim, num_heads, ffn_dim, eps) for _ in range(num_layers)] | |
| ) | |
| self.head = Head(dim, out_dim, patch_size, eps) | |
| head_dim = dim // num_heads | |
| self.freqs = precompute_freqs_cis_3d(head_dim) | |
| self._cross_kv_initialized = False | |
# Optional: manually clear / re-initialize the cross-attention KV cache
| def clear_cross_kv(self): | |
| for blk in self.blocks: | |
| blk.cross_attn.clear_cache() | |
| self._cross_kv_initialized = False | |
| @torch.no_grad() | |
| def reinit_cross_kv(self, new_context: torch.Tensor): | |
| ctx_txt = self.text_embedding(new_context) | |
| for blk in self.blocks: | |
| blk.cross_attn.init_cache(ctx_txt) | |
| self._cross_kv_initialized = True | |
| def patchify(self, x: torch.Tensor): | |
| x = self.patch_embedding(x) | |
| grid_size = x.shape[2:] | |
| x = rearrange(x, "b c f h w -> b (f h w) c").contiguous() | |
| return x, grid_size # x, grid_size: (f, h, w) | |
| def unpatchify(self, x: torch.Tensor, grid_size: torch.Tensor): | |
| return rearrange( | |
| x, | |
| "b (f h w) (x y z c) -> b c (f x) (h y) (w z)", | |
| f=grid_size[0], | |
| h=grid_size[1], | |
| w=grid_size[2], | |
| x=self.patch_size[0], | |
| y=self.patch_size[1], | |
| z=self.patch_size[2], | |
| ) | |
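# Shape sketch for patchify/unpatchify (assuming, for illustration, patch_size=(1, 2, 2)
# and a latent input of shape (B, in_dim, F, H, W)):
#   patchify:   (B, in_dim, F, H, W) -> tokens (B, F * H/2 * W/2, dim), grid_size = (F, H/2, W/2)
#   unpatchify: (B, F * H/2 * W/2, out_dim * 1 * 2 * 2) -> (B, out_dim, F, H, W)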
| def forward( | |
| self, | |
| x: torch.Tensor, | |
| timestep: torch.Tensor, | |
| context: torch.Tensor, | |
| use_gradient_checkpointing: bool = False, | |
| use_gradient_checkpointing_offload: bool = False, | |
| LQ_latents: Optional[List[torch.Tensor]] = None, | |
| train_img: bool = False, | |
| topk_ratio: Optional[float] = None, | |
| kv_ratio: Optional[float] = None, | |
| local_num: Optional[int] = None, | |
| is_full_block: bool = False, | |
| causal_idx: Optional[int] = None, | |
| **kwargs, | |
| ): | |
| # time / text embeds | |
| t = self.time_embedding(sinusoidal_embedding_1d(self.freq_dim, timestep)) | |
| t_mod = self.time_projection(t).unflatten(1, (6, self.dim)) | |
# Text would still be embedded here (CrossAttention ignores it once its KV cache is set)
# context = self.text_embedding(context)
# Patchify the input
| x, (f, h, w) = self.patchify(x) | |
| B = x.shape[0] | |
# Window / mask hyperparameters
| win = (2, 8, 8) | |
| seqlen = f // win[0] | |
| if local_num is None: | |
| local_random = random.random() | |
| if local_random < 0.3: | |
| local_num = seqlen - 3 | |
| elif local_random < 0.4: | |
| local_num = seqlen - 4 | |
| elif local_random < 0.5: | |
| local_num = seqlen - 2 | |
| else: | |
| local_num = seqlen | |
| window_size = win[0] * h * w // 128 | |
| square_num = window_size * window_size | |
# Respect an explicitly passed topk_ratio; default to 2.0 otherwise
if topk_ratio is None:
    topk_ratio = 2.0
| topk = min(max(int(square_num * topk_ratio), 1), int(square_num * seqlen) - 1) | |
| if kv_ratio is None: | |
| kv_ratio = (random.uniform(0.0, 1.0) ** 2) * (local_num - 2 - 2) + 2 | |
| kv_len = min(max(int(window_size * kv_ratio), 1), int(window_size * seqlen) - 1) | |
| decay_ratio = random.uniform(0.7, 1.0) | |
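# Worked example with illustrative numbers (not taken from the script): for a patch grid
# of f=8, h=32, w=32 and win=(2, 8, 8):
#   seqlen      = 8 // 2                 = 4
#   window_size = 2 * 32 * 32 // 128     = 16
#   square_num  = 16 * 16                = 256
#   topk        = min(max(int(256 * 2.0), 1), 256 * 4 - 1) = 512
#   kv_len is then clamped to [1, 16 * 4 - 1] = [1, 63], depending on the sampled kv_ratio.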
| # RoPE 3D | |
| freqs = ( | |
| torch.cat( | |
| [ | |
| self.freqs[0][:f].view(f, 1, 1, -1).expand(f, h, w, -1), | |
| self.freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1), | |
| self.freqs[2][:w].view(1, 1, w, -1).expand(f, h, w, -1), | |
| ], | |
| dim=-1, | |
| ) | |
| .reshape(f * h * w, 1, -1) | |
| .to(x.device) | |
| ) | |
| def create_custom_forward(module): | |
| def custom_forward(*inputs): | |
| return module(*inputs) | |
| return custom_forward | |
| # blocks | |
| for block_id, block in enumerate(self.blocks): | |
| if LQ_latents is not None and block_id < len(LQ_latents): | |
| x += LQ_latents[block_id] | |
| if self.training and use_gradient_checkpointing: | |
| if use_gradient_checkpointing_offload: | |
| with torch.autograd.graph.save_on_cpu(): | |
| x = torch.utils.checkpoint.checkpoint( | |
| create_custom_forward(block), | |
| x, | |
| context, | |
| t_mod, | |
| freqs, | |
| f, | |
| h, | |
| w, | |
| local_num, | |
| topk, | |
| train_img, | |
| block_id, | |
| kv_len, | |
| is_full_block, | |
| False, | |
| None, | |
| None, | |
| use_reentrant=False, | |
| ) | |
| else: | |
| x = torch.utils.checkpoint.checkpoint( | |
| create_custom_forward(block), | |
| x, | |
| context, | |
| t_mod, | |
| freqs, | |
| f, | |
| h, | |
| w, | |
| local_num, | |
| topk, | |
| train_img, | |
| block_id, | |
| kv_len, | |
| is_full_block, | |
| False, | |
| None, | |
| None, | |
| use_reentrant=False, | |
| ) | |
| else: | |
| x = block( | |
| x, | |
| context, | |
| t_mod, | |
| freqs, | |
| f, | |
| h, | |
| w, | |
| local_num, | |
| topk, | |
| train_img, | |
| block_id, | |
| kv_len, | |
| is_full_block, | |
| False, | |
| None, | |
| None, | |
| ) | |
| x = self.head(x, t) | |
| x = self.unpatchify(x, (f, h, w)) | |
| return x | |
| @staticmethod | |
| def state_dict_converter(): | |
| return WanModelStateDictConverter() | |
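# Instantiation sketch (the hyperparameters below are assumptions, roughly those of a
# 1.3B-scale Wan2.1 DiT; in this script the real values come from the state-dict loader):
# model = WanModel(
#     dim=1536, in_dim=16, ffn_dim=8960, out_dim=16, text_dim=4096,
#     freq_dim=256, eps=1e-6, patch_size=(1, 2, 2), num_heads=12, num_layers=30,
# )
# model.reinit_cross_kv(text_context)  # build the per-block cross-attn KV cache once, then reuse it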
| class WanVideoVAE(nn.Module): | |
| def __init__(self, z_dim=16, dim=96): | |
| super().__init__() | |
| mean = [ | |
| -0.7571, | |
| -0.7089, | |
| -0.9113, | |
| 0.1075, | |
| -0.1745, | |
| 0.9653, | |
| -0.1517, | |
| 1.5508, | |
| 0.4134, | |
| -0.0715, | |
| 0.5517, | |
| -0.3632, | |
| -0.1922, | |
| -0.9497, | |
| 0.2503, | |
| -0.2921, | |
| ] | |
| std = [ | |
| 2.8184, | |
| 1.4541, | |
| 2.3275, | |
| 2.6558, | |
| 1.2196, | |
| 1.7708, | |
| 2.6052, | |
| 2.0743, | |
| 3.2687, | |
| 2.1526, | |
| 2.8652, | |
| 1.5579, | |
| 1.6382, | |
| 1.1253, | |
| 2.8251, | |
| 1.9160, | |
| ] | |
| self.mean = torch.tensor(mean) | |
| self.std = torch.tensor(std) | |
| self.scale = [self.mean, 1.0 / self.std] | |
| # init model | |
| self.model = VideoVAE_(z_dim=z_dim, dim=dim).eval().requires_grad_(False) | |
| self.upsampling_factor = 8 | |
| def build_1d_mask(self, length, left_bound, right_bound, border_width): | |
| x = torch.ones((length,)) | |
| if not left_bound: | |
| x[:border_width] = (torch.arange(border_width) + 1) / border_width | |
| if not right_bound: | |
| x[-border_width:] = torch.flip( | |
| (torch.arange(border_width) + 1) / border_width, dims=(0,) | |
| ) | |
| return x | |
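# Example of the 1D feathering ramp (illustrative): build_1d_mask(8, False, False, 3)
# -> [1/3, 2/3, 1, 1, 1, 1, 2/3, 1/3]; edges flagged as bounds keep weight 1 so true
# image borders are never attenuated.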
| def build_mask(self, data, is_bound, border_width): | |
| _, _, _, H, W = data.shape | |
| h = self.build_1d_mask(H, is_bound[0], is_bound[1], border_width[0]) | |
| w = self.build_1d_mask(W, is_bound[2], is_bound[3], border_width[1]) | |
| h = repeat(h, "H -> H W", H=H, W=W) | |
| w = repeat(w, "W -> H W", H=H, W=W) | |
| mask = torch.stack([h, w]).min(dim=0).values | |
| mask = rearrange(mask, "H W -> 1 1 1 H W") | |
| return mask | |
| def tiled_decode(self, hidden_states, device, tile_size, tile_stride): | |
| _, _, T, H, W = hidden_states.shape | |
| size_h, size_w = tile_size | |
| stride_h, stride_w = tile_stride | |
| # Split tasks | |
| tasks = [] | |
| for h in range(0, H, stride_h): | |
| if h - stride_h >= 0 and h - stride_h + size_h >= H: | |
| continue | |
| for w in range(0, W, stride_w): | |
| if w - stride_w >= 0 and w - stride_w + size_w >= W: | |
| continue | |
| h_, w_ = h + size_h, w + size_w | |
| tasks.append((h, h_, w, w_)) | |
| data_device = "cpu" | |
| computation_device = device | |
| out_T = T * 4 - 3 | |
| weight = torch.zeros( | |
| (1, 1, out_T, H * self.upsampling_factor, W * self.upsampling_factor), | |
| dtype=hidden_states.dtype, | |
| device=data_device, | |
| ) | |
| values = torch.zeros( | |
| (1, 3, out_T, H * self.upsampling_factor, W * self.upsampling_factor), | |
| dtype=hidden_states.dtype, | |
| device=data_device, | |
| ) | |
| for h, h_, w, w_ in tqdm(tasks, desc="VAE decoding"): | |
| hidden_states_batch = hidden_states[:, :, :, h:h_, w:w_].to( | |
| computation_device | |
| ) | |
| hidden_states_batch = self.model.decode(hidden_states_batch, self.scale).to( | |
| data_device | |
| ) | |
| mask = self.build_mask( | |
| hidden_states_batch, | |
| is_bound=(h == 0, h_ >= H, w == 0, w_ >= W), | |
| border_width=( | |
| (size_h - stride_h) * self.upsampling_factor, | |
| (size_w - stride_w) * self.upsampling_factor, | |
| ), | |
| ).to(dtype=hidden_states.dtype, device=data_device) | |
| target_h = h * self.upsampling_factor | |
| target_w = w * self.upsampling_factor | |
| values[ | |
| :, | |
| :, | |
| :, | |
| target_h : target_h + hidden_states_batch.shape[3], | |
| target_w : target_w + hidden_states_batch.shape[4], | |
| ] += ( | |
| hidden_states_batch * mask | |
| ) | |
| weight[ | |
| :, | |
| :, | |
| :, | |
| target_h : target_h + hidden_states_batch.shape[3], | |
| target_w : target_w + hidden_states_batch.shape[4], | |
| ] += mask | |
| values = values / weight | |
| values = values.clamp_(-1, 1) | |
| return values | |
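# Blending note: each decoded tile is multiplied by its feathering mask and accumulated
# into `values`, while the masks accumulate into `weight`; the final frame is
# values / weight, a weighted average over overlapping tiles, so tile seams are
# cross-faded instead of hard-cut.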
| def tiled_encode(self, video, device, tile_size, tile_stride): | |
| _, _, T, H, W = video.shape | |
| size_h, size_w = tile_size | |
| stride_h, stride_w = tile_stride | |
| # Split tasks | |
| tasks = [] | |
| for h in range(0, H, stride_h): | |
| if h - stride_h >= 0 and h - stride_h + size_h >= H: | |
| continue | |
| for w in range(0, W, stride_w): | |
| if w - stride_w >= 0 and w - stride_w + size_w >= W: | |
| continue | |
| h_, w_ = h + size_h, w + size_w | |
| tasks.append((h, h_, w, w_)) | |
| data_device = "cpu" | |
| computation_device = device | |
| out_T = (T + 3) // 4 | |
| weight = torch.zeros( | |
| (1, 1, out_T, H // self.upsampling_factor, W // self.upsampling_factor), | |
| dtype=video.dtype, | |
| device=data_device, | |
| ) | |
| values = torch.zeros( | |
| (1, 16, out_T, H // self.upsampling_factor, W // self.upsampling_factor), | |
| dtype=video.dtype, | |
| device=data_device, | |
| ) | |
| for h, h_, w, w_ in tqdm(tasks, desc="VAE encoding"): | |
| hidden_states_batch = video[:, :, :, h:h_, w:w_].to(computation_device) | |
| hidden_states_batch = self.model.encode(hidden_states_batch, self.scale).to( | |
| data_device | |
| ) | |
| mask = self.build_mask( | |
| hidden_states_batch, | |
| is_bound=(h == 0, h_ >= H, w == 0, w_ >= W), | |
| border_width=( | |
| (size_h - stride_h) // self.upsampling_factor, | |
| (size_w - stride_w) // self.upsampling_factor, | |
| ), | |
| ).to(dtype=video.dtype, device=data_device) | |
| target_h = h // self.upsampling_factor | |
| target_w = w // self.upsampling_factor | |
| values[ | |
| :, | |
| :, | |
| :, | |
| target_h : target_h + hidden_states_batch.shape[3], | |
| target_w : target_w + hidden_states_batch.shape[4], | |
| ] += ( | |
| hidden_states_batch * mask | |
| ) | |
| weight[ | |
| :, | |
| :, | |
| :, | |
| target_h : target_h + hidden_states_batch.shape[3], | |
| target_w : target_w + hidden_states_batch.shape[4], | |
| ] += mask | |
| values = values / weight | |
| return values | |
| def single_encode(self, video, device): | |
| video = video.to(device) | |
| x = self.model.encode(video, self.scale) | |
| return x | |
| def single_decode(self, hidden_state, device): | |
| hidden_state = hidden_state.to(device) | |
| video = self.model.decode(hidden_state, self.scale) | |
| return video.clamp_(-1, 1) | |
| def encode( | |
| self, videos, device, tiled=False, tile_size=(34, 34), tile_stride=(18, 16) | |
| ): | |
| videos = [video.to("cpu") for video in videos] | |
| hidden_states = [] | |
| for video in videos: | |
| video = video.unsqueeze(0) | |
| if tiled: | |
| tile_size = (tile_size[0] * 8, tile_size[1] * 8) | |
| tile_stride = (tile_stride[0] * 8, tile_stride[1] * 8) | |
| hidden_state = self.tiled_encode(video, device, tile_size, tile_stride) | |
| else: | |
| hidden_state = self.single_encode(video, device) | |
| hidden_state = hidden_state.squeeze(0) | |
| hidden_states.append(hidden_state) | |
| hidden_states = torch.stack(hidden_states) | |
| return hidden_states | |
| def decode( | |
| self, | |
| hidden_states, | |
| device, | |
| tiled=False, | |
| tile_size=(34, 34), | |
| tile_stride=(18, 16), | |
| ): | |
| hidden_states = [hidden_state.to("cpu") for hidden_state in hidden_states] | |
| videos = [] | |
| for hidden_state in hidden_states: | |
| hidden_state = hidden_state.unsqueeze(0) | |
| if tiled: | |
| video = self.tiled_decode(hidden_state, device, tile_size, tile_stride) | |
| else: | |
| video = self.single_decode(hidden_state, device) | |
| video = video.squeeze(0) | |
| videos.append(video) | |
| videos = torch.stack(videos) | |
| return videos | |
| def clear_cache(self): | |
| self.model.clear_cache() | |
| def stream_decode( | |
| self, hidden_states, tiled=False, tile_size=(34, 34), tile_stride=(18, 16) | |
| ): | |
| hidden_states = [hidden_state for hidden_state in hidden_states] | |
| assert len(hidden_states) == 1 | |
| hidden_state = hidden_states[0] | |
| video = self.model.stream_decode(hidden_state, self.scale) | |
| return video | |
| @staticmethod | |
| def state_dict_converter(): | |
| return WanVideoVAEStateDictConverter() | |
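# Usage sketch (assumed calling convention, inferred from the methods above; the
# tensor size is hypothetical):
# vae = ...  # a WanVideoVAE with weights loaded via ModelManager
# video = torch.randn(3, 21, 480, 832)  # (C, T, H, W) in [-1, 1]
# latents = vae.encode([video], device="cuda", tiled=True)   # tile_size / tile_stride are in latent units
# frames = vae.decode(latents, device="cuda", tiled=True)    # back to pixels, clamped to [-1, 1]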
| model_loader_configs = [ | |
| # These configs are provided for detecting model type automatically. | |
| # The format is (state_dict_keys_hash, state_dict_keys_hash_with_shape, model_names, model_classes, model_resource) | |
| ( | |
| None, | |
| "9269f8db9040a9d860eaca435be61814", | |
| ["wan_video_dit"], | |
| [WanModel], | |
| "civitai", | |
| ), | |
| ( | |
| None, | |
| "aafcfd9672c3a2456dc46e1cb6e52c70", | |
| ["wan_video_dit"], | |
| [WanModel], | |
| "civitai", | |
| ), | |
| ( | |
| None, | |
| "6bfcfb3b342cb286ce886889d519a77e", | |
| ["wan_video_dit"], | |
| [WanModel], | |
| "civitai", | |
| ), | |
| ( | |
| None, | |
| "6d6ccde6845b95ad9114ab993d917893", | |
| ["wan_video_dit"], | |
| [WanModel], | |
| "civitai", | |
| ), | |
| ( | |
| None, | |
| "349723183fc063b2bfc10bb2835cf677", | |
| ["wan_video_dit"], | |
| [WanModel], | |
| "civitai", | |
| ), | |
| ( | |
| None, | |
| "efa44cddf936c70abd0ea28b6cbe946c", | |
| ["wan_video_dit"], | |
| [WanModel], | |
| "civitai", | |
| ), | |
| ( | |
| None, | |
| "3ef3b1f8e1dab83d5b71fd7b617f859f", | |
| ["wan_video_dit"], | |
| [WanModel], | |
| "civitai", | |
| ), | |
| ( | |
| None, | |
| "cb104773c6c2cb6df4f9529ad5c60d0b", | |
| ["wan_video_dit"], | |
| [WanModel], | |
| "diffusers", | |
| ), | |
| ( | |
| None, | |
| "1378ea763357eea97acdef78e65d6d96", | |
| ["wan_video_vae"], | |
| [WanVideoVAE], | |
| "civitai", | |
| ), | |
| ( | |
| None, | |
| "ccc42284ea13e1ad04693284c7a09be6", | |
| ["wan_video_vae"], | |
| [WanVideoVAE], | |
| "civitai", | |
| ), | |
| ] | |
| huggingface_model_loader_configs = [ | |
| # These configs are provided for detecting model type automatically. | |
| # The format is (architecture_in_huggingface_config, huggingface_lib, model_name, redirected_architecture) | |
| ] | |
| patch_model_loader_configs = [ | |
| # These configs are provided for detecting model type automatically. | |
| # The format is (state_dict_keys_hash_with_shape, model_name, model_class, extra_kwargs) | |
| ] | |
| def load_model_from_single_file( | |
| state_dict, model_names, model_classes, model_resource, torch_dtype, device | |
| ): | |
| loaded_model_names, loaded_models = [], [] | |
| for model_name, model_class in zip(model_names, model_classes): | |
| # print(f" model_name: {model_name} model_class: {model_class.__name__}") | |
| state_dict_converter = model_class.state_dict_converter() | |
| if model_resource == "civitai": | |
| state_dict_results = state_dict_converter.from_civitai(state_dict) | |
| elif model_resource == "diffusers": | |
| state_dict_results = state_dict_converter.from_diffusers(state_dict) | |
| if isinstance(state_dict_results, tuple): | |
| model_state_dict, extra_kwargs = state_dict_results | |
| # print(f" This model is initialized with extra kwargs: {extra_kwargs}") | |
| else: | |
| model_state_dict, extra_kwargs = state_dict_results, {} | |
| torch_dtype = ( | |
| torch.float32 | |
| if extra_kwargs.get("upcast_to_float32", False) | |
| else torch_dtype | |
| ) | |
| with init_weights_on_device(): | |
| model = model_class(**extra_kwargs) | |
| if hasattr(model, "eval"): | |
| model = model.eval() | |
| model.load_state_dict(model_state_dict, assign=True) | |
| model = model.to(dtype=torch_dtype, device=device) | |
| loaded_model_names.append(model_name) | |
| loaded_models.append(model) | |
| return loaded_model_names, loaded_models | |
| def load_model_from_huggingface_folder( | |
| file_path, model_names, model_classes, torch_dtype, device | |
| ): | |
| loaded_model_names, loaded_models = [], [] | |
| for model_name, model_class in zip(model_names, model_classes): | |
| if torch_dtype in [torch.float32, torch.float16, torch.bfloat16]: | |
| model = model_class.from_pretrained( | |
| file_path, torch_dtype=torch_dtype | |
| ).eval() | |
| else: | |
| model = model_class.from_pretrained(file_path).eval().to(dtype=torch_dtype) | |
| if torch_dtype == torch.float16 and hasattr(model, "half"): | |
| model = model.half() | |
| try: | |
| model = model.to(device=device) | |
| except: | |
| pass | |
| loaded_model_names.append(model_name) | |
| loaded_models.append(model) | |
| return loaded_model_names, loaded_models | |
| def load_single_patch_model_from_single_file( | |
| state_dict, model_name, model_class, base_model, extra_kwargs, torch_dtype, device | |
| ): | |
| # print(f" model_name: {model_name} model_class: {model_class.__name__} extra_kwargs: {extra_kwargs}") | |
| base_state_dict = base_model.state_dict() | |
| base_model.to("cpu") | |
| del base_model | |
| model = model_class(**extra_kwargs) | |
| model.load_state_dict(base_state_dict, strict=False) | |
| model.load_state_dict(state_dict, strict=False) | |
| model.to(dtype=torch_dtype, device=device) | |
| return model | |
| def load_patch_model_from_single_file( | |
| state_dict, | |
| model_names, | |
| model_classes, | |
| extra_kwargs, | |
| model_manager, | |
| torch_dtype, | |
| device, | |
| ): | |
| loaded_model_names, loaded_models = [], [] | |
| for model_name, model_class in zip(model_names, model_classes): | |
| while True: | |
| for model_id in range(len(model_manager.model)): | |
| base_model_name = model_manager.model_name[model_id] | |
| if base_model_name == model_name: | |
| base_model_path = model_manager.model_path[model_id] | |
| base_model = model_manager.model[model_id] | |
| print( | |
| f" Adding patch model to {base_model_name} ({base_model_path})" | |
| ) | |
| patched_model = load_single_patch_model_from_single_file( | |
| state_dict, | |
| model_name, | |
| model_class, | |
| base_model, | |
| extra_kwargs, | |
| torch_dtype, | |
| device, | |
| ) | |
| loaded_model_names.append(base_model_name) | |
| loaded_models.append(patched_model) | |
| model_manager.model.pop(model_id) | |
| model_manager.model_path.pop(model_id) | |
| model_manager.model_name.pop(model_id) | |
| break | |
| else: | |
| break | |
| return loaded_model_names, loaded_models | |
| class ModelDetectorTemplate: | |
| def __init__(self): | |
| pass | |
| def match(self, file_path="", state_dict={}): | |
| return False | |
| def load( | |
| self, | |
| file_path="", | |
| state_dict={}, | |
| device="cuda", | |
| torch_dtype=torch.float16, | |
| **kwargs, | |
| ): | |
| return [], [] | |
| class ModelDetectorFromSingleFile: | |
| def __init__(self, model_loader_configs=[]): | |
| self.keys_hash_with_shape_dict = {} | |
| self.keys_hash_dict = {} | |
| for metadata in model_loader_configs: | |
| self.add_model_metadata(*metadata) | |
| def add_model_metadata( | |
| self, | |
| keys_hash, | |
| keys_hash_with_shape, | |
| model_names, | |
| model_classes, | |
| model_resource, | |
| ): | |
| self.keys_hash_with_shape_dict[keys_hash_with_shape] = ( | |
| model_names, | |
| model_classes, | |
| model_resource, | |
| ) | |
| if keys_hash is not None: | |
| self.keys_hash_dict[keys_hash] = ( | |
| model_names, | |
| model_classes, | |
| model_resource, | |
| ) | |
| def match(self, file_path="", state_dict={}): | |
| if isinstance(file_path, str) and os.path.isdir(file_path): | |
| return False | |
| if len(state_dict) == 0: | |
| state_dict = load_state_dict(file_path) | |
| keys_hash_with_shape = hash_state_dict_keys(state_dict, with_shape=True) | |
| if keys_hash_with_shape in self.keys_hash_with_shape_dict: | |
| return True | |
| keys_hash = hash_state_dict_keys(state_dict, with_shape=False) | |
| if keys_hash in self.keys_hash_dict: | |
| return True | |
| return False | |
| def load( | |
| self, | |
| file_path="", | |
| state_dict={}, | |
| device="cuda", | |
| torch_dtype=torch.float16, | |
| **kwargs, | |
| ): | |
| if len(state_dict) == 0: | |
| state_dict = load_state_dict(file_path) | |
| # Load models with strict matching | |
| keys_hash_with_shape = hash_state_dict_keys(state_dict, with_shape=True) | |
| if keys_hash_with_shape in self.keys_hash_with_shape_dict: | |
| model_names, model_classes, model_resource = self.keys_hash_with_shape_dict[ | |
| keys_hash_with_shape | |
| ] | |
| loaded_model_names, loaded_models = load_model_from_single_file( | |
| state_dict, | |
| model_names, | |
| model_classes, | |
| model_resource, | |
| torch_dtype, | |
| device, | |
| ) | |
| return loaded_model_names, loaded_models | |
| # Load models without strict matching | |
| # (the shape of parameters may be inconsistent, and the state_dict_converter will modify the model architecture) | |
| keys_hash = hash_state_dict_keys(state_dict, with_shape=False) | |
| if keys_hash in self.keys_hash_dict: | |
| model_names, model_classes, model_resource = self.keys_hash_dict[keys_hash] | |
| loaded_model_names, loaded_models = load_model_from_single_file( | |
| state_dict, | |
| model_names, | |
| model_classes, | |
| model_resource, | |
| torch_dtype, | |
| device, | |
| ) | |
| return loaded_model_names, loaded_models | |
# Neither hash matched: return empty results instead of referencing undefined names
return [], []
| class ModelDetectorFromSplitedSingleFile(ModelDetectorFromSingleFile): | |
| def __init__(self, model_loader_configs=[]): | |
| super().__init__(model_loader_configs) | |
| def match(self, file_path="", state_dict={}): | |
| if isinstance(file_path, str) and os.path.isdir(file_path): | |
| return False | |
| if len(state_dict) == 0: | |
| state_dict = load_state_dict(file_path) | |
| splited_state_dict = split_state_dict_with_prefix(state_dict) | |
| for sub_state_dict in splited_state_dict: | |
| if super().match(file_path, sub_state_dict): | |
| return True | |
| return False | |
| def load( | |
| self, | |
| file_path="", | |
| state_dict={}, | |
| device="cuda", | |
| torch_dtype=torch.float16, | |
| **kwargs, | |
| ): | |
| # Split the state_dict and load from each component | |
| splited_state_dict = split_state_dict_with_prefix(state_dict) | |
| valid_state_dict = {} | |
| for sub_state_dict in splited_state_dict: | |
| if super().match(file_path, sub_state_dict): | |
| valid_state_dict.update(sub_state_dict) | |
| if super().match(file_path, valid_state_dict): | |
| loaded_model_names, loaded_models = super().load( | |
| file_path, valid_state_dict, device, torch_dtype | |
| ) | |
| else: | |
| loaded_model_names, loaded_models = [], [] | |
| for sub_state_dict in splited_state_dict: | |
| if super().match(file_path, sub_state_dict): | |
| loaded_model_names_, loaded_models_ = super().load( | |
file_path, sub_state_dict, device, torch_dtype
| ) | |
| loaded_model_names += loaded_model_names_ | |
| loaded_models += loaded_models_ | |
| return loaded_model_names, loaded_models | |
| class ModelDetectorFromHuggingfaceFolder: | |
| def __init__(self, model_loader_configs=[]): | |
| self.architecture_dict = {} | |
| for metadata in model_loader_configs: | |
| self.add_model_metadata(*metadata) | |
| def add_model_metadata( | |
| self, architecture, huggingface_lib, model_name, redirected_architecture | |
| ): | |
| self.architecture_dict[architecture] = ( | |
| huggingface_lib, | |
| model_name, | |
| redirected_architecture, | |
| ) | |
| def match(self, file_path="", state_dict={}): | |
| if not isinstance(file_path, str) or os.path.isfile(file_path): | |
| return False | |
| file_list = os.listdir(file_path) | |
| if "config.json" not in file_list: | |
| return False | |
| with open(os.path.join(file_path, "config.json"), "r") as f: | |
| config = json.load(f) | |
| if "architectures" not in config and "_class_name" not in config: | |
| return False | |
| return True | |
| def load( | |
| self, | |
| file_path="", | |
| state_dict={}, | |
| device="cuda", | |
| torch_dtype=torch.float16, | |
| **kwargs, | |
| ): | |
| with open(os.path.join(file_path, "config.json"), "r") as f: | |
| config = json.load(f) | |
| loaded_model_names, loaded_models = [], [] | |
| architectures = ( | |
| config["architectures"] | |
| if "architectures" in config | |
| else [config["_class_name"]] | |
| ) | |
| for architecture in architectures: | |
| huggingface_lib, model_name, redirected_architecture = ( | |
| self.architecture_dict[architecture] | |
| ) | |
| if redirected_architecture is not None: | |
| architecture = redirected_architecture | |
| model_class = importlib.import_module(huggingface_lib).__getattribute__( | |
| architecture | |
| ) | |
| loaded_model_names_, loaded_models_ = load_model_from_huggingface_folder( | |
| file_path, [model_name], [model_class], torch_dtype, device | |
| ) | |
| loaded_model_names += loaded_model_names_ | |
| loaded_models += loaded_models_ | |
| return loaded_model_names, loaded_models | |
| class ModelDetectorFromPatchedSingleFile: | |
| def __init__(self, model_loader_configs=[]): | |
| self.keys_hash_with_shape_dict = {} | |
| for metadata in model_loader_configs: | |
| self.add_model_metadata(*metadata) | |
| def add_model_metadata( | |
| self, keys_hash_with_shape, model_name, model_class, extra_kwargs | |
| ): | |
| self.keys_hash_with_shape_dict[keys_hash_with_shape] = ( | |
| model_name, | |
| model_class, | |
| extra_kwargs, | |
| ) | |
| def match(self, file_path="", state_dict={}): | |
| if not isinstance(file_path, str) or os.path.isdir(file_path): | |
| return False | |
| if len(state_dict) == 0: | |
| state_dict = load_state_dict(file_path) | |
| keys_hash_with_shape = hash_state_dict_keys(state_dict, with_shape=True) | |
| if keys_hash_with_shape in self.keys_hash_with_shape_dict: | |
| return True | |
| return False | |
| def load( | |
| self, | |
| file_path="", | |
| state_dict={}, | |
| device="cuda", | |
| torch_dtype=torch.float16, | |
| model_manager=None, | |
| **kwargs, | |
| ): | |
| if len(state_dict) == 0: | |
| state_dict = load_state_dict(file_path) | |
| # Load models with strict matching | |
| loaded_model_names, loaded_models = [], [] | |
| keys_hash_with_shape = hash_state_dict_keys(state_dict, with_shape=True) | |
| if keys_hash_with_shape in self.keys_hash_with_shape_dict: | |
| model_names, model_classes, extra_kwargs = self.keys_hash_with_shape_dict[ | |
| keys_hash_with_shape | |
| ] | |
| loaded_model_names_, loaded_models_ = load_patch_model_from_single_file( | |
| state_dict, | |
| model_names, | |
| model_classes, | |
| extra_kwargs, | |
| model_manager, | |
| torch_dtype, | |
| device, | |
| ) | |
| loaded_model_names += loaded_model_names_ | |
| loaded_models += loaded_models_ | |
| return loaded_model_names, loaded_models | |
| class ModelManager: | |
| def __init__( | |
| self, | |
| torch_dtype=torch.float16, | |
| device="cuda", | |
| file_path_list: List[str] = [], | |
| ): | |
| self.torch_dtype = torch_dtype | |
| self.device = device | |
| self.model = [] | |
| self.model_path = [] | |
| self.model_name = [] | |
| self.model_detector = [ | |
| ModelDetectorFromSingleFile(model_loader_configs), | |
| ModelDetectorFromSplitedSingleFile(model_loader_configs), | |
| ModelDetectorFromHuggingfaceFolder(huggingface_model_loader_configs), | |
| ModelDetectorFromPatchedSingleFile(patch_model_loader_configs), | |
| ] | |
| self.load_models(file_path_list) | |
| def load_model_from_single_file( | |
| self, | |
| file_path="", | |
| state_dict={}, | |
| model_names=[], | |
| model_classes=[], | |
| model_resource=None, | |
| ): | |
| print(f"Loading models from file: {file_path}") | |
| if len(state_dict) == 0: | |
| state_dict = load_state_dict(file_path) | |
| model_names, models = load_model_from_single_file( | |
| state_dict, | |
| model_names, | |
| model_classes, | |
| model_resource, | |
| self.torch_dtype, | |
| self.device, | |
| ) | |
| for model_name, model in zip(model_names, models): | |
| self.model.append(model) | |
| self.model_path.append(file_path) | |
| self.model_name.append(model_name) | |
| # print(f" The following models are loaded: {model_names}.") | |
| def load_model_from_huggingface_folder( | |
| self, file_path="", model_names=[], model_classes=[] | |
| ): | |
| print(f"Loading models from folder: {file_path}") | |
| model_names, models = load_model_from_huggingface_folder( | |
| file_path, model_names, model_classes, self.torch_dtype, self.device | |
| ) | |
| for model_name, model in zip(model_names, models): | |
| self.model.append(model) | |
| self.model_path.append(file_path) | |
| self.model_name.append(model_name) | |
| # print(f" The following models are loaded: {model_names}.") | |
| def load_patch_model_from_single_file( | |
| self, | |
| file_path="", | |
| state_dict={}, | |
| model_names=[], | |
| model_classes=[], | |
| extra_kwargs={}, | |
| ): | |
| print(f"Loading patch models from file: {file_path}") | |
| model_names, models = load_patch_model_from_single_file( | |
| state_dict, | |
| model_names, | |
| model_classes, | |
| extra_kwargs, | |
| self, | |
| self.torch_dtype, | |
| self.device, | |
| ) | |
| for model_name, model in zip(model_names, models): | |
| self.model.append(model) | |
| self.model_path.append(file_path) | |
| self.model_name.append(model_name) | |
| print(f" The following patched models are loaded: {model_names}.") | |
| def load_lora(self, file_path="", state_dict={}, lora_alpha=1.0): | |
| if isinstance(file_path, list): | |
| for file_path_ in file_path: | |
| self.load_lora(file_path_, state_dict=state_dict, lora_alpha=lora_alpha) | |
| else: | |
| print(f"Loading LoRA models from file: {file_path}") | |
| is_loaded = False | |
| if len(state_dict) == 0: | |
| state_dict = load_state_dict(file_path) | |
| for model_name, model, model_path in zip( | |
| self.model_name, self.model, self.model_path | |
| ): | |
| for lora in get_lora_loaders(): | |
| match_results = lora.match(model, state_dict) | |
| if match_results is not None: | |
| print(f" Adding LoRA to {model_name} ({model_path}).") | |
| lora_prefix, model_resource = match_results | |
| lora.load( | |
| model, | |
| state_dict, | |
| lora_prefix, | |
| alpha=lora_alpha, | |
| model_resource=model_resource, | |
| ) | |
| is_loaded = True | |
| break | |
| if not is_loaded: | |
| print(f" Cannot load LoRA: {file_path}") | |
| def load_model(self, file_path, model_names=None, device=None, torch_dtype=None): | |
| # print(f"Loading models from: {file_path}") | |
| if device is None: | |
| device = self.device | |
| if torch_dtype is None: | |
| torch_dtype = self.torch_dtype | |
| if isinstance(file_path, list): | |
| state_dict = {} | |
| for path in file_path: | |
| state_dict.update(load_state_dict(path)) | |
| elif os.path.isfile(file_path): | |
| state_dict = load_state_dict(file_path) | |
| else: | |
| state_dict = None | |
| for model_detector in self.model_detector: | |
| if model_detector.match(file_path, state_dict): | |
| model_names, models = model_detector.load( | |
| file_path, | |
| state_dict, | |
| device=device, | |
| torch_dtype=torch_dtype, | |
| allowed_model_names=model_names, | |
| model_manager=self, | |
| ) | |
| for model_name, model in zip(model_names, models): | |
| self.model.append(model) | |
| self.model_path.append(file_path) | |
| self.model_name.append(model_name) | |
| # print(f" The following models are loaded: {model_names}.") | |
| break | |
| else: | |
| print(f" We cannot detect the model type. No models are loaded.") | |
| def load_models( | |
| self, file_path_list, model_names=None, device=None, torch_dtype=None | |
| ): | |
| for file_path in file_path_list: | |
| self.load_model( | |
| file_path, model_names, device=device, torch_dtype=torch_dtype | |
| ) | |
| def fetch_model(self, model_name, file_path=None, require_model_path=False): | |
| fetched_models = [] | |
| fetched_model_paths = [] | |
| for model, model_path, model_name_ in zip( | |
| self.model, self.model_path, self.model_name | |
| ): | |
| if file_path is not None and file_path != model_path: | |
| continue | |
| if model_name == model_name_: | |
| fetched_models.append(model) | |
| fetched_model_paths.append(model_path) | |
| if len(fetched_models) == 0: | |
| # print(f"No {model_name} models available.") | |
| return None | |
| if len(fetched_models) == 1: | |
| print(f"Using {model_name} from {fetched_model_paths[0]}") | |
| else: | |
| print( | |
| f"More than one {model_name} models are loaded in model manager: {fetched_model_paths}. Using {model_name} from {fetched_model_paths[0]}" | |
| ) | |
| if require_model_path: | |
| return fetched_models[0], fetched_model_paths[0] | |
| else: | |
| return fetched_models[0] | |
| def to(self, device): | |
| for model in self.model: | |
| model.to(device) | |
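# Usage sketch (paths are placeholders, not from this gist):
# manager = ModelManager(torch_dtype=torch.bfloat16, device="cuda")
# manager.load_models(["/path/to/wan_dit.safetensors", "/path/to/wan_vae.pth"])
# dit = manager.fetch_model("wan_video_dit")
# vae = manager.fetch_model("wan_video_vae")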
| CACHE_T = 2 | |
| def check_is_instance(model, module_class): | |
| if isinstance(model, module_class): | |
| return True | |
| if hasattr(model, "module") and isinstance(model.module, module_class): | |
| return True | |
| return False | |
| def block_causal_mask(x, block_size): | |
| # params | |
| b, n, s, _, device = *x.size(), x.device | |
| assert s % block_size == 0 | |
| num_blocks = s // block_size | |
| # build mask | |
| mask = torch.zeros(b, n, s, s, dtype=torch.bool, device=device) | |
| for i in range(num_blocks): | |
| mask[:, :, i * block_size : (i + 1) * block_size, : (i + 1) * block_size] = 1 | |
| return mask | |
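# Example (illustrative): with s=4 and block_size=2 the mask for each (batch, head) is
#   [[1, 1, 0, 0],
#    [1, 1, 0, 0],
#    [1, 1, 1, 1],
#    [1, 1, 1, 1]]
# i.e. tokens attend to their own block and all earlier blocks, never to later ones.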
| class CausalConv3d(nn.Conv3d): | |
| """ | |
Causal 3D convolution.
| """ | |
| def __init__(self, *args, **kwargs): | |
| super().__init__(*args, **kwargs) | |
| self._padding = ( | |
| self.padding[2], | |
| self.padding[2], | |
| self.padding[1], | |
| self.padding[1], | |
| 2 * self.padding[0], | |
| 0, | |
| ) | |
| self.padding = (0, 0, 0) | |
| def forward(self, x, cache_x=None): | |
| padding = list(self._padding) | |
| if cache_x is not None and self._padding[4] > 0: | |
| cache_x = cache_x.to(x.device) | |
| # print('cache_x.shape', cache_x.shape, 'x.shape', x.shape) | |
| x = torch.cat([cache_x, x], dim=2) | |
| padding[4] -= cache_x.shape[2] | |
| x = F.pad(x, padding) | |
| return super().forward(x) | |
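# Note: the padding tuple is ordered (w_left, w_right, h_top, h_bottom, t_front, t_back);
# all temporal padding (2 * pad_t) goes to the front and none to the back, so every output
# frame depends only on current and past input frames - this is what makes the convolution
# causal and lets `cache_x` carry the trailing frames of the previous chunk.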
| class RMS_norm(nn.Module): | |
| def __init__(self, dim, channel_first=True, images=True, bias=False): | |
| super().__init__() | |
| broadcastable_dims = (1, 1, 1) if not images else (1, 1) | |
| shape = (dim, *broadcastable_dims) if channel_first else (dim,) | |
| self.channel_first = channel_first | |
| self.scale = dim**0.5 | |
| self.gamma = nn.Parameter(torch.ones(shape)) | |
| self.bias = nn.Parameter(torch.zeros(shape)) if bias else 0.0 | |
| def forward(self, x): | |
| return ( | |
| F.normalize(x, dim=(1 if self.channel_first else -1)) | |
| * self.scale | |
| * self.gamma | |
| + self.bias | |
| ) | |
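# In effect: y = x / ||x||_2 * sqrt(dim) * gamma (+ bias). F.normalize divides by the L2
# norm over the channel dimension and `scale = dim ** 0.5` turns that into an RMS-style norm.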
| class Upsample(nn.Upsample): | |
| def forward(self, x): | |
| """ | |
| Fix bfloat16 support for nearest neighbor interpolation. | |
| """ | |
| return super().forward(x.float()).type_as(x) | |
| class Resample(nn.Module): | |
| def __init__(self, dim, mode): | |
| assert mode in ( | |
| "none", | |
| "upsample2d", | |
| "upsample3d", | |
| "downsample2d", | |
| "downsample3d", | |
| ) | |
| super().__init__() | |
| self.dim = dim | |
| self.mode = mode | |
| # layers | |
| if mode == "upsample2d": | |
| self.resample = nn.Sequential( | |
| Upsample(scale_factor=(2.0, 2.0), mode="nearest-exact"), | |
| nn.Conv2d(dim, dim // 2, 3, padding=1), | |
| ) | |
| elif mode == "upsample3d": | |
| self.resample = nn.Sequential( | |
| Upsample(scale_factor=(2.0, 2.0), mode="nearest-exact"), | |
| nn.Conv2d(dim, dim // 2, 3, padding=1), | |
| ) | |
| self.time_conv = CausalConv3d(dim, dim * 2, (3, 1, 1), padding=(1, 0, 0)) | |
| elif mode == "downsample2d": | |
| self.resample = nn.Sequential( | |
| nn.ZeroPad2d((0, 1, 0, 1)), nn.Conv2d(dim, dim, 3, stride=(2, 2)) | |
| ) | |
| elif mode == "downsample3d": | |
| self.resample = nn.Sequential( | |
| nn.ZeroPad2d((0, 1, 0, 1)), nn.Conv2d(dim, dim, 3, stride=(2, 2)) | |
| ) | |
| self.time_conv = CausalConv3d( | |
| dim, dim, (3, 1, 1), stride=(2, 1, 1), padding=(0, 0, 0) | |
| ) | |
| else: | |
| self.resample = nn.Identity() | |
| def forward(self, x, feat_cache=None, feat_idx=[0]): | |
| b, c, t, h, w = x.size() | |
| if self.mode == "upsample3d": | |
| if feat_cache is not None: | |
| idx = feat_idx[0] | |
| if feat_cache[idx] is None: | |
| feat_cache[idx] = "Rep" | |
| feat_idx[0] += 1 | |
| else: | |
| cache_x = x[:, :, -CACHE_T:, :, :].clone() | |
| if ( | |
| cache_x.shape[2] < 2 | |
| and feat_cache[idx] is not None | |
| and feat_cache[idx] != "Rep" | |
| ): | |
| # cache last frame of last two chunk | |
| cache_x = torch.cat( | |
| [ | |
| feat_cache[idx][:, :, -1, :, :] | |
| .unsqueeze(2) | |
| .to(cache_x.device), | |
| cache_x, | |
| ], | |
| dim=2, | |
| ) | |
| if ( | |
| cache_x.shape[2] < 2 | |
| and feat_cache[idx] is not None | |
| and feat_cache[idx] == "Rep" | |
| ): | |
| cache_x = torch.cat( | |
| [torch.zeros_like(cache_x).to(cache_x.device), cache_x], | |
| dim=2, | |
| ) | |
| if feat_cache[idx] == "Rep": | |
| x = self.time_conv(x) | |
| else: | |
| x = self.time_conv(x, feat_cache[idx]) | |
| feat_cache[idx] = cache_x | |
| feat_idx[0] += 1 | |
| x = x.reshape(b, 2, c, t, h, w) | |
| x = torch.stack((x[:, 0, :, :, :, :], x[:, 1, :, :, :, :]), 3) | |
| x = x.reshape(b, c, t * 2, h, w) | |
| t = x.shape[2] | |
| x = rearrange(x, "b c t h w -> (b t) c h w") | |
| x = self.resample(x) | |
| x = rearrange(x, "(b t) c h w -> b c t h w", t=t) | |
| if self.mode == "downsample3d": | |
| if feat_cache is not None: | |
| idx = feat_idx[0] | |
| if feat_cache[idx] is None: | |
| feat_cache[idx] = x.clone() | |
| feat_idx[0] += 1 | |
| else: | |
| cache_x = x[:, :, -1:, :, :].clone() | |
| x = self.time_conv( | |
| torch.cat([feat_cache[idx][:, :, -1:, :, :], x], 2) | |
| ) | |
| feat_cache[idx] = cache_x | |
| feat_idx[0] += 1 | |
| return x | |
| def init_weight(self, conv): | |
| conv_weight = conv.weight | |
| nn.init.zeros_(conv_weight) | |
| c1, c2, t, h, w = conv_weight.size() | |
| one_matrix = torch.eye(c1, c2) | |
| init_matrix = one_matrix | |
| nn.init.zeros_(conv_weight) | |
| conv_weight.data[:, :, 1, 0, 0] = init_matrix | |
| conv.weight.data.copy_(conv_weight) | |
| nn.init.zeros_(conv.bias.data) | |
| def init_weight2(self, conv): | |
| conv_weight = conv.weight.data | |
| nn.init.zeros_(conv_weight) | |
| c1, c2, t, h, w = conv_weight.size() | |
| init_matrix = torch.eye(c1 // 2, c2) | |
| conv_weight[: c1 // 2, :, -1, 0, 0] = init_matrix | |
| conv_weight[c1 // 2 :, :, -1, 0, 0] = init_matrix | |
| conv.weight.data.copy_(conv_weight) | |
| nn.init.zeros_(conv.bias.data) | |
| class ResidualBlock(nn.Module): | |
| def __init__(self, in_dim, out_dim, dropout=0.0): | |
| super().__init__() | |
| self.in_dim = in_dim | |
| self.out_dim = out_dim | |
| # layers | |
| self.residual = nn.Sequential( | |
| RMS_norm(in_dim, images=False), | |
| nn.SiLU(), | |
| CausalConv3d(in_dim, out_dim, 3, padding=1), | |
| RMS_norm(out_dim, images=False), | |
| nn.SiLU(), | |
| nn.Dropout(dropout), | |
| CausalConv3d(out_dim, out_dim, 3, padding=1), | |
| ) | |
| self.shortcut = ( | |
| CausalConv3d(in_dim, out_dim, 1) if in_dim != out_dim else nn.Identity() | |
| ) | |
| def forward(self, x, feat_cache=None, feat_idx=[0]): | |
| h = self.shortcut(x) | |
| for layer in self.residual: | |
| if check_is_instance(layer, CausalConv3d) and feat_cache is not None: | |
| idx = feat_idx[0] | |
| cache_x = x[:, :, -CACHE_T:, :, :].clone() | |
| if cache_x.shape[2] < 2 and feat_cache[idx] is not None: | |
| # cache last frame of last two chunk | |
| cache_x = torch.cat( | |
| [ | |
| feat_cache[idx][:, :, -1, :, :] | |
| .unsqueeze(2) | |
| .to(cache_x.device), | |
| cache_x, | |
| ], | |
| dim=2, | |
| ) | |
| x = layer(x, feat_cache[idx]) | |
| feat_cache[idx] = cache_x | |
| feat_idx[0] += 1 | |
| else: | |
| x = layer(x) | |
| return x + h | |
| class AttentionBlock(nn.Module): | |
| """ | |
| Causal self-attention with a single head. | |
| """ | |
| def __init__(self, dim): | |
| super().__init__() | |
| self.dim = dim | |
| # layers | |
| self.norm = RMS_norm(dim) | |
| self.to_qkv = nn.Conv2d(dim, dim * 3, 1) | |
| self.proj = nn.Conv2d(dim, dim, 1) | |
| # zero out the last layer params | |
| nn.init.zeros_(self.proj.weight) | |
| def forward(self, x): | |
| identity = x | |
| b, c, t, h, w = x.size() | |
| x = rearrange(x, "b c t h w -> (b t) c h w") | |
| x = self.norm(x) | |
| # compute query, key, value | |
| q, k, v = ( | |
| self.to_qkv(x) | |
| .reshape(b * t, 1, c * 3, -1) | |
| .permute(0, 1, 3, 2) | |
| .contiguous() | |
| .chunk(3, dim=-1) | |
| ) | |
| # apply attention | |
| x = F.scaled_dot_product_attention( | |
| q, | |
| k, | |
| v, | |
| # attn_mask=block_causal_mask(q, block_size=h * w) | |
| ) | |
| x = x.squeeze(1).permute(0, 2, 1).reshape(b * t, c, h, w) | |
| # output | |
| x = self.proj(x) | |
| x = rearrange(x, "(b t) c h w-> b c t h w", t=t) | |
| return x + identity | |
| class Encoder3d(nn.Module): | |
| def __init__( | |
| self, | |
| dim=128, | |
| z_dim=4, | |
| dim_mult=[1, 2, 4, 4], | |
| num_res_blocks=2, | |
| attn_scales=[], | |
| temperal_downsample=[True, True, False], | |
| dropout=0.0, | |
| ): | |
| super().__init__() | |
| self.dim = dim | |
| self.z_dim = z_dim | |
| self.dim_mult = dim_mult | |
| self.num_res_blocks = num_res_blocks | |
| self.attn_scales = attn_scales | |
| self.temperal_downsample = temperal_downsample | |
| # dimensions | |
| dims = [dim * u for u in [1] + dim_mult] | |
| scale = 1.0 | |
| # init block | |
| self.conv1 = CausalConv3d(3, dims[0], 3, padding=1) | |
| # downsample blocks | |
| downsamples = [] | |
| for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])): | |
| # residual (+attention) blocks | |
| for _ in range(num_res_blocks): | |
| downsamples.append(ResidualBlock(in_dim, out_dim, dropout)) | |
| if scale in attn_scales: | |
| downsamples.append(AttentionBlock(out_dim)) | |
| in_dim = out_dim | |
| # downsample block | |
| if i != len(dim_mult) - 1: | |
| mode = "downsample3d" if temperal_downsample[i] else "downsample2d" | |
| downsamples.append(Resample(out_dim, mode=mode)) | |
| scale /= 2.0 | |
| self.downsamples = nn.Sequential(*downsamples) | |
| # middle blocks | |
| self.middle = nn.Sequential( | |
| ResidualBlock(out_dim, out_dim, dropout), | |
| AttentionBlock(out_dim), | |
| ResidualBlock(out_dim, out_dim, dropout), | |
| ) | |
| # output blocks | |
| self.head = nn.Sequential( | |
| RMS_norm(out_dim, images=False), | |
| nn.SiLU(), | |
| CausalConv3d(out_dim, z_dim, 3, padding=1), | |
| ) | |
| def forward(self, x, feat_cache=None, feat_idx=[0]): | |
| if feat_cache is not None: | |
| idx = feat_idx[0] | |
| cache_x = x[:, :, -CACHE_T:, :, :].clone() | |
| if cache_x.shape[2] < 2 and feat_cache[idx] is not None: | |
| # cache last frame of last two chunk | |
| cache_x = torch.cat( | |
| [ | |
| feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), | |
| cache_x, | |
| ], | |
| dim=2, | |
| ) | |
| x = self.conv1(x, feat_cache[idx]) | |
| feat_cache[idx] = cache_x | |
| feat_idx[0] += 1 | |
| else: | |
| x = self.conv1(x) | |
| ## downsamples | |
| for layer in self.downsamples: | |
| if feat_cache is not None: | |
| x = layer(x, feat_cache, feat_idx) | |
| else: | |
| x = layer(x) | |
| ## middle | |
| for layer in self.middle: | |
| if check_is_instance(layer, ResidualBlock) and feat_cache is not None: | |
| x = layer(x, feat_cache, feat_idx) | |
| else: | |
| x = layer(x) | |
| ## head | |
| for layer in self.head: | |
| if check_is_instance(layer, CausalConv3d) and feat_cache is not None: | |
| idx = feat_idx[0] | |
| cache_x = x[:, :, -CACHE_T:, :, :].clone() | |
| if cache_x.shape[2] < 2 and feat_cache[idx] is not None: | |
| # cache last frame of last two chunk | |
| cache_x = torch.cat( | |
| [ | |
| feat_cache[idx][:, :, -1, :, :] | |
| .unsqueeze(2) | |
| .to(cache_x.device), | |
| cache_x, | |
| ], | |
| dim=2, | |
| ) | |
| x = layer(x, feat_cache[idx]) | |
| feat_cache[idx] = cache_x | |
| feat_idx[0] += 1 | |
| else: | |
| x = layer(x) | |
| return x | |
| class Decoder3d(nn.Module): | |
| def __init__( | |
| self, | |
| dim=128, | |
| z_dim=4, | |
| dim_mult=[1, 2, 4, 4], | |
| num_res_blocks=2, | |
| attn_scales=[], | |
| temperal_upsample=[False, True, True], | |
| dropout=0.0, | |
| ): | |
| super().__init__() | |
| self.dim = dim | |
| self.z_dim = z_dim | |
| self.dim_mult = dim_mult | |
| self.num_res_blocks = num_res_blocks | |
| self.attn_scales = attn_scales | |
| self.temperal_upsample = temperal_upsample | |
| # dimensions | |
| dims = [dim * u for u in [dim_mult[-1]] + dim_mult[::-1]] | |
| scale = 1.0 / 2 ** (len(dim_mult) - 2) | |
| # init block | |
| self.conv1 = CausalConv3d(z_dim, dims[0], 3, padding=1) | |
| # middle blocks | |
| self.middle = nn.Sequential( | |
| ResidualBlock(dims[0], dims[0], dropout), | |
| AttentionBlock(dims[0]), | |
| ResidualBlock(dims[0], dims[0], dropout), | |
| ) | |
| # upsample blocks | |
| upsamples = [] | |
| for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])): | |
| # residual (+attention) blocks | |
| if i == 1 or i == 2 or i == 3: | |
| in_dim = in_dim // 2 | |
| for _ in range(num_res_blocks + 1): | |
| upsamples.append(ResidualBlock(in_dim, out_dim, dropout)) | |
| if scale in attn_scales: | |
| upsamples.append(AttentionBlock(out_dim)) | |
| in_dim = out_dim | |
| # upsample block | |
| if i != len(dim_mult) - 1: | |
| mode = "upsample3d" if temperal_upsample[i] else "upsample2d" | |
| upsamples.append(Resample(out_dim, mode=mode)) | |
| scale *= 2.0 | |
| self.upsamples = nn.Sequential(*upsamples) | |
| # output blocks | |
| self.head = nn.Sequential( | |
| RMS_norm(out_dim, images=False), | |
| nn.SiLU(), | |
| CausalConv3d(out_dim, 3, 3, padding=1), | |
| ) | |
| def forward(self, x, feat_cache=None, feat_idx=[0]): | |
| ## conv1 | |
| if feat_cache is not None: | |
| idx = feat_idx[0] | |
| cache_x = x[:, :, -CACHE_T:, :, :].clone() | |
| if cache_x.shape[2] < 2 and feat_cache[idx] is not None: | |
| # cache last frame of last two chunk | |
| cache_x = torch.cat( | |
| [ | |
| feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), | |
| cache_x, | |
| ], | |
| dim=2, | |
| ) | |
| x = self.conv1(x, feat_cache[idx]) | |
| feat_cache[idx] = cache_x | |
| feat_idx[0] += 1 | |
| else: | |
| x = self.conv1(x) | |
| ## middle | |
| for layer in self.middle: | |
| if check_is_instance(layer, ResidualBlock) and feat_cache is not None: | |
| x = layer(x, feat_cache, feat_idx) | |
| else: | |
| x = layer(x) | |
| ## upsamples | |
| for layer in self.upsamples: | |
| if feat_cache is not None: | |
| x = layer(x, feat_cache, feat_idx) | |
| else: | |
| x = layer(x) | |
| ## head | |
| for layer in self.head: | |
| if check_is_instance(layer, CausalConv3d) and feat_cache is not None: | |
| idx = feat_idx[0] | |
| cache_x = x[:, :, -CACHE_T:, :, :].clone() | |
| if cache_x.shape[2] < 2 and feat_cache[idx] is not None: | |
| # cache last frame of last two chunk | |
| cache_x = torch.cat( | |
| [ | |
| feat_cache[idx][:, :, -1, :, :] | |
| .unsqueeze(2) | |
| .to(cache_x.device), | |
| cache_x, | |
| ], | |
| dim=2, | |
| ) | |
| x = layer(x, feat_cache[idx]) | |
| feat_cache[idx] = cache_x | |
| feat_idx[0] += 1 | |
| else: | |
| x = layer(x) | |
| return x | |
| def count_conv3d(model): | |
| count = 0 | |
| for m in model.modules(): | |
| if check_is_instance(m, CausalConv3d): | |
| count += 1 | |
| return count | |
| class VideoVAE_(nn.Module): | |
| def __init__( | |
| self, | |
| dim=96, | |
| z_dim=16, | |
| dim_mult=[1, 2, 4, 4], | |
| num_res_blocks=2, | |
| attn_scales=[], | |
| temperal_downsample=[False, True, True], | |
| dropout=0.0, | |
| ): | |
| super().__init__() | |
| self.dim = dim | |
| self.z_dim = z_dim | |
| self.dim_mult = dim_mult | |
| self.num_res_blocks = num_res_blocks | |
| self.attn_scales = attn_scales | |
| self.temperal_downsample = temperal_downsample | |
| self.temperal_upsample = temperal_downsample[::-1] | |
| # modules | |
| self.encoder = Encoder3d( | |
| dim, | |
| z_dim * 2, | |
| dim_mult, | |
| num_res_blocks, | |
| attn_scales, | |
| self.temperal_downsample, | |
| dropout, | |
| ) | |
| self.conv1 = CausalConv3d(z_dim * 2, z_dim * 2, 1) | |
| self.conv2 = CausalConv3d(z_dim, z_dim, 1) | |
| self.decoder = Decoder3d( | |
| dim, | |
| z_dim, | |
| dim_mult, | |
| num_res_blocks, | |
| attn_scales, | |
| self.temperal_upsample, | |
| dropout, | |
| ) | |
| def forward(self, x): | |
| mu, log_var = self.encode(x) | |
| z = self.reparameterize(mu, log_var) | |
| x_recon = self.decode(z) | |
| return x_recon, mu, log_var | |
| def encode(self, x, scale): | |
| self.clear_cache() | |
| ## cache | |
| t = x.shape[2] | |
| iter_ = 1 + (t - 1) // 4 | |
| for i in range(iter_): | |
| self._enc_conv_idx = [0] | |
| if i == 0: | |
| out = self.encoder( | |
| x[:, :, :1, :, :], | |
| feat_cache=self._enc_feat_map, | |
| feat_idx=self._enc_conv_idx, | |
| ) | |
| else: | |
| out_ = self.encoder( | |
| x[:, :, 1 + 4 * (i - 1) : 1 + 4 * i, :, :], | |
| feat_cache=self._enc_feat_map, | |
| feat_idx=self._enc_conv_idx, | |
| ) | |
| out = torch.cat([out, out_], 2) | |
| mu, log_var = self.conv1(out).chunk(2, dim=1) | |
| if isinstance(scale[0], torch.Tensor): | |
| scale = [s.to(dtype=mu.dtype, device=mu.device) for s in scale] | |
| mu = (mu - scale[0].view(1, self.z_dim, 1, 1, 1)) * scale[1].view( | |
| 1, self.z_dim, 1, 1, 1 | |
| ) | |
| else: | |
| scale = scale.to(dtype=mu.dtype, device=mu.device) | |
| mu = (mu - scale[0]) * scale[1] | |
| return mu | |
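# Temporal chunking: the first pass encodes frame 0 alone, every later pass encodes a chunk
# of 4 frames with `feat_cache` carrying causal-conv state across chunks, so T input frames
# compress to 1 + (T - 1) // 4 latent frames (identical to the (T + 3) // 4 used in tiled_encode).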
| def decode(self, z, scale): | |
| self.clear_cache() | |
| # z: [b,c,t,h,w] | |
| if isinstance(scale[0], torch.Tensor): | |
| scale = [s.to(dtype=z.dtype, device=z.device) for s in scale] | |
| z = z / scale[1].view(1, self.z_dim, 1, 1, 1) + scale[0].view( | |
| 1, self.z_dim, 1, 1, 1 | |
| ) | |
| else: | |
| scale = scale.to(dtype=z.dtype, device=z.device) | |
| z = z / scale[1] + scale[0] | |
| iter_ = z.shape[2] | |
| x = self.conv2(z) | |
| for i in range(iter_): | |
| self._conv_idx = [0] | |
| if i == 0: | |
| out = self.decoder( | |
| x[:, :, i : i + 1, :, :], | |
| feat_cache=self._feat_map, | |
| feat_idx=self._conv_idx, | |
| ) | |
| else: | |
| out_ = self.decoder( | |
| x[:, :, i : i + 1, :, :], | |
| feat_cache=self._feat_map, | |
| feat_idx=self._conv_idx, | |
| ) | |
| out = torch.cat([out, out_], 2) # may add tensor offload | |
| return out | |
| def stream_decode(self, z, scale): | |
| # self.clear_cache() | |
| # z: [b,c,t,h,w] | |
| if isinstance(scale[0], torch.Tensor): | |
| scale = [s.to(dtype=z.dtype, device=z.device) for s in scale] | |
| z = z / scale[1].view(1, self.z_dim, 1, 1, 1) + scale[0].view( | |
| 1, self.z_dim, 1, 1, 1 | |
| ) | |
| else: | |
| scale = scale.to(dtype=z.dtype, device=z.device) | |
| z = z / scale[1] + scale[0] | |
| iter_ = z.shape[2] | |
| x = self.conv2(z) | |
| for i in range(iter_): | |
| self._conv_idx = [0] | |
| if i == 0: | |
| out = self.decoder( | |
| x[:, :, i : i + 1, :, :], | |
| feat_cache=self._feat_map, | |
| feat_idx=self._conv_idx, | |
| ) | |
| else: | |
| out_ = self.decoder( | |
| x[:, :, i : i + 1, :, :], | |
| feat_cache=self._feat_map, | |
| feat_idx=self._conv_idx, | |
| ) | |
| out = torch.cat([out, out_], 2) # may add tensor offload | |
| return out | |
| def reparameterize(self, mu, log_var): | |
| std = torch.exp(0.5 * log_var) | |
| eps = torch.randn_like(std) | |
| return eps * std + mu | |
| def sample(self, imgs, deterministic=False): | |
| mu, log_var = self.encode(imgs) | |
| if deterministic: | |
| return mu | |
| std = torch.exp(0.5 * log_var.clamp(-30.0, 20.0)) | |
| return mu + std * torch.randn_like(std) | |
| def clear_cache(self): | |
| self._conv_num = count_conv3d(self.decoder) | |
| self._conv_idx = [0] | |
| self._feat_map = [None] * self._conv_num | |
| # print('self._feat_map', len(self._feat_map)) | |
| # cache encode | |
| if self.encoder is not None: | |
| self._enc_conv_num = count_conv3d(self.encoder) | |
| self._enc_conv_idx = [0] | |
| self._enc_feat_map = [None] * self._enc_conv_num | |
| # print('self._enc_feat_map', len(self._enc_feat_map)) | |
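| # --------------------------------------------------------------- | |
| # Hedged sketch (illustration only, not used by the pipeline): the causal VAE | |
| # above encodes the first frame alone and every later group of 4 frames with | |
| # cached conv features. This helper just reproduces that chunking arithmetic | |
| # so the indexing in encode() is easier to follow. | |
| def _vae_encode_chunks_sketch(num_frames: int): | |
|     """Return the (start, end) frame slices encode() walks through.""" | |
|     chunks = [(0, 1)]  # the first frame is encoded on its own | |
|     iter_ = 1 + (num_frames - 1) // 4 | |
|     for i in range(1, iter_): | |
|         chunks.append((1 + 4 * (i - 1), 1 + 4 * i))  # 4-frame causal chunks | |
|     return chunks | |
| # e.g. _vae_encode_chunks_sketch(9) -> [(0, 1), (1, 5), (5, 9)] | |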
| class WanVideoVAEStateDictConverter: | |
| def __init__(self): | |
| pass | |
| def from_civitai(self, state_dict): | |
| state_dict_ = {} | |
| if "model_state" in state_dict: | |
| state_dict = state_dict["model_state"] | |
| for name in state_dict: | |
| state_dict_["model." + name] = state_dict[name] | |
| return state_dict_ | |
| class FlowMatchScheduler: | |
| def __init__( | |
| self, | |
| num_inference_steps=100, | |
| num_train_timesteps=1000, | |
| shift=3.0, | |
| sigma_max=1.0, | |
| sigma_min=0.003 / 1.002, | |
| inverse_timesteps=False, | |
| extra_one_step=False, | |
| reverse_sigmas=False, | |
| ): | |
| self.num_train_timesteps = num_train_timesteps | |
| self.shift = shift | |
| self.sigma_max = sigma_max | |
| self.sigma_min = sigma_min | |
| self.inverse_timesteps = inverse_timesteps | |
| self.extra_one_step = extra_one_step | |
| self.reverse_sigmas = reverse_sigmas | |
| self.set_timesteps(num_inference_steps) | |
| def set_timesteps( | |
| self, | |
| num_inference_steps=100, | |
| denoising_strength=1.0, | |
| training=False, | |
| shift=None, | |
| ): | |
| if shift is not None: | |
| self.shift = shift | |
| sigma_start = ( | |
| self.sigma_min + (self.sigma_max - self.sigma_min) * denoising_strength | |
| ) | |
| if self.extra_one_step: | |
| self.sigmas = torch.linspace( | |
| sigma_start, self.sigma_min, num_inference_steps + 1 | |
| )[:-1] | |
| else: | |
| self.sigmas = torch.linspace( | |
| sigma_start, self.sigma_min, num_inference_steps | |
| ) | |
| if self.inverse_timesteps: | |
| self.sigmas = torch.flip(self.sigmas, dims=[0]) | |
| self.sigmas = self.shift * self.sigmas / (1 + (self.shift - 1) * self.sigmas) | |
| if self.reverse_sigmas: | |
| self.sigmas = 1 - self.sigmas | |
| self.timesteps = self.sigmas * self.num_train_timesteps | |
| if training: | |
| x = self.timesteps | |
| y = torch.exp( | |
| -2 * ((x - num_inference_steps / 2) / num_inference_steps) ** 2 | |
| ) | |
| y_shifted = y - y.min() | |
| bsmntw_weighing = y_shifted * (num_inference_steps / y_shifted.sum()) | |
| self.linear_timesteps_weights = bsmntw_weighing | |
| def step(self, model_output, timestep, sample, to_final=False, **kwargs): | |
| if isinstance(timestep, torch.Tensor): | |
| timestep = timestep.cpu() | |
| timestep_id = torch.argmin((self.timesteps - timestep).abs()) | |
| sigma = self.sigmas[timestep_id] | |
| if to_final or timestep_id + 1 >= len(self.timesteps): | |
| sigma_ = 1 if (self.inverse_timesteps or self.reverse_sigmas) else 0 | |
| else: | |
| sigma_ = self.sigmas[timestep_id + 1] | |
| prev_sample = sample + model_output * (sigma_ - sigma) | |
| return prev_sample | |
| def return_to_timestep(self, timestep, sample, sample_stablized): | |
| if isinstance(timestep, torch.Tensor): | |
| timestep = timestep.cpu() | |
| timestep_id = torch.argmin((self.timesteps - timestep).abs()) | |
| sigma = self.sigmas[timestep_id] | |
| model_output = (sample - sample_stablized) / sigma | |
| return model_output | |
| def add_noise(self, original_samples, noise, timestep): | |
| if isinstance(timestep, torch.Tensor): | |
| timestep = timestep.cpu() | |
| timestep_id = torch.argmin((self.timesteps - timestep).abs()) | |
| sigma = self.sigmas[timestep_id] | |
| sample = (1 - sigma) * original_samples + sigma * noise | |
| return sample | |
| def training_target(self, sample, noise, timestep): | |
| target = noise - sample | |
| return target | |
| def training_weight(self, timestep): | |
| timestep_id = torch.argmin( | |
| (self.timesteps - timestep.to(self.timesteps.device)).abs() | |
| ) | |
| weights = self.linear_timesteps_weights[timestep_id] | |
| return weights | |
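| # --------------------------------------------------------------- | |
| # Hedged sketch (illustration only): one rectified-flow Euler update as done by | |
| # FlowMatchScheduler.step(), i.e. x_next = x + v * (sigma_next - sigma). The | |
| # tensors below are random stand-ins, not real model outputs. | |
| def _flow_match_step_sketch(): | |
|     scheduler = FlowMatchScheduler(num_inference_steps=4, shift=5.0, extra_one_step=True) | |
|     sample = torch.randn(1, 16, 1, 8, 8)  # fake latent | |
|     noise = torch.randn_like(sample) | |
|     timestep = scheduler.timesteps[0] | |
|     noisy = scheduler.add_noise(sample, noise, timestep)  # (1 - sigma) * x + sigma * noise | |
|     velocity = noise - sample  # same direction as training_target() | |
|     return scheduler.step(velocity, timestep, noisy)  # one Euler step toward the data | |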
| class BasePipeline(torch.nn.Module): | |
| def __init__( | |
| self, | |
| device="cuda", | |
| torch_dtype=torch.float16, | |
| height_division_factor=64, | |
| width_division_factor=64, | |
| ): | |
| super().__init__() | |
| self.device = device | |
| self.torch_dtype = torch_dtype | |
| self.height_division_factor = height_division_factor | |
| self.width_division_factor = width_division_factor | |
| self.cpu_offload = False | |
| self.model_names = [] | |
| def check_resize_height_width(self, height, width): | |
| if height % self.height_division_factor != 0: | |
| height = ( | |
| (height + self.height_division_factor - 1) | |
| // self.height_division_factor | |
| * self.height_division_factor | |
| ) | |
| print( | |
| f"The height cannot be evenly divided by {self.height_division_factor}. We round it up to {height}." | |
| ) | |
| if width % self.width_division_factor != 0: | |
| width = ( | |
| (width + self.width_division_factor - 1) | |
| // self.width_division_factor | |
| * self.width_division_factor | |
| ) | |
| print( | |
| f"The width cannot be evenly divided by {self.width_division_factor}. We round it up to {width}." | |
| ) | |
| return height, width | |
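| # Hedged note (illustration only): the rounding above is a ceil-to-multiple, | |
| # height -> ((height + factor - 1) // factor) * factor; e.g. with factor=64 a | |
| # 720-pixel height becomes ((720 + 63) // 64) * 64 = 768. | |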
| def preprocess_image(self, image): | |
| image = ( | |
| torch.Tensor(np.array(image, dtype=np.float32) * (2 / 255) - 1) | |
| .permute(2, 0, 1) | |
| .unsqueeze(0) | |
| ) | |
| return image | |
| def preprocess_images(self, images): | |
| return [self.preprocess_image(image) for image in images] | |
| def vae_output_to_image(self, vae_output): | |
| image = vae_output[0].cpu().float().permute(1, 2, 0).numpy() | |
| image = Image.fromarray(((image / 2 + 0.5).clip(0, 1) * 255).astype("uint8")) | |
| return image | |
| def vae_output_to_video(self, vae_output): | |
| video = vae_output.cpu().permute(1, 2, 0).numpy() | |
| video = [ | |
| Image.fromarray(((image / 2 + 0.5).clip(0, 1) * 255).astype("uint8")) | |
| for image in video | |
| ] | |
| return video | |
| def merge_latents( | |
| self, value, latents, masks, scales, blur_kernel_size=33, blur_sigma=10.0 | |
| ): | |
| if len(latents) > 0: | |
| blur = GaussianBlur(kernel_size=blur_kernel_size, sigma=blur_sigma) | |
| height, width = value.shape[-2:] | |
| weight = torch.ones_like(value) | |
| for latent, mask, scale in zip(latents, masks, scales): | |
| mask = ( | |
| self.preprocess_image(mask.resize((width, height))).mean( | |
| dim=1, keepdim=True | |
| ) | |
| > 0 | |
| ) | |
| mask = mask.repeat(1, latent.shape[1], 1, 1).to( | |
| dtype=latent.dtype, device=latent.device | |
| ) | |
| mask = blur(mask) | |
| value += latent * mask * scale | |
| weight += mask * scale | |
| value /= weight | |
| return value | |
| def control_noise_via_local_prompts( | |
| self, | |
| prompt_emb_global, | |
| prompt_emb_locals, | |
| masks, | |
| mask_scales, | |
| inference_callback, | |
| special_kwargs=None, | |
| special_local_kwargs_list=None, | |
| ): | |
| if special_kwargs is None: | |
| noise_pred_global = inference_callback(prompt_emb_global) | |
| else: | |
| noise_pred_global = inference_callback(prompt_emb_global, special_kwargs) | |
| if special_local_kwargs_list is None: | |
| noise_pred_locals = [ | |
| inference_callback(prompt_emb_local) | |
| for prompt_emb_local in prompt_emb_locals | |
| ] | |
| else: | |
| noise_pred_locals = [ | |
| inference_callback(prompt_emb_local, special_kwargs) | |
| for prompt_emb_local, special_kwargs in zip( | |
| prompt_emb_locals, special_local_kwargs_list | |
| ) | |
| ] | |
| noise_pred = self.merge_latents( | |
| noise_pred_global, noise_pred_locals, masks, mask_scales | |
| ) | |
| return noise_pred | |
| def extend_prompt(self, prompt, local_prompts, masks, mask_scales): | |
| local_prompts = local_prompts or [] | |
| masks = masks or [] | |
| mask_scales = mask_scales or [] | |
| extended_prompt_dict = self.prompter.extend_prompt(prompt) | |
| prompt = extended_prompt_dict.get("prompt", prompt) | |
| local_prompts += extended_prompt_dict.get("prompts", []) | |
| masks += extended_prompt_dict.get("masks", []) | |
| mask_scales += [100.0] * len(extended_prompt_dict.get("masks", [])) | |
| return prompt, local_prompts, masks, mask_scales | |
| def enable_cpu_offload(self): | |
| self.cpu_offload = True | |
| def load_models_to_device(self, loadmodel_names=[]): | |
| # only load models to device if cpu_offload is enabled | |
| if not self.cpu_offload: | |
| return | |
| # offload the unneeded models to cpu | |
| for model_name in self.model_names: | |
| if model_name not in loadmodel_names: | |
| model = getattr(self, model_name) | |
| if model is not None: | |
| if ( | |
| hasattr(model, "vram_management_enabled") | |
| and model.vram_management_enabled | |
| ): | |
| for module in model.modules(): | |
| if hasattr(module, "offload"): | |
| module.offload() | |
| else: | |
| model.cpu() | |
| # load the needed models to device | |
| for model_name in loadmodel_names: | |
| model = getattr(self, model_name) | |
| if model is not None: | |
| if ( | |
| hasattr(model, "vram_management_enabled") | |
| and model.vram_management_enabled | |
| ): | |
| for module in model.modules(): | |
| if hasattr(module, "onload"): | |
| module.onload() | |
| else: | |
| model.to(self.device) | |
| # refresh the cuda cache | |
| if torch.cuda.is_available(): | |
| torch.cuda.empty_cache() | |
| if torch.backends.mps.is_available(): | |
| torch.mps.empty_cache() | |
| def generate_noise(self, shape, seed=None, device="cpu", dtype=torch.float16): | |
| generator = None if seed is None else torch.Generator(device).manual_seed(seed) | |
| noise = torch.randn(shape, generator=generator, device=device, dtype=dtype) | |
| return noise | |
| # ----------------------------- | |
| # Basic utilities: statistics needed for ADAIN (kept in case they are needed; the pipeline defaults to wavelet) | |
| # ----------------------------- | |
| def _calc_mean_std( | |
| feat: torch.Tensor, eps: float = 1e-5 | |
| ) -> Tuple[torch.Tensor, torch.Tensor]: | |
| assert feat.dim() == 4, "feat must be (N, C, H, W)" | |
| N, C = feat.shape[:2] | |
| var = feat.view(N, C, -1).var(dim=2, unbiased=False) + eps | |
| std = var.sqrt().view(N, C, 1, 1) | |
| mean = feat.view(N, C, -1).mean(dim=2).view(N, C, 1, 1) | |
| return mean, std | |
| def _adain(content_feat: torch.Tensor, style_feat: torch.Tensor) -> torch.Tensor: | |
| assert content_feat.shape[:2] == style_feat.shape[:2], "ADAIN: N and C must match" | |
| size = content_feat.size() | |
| style_mean, style_std = _calc_mean_std(style_feat) | |
| content_mean, content_std = _calc_mean_std(content_feat) | |
| normalized = (content_feat - content_mean.expand(size)) / content_std.expand(size) | |
| return normalized * style_std.expand(size) + style_mean.expand(size) | |
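| # --------------------------------------------------------------- | |
| # Hedged sketch (illustration only): _adain re-normalizes the content feature so | |
| # that its per-channel mean/std match the style feature. Random tensors stand in | |
| # for real features; this is not called by the pipeline. | |
| def _adain_sketch(): | |
|     content = torch.randn(1, 3, 16, 16) | |
|     style = torch.randn(1, 3, 16, 16) * 2.0 + 1.0 | |
|     out = _adain(content, style) | |
|     out_mean, out_std = _calc_mean_std(out) | |
|     style_mean, style_std = _calc_mean_std(style) | |
|     # `out` now carries the style statistics (up to the eps inside _calc_mean_std) | |
|     return torch.allclose(out_mean, style_mean, atol=1e-4), torch.allclose(out_std, style_std, atol=1e-3) | |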
| # ----------------------------- | |
| # Wavelet-style blur and decomposition / reconstruction (used by ColorCorrector) | |
| # ----------------------------- | |
| def _make_gaussian3x3_kernel(dtype, device) -> torch.Tensor: | |
| vals = [ | |
| [0.0625, 0.125, 0.0625], | |
| [0.125, 0.25, 0.125], | |
| [0.0625, 0.125, 0.0625], | |
| ] | |
| return torch.tensor(vals, dtype=dtype, device=device) | |
| def _wavelet_blur(x: torch.Tensor, radius: int) -> torch.Tensor: | |
| assert x.dim() == 4, "x must be (N, C, H, W)" | |
| N, C, H, W = x.shape | |
| base = _make_gaussian3x3_kernel(x.dtype, x.device) | |
| weight = base.view(1, 1, 3, 3).repeat(C, 1, 1, 1) | |
| pad = radius | |
| x_pad = F.pad(x, (pad, pad, pad, pad), mode="replicate") | |
| out = F.conv2d( | |
| x_pad, weight, bias=None, stride=1, padding=0, dilation=radius, groups=C | |
| ) | |
| return out | |
| def _wavelet_decompose( | |
| x: torch.Tensor, levels: int = 5 | |
| ) -> Tuple[torch.Tensor, torch.Tensor]: | |
| assert x.dim() == 4, "x must be (N, C, H, W)" | |
| high = torch.zeros_like(x) | |
| low = x | |
| for i in range(levels): | |
| radius = 2**i | |
| blurred = _wavelet_blur(low, radius) | |
| high = high + (low - blurred) | |
| low = blurred | |
| return high, low | |
| def _wavelet_reconstruct( | |
| content: torch.Tensor, style: torch.Tensor, levels: int = 5 | |
| ) -> torch.Tensor: | |
| c_high, _ = _wavelet_decompose(content, levels=levels) | |
| _, s_low = _wavelet_decompose(style, levels=levels) | |
| return c_high + s_low | |
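| # --------------------------------------------------------------- | |
| # Hedged sketch (illustration only): the decomposition above is telescoping, so | |
| # high + low reproduces the input exactly, and _wavelet_reconstruct simply swaps | |
| # the low-frequency (color / illumination) band of `content` for that of `style`. | |
| def _wavelet_sketch():  | |
|     x = torch.rand(1, 3, 64, 64) | |
|     high, low = _wavelet_decompose(x, levels=5) | |
|     exact = torch.allclose(high + low, x, atol=1e-5)  # telescoping sum | |
|     recolored = _wavelet_reconstruct(content=x, style=torch.rand(1, 3, 64, 64)) | |
|     return exact, recolored.shape  # -> True, torch.Size([1, 3, 64, 64]) | |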
| # ----------------------------- | |
| # Stateless color-correction module (video-friendly; wavelet by default) | |
| # ----------------------------- | |
| class TorchColorCorrectorWavelet(nn.Module): | |
| def __init__(self, levels: int = 5): | |
| super().__init__() | |
| self.levels = levels | |
| @staticmethod | |
| def _flatten_time(x: torch.Tensor) -> Tuple[torch.Tensor, int, int]: | |
| assert x.dim() == 5, "input must be (B, C, f, H, W)" | |
| B, C, f, H, W = x.shape | |
| y = x.permute(0, 2, 1, 3, 4).reshape(B * f, C, H, W) | |
| return y, B, f | |
| @staticmethod | |
| def _unflatten_time(y: torch.Tensor, B: int, f: int) -> torch.Tensor: | |
| BF, C, H, W = y.shape | |
| assert BF == B * f | |
| return y.reshape(B, f, C, H, W).permute(0, 2, 1, 3, 4) | |
| def forward( | |
| self, | |
| hq_image: torch.Tensor, # (B, C, f, H, W) | |
| lq_image: torch.Tensor, # (B, C, f, H, W) | |
| clip_range: Tuple[float, float] = (-1.0, 1.0), | |
| method: Literal["wavelet", "adain"] = "wavelet", | |
| chunk_size: Optional[int] = None, | |
| ) -> torch.Tensor: | |
| assert hq_image.shape == lq_image.shape, "HQ and LQ must have the same shape" | |
| assert ( | |
| hq_image.dim() == 5 and hq_image.shape[1] == 3 | |
| ), "input must be (B, 3, f, H, W)" | |
| B, C, f, H, W = hq_image.shape | |
| if chunk_size is None or chunk_size >= f: | |
| hq4, B, f = self._flatten_time(hq_image) | |
| lq4, _, _ = self._flatten_time(lq_image) | |
| if method == "wavelet": | |
| out4 = _wavelet_reconstruct(hq4, lq4, levels=self.levels) | |
| elif method == "adain": | |
| out4 = _adain(hq4, lq4) | |
| else: | |
| raise ValueError(f"未知 method: {method}") | |
| out4 = torch.clamp(out4, *clip_range) | |
| out = self._unflatten_time(out4, B, f) | |
| return out | |
| outs = [] | |
| for start in range(0, f, chunk_size): | |
| end = min(start + chunk_size, f) | |
| hq_chunk = hq_image[:, :, start:end] | |
| lq_chunk = lq_image[:, :, start:end] | |
| hq4, B_, f_ = self._flatten_time(hq_chunk) | |
| lq4, _, _ = self._flatten_time(lq_chunk) | |
| if method == "wavelet": | |
| out4 = _wavelet_reconstruct(hq4, lq4, levels=self.levels) | |
| elif method == "adain": | |
| out4 = _adain(hq4, lq4) | |
| else: | |
| raise ValueError(f"未知 method: {method}") | |
| out4 = torch.clamp(out4, *clip_range) | |
| out_chunk = self._unflatten_time(out4, B_, f_) | |
| outs.append(out_chunk) | |
| out = torch.cat(outs, dim=2) | |
| return out | |
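| # --------------------------------------------------------------- | |
| # Hedged usage sketch (illustration only): the corrector is stateless, so full and | |
| # chunked processing along the frame axis agree for method="wavelet". Random | |
| # tensors stand in for real HQ / LQ clips. | |
| def _color_corrector_sketch(): | |
|     corrector = TorchColorCorrectorWavelet(levels=5) | |
|     hq = torch.rand(1, 3, 8, 64, 64) * 2 - 1  # (B, C, f, H, W) in [-1, 1] | |
|     lq = torch.rand(1, 3, 8, 64, 64) * 2 - 1 | |
|     full = corrector(hq, lq, method="wavelet") | |
|     chunked = corrector(hq, lq, method="wavelet", chunk_size=4) | |
|     return torch.allclose(full, chunked, atol=1e-5) | |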
| def cast_to(weight, dtype, device): | |
| r = torch.empty_like(weight, dtype=dtype, device=device) | |
| r.copy_(weight) | |
| return r | |
| class AutoWrappedModule(torch.nn.Module): | |
| def __init__(self, module: torch.nn.Module, offload_dtype, offload_device, onload_dtype, onload_device, computation_dtype, computation_device): | |
| super().__init__() | |
| self.module = module.to(dtype=offload_dtype, device=offload_device) | |
| self.offload_dtype = offload_dtype | |
| self.offload_device = offload_device | |
| self.onload_dtype = onload_dtype | |
| self.onload_device = onload_device | |
| self.computation_dtype = computation_dtype | |
| self.computation_device = computation_device | |
| self.state = 0 | |
| def offload(self): | |
| if self.state == 1 and (self.offload_dtype != self.onload_dtype or self.offload_device != self.onload_device): | |
| self.module.to(dtype=self.offload_dtype, device=self.offload_device) | |
| self.state = 0 | |
| def onload(self): | |
| if self.state == 0 and (self.offload_dtype != self.onload_dtype or self.offload_device != self.onload_device): | |
| self.module.to(dtype=self.onload_dtype, device=self.onload_device) | |
| self.state = 1 | |
| def forward(self, *args, **kwargs): | |
| if self.onload_dtype == self.computation_dtype and self.onload_device == self.computation_device: | |
| module = self.module | |
| else: | |
| import copy  # local import: `copy` is not imported at the top of this file | |
| module = copy.deepcopy(self.module).to(dtype=self.computation_dtype, device=self.computation_device) | |
| return module(*args, **kwargs) | |
| class AutoWrappedLinear(torch.nn.Linear): | |
| def __init__(self, module: torch.nn.Linear, offload_dtype, offload_device, onload_dtype, onload_device, computation_dtype, computation_device): | |
| with init_weights_on_device(device=torch.device("meta")): | |
| super().__init__(in_features=module.in_features, out_features=module.out_features, bias=module.bias is not None, dtype=offload_dtype, device=offload_device) | |
| self.weight = module.weight | |
| self.bias = module.bias | |
| self.offload_dtype = offload_dtype | |
| self.offload_device = offload_device | |
| self.onload_dtype = onload_dtype | |
| self.onload_device = onload_device | |
| self.computation_dtype = computation_dtype | |
| self.computation_device = computation_device | |
| self.state = 0 | |
| def offload(self): | |
| if self.state == 1 and (self.offload_dtype != self.onload_dtype or self.offload_device != self.onload_device): | |
| self.to(dtype=self.offload_dtype, device=self.offload_device) | |
| self.state = 0 | |
| def onload(self): | |
| if self.state == 0 and (self.offload_dtype != self.onload_dtype or self.offload_device != self.onload_device): | |
| self.to(dtype=self.onload_dtype, device=self.onload_device) | |
| self.state = 1 | |
| def forward(self, x, *args, **kwargs): | |
| if self.onload_dtype == self.computation_dtype and self.onload_device == self.computation_device: | |
| weight, bias = self.weight, self.bias | |
| else: | |
| weight = cast_to(self.weight, self.computation_dtype, self.computation_device) | |
| bias = None if self.bias is None else cast_to(self.bias, self.computation_dtype, self.computation_device) | |
| return torch.nn.functional.linear(x, weight, bias) | |
| def enable_vram_management_recursively(model: torch.nn.Module, module_map: dict, module_config: dict, max_num_param=None, overflow_module_config: dict = None, total_num_param=0): | |
| for name, module in model.named_children(): | |
| for source_module, target_module in module_map.items(): | |
| if isinstance(module, source_module): | |
| num_param = sum(p.numel() for p in module.parameters()) | |
| if max_num_param is not None and total_num_param + num_param > max_num_param: | |
| module_config_ = overflow_module_config | |
| else: | |
| module_config_ = module_config | |
| module_ = target_module(module, **module_config_) | |
| setattr(model, name, module_) | |
| total_num_param += num_param | |
| break | |
| else: | |
| total_num_param = enable_vram_management_recursively(module, module_map, module_config, max_num_param, overflow_module_config, total_num_param) | |
| return total_num_param | |
| def enable_vram_management_(model: torch.nn.Module, module_map: dict, module_config: dict, max_num_param=None, overflow_module_config: dict = None): | |
| enable_vram_management_recursively(model, module_map, module_config, max_num_param, overflow_module_config, total_num_param=0) | |
| model.vram_management_enabled = True | |
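| # --------------------------------------------------------------- | |
| # Hedged sketch (illustration only): enable_vram_management_ swaps mapped | |
| # sub-modules for offload-aware wrappers. Here a toy MLP is wrapped entirely on | |
| # the CPU with AutoWrappedModule, so offload()/onload() are no-ops but the | |
| # bookkeeping (vram_management_enabled, per-module state) can be inspected. | |
| def _vram_management_sketch(): | |
|     toy = torch.nn.Sequential(torch.nn.Linear(8, 8), torch.nn.ReLU(), torch.nn.Linear(8, 4)) | |
|     enable_vram_management_( | |
|         toy, | |
|         module_map={torch.nn.Linear: AutoWrappedModule}, | |
|         module_config=dict( | |
|             offload_dtype=torch.float32, offload_device="cpu", | |
|             onload_dtype=torch.float32, onload_device="cpu", | |
|             computation_dtype=torch.float32, computation_device="cpu", | |
|         ), | |
|     ) | |
|     return toy.vram_management_enabled, toy(torch.randn(2, 8)).shape  # -> True, torch.Size([2, 4]) | |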
| # ----------------------------- | |
| # Simplified pipeline (dit + vae only) | |
| # ----------------------------- | |
| class FlashVSRTinyPipeline(BasePipeline): | |
| def __init__(self, device="cuda", torch_dtype=torch.float16): | |
| super().__init__(device=device, torch_dtype=torch_dtype) | |
| self.scheduler = FlowMatchScheduler(shift=5, sigma_min=0.0, extra_one_step=True) | |
| self.dit: WanModel = None | |
| self.vae: WanVideoVAE = None | |
| self.model_names = ["dit", "vae"] | |
| self.height_division_factor = 16 | |
| self.width_division_factor = 16 | |
| self.use_unified_sequence_parallel = False | |
| self.prompt_emb_posi = None | |
| self.ColorCorrector = TorchColorCorrectorWavelet(levels=5) | |
| print( | |
| r""" | |
| ███████╗██╗ █████╗ ███████╗██╗ ██╗██╗ ██╗███████╗█████╗ | |
| ██╔════╝██║ ██╔══██╗██╔════╝██║ ██║██║ ██║██╔════╝██╔══██╗ ██╗ | |
| █████╗ ██║ ███████║███████╗███████║╚██╗ ██╔╝███████╗███████║ ██████╗ | |
| ██╔══╝ ██║ ██╔══██║╚════██║██╔══██║ ╚████╔╝ ╚════██║██╔═██║ ██╔═╝ | |
| ██║ ███████╗██║ ██║███████║██║ ██║ ╚██╔╝ ███████║██║ ██║ ╚═╝ | |
| ╚═╝ ╚══════╝╚═╝ ╚═╝╚══════╝╚═╝ ╚═╝ ╚═╝ ╚══════╝╚═╝ ╚═╝ | |
| """ | |
| ) | |
| def enable_vram_management(self, num_persistent_param_in_dit=None): | |
| # only the dit / vae modules are managed | |
| dtype = next(iter(self.dit.parameters())).dtype | |
| enable_vram_management_( | |
| self.dit, | |
| module_map={ | |
| torch.nn.Linear: AutoWrappedLinear, | |
| torch.nn.Conv3d: AutoWrappedModule, | |
| torch.nn.LayerNorm: AutoWrappedModule, | |
| RMSNorm: AutoWrappedModule, | |
| }, | |
| module_config=dict( | |
| offload_dtype=dtype, | |
| offload_device="cpu", | |
| onload_dtype=dtype, | |
| onload_device=self.device, | |
| computation_dtype=self.torch_dtype, | |
| computation_device=self.device, | |
| ), | |
| max_num_param=num_persistent_param_in_dit, | |
| overflow_module_config=dict( | |
| offload_dtype=dtype, | |
| offload_device="cpu", | |
| onload_dtype=dtype, | |
| onload_device="cpu", | |
| computation_dtype=self.torch_dtype, | |
| computation_device=self.device, | |
| ), | |
| ) | |
| self.enable_cpu_offload() | |
| def fetch_models(self, model_manager: ModelManager): | |
| self.dit = model_manager.fetch_model("wan_video_dit") | |
| self.vae = model_manager.fetch_model("wan_video_vae") | |
| @staticmethod | |
| def from_model_manager( | |
| model_manager: ModelManager, torch_dtype=None, device=None, use_usp=False | |
| ): | |
| if device is None: | |
| device = model_manager.device | |
| if torch_dtype is None: | |
| torch_dtype = model_manager.torch_dtype | |
| pipe = FlashVSRTinyPipeline(device=device, torch_dtype=torch_dtype) | |
| pipe.fetch_models(model_manager) | |
| # Optional: unified sequence-parallel entry point (disabled by default here) | |
| pipe.use_unified_sequence_parallel = False | |
| return pipe | |
| def denoising_model(self): | |
| return self.dit | |
| # ------------------------- | |
| # New: explicit KV pre-initialization | |
| # ------------------------- | |
| def init_cross_kv( | |
| self, context_tensor: Optional[torch.Tensor] = None, prompt_path=None | |
| ): | |
| self.load_models_to_device(["dit"]) | |
| """ | |
| 使用固定 prompt 生成文本 context,并在 WanModel 中初始化所有 CrossAttention 的 KV 缓存。 | |
| 必须在 __call__ 前显式调用一次。 | |
| """ | |
| # prompt_path = "../../examples/WanVSR/prompt_tensor/posi_prompt.pth" | |
| if self.dit is None: | |
| raise RuntimeError( | |
| "请先通过 fetch_models / from_model_manager 初始化 self.dit" | |
| ) | |
| if context_tensor is None: | |
| if prompt_path is None: | |
| raise ValueError( | |
| "init_cross_kv: either prompt_path or context_tensor must be provided" | |
| ) | |
| ctx = torch.load(prompt_path, map_location=self.device) | |
| else: | |
| ctx = context_tensor | |
| ctx = ctx.to(dtype=self.torch_dtype, device=self.device) | |
| if self.prompt_emb_posi is None: | |
| self.prompt_emb_posi = {} | |
| self.prompt_emb_posi["context"] = ctx | |
| self.prompt_emb_posi["stats"] = "load" | |
| if hasattr(self.dit, "reinit_cross_kv"): | |
| self.dit.reinit_cross_kv(ctx) | |
| else: | |
| raise AttributeError( | |
| "WanModel 缺少 reinit_cross_kv(ctx) 方法,请在模型实现中加入该能力。" | |
| ) | |
| self.timestep = torch.tensor( | |
| [1000.0], device=self.device, dtype=self.torch_dtype | |
| ) | |
| self.t = self.dit.time_embedding( | |
| sinusoidal_embedding_1d(self.dit.freq_dim, self.timestep) | |
| ) | |
| self.t_mod = self.dit.time_projection(self.t).unflatten(1, (6, self.dit.dim)) | |
| # Scheduler | |
| self.scheduler.set_timesteps(1, denoising_strength=1.0, shift=5.0) | |
| self.load_models_to_device([]) | |
| def prepare_unified_sequence_parallel(self): | |
| return {"use_unified_sequence_parallel": self.use_unified_sequence_parallel} | |
| def prepare_extra_input(self, latents=None): | |
| return {} | |
| def encode_video( | |
| self, input_video, tiled=True, tile_size=(34, 34), tile_stride=(18, 16) | |
| ): | |
| latents = self.vae.encode( | |
| input_video, | |
| device=self.device, | |
| tiled=tiled, | |
| tile_size=tile_size, | |
| tile_stride=tile_stride, | |
| ) | |
| return latents | |
| def _decode_video( | |
| self, latents, tiled=True, tile_size=(34, 34), tile_stride=(18, 16) | |
| ): | |
| frames = self.vae.decode( | |
| latents, | |
| device=self.device, | |
| tiled=tiled, | |
| tile_size=tile_size, | |
| tile_stride=tile_stride, | |
| ) | |
| return frames | |
| def decode_video(self, latents, cond=None, **kwargs): | |
| frames = ( | |
| self.TCDecoder.decode_video( | |
| latents.transpose(1, 2), # TCDecoder expects (B, F, C, H, W) | |
| parallel=False, | |
| show_progress_bar=False, | |
| cond=cond, | |
| ) | |
| .transpose(1, 2) | |
| .mul_(2) | |
| .sub_(1) | |
| ) # back to (B, C, F, H, W) format, range -1 to 1 | |
| return frames | |
| def offload_model(self, keep_vae=False): | |
| self.dit.clear_cross_kv() | |
| self.prompt_emb_posi["stats"] = "offload" | |
| self.load_models_to_device([]) | |
| if hasattr(self.dit, "LQ_proj_in"): | |
| self.dit.LQ_proj_in.to("cpu") | |
| if not keep_vae: | |
| self.TCDecoder.to("cpu") | |
| @torch.no_grad() | |
| def __call__( | |
| self, | |
| prompt=None, | |
| negative_prompt="", | |
| denoising_strength=1.0, | |
| seed=None, | |
| rand_device="gpu", | |
| height=480, | |
| width=832, | |
| num_frames=81, | |
| cfg_scale=5.0, | |
| num_inference_steps=50, | |
| sigma_shift=5.0, | |
| tiled=True, | |
| tile_size=(60, 104), | |
| tile_stride=(30, 52), | |
| tea_cache_l1_thresh=None, | |
| tea_cache_model_id="Wan2.1-T2V-1.3B", | |
| progress_bar_cmd=tqdm, | |
| progress_bar_st=None, | |
| LQ_video=None, | |
| is_full_block=False, | |
| if_buffer=False, | |
| topk_ratio=2.0, | |
| kv_ratio=3.0, | |
| local_range=9, | |
| color_fix=True, | |
| unload_dit=False, | |
| force_offload=False, | |
| **kwargs, | |
| ): | |
| # Only cfg_scale=1.0 is accepted (consistent with the original code) | |
| assert cfg_scale == 1.0, "cfg_scale must be 1.0" | |
| # Requirement: init_cross_kv() must have been called first | |
| if self.prompt_emb_posi is None or "context" not in self.prompt_emb_posi: | |
| raise RuntimeError( | |
| "Cross-Attention KV not initialized. Please call __call__ only after:\n" | |
| " pipe.init_cross_kv()\n" | |
| "Or provide a custom context:\n" | |
| " pipe.init_cross_kv(context_tensor=your_context_tensor)" | |
| ) | |
| # Size correction | |
| height, width = self.check_resize_height_width(height, width) | |
| if num_frames % 4 != 1: | |
| num_frames = (num_frames + 2) // 4 * 4 + 1 | |
| print( | |
| f"Only `num_frames % 4 != 1` is acceptable. We round it up to {num_frames}." | |
| ) | |
| # Tiler parameters | |
| tiler_kwargs = { | |
| "tiled": tiled, | |
| "tile_size": tile_size, | |
| "tile_stride": tile_stride, | |
| } | |
| # Initialize noise | |
| if if_buffer: | |
| noise = self.generate_noise( | |
| (1, 16, (num_frames - 1) // 4, height // 8, width // 8), | |
| seed=seed, | |
| device=self.device, | |
| dtype=self.torch_dtype, | |
| ) | |
| else: | |
| noise = self.generate_noise( | |
| (1, 16, (num_frames - 1) // 4 + 1, height // 8, width // 8), | |
| seed=seed, | |
| device=self.device, | |
| dtype=self.torch_dtype, | |
| ) | |
| # noise = noise.to(dtype=self.torch_dtype, device=self.device) | |
| latents = noise | |
| process_total_num = (num_frames - 1) // 8 - 2 | |
| is_stream = True | |
| if self.prompt_emb_posi["stats"] == "offload": | |
| self.init_cross_kv(context_tensor=self.prompt_emb_posi["context"]) | |
| self.load_models_to_device(["dit"]) | |
| self.dit.LQ_proj_in.to(self.device) | |
| self.TCDecoder.to(self.device) | |
| # Clear any existing LQ_proj_in cache | |
| if hasattr(self.dit, "LQ_proj_in"): | |
| self.dit.LQ_proj_in.clear_cache() | |
| latents_total = [] | |
| self.TCDecoder.clean_mem() | |
| LQ_pre_idx = 0 | |
| LQ_cur_idx = 0 | |
| with torch.no_grad(): | |
| for cur_process_idx in progress_bar_cmd(range(process_total_num)): | |
| if cur_process_idx == 0: | |
| pre_cache_k = [None] * len(self.dit.blocks) | |
| pre_cache_v = [None] * len(self.dit.blocks) | |
| LQ_latents = None | |
| inner_loop_num = 7 | |
| for inner_idx in range(inner_loop_num): | |
| cur = ( | |
| self.denoising_model().LQ_proj_in.stream_forward( | |
| LQ_video[ | |
| :, | |
| :, | |
| max(0, inner_idx * 4 - 3) : (inner_idx + 1) * 4 - 3, | |
| :, | |
| :, | |
| ] | |
| ) | |
| if LQ_video is not None | |
| else None | |
| ) | |
| if cur is None: | |
| continue | |
| if LQ_latents is None: | |
| LQ_latents = cur | |
| else: | |
| for layer_idx in range(len(LQ_latents)): | |
| LQ_latents[layer_idx] = torch.cat( | |
| [LQ_latents[layer_idx], cur[layer_idx]], dim=1 | |
| ) | |
| LQ_cur_idx = (inner_loop_num - 1) * 4 - 3 | |
| cur_latents = latents[:, :, :6, :, :] | |
| else: | |
| LQ_latents = None | |
| inner_loop_num = 2 | |
| for inner_idx in range(inner_loop_num): | |
| cur = ( | |
| self.denoising_model().LQ_proj_in.stream_forward( | |
| LQ_video[ | |
| :, | |
| :, | |
| cur_process_idx * 8 | |
| + 17 | |
| + inner_idx * 4 : cur_process_idx * 8 | |
| + 21 | |
| + inner_idx * 4, | |
| :, | |
| :, | |
| ] | |
| ) | |
| if LQ_video is not None | |
| else None | |
| ) | |
| if cur is None: | |
| continue | |
| if LQ_latents is None: | |
| LQ_latents = cur | |
| else: | |
| for layer_idx in range(len(LQ_latents)): | |
| LQ_latents[layer_idx] = torch.cat( | |
| [LQ_latents[layer_idx], cur[layer_idx]], dim=1 | |
| ) | |
| LQ_cur_idx = cur_process_idx * 8 + 21 + (inner_loop_num - 2) * 4 | |
| cur_latents = latents[ | |
| :, :, 4 + cur_process_idx * 2 : 6 + cur_process_idx * 2, :, : | |
| ] | |
| # Inference (no motion_controller / vace) | |
| noise_pred_posi, pre_cache_k, pre_cache_v = model_fn_wan_video( | |
| self.dit, | |
| x=cur_latents, | |
| timestep=self.timestep, | |
| context=None, | |
| tea_cache=None, | |
| use_unified_sequence_parallel=False, | |
| LQ_latents=LQ_latents, | |
| is_full_block=is_full_block, | |
| is_stream=is_stream, | |
| pre_cache_k=pre_cache_k, | |
| pre_cache_v=pre_cache_v, | |
| topk_ratio=topk_ratio, | |
| kv_ratio=kv_ratio, | |
| cur_process_idx=cur_process_idx, | |
| t_mod=self.t_mod, | |
| t=self.t, | |
| local_range=local_range, | |
| ) | |
| # Update the latent | |
| cur_latents = cur_latents - noise_pred_posi | |
| latents_total.append(cur_latents) | |
| LQ_pre_idx = LQ_cur_idx | |
| if hasattr(self.dit, "LQ_proj_in"): | |
| self.dit.LQ_proj_in.clear_cache() | |
| if ( | |
| unload_dit | |
| and hasattr(self, "dit") | |
| and not next(self.dit.parameters()).is_cpu | |
| ): | |
| print("[FlashVSR] Offloading DiT to the CPU to free up VRAM...") | |
| self.offload_model(keep_vae=True) | |
| latents = torch.cat(latents_total, dim=2) | |
| # Decode | |
| print("[FlashVSR] Starting VAE decoding...") | |
| frames = ( | |
| self.TCDecoder.decode_video( | |
| latents.transpose(1, 2), | |
| parallel=False, | |
| show_progress_bar=False, | |
| cond=LQ_video[:, :, :LQ_cur_idx, :, :], | |
| ) | |
| .transpose(1, 2) | |
| .mul_(2) | |
| .sub_(1) | |
| ) | |
| self.TCDecoder.clean_mem() | |
| if force_offload: | |
| self.offload_model() | |
| # Color correction (TorchColorCorrectorWavelet, called with method="adain" here) | |
| try: | |
| if color_fix: | |
| frames = self.ColorCorrector( | |
| frames.to(device=LQ_video.device), | |
| LQ_video[:, :, : frames.shape[2], :, :], | |
| clip_range=(-1, 1), | |
| chunk_size=16, | |
| method="adain", | |
| ) | |
| except Exception as e: | |
| print(f"[FlashVSR] Color correction failed, returning uncorrected frames: {e}") | |
| return frames[0] | |
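| # --------------------------------------------------------------- | |
| # Hedged usage sketch (illustration only, never executed here): the call order | |
| # FlashVSRTinyPipeline expects. `model_manager` and `lq_video_tensor` are | |
| # assumptions supplied by the surrounding loader code, the prompt path is a | |
| # hypothetical placeholder, and TCDecoder / LQ_proj_in are attached elsewhere in | |
| # this file before __call__ can run. | |
| def _flashvsr_usage_sketch(model_manager, lq_video_tensor, num_frames=81, height=480, width=832): | |
|     pipe = FlashVSRTinyPipeline.from_model_manager(model_manager, torch_dtype=torch.bfloat16, device="cuda") | |
|     pipe.init_cross_kv(prompt_path="posi_prompt.pth")  # hypothetical path to a saved prompt context | |
|     frames = pipe( | |
|         cfg_scale=1.0,  # the pipeline only accepts cfg_scale == 1.0 | |
|         num_frames=num_frames, height=height, width=width, | |
|         LQ_video=lq_video_tensor,  # (1, 3, F, H, W) low-quality input | |
|         seed=0, | |
|     ) | |
|     return frames | |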
| # ----------------------------- | |
| # TeaCache (original logic preserved; disabled by default here) | |
| # ----------------------------- | |
| class TeaCache: | |
| def __init__(self, num_inference_steps, rel_l1_thresh, model_id): | |
| self.num_inference_steps = num_inference_steps | |
| self.step = 0 | |
| self.accumulated_rel_l1_distance = 0 | |
| self.previous_modulated_input = None | |
| self.rel_l1_thresh = rel_l1_thresh | |
| self.previous_residual = None | |
| self.previous_hidden_states = None | |
| self.coefficients_dict = { | |
| "Wan2.1-T2V-1.3B": [ | |
| -5.21862437e04, | |
| 9.23041404e03, | |
| -5.28275948e02, | |
| 1.36987616e01, | |
| -4.99875664e-02, | |
| ], | |
| "Wan2.1-T2V-14B": [ | |
| -3.03318725e05, | |
| 4.90537029e04, | |
| -2.65530556e03, | |
| 5.87365115e01, | |
| -3.15583525e-01, | |
| ], | |
| "Wan2.1-I2V-14B-480P": [ | |
| 2.57151496e05, | |
| -3.54229917e04, | |
| 1.40286849e03, | |
| -1.35890334e01, | |
| 1.32517977e-01, | |
| ], | |
| "Wan2.1-I2V-14B-720P": [ | |
| 8.10705460e03, | |
| 2.13393892e03, | |
| -3.72934672e02, | |
| 1.66203073e01, | |
| -4.17769401e-02, | |
| ], | |
| } | |
| if model_id not in self.coefficients_dict: | |
| supported_model_ids = ", ".join([i for i in self.coefficients_dict]) | |
| raise ValueError( | |
| f"{model_id} is not a supported TeaCache model id. Please choose a valid model id in ({supported_model_ids})." | |
| ) | |
| self.coefficients = self.coefficients_dict[model_id] | |
| def check(self, dit: WanModel, x, t_mod): | |
| modulated_inp = t_mod.clone() | |
| if self.step == 0 or self.step == self.num_inference_steps - 1: | |
| should_calc = True | |
| self.accumulated_rel_l1_distance = 0 | |
| else: | |
| coefficients = self.coefficients | |
| rescale_func = np.poly1d(coefficients) | |
| self.accumulated_rel_l1_distance += rescale_func( | |
| ( | |
| (modulated_inp - self.previous_modulated_input).abs().mean() | |
| / self.previous_modulated_input.abs().mean() | |
| ) | |
| .cpu() | |
| .item() | |
| ) | |
| should_calc = not (self.accumulated_rel_l1_distance < self.rel_l1_thresh) | |
| if should_calc: | |
| self.accumulated_rel_l1_distance = 0 | |
| self.previous_modulated_input = modulated_inp | |
| self.step = (self.step + 1) % self.num_inference_steps | |
| if should_calc: | |
| self.previous_hidden_states = x.clone() | |
| return not should_calc | |
| def store(self, hidden_states): | |
| self.previous_residual = hidden_states - self.previous_hidden_states | |
| self.previous_hidden_states = None | |
| def update(self, hidden_states): | |
| hidden_states = hidden_states + self.previous_residual | |
| return hidden_states | |
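| # --------------------------------------------------------------- | |
| # Hedged sketch (illustration only): TeaCache decides per step whether the DiT | |
| # forward can be skipped and replaced by the cached residual. `dit` is unused | |
| # inside check(), so None is passed; tensors are random stand-ins. | |
| def _tea_cache_sketch(): | |
|     cache = TeaCache(num_inference_steps=4, rel_l1_thresh=1.0, model_id="Wan2.1-T2V-1.3B") | |
|     x = torch.randn(1, 16, 128) | |
|     t_mod = torch.randn(1, 6, 128) | |
|     skipped = cache.check(None, x, t_mod)  # step 0 always computes -> False | |
|     cache.store(x + 0.1)  # cache the residual of the real forward pass | |
|     skipped_next = cache.check(None, x, t_mod * 1.001)  # tiny change -> likely skipped | |
|     approx = cache.update(x) if skipped_next else None  # reuse the cached residual | |
|     return skipped, skipped_next, approx | |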
| # ----------------------------- | |
| # Simplified model forward wrapper (no vace / no motion_controller) | |
| # ----------------------------- | |
| def model_fn_wan_video( | |
| dit: WanModel, | |
| x: torch.Tensor, | |
| timestep: torch.Tensor, | |
| context: torch.Tensor, | |
| tea_cache: Optional[TeaCache] = None, | |
| use_unified_sequence_parallel: bool = False, | |
| LQ_latents: Optional[torch.Tensor] = None, | |
| is_full_block: bool = False, | |
| is_stream: bool = False, | |
| pre_cache_k: Optional[list[torch.Tensor]] = None, | |
| pre_cache_v: Optional[list[torch.Tensor]] = None, | |
| topk_ratio: float = 2.0, | |
| kv_ratio: float = 3.0, | |
| cur_process_idx: int = 0, | |
| t_mod: torch.Tensor = None, | |
| t: torch.Tensor = None, | |
| local_range: int = 9, | |
| **kwargs, | |
| ): | |
| # patchify | |
| x, (f, h, w) = dit.patchify(x) | |
| win = (2, 8, 8) | |
| seqlen = f // win[0] | |
| local_num = seqlen | |
| window_size = win[0] * h * w // 128 | |
| square_num = window_size * window_size | |
| topk = int(square_num * topk_ratio) - 1 | |
| kv_len = int(kv_ratio) | |
| # RoPE positions (segmented) | |
| if cur_process_idx == 0: | |
| freqs = ( | |
| torch.cat( | |
| [ | |
| dit.freqs[0][:f].view(f, 1, 1, -1).expand(f, h, w, -1), | |
| dit.freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1), | |
| dit.freqs[2][:w].view(1, 1, w, -1).expand(f, h, w, -1), | |
| ], | |
| dim=-1, | |
| ) | |
| .reshape(f * h * w, 1, -1) | |
| .to(x.device) | |
| ) | |
| else: | |
| freqs = ( | |
| torch.cat( | |
| [ | |
| dit.freqs[0][4 + cur_process_idx * 2 : 4 + cur_process_idx * 2 + f] | |
| .view(f, 1, 1, -1) | |
| .expand(f, h, w, -1), | |
| dit.freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1), | |
| dit.freqs[2][:w].view(1, 1, w, -1).expand(f, h, w, -1), | |
| ], | |
| dim=-1, | |
| ) | |
| .reshape(f * h * w, 1, -1) | |
| .to(x.device) | |
| ) | |
| # TeaCache (disabled by default) | |
| tea_cache_update = ( | |
| tea_cache.check(dit, x, t_mod) if tea_cache is not None else False | |
| ) | |
| # Unified sequence parallelism (disabled by default here) | |
| if use_unified_sequence_parallel: | |
| import torch.distributed as dist | |
| from xfuser.core.distributed import ( | |
| get_sequence_parallel_rank, | |
| get_sequence_parallel_world_size, | |
| get_sp_group, | |
| ) | |
| if dist.is_initialized() and dist.get_world_size() > 1: | |
| x = torch.chunk(x, get_sequence_parallel_world_size(), dim=1)[ | |
| get_sequence_parallel_rank() | |
| ] | |
| # Stacked transformer blocks | |
| if tea_cache_update: | |
| x = tea_cache.update(x) | |
| else: | |
| for block_id, block in enumerate(dit.blocks): | |
| if LQ_latents is not None and block_id < len(LQ_latents): | |
| x = x + LQ_latents[block_id] | |
| x, last_pre_cache_k, last_pre_cache_v = block( | |
| x, | |
| context, | |
| t_mod, | |
| freqs, | |
| f, | |
| h, | |
| w, | |
| local_num, | |
| topk, | |
| block_id=block_id, | |
| kv_len=kv_len, | |
| is_full_block=is_full_block, | |
| is_stream=is_stream, | |
| pre_cache_k=pre_cache_k[block_id] if pre_cache_k is not None else None, | |
| pre_cache_v=pre_cache_v[block_id] if pre_cache_v is not None else None, | |
| local_range=local_range, | |
| ) | |
| if pre_cache_k is not None: | |
| pre_cache_k[block_id] = last_pre_cache_k | |
| if pre_cache_v is not None: | |
| pre_cache_v[block_id] = last_pre_cache_v | |
| x = dit.head(x, t) | |
| if use_unified_sequence_parallel: | |
| import torch.distributed as dist | |
| from xfuser.core.distributed import get_sp_group | |
| if dist.is_initialized() and dist.get_world_size() > 1: | |
| x = get_sp_group().all_gather(x, dim=1) | |
| x = dit.unpatchify(x, (f, h, w)) | |
| return x, pre_cache_k, pre_cache_v | |
| try: | |
| import flash_attn_interface | |
| FLASH_ATTN_3_AVAILABLE = True | |
| except ModuleNotFoundError: | |
| FLASH_ATTN_3_AVAILABLE = False | |
| try: | |
| import flash_attn | |
| FLASH_ATTN_2_AVAILABLE = True | |
| except ModuleNotFoundError: | |
| FLASH_ATTN_2_AVAILABLE = False | |
| try: | |
| from sageattention import sageattn | |
| SAGE_ATTN_AVAILABLE = True | |
| except ModuleNotFoundError: | |
| SAGE_ATTN_AVAILABLE = False | |
| try: | |
| from block_sparse_attn import block_sparse_attn_func | |
| BLOCK_ATTN_AVAILABLE = True | |
| except Exception: | |
| BLOCK_ATTN_AVAILABLE = False | |
| @triton.jit | |
| def quant_per_block_int8_kernel( | |
| Input, | |
| Output, | |
| Scale, | |
| L, | |
| stride_iz, | |
| stride_ih, | |
| stride_in, | |
| stride_oz, | |
| stride_oh, | |
| stride_on, | |
| stride_sz, | |
| stride_sh, | |
| sm_scale, | |
| C: tl.constexpr, | |
| BLK: tl.constexpr, | |
| ): | |
| off_blk = tl.program_id(0) | |
| off_h = tl.program_id(1) | |
| off_b = tl.program_id(2) | |
| offs_n = off_blk * BLK + tl.arange(0, BLK) | |
| offs_k = tl.arange(0, C) | |
| input_ptrs = ( | |
| Input | |
| + off_b * stride_iz | |
| + off_h * stride_ih | |
| + offs_n[:, None] * stride_in | |
| + offs_k[None, :] | |
| ) | |
| output_ptrs = ( | |
| Output | |
| + off_b * stride_oz | |
| + off_h * stride_oh | |
| + offs_n[:, None] * stride_on | |
| + offs_k[None, :] | |
| ) | |
| scale_ptrs = Scale + off_b * stride_sz + off_h * stride_sh + off_blk | |
| x = tl.load(input_ptrs, mask=offs_n[:, None] < L) | |
| x = x.to(tl.float32) | |
| x *= sm_scale | |
| scale = tl.max(tl.abs(x)) / 127.0 | |
| x_int8 = x / scale | |
| x_int8 += 0.5 * tl.where(x_int8 >= 0, 1, -1) | |
| x_int8 = x_int8.to(tl.int8) | |
| tl.store(output_ptrs, x_int8, mask=offs_n[:, None] < L) | |
| tl.store(scale_ptrs, scale) | |
| def per_block_int8( | |
| q, k, km=None, BLKQ=128, BLKK=64, sm_scale=None, tensor_layout="HND" | |
| ): | |
| q_int8 = torch.empty(q.shape, dtype=torch.int8, device=q.device) | |
| k_int8 = torch.empty(k.shape, dtype=torch.int8, device=k.device) | |
| if km is not None: | |
| k = k - km | |
| if tensor_layout == "HND": | |
| b, h_qo, qo_len, head_dim = q.shape | |
| _, h_kv, kv_len, _ = k.shape | |
| stride_bz_q, stride_h_q, stride_seq_q = q.stride(0), q.stride(1), q.stride(2) | |
| stride_bz_qo, stride_h_qo, stride_seq_qo = ( | |
| q_int8.stride(0), | |
| q_int8.stride(1), | |
| q_int8.stride(2), | |
| ) | |
| stride_bz_k, stride_h_k, stride_seq_k = k.stride(0), k.stride(1), k.stride(2) | |
| stride_bz_ko, stride_h_ko, stride_seq_ko = ( | |
| k_int8.stride(0), | |
| k_int8.stride(1), | |
| k_int8.stride(2), | |
| ) | |
| elif tensor_layout == "NHD": | |
| b, qo_len, h_qo, head_dim = q.shape | |
| _, kv_len, h_kv, _ = k.shape | |
| stride_bz_q, stride_h_q, stride_seq_q = q.stride(0), q.stride(2), q.stride(1) | |
| stride_bz_qo, stride_h_qo, stride_seq_qo = ( | |
| q_int8.stride(0), | |
| q_int8.stride(2), | |
| q_int8.stride(1), | |
| ) | |
| stride_bz_k, stride_h_k, stride_seq_k = k.stride(0), k.stride(2), k.stride(1) | |
| stride_bz_ko, stride_h_ko, stride_seq_ko = ( | |
| k_int8.stride(0), | |
| k_int8.stride(2), | |
| k_int8.stride(1), | |
| ) | |
| else: | |
| raise ValueError(f"Unknown tensor layout: {tensor_layout}") | |
| q_scale = torch.empty( | |
| (b, h_qo, (qo_len + BLKQ - 1) // BLKQ), device=q.device, dtype=torch.float32 | |
| ) | |
| k_scale = torch.empty( | |
| (b, h_kv, (kv_len + BLKK - 1) // BLKK), device=q.device, dtype=torch.float32 | |
| ) | |
| if sm_scale is None: | |
| sm_scale = head_dim**-0.5 | |
| grid = ((qo_len + BLKQ - 1) // BLKQ, h_qo, b) | |
| quant_per_block_int8_kernel[grid]( | |
| q, | |
| q_int8, | |
| q_scale, | |
| qo_len, | |
| stride_bz_q, | |
| stride_h_q, | |
| stride_seq_q, | |
| stride_bz_qo, | |
| stride_h_qo, | |
| stride_seq_qo, | |
| q_scale.stride(0), | |
| q_scale.stride(1), | |
| sm_scale=(sm_scale * 1.44269504), | |
| C=head_dim, | |
| BLK=BLKQ, | |
| ) | |
| grid = ((kv_len + BLKK - 1) // BLKK, h_kv, b) | |
| quant_per_block_int8_kernel[grid]( | |
| k, | |
| k_int8, | |
| k_scale, | |
| kv_len, | |
| stride_bz_k, | |
| stride_h_k, | |
| stride_seq_k, | |
| stride_bz_ko, | |
| stride_h_ko, | |
| stride_seq_ko, | |
| k_scale.stride(0), | |
| k_scale.stride(1), | |
| sm_scale=1.0, | |
| C=head_dim, | |
| BLK=BLKK, | |
| ) | |
| return q_int8, q_scale, k_int8, k_scale | |
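| # --------------------------------------------------------------- | |
| # Hedged sketch (illustration only; needs a CUDA device with Triton): per-block | |
| # INT8 quantization of q/k before the sparse SageAttention kernel. Shapes follow | |
| # the "HND" layout used above: (batch, heads, seq_len, head_dim). | |
| def _per_block_int8_sketch(device="cuda"): | |
|     q = torch.randn(1, 8, 256, 64, dtype=torch.float16, device=device) | |
|     k = torch.randn(1, 8, 256, 64, dtype=torch.float16, device=device) | |
|     km = k.mean(dim=2, keepdim=True)  # per-head key mean, as in sparse_sageattn() | |
|     q_int8, q_scale, k_int8, k_scale = per_block_int8(q, k, km=km, tensor_layout="HND") | |
|     # q is quantized in 128-token blocks and k in 64-token blocks, so: | |
|     # q_scale: (1, 8, 256 // 128) = (1, 8, 2), k_scale: (1, 8, 256 // 64) = (1, 8, 4) | |
|     return q_int8.dtype, q_scale.shape, k_scale.shape | |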
| @triton.jit | |
| def _attn_fwd_inner( | |
| acc, | |
| l_i, | |
| old_m, | |
| q, | |
| q_scale, | |
| kv_len, | |
| K_ptrs, | |
| K_bid_ptr, | |
| K_scale_ptr, | |
| V_ptrs, | |
| stride_kn, | |
| stride_vn, | |
| start_m, | |
| BLOCK_M: tl.constexpr, | |
| HEAD_DIM: tl.constexpr, | |
| BLOCK_N: tl.constexpr, | |
| STAGE: tl.constexpr, | |
| offs_m: tl.constexpr, | |
| offs_n: tl.constexpr, | |
| ): | |
| if STAGE == 1: | |
| lo, hi = 0, start_m * BLOCK_M | |
| elif STAGE == 2: | |
| lo, hi = start_m * BLOCK_M, (start_m + 1) * BLOCK_M | |
| lo = tl.multiple_of(lo, BLOCK_M) | |
| K_scale_ptr += lo // BLOCK_N | |
| K_ptrs += stride_kn * lo | |
| V_ptrs += stride_vn * lo | |
| elif STAGE == 3: | |
| lo, hi = 0, kv_len | |
| for start_n in range(lo, hi, BLOCK_N): | |
| kbid = tl.load(K_bid_ptr + start_n // BLOCK_N) | |
| if kbid: | |
| k_mask = offs_n[None, :] < (kv_len - start_n) | |
| k = tl.load(K_ptrs, mask=k_mask) | |
| k_scale = tl.load(K_scale_ptr) | |
| qk = tl.dot(q, k).to(tl.float32) * q_scale * k_scale | |
| if STAGE == 2: | |
| mask = offs_m[:, None] >= (start_n + offs_n[None, :]) | |
| qk = qk + tl.where(mask, 0, -1.0e6) | |
| local_m = tl.max(qk, 1) | |
| new_m = tl.maximum(old_m, local_m) | |
| qk -= new_m[:, None] | |
| else: | |
| local_m = tl.max(qk, 1) | |
| new_m = tl.maximum(old_m, local_m) | |
| qk = qk - new_m[:, None] | |
| p = tl.math.exp2(qk) | |
| l_ij = tl.sum(p, 1) | |
| alpha = tl.math.exp2(old_m - new_m) | |
| l_i = l_i * alpha + l_ij | |
| acc = acc * alpha[:, None] | |
| v = tl.load(V_ptrs, mask=offs_n[:, None] < (kv_len - start_n)) | |
| p = p.to(tl.float16) | |
| acc += tl.dot(p, v, out_dtype=tl.float16) | |
| old_m = new_m | |
| K_ptrs += BLOCK_N * stride_kn | |
| K_scale_ptr += 1 | |
| V_ptrs += BLOCK_N * stride_vn | |
| return acc, l_i, old_m | |
| @triton.jit | |
| def _attn_fwd( | |
| Q, | |
| K, | |
| K_blkid, | |
| V, | |
| Q_scale, | |
| K_scale, | |
| Out, | |
| stride_qz, | |
| stride_qh, | |
| stride_qn, | |
| stride_kz, | |
| stride_kh, | |
| stride_kn, | |
| stride_vz, | |
| stride_vh, | |
| stride_vn, | |
| stride_oz, | |
| stride_oh, | |
| stride_on, | |
| stride_kbidq, | |
| stride_kbidk, | |
| qo_len, | |
| kv_len, | |
| H: tl.constexpr, | |
| num_kv_groups: tl.constexpr, | |
| HEAD_DIM: tl.constexpr, | |
| BLOCK_M: tl.constexpr, | |
| BLOCK_N: tl.constexpr, | |
| STAGE: tl.constexpr, | |
| ): | |
| start_m = tl.program_id(0) | |
| off_z = tl.program_id(2).to(tl.int64) | |
| off_h = tl.program_id(1).to(tl.int64) | |
| q_scale_offset = (off_z * H + off_h) * tl.cdiv(qo_len, BLOCK_M) | |
| k_scale_offset = (off_z * (H // num_kv_groups) + off_h // num_kv_groups) * tl.cdiv( | |
| kv_len, BLOCK_N | |
| ) | |
| k_bid_offset = ( | |
| off_z * (H // num_kv_groups) + off_h // num_kv_groups | |
| ) * stride_kbidq | |
| offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) | |
| offs_n = tl.arange(0, BLOCK_N) | |
| offs_k = tl.arange(0, HEAD_DIM) | |
| Q_ptrs = ( | |
| Q | |
| + (off_z * stride_qz + off_h * stride_qh) | |
| + offs_m[:, None] * stride_qn | |
| + offs_k[None, :] | |
| ) | |
| Q_scale_ptr = Q_scale + q_scale_offset + start_m | |
| K_ptrs = ( | |
| K | |
| + (off_z * stride_kz + (off_h // num_kv_groups) * stride_kh) | |
| + offs_n[None, :] * stride_kn | |
| + offs_k[:, None] | |
| ) | |
| K_scale_ptr = K_scale + k_scale_offset | |
| K_bid_ptr = K_blkid + k_bid_offset + start_m * stride_kbidk | |
| V_ptrs = ( | |
| V | |
| + (off_z * stride_vz + (off_h // num_kv_groups) * stride_vh) | |
| + offs_n[:, None] * stride_vn | |
| + offs_k[None, :] | |
| ) | |
| O_block_ptr = ( | |
| Out | |
| + (off_z * stride_oz + off_h * stride_oh) | |
| + offs_m[:, None] * stride_on | |
| + offs_k[None, :] | |
| ) | |
| m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") | |
| l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + 1.0 | |
| acc = tl.zeros([BLOCK_M, HEAD_DIM], dtype=tl.float32) | |
| q = tl.load(Q_ptrs, mask=offs_m[:, None] < qo_len) | |
| q_scale = tl.load(Q_scale_ptr) | |
| acc, l_i, m_i = _attn_fwd_inner( | |
| acc, | |
| l_i, | |
| m_i, | |
| q, | |
| q_scale, | |
| kv_len, | |
| K_ptrs, | |
| K_bid_ptr, | |
| K_scale_ptr, | |
| V_ptrs, | |
| stride_kn, | |
| stride_vn, | |
| start_m, | |
| BLOCK_M, | |
| HEAD_DIM, | |
| BLOCK_N, | |
| 4 - STAGE, | |
| offs_m, | |
| offs_n, | |
| ) | |
| if STAGE != 1: | |
| acc, l_i, _ = _attn_fwd_inner( | |
| acc, | |
| l_i, | |
| m_i, | |
| q, | |
| q_scale, | |
| kv_len, | |
| K_ptrs, | |
| K_bid_ptr, | |
| K_scale_ptr, | |
| V_ptrs, | |
| stride_kn, | |
| stride_vn, | |
| start_m, | |
| BLOCK_M, | |
| HEAD_DIM, | |
| BLOCK_N, | |
| 2, | |
| offs_m, | |
| offs_n, | |
| ) | |
| acc = acc / l_i[:, None] | |
| tl.store(O_block_ptr, acc.to(Out.type.element_ty), mask=(offs_m[:, None] < qo_len)) | |
| def sparse_sageattn_fwd( | |
| q, | |
| k, | |
| k_block_id, | |
| v, | |
| q_scale, | |
| k_scale, | |
| is_causal=False, | |
| tensor_layout="HND", | |
| output_dtype=torch.float16, | |
| ): | |
| BLOCK_M = 128 | |
| BLOCK_N = 64 | |
| stage = 3 if is_causal else 1 | |
| o = torch.empty(q.shape, dtype=output_dtype, device=q.device) | |
| if tensor_layout == "HND": | |
| b, h_qo, qo_len, head_dim = q.shape | |
| _, h_kv, kv_len, _ = k.shape | |
| stride_bz_q, stride_h_q, stride_seq_q = q.stride(0), q.stride(1), q.stride(2) | |
| stride_bz_k, stride_h_k, stride_seq_k = k.stride(0), k.stride(1), k.stride(2) | |
| stride_bz_v, stride_h_v, stride_seq_v = v.stride(0), v.stride(1), v.stride(2) | |
| stride_bz_o, stride_h_o, stride_seq_o = o.stride(0), o.stride(1), o.stride(2) | |
| elif tensor_layout == "NHD": | |
| b, qo_len, h_qo, head_dim = q.shape | |
| _, kv_len, h_kv, _ = k.shape | |
| stride_bz_q, stride_h_q, stride_seq_q = q.stride(0), q.stride(2), q.stride(1) | |
| stride_bz_k, stride_h_k, stride_seq_k = k.stride(0), k.stride(2), k.stride(1) | |
| stride_bz_v, stride_h_v, stride_seq_v = v.stride(0), v.stride(2), v.stride(1) | |
| stride_bz_o, stride_h_o, stride_seq_o = o.stride(0), o.stride(2), o.stride(1) | |
| else: | |
| raise ValueError(f"tensor_layout {tensor_layout} not supported") | |
| if is_causal: | |
| assert qo_len == kv_len, "qo_len and kv_len must be equal for causal attention" | |
| HEAD_DIM_K = head_dim | |
| num_kv_groups = h_qo // h_kv | |
| grid = (triton.cdiv(qo_len, BLOCK_M), h_qo, b) | |
| _attn_fwd[grid]( | |
| q, | |
| k, | |
| k_block_id, | |
| v, | |
| q_scale, | |
| k_scale, | |
| o, | |
| stride_bz_q, | |
| stride_h_q, | |
| stride_seq_q, | |
| stride_bz_k, | |
| stride_h_k, | |
| stride_seq_k, | |
| stride_bz_v, | |
| stride_h_v, | |
| stride_seq_v, | |
| stride_bz_o, | |
| stride_h_o, | |
| stride_seq_o, | |
| k_block_id.stride(1), | |
| k_block_id.stride(2), | |
| qo_len, | |
| kv_len, | |
| h_qo, | |
| num_kv_groups, | |
| BLOCK_M=BLOCK_M, | |
| BLOCK_N=BLOCK_N, | |
| HEAD_DIM=HEAD_DIM_K, | |
| STAGE=stage, | |
| num_warps=4 if head_dim == 64 else 8, | |
| num_stages=4, | |
| ) | |
| return o | |
| def sparse_sageattn(q, k, v, mask_id=None, is_causal=False, tensor_layout="HND"): | |
| if mask_id is None: | |
| mask_id = torch.ones( | |
| ( | |
| q.shape[0], | |
| q.shape[1], | |
| (q.shape[2] + 128 - 1) // 128, | |
| (q.shape[3] + 64 - 1) // 64, | |
| ), | |
| dtype=torch.int8, | |
| device=q.device, | |
| ) # TODO | |
| output_dtype = q.dtype | |
| if output_dtype == torch.bfloat16 or output_dtype == torch.float32: | |
| v = v.to(torch.float16) | |
| seq_dim = 1 if tensor_layout == "NHD" else 2 | |
| km = k.mean(dim=seq_dim, keepdim=True) | |
| # km = torch.zeros((k.size(0), k.size(1), 1, k.size(3)), dtype=torch.float16, device=k.device) # Placeholder for mean, not used in quantization | |
| q_int8, q_scale, k_int8, k_scale = per_block_int8( | |
| q, k, km=km, tensor_layout=tensor_layout | |
| ) | |
| o = sparse_sageattn_fwd( | |
| q_int8, | |
| k_int8, | |
| mask_id, | |
| v, | |
| q_scale, | |
| k_scale, | |
| is_causal=is_causal, | |
| tensor_layout=tensor_layout, | |
| output_dtype=output_dtype, | |
| ) | |
| return o | |
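| # --------------------------------------------------------------- | |
| # Hedged sketch (illustration only; needs CUDA + Triton): calling the sparse | |
| # SageAttention wrapper with a dense (all-ones) block mask, which makes it behave | |
| # like ordinary attention. mask_id holds one flag per (128-query x 64-key) block. | |
| def _sparse_sageattn_sketch(device="cuda"): | |
|     q = torch.randn(1, 8, 256, 64, dtype=torch.float16, device=device) | |
|     k = torch.randn(1, 8, 256, 64, dtype=torch.float16, device=device) | |
|     v = torch.randn(1, 8, 256, 64, dtype=torch.float16, device=device) | |
|     mask_id = torch.ones(1, 8, 256 // 128, 256 // 64, dtype=torch.int8, device=device) | |
|     out = sparse_sageattn(q, k, v, mask_id=mask_id, is_causal=False, tensor_layout="HND") | |
|     return out.shape  # (1, 8, 256, 64), same layout as q | |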
| USE_BLOCK_ATTN = False | |
| # ---------------------------- | |
| # Local / window masks | |
| # ---------------------------- | |
| @torch.no_grad() | |
| def build_local_block_mask_shifted_vec( | |
| block_h: int, | |
| block_w: int, | |
| win_h: int = 6, | |
| win_w: int = 6, | |
| include_self: bool = True, | |
| device=None, | |
| ) -> torch.Tensor: | |
| device = device or torch.device("cpu") | |
| H, W = block_h, block_w | |
| r = torch.arange(H, device=device) | |
| c = torch.arange(W, device=device) | |
| YY, XX = torch.meshgrid(r, c, indexing="ij") | |
| r_all = YY.reshape(-1) | |
| c_all = XX.reshape(-1) | |
| r_half = win_h // 2 | |
| c_half = win_w // 2 | |
| start_r = torch.clamp(r_all - r_half, 0, H - win_h) | |
| end_r = start_r + win_h - 1 | |
| start_c = torch.clamp(c_all - c_half, 0, W - win_w) | |
| end_c = start_c + win_w - 1 | |
| in_row = (r_all[None, :] >= start_r[:, None]) & (r_all[None, :] <= end_r[:, None]) | |
| in_col = (c_all[None, :] >= start_c[:, None]) & (c_all[None, :] <= end_c[:, None]) | |
| mask = in_row & in_col | |
| if not include_self: | |
| mask.fill_diagonal_(False) | |
| return mask | |
| @torch.no_grad() | |
| def build_local_block_mask_shifted_vec_normal_slide( | |
| block_h: int, | |
| block_w: int, | |
| win_h: int = 6, | |
| win_w: int = 6, | |
| include_self: bool = True, | |
| device=None, | |
| ) -> torch.Tensor: | |
| device = device or torch.device("cpu") | |
| H, W = block_h, block_w | |
| r = torch.arange(H, device=device) | |
| c = torch.arange(W, device=device) | |
| YY, XX = torch.meshgrid(r, c, indexing="ij") | |
| r_all = YY.reshape(-1) | |
| c_all = XX.reshape(-1) | |
| r_half = win_h // 2 | |
| c_half = win_w // 2 | |
| start_r = r_all - r_half | |
| end_r = start_r + win_h - 1 | |
| start_c = c_all - c_half | |
| end_c = start_c + win_w - 1 | |
| in_row = (r_all[None, :] >= start_r[:, None]) & (r_all[None, :] <= end_r[:, None]) | |
| in_col = (c_all[None, :] >= start_c[:, None]) & (c_all[None, :] <= end_c[:, None]) | |
| mask = in_row & in_col | |
| if not include_self: | |
| mask.fill_diagonal_(False) | |
| return mask | |
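| # --------------------------------------------------------------- | |
| # Hedged sketch (illustration only): both helpers return an (H*W, H*W) boolean | |
| # mask over spatial blocks. The "shifted" variant clamps the window inside the | |
| # grid, so every row keeps exactly win_h * win_w neighbors; the "normal_slide" | |
| # variant lets windows run past the border, so edge rows keep fewer. | |
| def _local_mask_sketch(): | |
|     shifted = build_local_block_mask_shifted_vec(block_h=8, block_w=8, win_h=4, win_w=4) | |
|     sliding = build_local_block_mask_shifted_vec_normal_slide(block_h=8, block_w=8, win_h=4, win_w=4) | |
|     # -> torch.Size([64, 64]), tensor([16]), and a count below 16 for corner blocks | |
|     return shifted.shape, shifted.sum(dim=1).unique(), sliding.sum(dim=1).min() | |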
| class WindowPartition3D: | |
| """Partition / reverse-partition helpers for 5-D tensors (B,F,H,W,C).""" | |
| @staticmethod | |
| def partition(x: torch.Tensor, win: Tuple[int, int, int]): | |
| B, F, H, W, C = x.shape | |
| wf, wh, ww = win | |
| assert ( | |
| F % wf == 0 and H % wh == 0 and W % ww == 0 | |
| ), "Dims must divide by window size." | |
| x = x.view(B, F // wf, wf, H // wh, wh, W // ww, ww, C) | |
| x = x.permute(0, 1, 3, 5, 2, 4, 6, 7).contiguous() | |
| return x.view(-1, wf * wh * ww, C) | |
| @staticmethod | |
| def reverse( | |
| windows: torch.Tensor, win: Tuple[int, int, int], orig: Tuple[int, int, int] | |
| ): | |
| F, H, W = orig | |
| wf, wh, ww = win | |
| nf, nh, nw = F // wf, H // wh, W // ww | |
| B = windows.size(0) // (nf * nh * nw) | |
| x = windows.view(B, nf, nh, nw, wf, wh, ww, -1) | |
| x = x.permute(0, 1, 4, 2, 5, 3, 6, 7).contiguous() | |
| return x.view(B, F, H, W, -1) | |
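| # --------------------------------------------------------------- | |
| # Hedged sketch (illustration only): partition() splits a (B, F, H, W, C) tensor | |
| # into non-overlapping 3-D windows and reverse() undoes it exactly, the invariant | |
| # the windowed attention below relies on. | |
| def _window_partition_sketch(): | |
|     x = torch.randn(1, 4, 16, 16, 32)  # (B, F, H, W, C) | |
|     win = (2, 8, 8) | |
|     windows = WindowPartition3D.partition(x, win)  # (num_windows, 2*8*8, C) | |
|     restored = WindowPartition3D.reverse(windows, win, orig=(4, 16, 16)) | |
|     return windows.shape, torch.equal(restored, x)  # -> torch.Size([8, 128, 32]), True | |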
| @torch.no_grad() | |
| def generate_draft_block_mask( | |
| batch_size, nheads, seqlen, q_w, k_w, topk=10, local_attn_mask=None | |
| ): | |
| assert batch_size == 1, "Only batch_size=1 supported for now" | |
| assert local_attn_mask is not None, "local_attn_mask must be provided" | |
| avgpool_q = torch.mean(q_w, dim=1) | |
| avgpool_k = torch.mean(k_w, dim=1) | |
| avgpool_q = rearrange(avgpool_q, "s (h d) -> s h d", h=nheads) | |
| avgpool_k = rearrange(avgpool_k, "s (h d) -> s h d", h=nheads) | |
| q_heads = avgpool_q.permute(1, 0, 2) | |
| k_heads = avgpool_k.permute(1, 0, 2) | |
| D = avgpool_q.shape[-1] | |
| scores = torch.einsum("hld,hmd->hlm", q_heads, k_heads) / math.sqrt(D) | |
| repeat_head = scores.shape[0] | |
| repeat_len = scores.shape[1] // local_attn_mask.shape[0] | |
| repeat_num = scores.shape[2] // local_attn_mask.shape[1] | |
| local_attn_mask = ( | |
| local_attn_mask.unsqueeze(1).unsqueeze(0).repeat(repeat_len, 1, repeat_num, 1) | |
| ) | |
| local_attn_mask = rearrange(local_attn_mask, "x a y b -> (x a) (y b)") | |
| local_attn_mask = local_attn_mask.unsqueeze(0).repeat(repeat_head, 1, 1) | |
| local_attn_mask = local_attn_mask.to(torch.float32) | |
| local_attn_mask = local_attn_mask.masked_fill( | |
| local_attn_mask == False, -float("inf") | |
| ) | |
| local_attn_mask = local_attn_mask.masked_fill(local_attn_mask == True, 0) | |
| scores = scores + local_attn_mask | |
| attn_map = torch.softmax(scores, dim=-1) | |
| attn_map = rearrange(attn_map, "h (it s1) s2 -> (h it) s1 s2", it=seqlen) | |
| loop_num, s1, s2 = attn_map.shape | |
| flat = attn_map.reshape(loop_num, -1) | |
| n = flat.shape[1] | |
| apply_topk = min(flat.shape[1] - 1, topk) | |
| thresholds = torch.topk(flat, k=apply_topk + 1, dim=1, largest=True).values[:, -1] | |
| thresholds = thresholds.unsqueeze(1) | |
| mask_new = (flat > thresholds).reshape(loop_num, s1, s2) | |
| mask_new = rearrange( | |
| mask_new, "(h it) s1 s2 -> h (it s1) s2", it=seqlen | |
| )  # back to (h, it*s1, s2) | |
| # Fix: variable name unified with the line above (mask_new, not attn_map) | |
| # mask_new = rearrange(attn_map, 'h (it s1) s2 -> h (it s1) s2', it=seqlen) * 0 + mask_new | |
| mask = mask_new.unsqueeze(0).repeat(batch_size, 1, 1, 1) | |
| return mask | |
| @torch.no_grad() | |
| def generate_draft_block_mask_sage( | |
| batch_size, nheads, seqlen, q_w, k_w, topk=10, local_attn_mask=None | |
| ): | |
| assert batch_size == 1, "Only batch_size=1 supported for now" | |
| assert local_attn_mask is not None, "local_attn_mask must be provided" | |
| avgpool_q = torch.mean(q_w, dim=1) | |
| avgpool_q = rearrange(avgpool_q, "s (h d) -> s h d", h=nheads) | |
| q_heads = avgpool_q.permute(1, 0, 2) | |
| D = avgpool_q.shape[-1] | |
| k_w_split = k_w.view(k_w.shape[0], 2, 64, k_w.shape[2]) | |
| avgpool_k_split = torch.mean(k_w_split, dim=2) | |
| avgpool_k_refined = rearrange( | |
| avgpool_k_split, "s two d -> (s two) d", two=2 | |
| ) # shape: (s*2, C) | |
| avgpool_k_refined = rearrange( | |
| avgpool_k_refined, "s (h d) -> s h d", h=nheads | |
| ) # shape: (s*2, h, d) | |
| k_heads_doubled = avgpool_k_refined.permute(1, 0, 2) # shape: (h, s*2, d) | |
| k_heads_1, k_heads_2 = torch.chunk(k_heads_doubled, 2, dim=1) | |
| scores_1 = torch.einsum("hld,hmd->hlm", q_heads, k_heads_1) / math.sqrt(D) | |
| scores_2 = torch.einsum("hld,hmd->hlm", q_heads, k_heads_2) / math.sqrt(D) | |
| scores = torch.cat([scores_1, scores_2], dim=-1) | |
| repeat_head = scores.shape[0] | |
| repeat_len = scores.shape[1] // local_attn_mask.shape[0] | |
| repeat_num = (scores.shape[2] // 2) // local_attn_mask.shape[1] | |
| local_attn_mask = ( | |
| local_attn_mask.unsqueeze(1).unsqueeze(0).repeat(repeat_len, 1, repeat_num, 1) | |
| ) | |
| local_attn_mask = rearrange(local_attn_mask, "x a y b -> (x a) (y b)") | |
| local_attn_mask = local_attn_mask.repeat_interleave(2, dim=1) | |
| local_attn_mask = local_attn_mask.unsqueeze(0).repeat(repeat_head, 1, 1) | |
| assert ( | |
| scores.shape == local_attn_mask.shape | |
| ), f"Scores shape {scores.shape} != Mask shape {local_attn_mask.shape}" | |
| local_attn_mask = local_attn_mask.to(torch.float32) | |
| local_attn_mask = local_attn_mask.masked_fill( | |
| local_attn_mask == False, -float("inf") | |
| ) | |
| local_attn_mask = local_attn_mask.masked_fill(local_attn_mask == True, 0) | |
| scores = scores + local_attn_mask | |
| attn_map = torch.softmax(scores, dim=-1) | |
| attn_map = rearrange(attn_map, "h (it s1) s2 -> (h it) s1 s2", it=seqlen) | |
| loop_num, s1, s2 = attn_map.shape | |
| flat = attn_map.reshape(loop_num, -1) | |
| apply_topk = min(flat.shape[1] - 1, topk) | |
| if apply_topk <= 0: | |
| mask_new = torch.zeros_like(flat, dtype=torch.bool).reshape(loop_num, s1, s2) | |
| else: | |
| thresholds = torch.topk(flat, k=apply_topk + 1, dim=1, largest=True).values[ | |
| :, -1 | |
| ] | |
| thresholds = thresholds.unsqueeze(1) | |
| mask_new = (flat > thresholds).reshape(loop_num, s1, s2) | |
| mask_new = rearrange(mask_new, "(h it) s1 s2 -> h (it s1) s2", it=seqlen) | |
| mask = mask_new.unsqueeze(0).repeat(batch_size, 1, 1, 1) | |
| return mask | |
| # ---------------------------- | |
| # Attention kernels | |
| # ---------------------------- | |
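| # flash_attention dispatches by availability: a masked path (block-sparse attention or | |
| # sparse SageAttention when a draft block mask is provided), otherwise FlashAttention-3, | |
| # FlashAttention-2, SageAttention, and finally PyTorch SDPA as the fallback (SDPA is also | |
| # used when compatibility_mode=True). | |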
| def flash_attention( | |
| q: torch.Tensor, | |
| k: torch.Tensor, | |
| v: torch.Tensor, | |
| num_heads: int, | |
| compatibility_mode=False, | |
| attention_mask=None, | |
| return_KV=False, | |
| ): | |
| if attention_mask is not None: | |
| seqlen = q.shape[1] | |
| seqlen_kv = k.shape[1] | |
| if USE_BLOCK_ATTN and BLOCK_ATTN_AVAILABLE: | |
| q = rearrange(q, "b s (n d) -> (b s) n d", n=num_heads) | |
| k = rearrange(k, "b s (n d) -> (b s) n d", n=num_heads) | |
| v = rearrange(v, "b s (n d) -> (b s) n d", n=num_heads) | |
| else: | |
| q = rearrange(q, "b s (n d) -> b n s d", n=num_heads) | |
| k = rearrange(k, "b s (n d) -> b n s d", n=num_heads) | |
| v = rearrange(v, "b s (n d) -> b n s d", n=num_heads) | |
| cu_seqlens_q = torch.tensor([0, seqlen], device=q.device, dtype=torch.int32) | |
| cu_seqlens_k = torch.tensor([0, seqlen_kv], device=q.device, dtype=torch.int32) | |
| head_mask_type = torch.tensor( | |
| [1] * num_heads, device=q.device, dtype=torch.int32 | |
| ) | |
| streaming_info = None | |
| base_blockmask = attention_mask | |
| max_seqlen_q_ = seqlen | |
| max_seqlen_k_ = seqlen_kv | |
| p_dropout = 0.0 | |
| if USE_BLOCK_ATTN and BLOCK_ATTN_AVAILABLE: | |
| x = block_sparse_attn_func( | |
| q, | |
| k, | |
| v, | |
| cu_seqlens_q, | |
| cu_seqlens_k, | |
| head_mask_type, | |
| streaming_info, | |
| base_blockmask, | |
| max_seqlen_q_, | |
| max_seqlen_k_, | |
| p_dropout, | |
| deterministic=False, | |
| softmax_scale=None, | |
| is_causal=False, | |
| exact_streaming=False, | |
| return_attn_probs=False, | |
| ).unsqueeze(0) | |
| x = rearrange(x, "b s n d -> b s (n d)", n=num_heads) | |
| else: | |
| x = sparse_sageattn( | |
| q, | |
| k, | |
| v, | |
| mask_id=base_blockmask.to(torch.int8), | |
| is_causal=False, | |
| tensor_layout="HND", | |
| ) | |
| x = rearrange(x, "b n s d -> b s (n d)", n=num_heads) | |
| elif compatibility_mode: | |
| q = rearrange(q, "b s (n d) -> b n s d", n=num_heads) | |
| k = rearrange(k, "b s (n d) -> b n s d", n=num_heads) | |
| v = rearrange(v, "b s (n d) -> b n s d", n=num_heads) | |
| x = F.scaled_dot_product_attention(q, k, v) | |
| x = rearrange(x, "b n s d -> b s (n d)", n=num_heads) | |
| elif FLASH_ATTN_3_AVAILABLE: | |
| q = rearrange(q, "b s (n d) -> b s n d", n=num_heads) | |
| k = rearrange(k, "b s (n d) -> b s n d", n=num_heads) | |
| v = rearrange(v, "b s (n d) -> b s n d", n=num_heads) | |
| x = flash_attn_interface.flash_attn_func(q, k, v) | |
| if isinstance(x, tuple): | |
| x = x[0] | |
| x = rearrange(x, "b s n d -> b s (n d)", n=num_heads) | |
| elif FLASH_ATTN_2_AVAILABLE: | |
| q = rearrange(q, "b s (n d) -> b s n d", n=num_heads) | |
| k = rearrange(k, "b s (n d) -> b s n d", n=num_heads) | |
| v = rearrange(v, "b s (n d) -> b s n d", n=num_heads) | |
| x = flash_attn.flash_attn_func(q, k, v) | |
| x = rearrange(x, "b s n d -> b s (n d)", n=num_heads) | |
| elif SAGE_ATTN_AVAILABLE: | |
| q = rearrange(q, "b s (n d) -> b n s d", n=num_heads) | |
| k = rearrange(k, "b s (n d) -> b n s d", n=num_heads) | |
| v = rearrange(v, "b s (n d) -> b n s d", n=num_heads) | |
| x = sageattn(q, k, v) | |
| x = rearrange(x, "b n s d -> b s (n d)", n=num_heads) | |
| else: | |
| q = rearrange(q, "b s (n d) -> b n s d", n=num_heads) | |
| k = rearrange(k, "b s (n d) -> b n s d", n=num_heads) | |
| v = rearrange(v, "b s (n d) -> b n s d", n=num_heads) | |
| x = F.scaled_dot_product_attention(q, k, v) | |
| x = rearrange(x, "b n s d -> b s (n d)", n=num_heads) | |
| return x | |
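| # modulate applies AdaLN-style conditioning: shift/scale come from the timestep | |
| # embedding, i.e. modulate(norm(x), shift, scale) == norm(x) * (1 + scale) + shift. | |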
| def modulate(x: torch.Tensor, shift: torch.Tensor, scale: torch.Tensor): | |
| return x * (1 + scale) + shift | |
| def sinusoidal_embedding_1d(dim, position): | |
| sinusoid = torch.outer( | |
| position.type(torch.float64), | |
| torch.pow( | |
| 10000, | |
| -torch.arange(dim // 2, dtype=torch.float64, device=position.device).div( | |
| dim // 2 | |
| ), | |
| ), | |
| ) | |
| x = torch.cat([torch.cos(sinusoid), torch.sin(sinusoid)], dim=1) | |
| return x.to(position.dtype) | |
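| # 3D RoPE: the per-head channel dim is split into a temporal part (dim - 2 * (dim // 3)) | |
| # and two spatial parts (dim // 3 each for height and width), each with its own | |
| # frequency table. | |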
| def precompute_freqs_cis_3d(dim: int, end: int = 1024, theta: float = 10000.0): | |
| f_freqs_cis = precompute_freqs_cis(dim - 2 * (dim // 3), end, theta) | |
| h_freqs_cis = precompute_freqs_cis(dim // 3, end, theta) | |
| w_freqs_cis = precompute_freqs_cis(dim // 3, end, theta) | |
| return f_freqs_cis, h_freqs_cis, w_freqs_cis | |
| def precompute_freqs_cis(dim: int, end: int = 1024, theta: float = 10000.0): | |
| freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].double() / dim)) | |
| freqs = torch.outer(torch.arange(end, device=freqs.device), freqs) | |
| freqs_cis = torch.polar(torch.ones_like(freqs), freqs) | |
| return freqs_cis | |
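| # rope_apply rotates q/k per head by complex multiplication with the precomputed | |
| # frequencies and returns the result flattened back to [b, s, num_heads * head_dim]. | |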
| def rope_apply(x, freqs, num_heads): | |
| x = rearrange(x, "b s (n d) -> b s n d", n=num_heads) | |
| x_out = torch.view_as_complex( | |
| x.to(torch.float64).reshape(x.shape[0], x.shape[1], x.shape[2], -1, 2) | |
| ) | |
| x_out = torch.view_as_real(x_out * freqs).flatten(2) | |
| return x_out.to(x.dtype) | |
| # ---------------------------- | |
| # Norms & Blocks | |
| # ---------------------------- | |
| class RMSNorm(nn.Module): | |
| def __init__(self, dim, eps=1e-5): | |
| super().__init__() | |
| self.eps = eps | |
| self.weight = nn.Parameter(torch.ones(dim)) | |
| def norm(self, x): | |
| return x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps) | |
| def forward(self, x): | |
| dtype = x.dtype | |
| return self.norm(x.float()).to(dtype) * self.weight | |
| class AttentionModule(nn.Module): | |
| def __init__(self, num_heads): | |
| super().__init__() | |
| self.num_heads = num_heads | |
| def forward(self, q, k, v, attention_mask=None): | |
| x = flash_attention( | |
| q=q, k=k, v=v, num_heads=self.num_heads, attention_mask=attention_mask | |
| ) | |
| return x | |
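| # SelfAttention below is windowed and streamable: tokens are partitioned into | |
| # (2, 8, 8) (frame, height, width) windows, a draft block mask is built from pooled | |
| # Q/K scores restricted by a local neighbourhood mask plus a top-k over each head's | |
| # pooled score map, and the current K/V windows are returned as a cache so the next | |
| # chunk can attend to them. | |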
| class SelfAttention(nn.Module): | |
| def __init__(self, dim: int, num_heads: int, eps: float = 1e-6): | |
| super().__init__() | |
| self.dim = dim | |
| self.num_heads = num_heads | |
| self.head_dim = dim // num_heads | |
| self.q = nn.Linear(dim, dim) | |
| self.k = nn.Linear(dim, dim) | |
| self.v = nn.Linear(dim, dim) | |
| self.o = nn.Linear(dim, dim) | |
| self.norm_q = RMSNorm(dim, eps=eps) | |
| self.norm_k = RMSNorm(dim, eps=eps) | |
| self.attn = AttentionModule(self.num_heads) | |
| self.local_attn_mask = None | |
| def forward( | |
| self, | |
| x, | |
| freqs, | |
| f=None, | |
| h=None, | |
| w=None, | |
| local_num=None, | |
| topk=None, | |
| train_img=False, | |
| block_id=None, | |
| kv_len=None, | |
| is_full_block=False, | |
| is_stream=False, | |
| pre_cache_k=None, | |
| pre_cache_v=None, | |
| local_range=9, | |
| ): | |
| B, L, D = x.shape | |
| if is_stream and pre_cache_k is not None and pre_cache_v is not None: | |
| assert f == 2, "f must be 2" | |
| if is_stream and (pre_cache_k is None or pre_cache_v is None): | |
| assert f == 6, "start f must be 6" | |
| assert L == f * h * w, "Sequence length mismatch with provided (f,h,w)." | |
| q = self.norm_q(self.q(x)) | |
| k = self.norm_k(self.k(x)) | |
| v = self.v(x) | |
| q = rope_apply(q, freqs, self.num_heads) | |
| k = rope_apply(k, freqs, self.num_heads) | |
| win = (2, 8, 8) | |
| q = q.view(B, f, h, w, D) | |
| k = k.view(B, f, h, w, D) | |
| v = v.view(B, f, h, w, D) | |
| q_w = WindowPartition3D.partition(q, win) | |
| k_w = WindowPartition3D.partition(k, win) | |
| v_w = WindowPartition3D.partition(v, win) | |
| seqlen = f // win[0] | |
| one_len = k_w.shape[0] // B // seqlen | |
| if pre_cache_k is not None and pre_cache_v is not None: | |
| k_w = torch.cat([pre_cache_k, k_w], dim=0) | |
| v_w = torch.cat([pre_cache_v, v_w], dim=0) | |
| block_n = q_w.shape[0] // B | |
| block_s = q_w.shape[1] | |
| block_n_kv = k_w.shape[0] // B | |
| reorder_q = rearrange( | |
| q_w, | |
| "(b block_n) (block_s) d -> b (block_n block_s) d", | |
| block_n=block_n, | |
| block_s=block_s, | |
| ) | |
| reorder_k = rearrange( | |
| k_w, | |
| "(b block_n) (block_s) d -> b (block_n block_s) d", | |
| block_n=block_n_kv, | |
| block_s=block_s, | |
| ) | |
| reorder_v = rearrange( | |
| v_w, | |
| "(b block_n) (block_s) d -> b (block_n block_s) d", | |
| block_n=block_n_kv, | |
| block_s=block_s, | |
| ) | |
| window_size = win[0] * h * w // 128 | |
| if ( | |
| self.local_attn_mask is None | |
| or self.local_attn_mask_h != h // 8 | |
| or self.local_attn_mask_w != w // 8 | |
| or self.local_range != local_range | |
| ): | |
| self.local_attn_mask = build_local_block_mask_shifted_vec_normal_slide( | |
| h // 8, | |
| w // 8, | |
| local_range, | |
| local_range, | |
| include_self=True, | |
| device=k_w.device, | |
| ) | |
| self.local_attn_mask_h = h // 8 | |
| self.local_attn_mask_w = w // 8 | |
| self.local_range = local_range | |
| if USE_BLOCK_ATTN and BLOCK_ATTN_AVAILABLE: | |
| attention_mask = generate_draft_block_mask( | |
| B, | |
| self.num_heads, | |
| seqlen, | |
| q_w, | |
| k_w, | |
| topk=topk, | |
| local_attn_mask=self.local_attn_mask, | |
| ) | |
| else: | |
| attention_mask = generate_draft_block_mask_sage( | |
| B, | |
| self.num_heads, | |
| seqlen, | |
| q_w, | |
| k_w, | |
| topk=topk, | |
| local_attn_mask=self.local_attn_mask, | |
| ) | |
| x = self.attn(reorder_q, reorder_k, reorder_v, attention_mask) | |
| cur_block_n, cur_block_s, _ = k_w.shape | |
| cache_num = cur_block_n // one_len | |
| if cache_num > kv_len: | |
| cache_k = k_w[one_len:, :, :] | |
| cache_v = v_w[one_len:, :, :] | |
| else: | |
| cache_k = k_w | |
| cache_v = v_w | |
| x = rearrange( | |
| x, | |
| "b (block_n block_s) d -> (b block_n) (block_s) d", | |
| block_n=block_n, | |
| block_s=block_s, | |
| ) | |
| x = WindowPartition3D.reverse(x, win, (f, h, w)) | |
| x = x.view(B, f * h * w, D) | |
| if is_stream: | |
| return self.o(x), cache_k, cache_v | |
| return self.o(x) | |
| class CrossAttention(nn.Module): | |
| """ | |
| Text-only context; provides a persistent KV cache. | |
| """ | |
| def __init__(self, dim: int, num_heads: int, eps: float = 1e-6): | |
| super().__init__() | |
| self.dim = dim | |
| self.num_heads = num_heads | |
| self.head_dim = dim // num_heads | |
| self.q = nn.Linear(dim, dim) | |
| self.k = nn.Linear(dim, dim) | |
| self.v = nn.Linear(dim, dim) | |
| self.o = nn.Linear(dim, dim) | |
| self.norm_q = RMSNorm(dim, eps=eps) | |
| self.norm_k = RMSNorm(dim, eps=eps) | |
| self.attn = AttentionModule(self.num_heads) | |
| # Persistent cache | |
| self.cache_k = None | |
| self.cache_v = None | |
| @torch.no_grad() | |
| def init_cache(self, ctx: torch.Tensor): | |
| """ctx: [B, S_ctx, dim] —— 经过 text_embedding 之后的上下文""" | |
| self.cache_k = self.norm_k(self.k(ctx)) | |
| self.cache_v = self.v(ctx) | |
| def clear_cache(self): | |
| self.cache_k = None | |
| self.cache_v = None | |
| def forward(self, x: torch.Tensor, y: torch.Tensor, is_stream: bool = False): | |
| """ | |
| y is the text context (no other branches are handled). | |
| """ | |
| q = self.norm_q(self.q(x)) | |
| assert self.cache_k is not None and self.cache_v is not None | |
| k = self.cache_k | |
| v = self.cache_v | |
| x = self.attn(q, k, v) | |
| return self.o(x) | |
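| # Minimal usage sketch for the persistent cross-attention KV cache (the surrounding | |
| # names are assumptions, not from this file): embed the text context once, cache it per | |
| # block before denoising, and clear it afterwards, e.g. | |
| #   for block in dit.blocks: | |
| #       block.cross_attn.init_cache(context)   # context: [B, S_ctx, dim] | |
| #   ... run the denoising steps ... | |
| #   for block in dit.blocks: | |
| #       block.cross_attn.clear_cache() | |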
| class GateModule(nn.Module): | |
| def __init__( | |
| self, | |
| ): | |
| super().__init__() | |
| def forward(self, x, gate, residual): | |
| return x + gate * residual | |
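| # DiTBlock: AdaLN modulation (6 chunks: shift/scale/gate for self-attention and for the | |
| # FFN) around gated self-attention, text cross-attention, and the feed-forward network; | |
| # in streaming mode it also passes through the self-attention K/V cache. | |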
| class DiTBlock(nn.Module): | |
| def __init__(self, dim: int, num_heads: int, ffn_dim: int, eps: float = 1e-6): | |
| super().__init__() | |
| self.dim = dim | |
| self.num_heads = num_heads | |
| self.ffn_dim = ffn_dim | |
| self.self_attn = SelfAttention(dim, num_heads, eps) | |
| self.cross_attn = CrossAttention(dim, num_heads, eps) | |
| self.norm1 = nn.LayerNorm(dim, eps=eps, elementwise_affine=False) | |
| self.norm2 = nn.LayerNorm(dim, eps=eps, elementwise_affine=False) | |
| self.norm3 = nn.LayerNorm(dim, eps=eps) | |
| self.ffn = nn.Sequential( | |
| nn.Linear(dim, ffn_dim), | |
| nn.GELU(approximate="tanh"), | |
| nn.Linear(ffn_dim, dim), | |
| ) | |
| self.modulation = nn.Parameter(torch.randn(1, 6, dim) / dim**0.5) | |
| self.gate = GateModule() | |
| def forward( | |
| self, | |
| x, | |
| context, | |
| t_mod, | |
| freqs, | |
| f, | |
| h, | |
| w, | |
| local_num=None, | |
| topk=None, | |
| train_img=False, | |
| block_id=None, | |
| kv_len=None, | |
| is_full_block=False, | |
| is_stream=False, | |
| pre_cache_k=None, | |
| pre_cache_v=None, | |
| local_range=9, | |
| ): | |
| shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = ( | |
| self.modulation.to(dtype=t_mod.dtype, device=t_mod.device) + t_mod | |
| ).chunk(6, dim=1) | |
| input_x = modulate(self.norm1(x), shift_msa, scale_msa) | |
| self_attn_output, self_attn_cache_k, self_attn_cache_v = self.self_attn( | |
| input_x, | |
| freqs, | |
| f, | |
| h, | |
| w, | |
| local_num, | |
| topk, | |
| train_img, | |
| block_id, | |
| kv_len=kv_len, | |
| is_full_block=is_full_block, | |
| is_stream=is_stream, | |
| pre_cache_k=pre_cache_k, | |
| pre_cache_v=pre_cache_v, | |
| local_range=local_range, | |
| ) | |
| x = self.gate(x, gate_msa, self_attn_output) | |
| x = x + self.cross_attn(self.norm3(x), context, is_stream=is_stream) | |
| input_x = modulate(self.norm2(x), shift_mlp, scale_mlp) | |
| x = self.gate(x, gate_mlp, self.ffn(input_x)) | |
| if is_stream: | |
| return x, self_attn_cache_k, self_attn_cache_v | |
| return x | |
| class MLP(torch.nn.Module): | |
| def __init__(self, in_dim, out_dim, has_pos_emb=False): | |
| super().__init__() | |
| self.proj = torch.nn.Sequential( | |
| nn.LayerNorm(in_dim), | |
| nn.Linear(in_dim, in_dim), | |
| nn.GELU(), | |
| nn.Linear(in_dim, out_dim), | |
| nn.LayerNorm(out_dim), | |
| ) | |
| self.has_pos_emb = has_pos_emb | |
| if has_pos_emb: | |
| self.emb_pos = torch.nn.Parameter(torch.zeros((1, 514, 1280))) | |
| def forward(self, x): | |
| if self.has_pos_emb: | |
| x = x + self.emb_pos.to(dtype=x.dtype, device=x.device) | |
| return self.proj(x) | |
| class Head(nn.Module): | |
| def __init__( | |
| self, dim: int, out_dim: int, patch_size: Tuple[int, int, int], eps: float | |
| ): | |
| super().__init__() | |
| self.dim = dim | |
| self.patch_size = patch_size | |
| self.norm = nn.LayerNorm(dim, eps=eps, elementwise_affine=False) | |
| self.head = nn.Linear(dim, out_dim * math.prod(patch_size)) | |
| self.modulation = nn.Parameter(torch.randn(1, 2, dim) / dim**0.5) | |
| def forward(self, x, t_mod): | |
| shift, scale = ( | |
| self.modulation.to(dtype=t_mod.dtype, device=t_mod.device) + t_mod | |
| ).chunk(2, dim=1) | |
| x = self.head(self.norm(x) * (1 + scale) + shift) | |
| return x | |
| # ---------------------------- | |
| # State dict converter (keeps the original mapping; has_image_input is ignored) | |
| # ---------------------------- | |
| class WanModelStateDictConverter: | |
| def __init__(self): | |
| pass | |
| def from_diffusers(self, state_dict): | |
| rename_dict = { | |
| "blocks.0.attn1.norm_k.weight": "blocks.0.self_attn.norm_k.weight", | |
| "blocks.0.attn1.norm_q.weight": "blocks.0.self_attn.norm_q.weight", | |
| "blocks.0.attn1.to_k.bias": "blocks.0.self_attn.k.bias", | |
| "blocks.0.attn1.to_k.weight": "blocks.0.self_attn.k.weight", | |
| "blocks.0.attn1.to_out.0.bias": "blocks.0.self_attn.o.bias", | |
| "blocks.0.attn1.to_out.0.weight": "blocks.0.self_attn.o.weight", | |
| "blocks.0.attn1.to_q.bias": "blocks.0.self_attn.q.bias", | |
| "blocks.0.attn1.to_q.weight": "blocks.0.self_attn.q.weight", | |
| "blocks.0.attn1.to_v.bias": "blocks.0.self_attn.v.bias", | |
| "blocks.0.attn1.to_v.weight": "blocks.0.self_attn.v.weight", | |
| "blocks.0.attn2.norm_k.weight": "blocks.0.cross_attn.norm_k.weight", | |
| "blocks.0.attn2.norm_q.weight": "blocks.0.cross_attn.norm_q.weight", | |
| "blocks.0.attn2.to_k.bias": "blocks.0.cross_attn.k.bias", | |
| "blocks.0.attn2.to_k.weight": "blocks.0.cross_attn.k.weight", | |
| "blocks.0.attn2.to_out.0.bias": "blocks.0.cross_attn.o.bias", | |
| "blocks.0.attn2.to_out.0.weight": "blocks.0.cross_attn.o.weight", | |
| "blocks.0.attn2.to_q.bias": "blocks.0.cross_attn.q.bias", | |
| "blocks.0.attn2.to_q.weight": "blocks.0.cross_attn.q.weight", | |
| "blocks.0.attn2.to_v.bias": "blocks.0.cross_attn.v.bias", | |
| "blocks.0.attn2.to_v.weight": "blocks.0.cross_attn.v.weight", | |
| "blocks.0.ffn.net.0.proj.bias": "blocks.0.ffn.0.bias", | |
| "blocks.0.ffn.net.0.proj.weight": "blocks.0.ffn.0.weight", | |
| "blocks.0.ffn.net.2.bias": "blocks.0.ffn.2.bias", | |
| "blocks.0.ffn.net.2.weight": "blocks.0.ffn.2.weight", | |
| "blocks.0.norm2.bias": "blocks.0.norm3.bias", | |
| "blocks.0.norm2.weight": "blocks.0.norm3.weight", | |
| "blocks.0.scale_shift_table": "blocks.0.modulation", | |
| "condition_embedder.text_embedder.linear_1.bias": "text_embedding.0.bias", | |
| "condition_embedder.text_embedder.linear_1.weight": "text_embedding.0.weight", | |
| "condition_embedder.text_embedder.linear_2.bias": "text_embedding.2.bias", | |
| "condition_embedder.text_embedder.linear_2.weight": "text_embedding.2.weight", | |
| "condition_embedder.time_embedder.linear_1.bias": "time_embedding.0.bias", | |
| "condition_embedder.time_embedder.linear_1.weight": "time_embedding.0.weight", | |
| "condition_embedder.time_embedder.linear_2.bias": "time_embedding.2.bias", | |
| "condition_embedder.time_embedder.linear_2.weight": "time_embedding.2.weight", | |
| "condition_embedder.time_proj.bias": "time_projection.1.bias", | |
| "condition_embedder.time_proj.weight": "time_projection.1.weight", | |
| "patch_embedding.bias": "patch_embedding.bias", | |
| "patch_embedding.weight": "patch_embedding.weight", | |
| "scale_shift_table": "head.modulation", | |
| "proj_out.bias": "head.head.bias", | |
| "proj_out.weight": "head.head.weight", | |
| } | |
| state_dict_ = {} | |
| for name, param in state_dict.items(): | |
| if name in rename_dict: | |
| state_dict_[rename_dict[name]] = param | |
| else: | |
| name_ = ".".join(name.split(".")[:1] + ["0"] + name.split(".")[2:]) | |
| if name_ in rename_dict: | |
| name_ = rename_dict[name_] | |
| name_ = ".".join( | |
| name_.split(".")[:1] | |
| + [name.split(".")[1]] | |
| + name_.split(".")[2:] | |
| ) | |
| state_dict_[name_] = param | |
| if hash_state_dict_keys(state_dict) == "cb104773c6c2cb6df4f9529ad5c60d0b": | |
| config = { | |
| "model_type": "t2v", | |
| "patch_size": (1, 2, 2), | |
| "text_len": 512, | |
| "in_dim": 16, | |
| "dim": 5120, | |
| "ffn_dim": 13824, | |
| "freq_dim": 256, | |
| "text_dim": 4096, | |
| "out_dim": 16, | |
| "num_heads": 40, | |
| "num_layers": 40, | |
| "window_size": (-1, -1), | |
| "qk_norm": True, | |
| "cross_attn_norm": True, | |
| "eps": 1e-6, | |
| } | |
| else: | |
| config = {} | |
| return state_dict_, config | |
| def from_civitai(self, state_dict): | |
| state_dict = { | |
| name: param | |
| for name, param in state_dict.items() | |
| if not name.startswith("vace") | |
| } | |
| # Keep the configs returned by the original hash matching; this implementation does not use the has_image_input branch | |
| if hash_state_dict_keys(state_dict) == "9269f8db9040a9d860eaca435be61814": | |
| config = { | |
| "has_image_input": False, | |
| "patch_size": [1, 2, 2], | |
| "in_dim": 16, | |
| "dim": 1536, | |
| "ffn_dim": 8960, | |
| "freq_dim": 256, | |
| "text_dim": 4096, | |
| "out_dim": 16, | |
| "num_heads": 12, | |
| "num_layers": 30, | |
| "eps": 1e-6, | |
| } | |
| elif hash_state_dict_keys(state_dict) == "aafcfd9672c3a2456dc46e1cb6e52c70": | |
| config = { | |
| "has_image_input": False, | |
| "patch_size": [1, 2, 2], | |
| "in_dim": 16, | |
| "dim": 5120, | |
| "ffn_dim": 13824, | |
| "freq_dim": 256, | |
| "text_dim": 4096, | |
| "out_dim": 16, | |
| "num_heads": 40, | |
| "num_layers": 40, | |
| "eps": 1e-6, | |
| } | |
| elif hash_state_dict_keys(state_dict) == "6bfcfb3b342cb286ce886889d519a77e": | |
| config = { | |
| "has_image_input": False, | |
| "patch_size": [1, 2, 2], | |
| "in_dim": 36, | |
| "dim": 5120, | |
| "ffn_dim": 13824, | |
| "freq_dim": 256, | |
| "text_dim": 4096, | |
| "out_dim": 16, | |
| "num_heads": 40, | |
| "num_layers": 40, | |
| "eps": 1e-6, | |
| } | |
| elif hash_state_dict_keys(state_dict) == "6d6ccde6845b95ad9114ab993d917893": | |
| config = { | |
| "has_image_input": False, | |
| "patch_size": [1, 2, 2], | |
| "in_dim": 36, | |
| "dim": 1536, | |
| "ffn_dim": 8960, | |
| "freq_dim": 256, | |
| "text_dim": 4096, | |
| "out_dim": 16, | |
| "num_heads": 12, | |
| "num_layers": 30, | |
| "eps": 1e-6, | |
| } | |
| elif hash_state_dict_keys(state_dict) == "349723183fc063b2bfc10bb2835cf677": | |
| config = { | |
| "has_image_input": False, | |
| "patch_size": [1, 2, 2], | |
| "in_dim": 48, | |
| "dim": 1536, | |
| "ffn_dim": 8960, | |
| "freq_dim": 256, | |
| "text_dim": 4096, | |
| "out_dim": 16, | |
| "num_heads": 12, | |
| "num_layers": 30, | |
| "eps": 1e-6, | |
| } | |
| elif hash_state_dict_keys(state_dict) == "efa44cddf936c70abd0ea28b6cbe946c": | |
| config = { | |
| "has_image_input": False, | |
| "patch_size": [1, 2, 2], | |
| "in_dim": 48, | |
| "dim": 5120, | |
| "ffn_dim": 13824, | |
| "freq_dim": 256, | |
| "text_dim": 4096, | |
| "out_dim": 16, | |
| "num_heads": 40, | |
| "num_layers": 40, | |
| "eps": 1e-6, | |
| } | |
| elif hash_state_dict_keys(state_dict) == "3ef3b1f8e1dab83d5b71fd7b617f859f": | |
| config = { | |
| "has_image_input": False, | |
| "patch_size": [1, 2, 2], | |
| "in_dim": 36, | |
| "dim": 5120, | |
| "ffn_dim": 13824, | |
| "freq_dim": 256, | |
| "text_dim": 4096, | |
| "out_dim": 16, | |
| "num_heads": 40, | |
| "num_layers": 40, | |
| "eps": 1e-6, | |
| "has_image_pos_emb": False, | |
| } | |
| else: | |
| config = {} | |
| return state_dict, config | |
| DecoderResult = namedtuple("DecoderResult", ("frame", "memory")) | |
| TWorkItem = namedtuple("TWorkItem", ("input_tensor", "block_index")) | |
| # ---------------------------- | |
| # Utility / building blocks | |
| # ---------------------------- | |
| class IdentityConv2d(nn.Conv2d): | |
| """Same-shape Conv2d initialized to identity (Dirac).""" | |
| def __init__(self, C, kernel_size=3, bias=False): | |
| pad = kernel_size // 2 | |
| super().__init__(C, C, kernel_size, padding=pad, bias=bias) | |
| with torch.no_grad(): | |
| init.dirac_(self.weight) | |
| if self.bias is not None: | |
| self.bias.zero_() | |
| def conv(n_in, n_out, **kwargs): | |
| return nn.Conv2d(n_in, n_out, 3, padding=1, **kwargs) | |
| class Clamp(nn.Module): | |
| def forward(self, x): | |
| return torch.tanh(x / 3) * 3 | |
| class MemBlock(nn.Module): | |
| def __init__(self, n_in, n_out): | |
| super().__init__() | |
| self.conv = nn.Sequential( | |
| conv(n_in * 2, n_out), | |
| nn.ReLU(inplace=True), | |
| conv(n_out, n_out), | |
| nn.ReLU(inplace=True), | |
| conv(n_out, n_out), | |
| ) | |
| self.skip = ( | |
| nn.Conv2d(n_in, n_out, 1, bias=False) if n_in != n_out else nn.Identity() | |
| ) | |
| self.act = nn.ReLU(inplace=True) | |
| def forward(self, x, past): | |
| return self.act(self.conv(torch.cat([x, past], 1)) + self.skip(x)) | |
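| # TPool folds `stride` consecutive frames into the channel dimension and mixes them with | |
| # a 1x1 conv (temporal downsample); TGrow is the inverse, expanding channels and then | |
| # unfolding them back into `stride` frames (temporal upsample). | |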
| class TPool(nn.Module): | |
| def __init__(self, n_f, stride): | |
| super().__init__() | |
| self.stride = stride | |
| self.conv = nn.Conv2d(n_f * stride, n_f, 1, bias=False) | |
| def forward(self, x): | |
| _NT, C, H, W = x.shape | |
| return self.conv(x.reshape(-1, self.stride * C, H, W)) | |
| class TGrow(nn.Module): | |
| def __init__(self, n_f, stride): | |
| super().__init__() | |
| self.stride = stride | |
| self.conv = nn.Conv2d(n_f, n_f * stride, 1, bias=False) | |
| def forward(self, x): | |
| _NT, C, H, W = x.shape | |
| x = self.conv(x) | |
| return x.reshape(-1, C, H, W) | |
| class PixelShuffle3dTAEHV(nn.Module): | |
| def __init__(self, ff, hh, ww): | |
| super().__init__() | |
| self.ff = ff | |
| self.hh = hh | |
| self.ww = ww | |
| def forward(self, x): | |
| # x: (B, C, F, H, W) | |
| B, C, F, H, W = x.shape | |
| if F % self.ff != 0: | |
| first_frame = x[:, :, 0:1, :, :].repeat(1, 1, self.ff - F % self.ff, 1, 1) | |
| x = torch.cat([first_frame, x], dim=2) | |
| return rearrange( | |
| x, | |
| "b c (f ff) (h hh) (w ww) -> b (c ff hh ww) f h w", | |
| ff=self.ff, | |
| hh=self.hh, | |
| ww=self.ww, | |
| ).transpose(1, 2) | |
| # ---------------------------- | |
| # Generic NTCHW graph executor (kept; used by decoder) | |
| # ---------------------------- | |
| def apply_model_with_memblocks(model, x, parallel, show_progress_bar, mem=None): | |
| """ | |
| Apply a sequential model with memblocks to the given input. | |
| Args: | |
| - model: nn.Sequential of blocks to apply | |
| - x: input data, of dimensions NTCHW | |
| - parallel: if True, parallelize over timesteps (fast but uses O(T) memory) | |
| if False, each timestep will be processed sequentially (slow but uses O(1) memory) | |
| - show_progress_bar: if True, enables tqdm progressbar display | |
| Returns NTCHW tensor of output data. | |
| """ | |
| assert x.ndim == 5, f"TAEHV operates on NTCHW tensors, but got {x.ndim}-dim tensor" | |
| N, T, C, H, W = x.shape | |
| if parallel: | |
| x = x.reshape(N * T, C, H, W) | |
| for b in tqdm(model, disable=not show_progress_bar): | |
| if isinstance(b, MemBlock): | |
| NT, C, H, W = x.shape | |
| T = NT // N | |
| _x = x.reshape(N, T, C, H, W) | |
| mem = F.pad(_x, (0, 0, 0, 0, 0, 0, 1, 0), value=0)[:, :T].reshape( | |
| x.shape | |
| ) | |
| x = b(x, mem) | |
| else: | |
| x = b(x) | |
| NT, C, H, W = x.shape | |
| T = NT // N | |
| x = x.view(N, T, C, H, W) | |
| else: | |
| out = [] | |
| work_queue = [ | |
| TWorkItem(xt, 0) | |
| for t, xt in enumerate(x.reshape(N, T * C, H, W).chunk(T, dim=1)) | |
| ] | |
| progress_bar = tqdm(range(T), disable=not show_progress_bar) | |
| while work_queue: | |
| xt, i = work_queue.pop(0) | |
| if i == 0: | |
| progress_bar.update(1) | |
| if i == len(model): | |
| out.append(xt) | |
| else: | |
| b = model[i] | |
| if isinstance(b, MemBlock): | |
| if mem[i] is None: | |
| xt_new = b(xt, xt * 0) | |
| mem[i] = xt | |
| else: | |
| xt_new = b(xt, mem[i]) | |
| mem[i].copy_(xt) | |
| work_queue.insert(0, TWorkItem(xt_new, i + 1)) | |
| elif isinstance(b, TPool): | |
| if mem[i] is None: | |
| mem[i] = [] | |
| mem[i].append(xt) | |
| if len(mem[i]) > b.stride: | |
| raise ValueError("TPool internal state invalid.") | |
| elif len(mem[i]) == b.stride: | |
| N_, C_, H_, W_ = xt.shape | |
| xt = b(torch.cat(mem[i], 1).view(N_ * b.stride, C_, H_, W_)) | |
| mem[i] = [] | |
| work_queue.insert(0, TWorkItem(xt, i + 1)) | |
| elif isinstance(b, TGrow): | |
| xt = b(xt) | |
| NT, C_, H_, W_ = xt.shape | |
| for xt_next in reversed( | |
| xt.view(N, b.stride * C_, H_, W_).chunk(b.stride, 1) | |
| ): | |
| work_queue.insert(0, TWorkItem(xt_next, i + 1)) | |
| else: | |
| xt = b(xt) | |
| work_queue.insert(0, TWorkItem(xt, i + 1)) | |
| progress_bar.close() | |
| x = torch.stack(out, 1) | |
| return x, mem | |
| # ---------------------------- | |
| # Decoder-only TAEHV | |
| # ---------------------------- | |
| class TAEHV(nn.Module): | |
| image_channels = 3 | |
| def __init__( | |
| self, | |
| checkpoint_path="taehv.pth", | |
| decoder_time_upscale=(True, True), | |
| decoder_space_upscale=(True, True, True), | |
| channels=[256, 128, 64, 64], | |
| latent_channels=16, | |
| ): | |
| """Initialize TAEHV (decoder-only) with built-in deepening after every ReLU. | |
| Deepening config: how_many_each=1, k=3 (fixed as requested). | |
| """ | |
| super().__init__() | |
| self.latent_channels = latent_channels | |
| n_f = channels | |
| self.frames_to_trim = 2 ** sum(decoder_time_upscale) - 1 | |
| # Build the decoder "skeleton" | |
| base_decoder = nn.Sequential( | |
| Clamp(), | |
| conv(self.latent_channels, n_f[0]), | |
| nn.ReLU(inplace=True), | |
| MemBlock(n_f[0], n_f[0]), | |
| MemBlock(n_f[0], n_f[0]), | |
| MemBlock(n_f[0], n_f[0]), | |
| nn.Upsample(scale_factor=2 if decoder_space_upscale[0] else 1), | |
| TGrow(n_f[0], 1), | |
| conv(n_f[0], n_f[1], bias=False), | |
| MemBlock(n_f[1], n_f[1]), | |
| MemBlock(n_f[1], n_f[1]), | |
| MemBlock(n_f[1], n_f[1]), | |
| nn.Upsample(scale_factor=2 if decoder_space_upscale[1] else 1), | |
| TGrow(n_f[1], 2 if decoder_time_upscale[0] else 1), | |
| conv(n_f[1], n_f[2], bias=False), | |
| MemBlock(n_f[2], n_f[2]), | |
| MemBlock(n_f[2], n_f[2]), | |
| MemBlock(n_f[2], n_f[2]), | |
| nn.Upsample(scale_factor=2 if decoder_space_upscale[2] else 1), | |
| TGrow(n_f[2], 2 if decoder_time_upscale[1] else 1), | |
| conv(n_f[2], n_f[3], bias=False), | |
| nn.ReLU(inplace=True), | |
| conv(n_f[3], TAEHV.image_channels), | |
| ) | |
| # Inline deepening: insert (IdentityConv2d(k=3) + ReLU) after every ReLU | |
| self.decoder = self._apply_identity_deepen(base_decoder, how_many_each=1, k=3) | |
| self.pixel_shuffle = PixelShuffle3dTAEHV(4, 8, 8) | |
| if checkpoint_path is not None: | |
| missing_keys = self.load_state_dict( | |
| self.patch_tgrow_layers( | |
| torch.load(checkpoint_path, map_location="cpu", weights_only=True) | |
| ), | |
| strict=False, | |
| ) | |
| print("missing_keys", missing_keys) | |
| # Initialize decoder mem state | |
| self.mem = [None] * len(self.decoder) | |
| @staticmethod | |
| def _apply_identity_deepen( | |
| decoder: nn.Sequential, how_many_each=1, k=3 | |
| ) -> nn.Sequential: | |
| """Return a new Sequential where every nn.ReLU is followed by how_many_each*(IdentityConv2d(k)+ReLU).""" | |
| new_layers = [] | |
| for b in decoder: | |
| new_layers.append(b) | |
| if isinstance(b, nn.ReLU): | |
| # Deduce channel count from preceding layer | |
| C = None | |
| if len(new_layers) >= 2 and isinstance(new_layers[-2], nn.Conv2d): | |
| C = new_layers[-2].out_channels | |
| elif len(new_layers) >= 2 and isinstance(new_layers[-2], MemBlock): | |
| C = new_layers[-2].conv[-1].out_channels | |
| if C is not None: | |
| for _ in range(how_many_each): | |
| new_layers.append(IdentityConv2d(C, kernel_size=k, bias=False)) | |
| new_layers.append(nn.ReLU(inplace=True)) | |
| return nn.Sequential(*new_layers) | |
| def patch_tgrow_layers(self, sd): | |
| """Patch TGrow layers to use a smaller kernel if needed (decoder-only).""" | |
| new_sd = self.state_dict() | |
| for i, layer in enumerate(self.decoder): | |
| if isinstance(layer, TGrow): | |
| key = f"decoder.{i}.conv.weight" | |
| if key in sd and sd[key].shape[0] > new_sd[key].shape[0]: | |
| sd[key] = sd[key][-new_sd[key].shape[0] :] | |
| return sd | |
| def decode_video(self, x, parallel=True, show_progress_bar=False, cond=None): | |
| """Decode a sequence of frames from latents. | |
| x: NTCHW latent tensor; returns NTCHW RGB in ~[0, 1]. | |
| """ | |
| trim_flag = self.mem[-8] is None # keeps original relative check | |
| if cond is not None: | |
| x = torch.cat([self.pixel_shuffle(cond), x], dim=2) | |
| x, self.mem = apply_model_with_memblocks( | |
| self.decoder, x, parallel, show_progress_bar, mem=self.mem | |
| ) | |
| if trim_flag: | |
| return x[:, self.frames_to_trim :] | |
| return x | |
| def forward(self, *args, **kwargs): | |
| raise NotImplementedError("Decoder-only model: call decode_video(...) instead.") | |
| def clean_mem(self): | |
| self.mem = [None] * len(self.decoder) | |
| class DotDict(dict): | |
| __getattr__ = dict.__getitem__ | |
| __setattr__ = dict.__setitem__ | |
| class TAEW2_1DiffusersWrapper(nn.Module): | |
| def __init__(self, pretrained_path=None, channels=[256, 128, 64, 64]): | |
| super().__init__() | |
| self.dtype = torch.bfloat16 | |
| self.device = "cuda" | |
| self.taehv = TAEHV(pretrained_path, channels=channels).to(self.dtype) | |
| self.temperal_downsample = [True, True, False] # [sic] | |
| self.config = DotDict( | |
| scaling_factor=1.0, | |
| latents_mean=torch.zeros(16), | |
| z_dim=16, | |
| latents_std=torch.ones(16), | |
| ) | |
| def decode(self, latents, return_dict=None): | |
| n, c, t, h, w = latents.shape | |
| return ( | |
| self.taehv.decode_video(latents.transpose(1, 2), parallel=False) | |
| .transpose(1, 2) | |
| .mul_(2) | |
| .sub_(1), | |
| ) | |
| def stream_decode_with_cond(self, latents, tiled=False, cond=None): | |
| n, c, t, h, w = latents.shape | |
| return ( | |
| self.taehv.decode_video(latents.transpose(1, 2), parallel=False, cond=cond) | |
| .transpose(1, 2) | |
| .mul_(2) | |
| .sub_(1) | |
| ) | |
| def clean_mem(self): | |
| self.taehv.clean_mem() | |
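| # Minimal usage sketch (the checkpoint name is an assumption): the wrapper takes NCTHW | |
| # latents and returns NCTHW RGB in roughly [-1, 1]: | |
| #   vae = TAEW2_1DiffusersWrapper("taew2_1.pth") | |
| #   frames, = vae.decode(latents)          # latents: (N, 16, T, h, w) | |
| #   vae.clean_mem()                        # reset streaming state between videos | |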
| # ---------------------------- | |
| # Simplified builder (no small, no transplant, no post-hoc deepening) | |
| # ---------------------------- | |
| def build_tcdecoder( | |
| new_channels=[512, 256, 128, 128], | |
| device="cuda", | |
| dtype=torch.bfloat16, | |
| new_latent_channels=None, | |
| ): | |
| """ | |
| Build a "wider" decoder; depth augmentation (IdentityConv2d + ReLU) is already handled inside TAEHV. | |
| - No "small" model is created and no weight transplanting is performed | |
| - The base_ckpt_path parameter is kept but unused (interface compatibility) | |
| Returns: big (a single model) | |
| """ | |
| if new_latent_channels is not None: | |
| big = ( | |
| TAEHV( | |
| checkpoint_path=None, | |
| channels=new_channels, | |
| latent_channels=new_latent_channels, | |
| ) | |
| .to(device) | |
| .to(dtype) | |
| .train() | |
| ) | |
| else: | |
| big = ( | |
| TAEHV(checkpoint_path=None, channels=new_channels) | |
| .to(device) | |
| .to(dtype) | |
| .train() | |
| ) | |
| big.clean_mem() | |
| return big | |
| CACHE_T = 2 | |
| @contextmanager | |
| def init_weights_on_device(device=torch.device("meta"), include_buffers: bool = False): | |
| old_register_parameter = torch.nn.Module.register_parameter | |
| if include_buffers: | |
| old_register_buffer = torch.nn.Module.register_buffer | |
| def register_empty_parameter(module, name, param): | |
| old_register_parameter(module, name, param) | |
| if param is not None: | |
| param_cls = type(module._parameters[name]) | |
| kwargs = module._parameters[name].__dict__ | |
| kwargs["requires_grad"] = param.requires_grad | |
| module._parameters[name] = param_cls( | |
| module._parameters[name].to(device), **kwargs | |
| ) | |
| def register_empty_buffer(module, name, buffer, persistent=True): | |
| old_register_buffer(module, name, buffer, persistent=persistent) | |
| if buffer is not None: | |
| module._buffers[name] = module._buffers[name].to(device) | |
| def patch_tensor_constructor(fn): | |
| def wrapper(*args, **kwargs): | |
| kwargs["device"] = device | |
| return fn(*args, **kwargs) | |
| return wrapper | |
| if include_buffers: | |
| tensor_constructors_to_patch = { | |
| torch_function_name: getattr(torch, torch_function_name) | |
| for torch_function_name in ["empty", "zeros", "ones", "full"] | |
| } | |
| else: | |
| tensor_constructors_to_patch = {} | |
| try: | |
| torch.nn.Module.register_parameter = register_empty_parameter | |
| if include_buffers: | |
| torch.nn.Module.register_buffer = register_empty_buffer | |
| for torch_function_name in tensor_constructors_to_patch.keys(): | |
| setattr( | |
| torch, | |
| torch_function_name, | |
| patch_tensor_constructor(getattr(torch, torch_function_name)), | |
| ) | |
| yield | |
| finally: | |
| torch.nn.Module.register_parameter = old_register_parameter | |
| if include_buffers: | |
| torch.nn.Module.register_buffer = old_register_buffer | |
| for ( | |
| torch_function_name, | |
| old_torch_function, | |
| ) in tensor_constructors_to_patch.items(): | |
| setattr(torch, torch_function_name, old_torch_function) | |
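| # Usage sketch (an assumption about intended use): construct the module under this | |
| # context so parameters land on the meta device and allocate no real memory, then load | |
| # real weights afterwards: | |
| #   with init_weights_on_device(): | |
| #       model = WanModel(...)                        # parameters on the meta device | |
| #   model.load_state_dict(state_dict, assign=True)   # assign=True replaces meta tensors | |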
| def load_state_dict_from_folder(file_path, torch_dtype=None): | |
| state_dict = {} | |
| for file_name in os.listdir(file_path): | |
| if "." in file_name and file_name.split(".")[-1] in [ | |
| "safetensors", | |
| "bin", | |
| "ckpt", | |
| "pth", | |
| "pt", | |
| ]: | |
| state_dict.update( | |
| load_state_dict( | |
| os.path.join(file_path, file_name), torch_dtype=torch_dtype | |
| ) | |
| ) | |
| return state_dict | |
| def load_state_dict(file_path, torch_dtype=None): | |
| if file_path.endswith(".safetensors"): | |
| return load_state_dict_from_safetensors(file_path, torch_dtype=torch_dtype) | |
| else: | |
| return load_state_dict_from_bin(file_path, torch_dtype=torch_dtype) | |
| def load_state_dict_from_safetensors(file_path, torch_dtype=None): | |
| state_dict = {} | |
| with safe_open(file_path, framework="pt", device="cpu") as f: | |
| for k in f.keys(): | |
| state_dict[k] = f.get_tensor(k) | |
| if torch_dtype is not None: | |
| state_dict[k] = state_dict[k].to(torch_dtype) | |
| return state_dict | |
| def load_state_dict_from_bin(file_path, torch_dtype=None): | |
| state_dict = torch.load(file_path, map_location="cpu", weights_only=True) | |
| if torch_dtype is not None: | |
| for i in state_dict: | |
| if isinstance(state_dict[i], torch.Tensor): | |
| state_dict[i] = state_dict[i].to(torch_dtype) | |
| return state_dict | |
| def search_for_embeddings(state_dict): | |
| embeddings = [] | |
| for k in state_dict: | |
| if isinstance(state_dict[k], torch.Tensor): | |
| embeddings.append(state_dict[k]) | |
| elif isinstance(state_dict[k], dict): | |
| embeddings += search_for_embeddings(state_dict[k]) | |
| return embeddings | |
| def search_parameter(param, state_dict): | |
| for name, param_ in state_dict.items(): | |
| if param.numel() == param_.numel(): | |
| if param.shape == param_.shape: | |
| if torch.dist(param, param_) < 1e-3: | |
| return name | |
| else: | |
| if torch.dist(param.flatten(), param_.flatten()) < 1e-3: | |
| return name | |
| return None | |
| def build_rename_dict(source_state_dict, target_state_dict, split_qkv=False): | |
| matched_keys = set() | |
| with torch.no_grad(): | |
| for name in source_state_dict: | |
| rename = search_parameter(source_state_dict[name], target_state_dict) | |
| if rename is not None: | |
| print(f'"{name}": "{rename}",') | |
| matched_keys.add(rename) | |
| elif ( | |
| split_qkv | |
| and len(source_state_dict[name].shape) >= 1 | |
| and source_state_dict[name].shape[0] % 3 == 0 | |
| ): | |
| length = source_state_dict[name].shape[0] // 3 | |
| rename = [] | |
| for i in range(3): | |
| rename.append( | |
| search_parameter( | |
| source_state_dict[name][i * length : i * length + length], | |
| target_state_dict, | |
| ) | |
| ) | |
| if None not in rename: | |
| print(f'"{name}": {rename},') | |
| for rename_ in rename: | |
| matched_keys.add(rename_) | |
| for name in target_state_dict: | |
| if name not in matched_keys: | |
| print("Cannot find", name, target_state_dict[name].shape) | |
| def search_for_files(folder, extensions): | |
| files = [] | |
| if os.path.isdir(folder): | |
| for file in sorted(os.listdir(folder)): | |
| files += search_for_files(os.path.join(folder, file), extensions) | |
| elif os.path.isfile(folder): | |
| for extension in extensions: | |
| if folder.endswith(extension): | |
| files.append(folder) | |
| break | |
| return files | |
| def convert_state_dict_keys_to_single_str(state_dict, with_shape=True): | |
| keys = [] | |
| for key, value in state_dict.items(): | |
| if isinstance(key, str): | |
| if isinstance(value, torch.Tensor): | |
| if with_shape: | |
| shape = "_".join(map(str, list(value.shape))) | |
| keys.append(key + ":" + shape) | |
| keys.append(key) | |
| elif isinstance(value, dict): | |
| keys.append( | |
| key | |
| + "|" | |
| + convert_state_dict_keys_to_single_str( | |
| value, with_shape=with_shape | |
| ) | |
| ) | |
| keys.sort() | |
| keys_str = ",".join(keys) | |
| return keys_str | |
| def split_state_dict_with_prefix(state_dict): | |
| keys = sorted([key for key in state_dict if isinstance(key, str)]) | |
| prefix_dict = {} | |
| for key in keys: | |
| prefix = key if "." not in key else key.split(".")[0] | |
| if prefix not in prefix_dict: | |
| prefix_dict[prefix] = [] | |
| prefix_dict[prefix].append(key) | |
| state_dicts = [] | |
| for prefix, keys in prefix_dict.items(): | |
| sub_state_dict = {key: state_dict[key] for key in keys} | |
| state_dicts.append(sub_state_dict) | |
| return state_dicts | |
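| # hash_state_dict_keys fingerprints a checkpoint by MD5-hashing its sorted "key:shape" | |
| # strings; the converters above use these hashes to pick the matching model config. | |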
| def hash_state_dict_keys(state_dict, with_shape=True): | |
| keys_str = convert_state_dict_keys_to_single_str(state_dict, with_shape=with_shape) | |
| keys_str = keys_str.encode(encoding="UTF-8") | |
| return hashlib.md5(keys_str).hexdigest() | |
| def clean_vram(): | |
| gc.collect() | |
| if torch.cuda.is_available(): | |
| torch.cuda.empty_cache() | |
| torch.cuda.ipc_collect() | |
| if torch.mps.is_available(): | |
| torch.mps.empty_cache() | |
| def get_device_list(): | |
| devs = [] | |
| try: | |
| if ( | |
| hasattr(torch, "cuda") | |
| and hasattr(torch.cuda, "is_available") | |
| and torch.cuda.is_available() | |
| ): | |
| devs += [f"cuda:{i}" for i in range(torch.cuda.device_count())] | |
| except Exception: | |
| pass | |
| try: | |
| if ( | |
| hasattr(torch, "mps") | |
| and hasattr(torch.mps, "is_available") | |
| and torch.mps.is_available() | |
| ): | |
| devs += [f"mps:{i}" for i in range(torch.mps.device_count())] | |
| except Exception: | |
| pass | |
| return devs | |
| class RMS_norm(nn.Module): | |
| def __init__(self, dim, channel_first=True, images=True, bias=False): | |
| super().__init__() | |
| broadcastable_dims = (1, 1, 1) if not images else (1, 1) | |
| shape = (dim, *broadcastable_dims) if channel_first else (dim,) | |
| self.channel_first = channel_first | |
| self.scale = dim**0.5 | |
| self.gamma = nn.Parameter(torch.ones(shape)) | |
| self.bias = nn.Parameter(torch.zeros(shape)) if bias else 0.0 | |
| def forward(self, x): | |
| return ( | |
| F.normalize(x, dim=(1 if self.channel_first else -1)) | |
| * self.scale | |
| * self.gamma | |
| + self.bias | |
| ) | |
| class CausalConv3d(nn.Conv3d): | |
| """ | |
| Causal 3D convolution. | |
| """ | |
| def __init__(self, *args, **kwargs): | |
| super().__init__(*args, **kwargs) | |
| self._padding = ( | |
| self.padding[2], | |
| self.padding[2], | |
| self.padding[1], | |
| self.padding[1], | |
| 2 * self.padding[0], | |
| 0, | |
| ) | |
| self.padding = (0, 0, 0) | |
| def forward(self, x, cache_x=None): | |
| padding = list(self._padding) | |
| if cache_x is not None and self._padding[4] > 0: | |
| cache_x = cache_x.to(x.device) | |
| # print(cache_x.shape, x.shape) | |
| x = torch.cat([cache_x, x], dim=2) | |
| padding[4] -= cache_x.shape[2] | |
| # print('cache!') | |
| x = F.pad(x, padding, mode="replicate") # mode='replicate' | |
| # print(x[0,0,:,0,0]) | |
| return super().forward(x) | |
| class PixelShuffle3d(nn.Module): | |
| def __init__(self, ff, hh, ww): | |
| super().__init__() | |
| self.ff = ff | |
| self.hh = hh | |
| self.ww = ww | |
| def forward(self, x): | |
| # x: (B, C, F, H, W) | |
| return rearrange( | |
| x, | |
| "b c (f ff) (h hh) (w ww) -> b (c ff hh ww) f h w", | |
| ff=self.ff, | |
| hh=self.hh, | |
| ww=self.ww, | |
| ) | |
| class Buffer_LQ4x_Proj(nn.Module): | |
| def __init__(self, in_dim, out_dim, layer_num=30): | |
| super().__init__() | |
| self.ff = 1 | |
| self.hh = 16 | |
| self.ww = 16 | |
| self.hidden_dim1 = 2048 | |
| self.hidden_dim2 = 3072 | |
| self.layer_num = layer_num | |
| self.pixel_shuffle = PixelShuffle3d(self.ff, self.hh, self.ww) | |
| self.conv1 = CausalConv3d( | |
| in_dim * self.ff * self.hh * self.ww, | |
| self.hidden_dim1, | |
| (4, 3, 3), | |
| stride=(2, 1, 1), | |
| padding=(1, 1, 1), | |
| ) # f -> f/2 h -> h w -> w | |
| self.norm1 = RMS_norm(self.hidden_dim1, images=False) | |
| self.act1 = nn.SiLU() | |
| self.conv2 = CausalConv3d( | |
| self.hidden_dim1, | |
| self.hidden_dim2, | |
| (4, 3, 3), | |
| stride=(2, 1, 1), | |
| padding=(1, 1, 1), | |
| ) # f -> f/2 h -> h w -> w | |
| self.norm2 = RMS_norm(self.hidden_dim2, images=False) | |
| self.act2 = nn.SiLU() | |
| self.linear_layers = nn.ModuleList( | |
| [nn.Linear(self.hidden_dim2, out_dim) for _ in range(layer_num)] | |
| ) | |
| self.clip_idx = 0 | |
| def forward(self, video): | |
| self.clear_cache() | |
| # x: (B, C, F, H, W) | |
| t = video.shape[2] | |
| iter_ = 1 + (t - 1) // 4 | |
| first_frame = video[:, :, :1, :, :].repeat(1, 1, 3, 1, 1) | |
| video = torch.cat([first_frame, video], dim=2) | |
| # print(video.shape) | |
| out_x = [] | |
| for i in range(iter_): | |
| x = self.pixel_shuffle(video[:, :, i * 4 : (i + 1) * 4, :, :]) | |
| cache1_x = x[:, :, -CACHE_T:, :, :].clone() | |
| self.cache["conv1"] = cache1_x | |
| x = self.conv1(x, self.cache["conv1"]) | |
| x = self.norm1(x) | |
| x = self.act1(x) | |
| cache2_x = x[:, :, -CACHE_T:, :, :].clone() | |
| self.cache["conv2"] = cache2_x | |
| if i == 0: | |
| continue | |
| x = self.conv2(x, self.cache["conv2"]) | |
| x = self.norm2(x) | |
| x = self.act2(x) | |
| out_x.append(x) | |
| out_x = torch.cat(out_x, dim=2) | |
| # print(out_x.shape) | |
| out_x = rearrange(out_x, "b c f h w -> b (f h w) c") | |
| outputs = [] | |
| for i in range(self.layer_num): | |
| outputs.append(self.linear_layers[i](out_x)) | |
| return outputs | |
| def clear_cache(self): | |
| self.cache = {} | |
| self.cache["conv1"] = None | |
| self.cache["conv2"] = None | |
| self.clip_idx = 0 | |
| def stream_forward(self, video_clip): | |
| if self.clip_idx == 0: | |
| # self.clear_cache() | |
| first_frame = video_clip[:, :, :1, :, :].repeat(1, 1, 3, 1, 1) | |
| video_clip = torch.cat([first_frame, video_clip], dim=2) | |
| x = self.pixel_shuffle(video_clip) | |
| cache1_x = x[:, :, -CACHE_T:, :, :].clone() | |
| self.cache["conv1"] = cache1_x | |
| x = self.conv1(x, self.cache["conv1"]) | |
| x = self.norm1(x) | |
| x = self.act1(x) | |
| cache2_x = x[:, :, -CACHE_T:, :, :].clone() | |
| self.cache["conv2"] = cache2_x | |
| self.clip_idx += 1 | |
| return None | |
| else: | |
| x = self.pixel_shuffle(video_clip) | |
| cache1_x = x[:, :, -CACHE_T:, :, :].clone() | |
| self.cache["conv1"] = cache1_x | |
| x = self.conv1(x, self.cache["conv1"]) | |
| x = self.norm1(x) | |
| x = self.act1(x) | |
| cache2_x = x[:, :, -CACHE_T:, :, :].clone() | |
| self.cache["conv2"] = cache2_x | |
| x = self.conv2(x, self.cache["conv2"]) | |
| x = self.norm2(x) | |
| x = self.act2(x) | |
| out_x = rearrange(x, "b c f h w -> b (f h w) c") | |
| outputs = [] | |
| for i in range(self.layer_num): | |
| outputs.append(self.linear_layers[i](out_x)) | |
| self.clip_idx += 1 | |
| return outputs | |
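| # Causal_LQ4x_Proj differs from Buffer_LQ4x_Proj only in cache timing: here each conv | |
| # consumes the cache from the previous clip and then updates it, keeping the convolutions | |
| # causal across clips, whereas the Buffer variant feeds the current clip's own tail back | |
| # in as its cache. | |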
| class Causal_LQ4x_Proj(nn.Module): | |
| def __init__(self, in_dim, out_dim, layer_num=30): | |
| super().__init__() | |
| self.ff = 1 | |
| self.hh = 16 | |
| self.ww = 16 | |
| self.hidden_dim1 = 2048 | |
| self.hidden_dim2 = 3072 | |
| self.layer_num = layer_num | |
| self.pixel_shuffle = PixelShuffle3d(self.ff, self.hh, self.ww) | |
| self.conv1 = CausalConv3d( | |
| in_dim * self.ff * self.hh * self.ww, | |
| self.hidden_dim1, | |
| (4, 3, 3), | |
| stride=(2, 1, 1), | |
| padding=(1, 1, 1), | |
| ) # f -> f/2 h -> h w -> w | |
| self.norm1 = RMS_norm(self.hidden_dim1, images=False) | |
| self.act1 = nn.SiLU() | |
| self.conv2 = CausalConv3d( | |
| self.hidden_dim1, | |
| self.hidden_dim2, | |
| (4, 3, 3), | |
| stride=(2, 1, 1), | |
| padding=(1, 1, 1), | |
| ) # f -> f/2 h -> h w -> w | |
| self.norm2 = RMS_norm(self.hidden_dim2, images=False) | |
| self.act2 = nn.SiLU() | |
| self.linear_layers = nn.ModuleList( | |
| [nn.Linear(self.hidden_dim2, out_dim) for _ in range(layer_num)] | |
| ) | |
| self.clip_idx = 0 | |
| def forward(self, video): | |
| self.clear_cache() | |
| # x: (B, C, F, H, W) | |
| t = video.shape[2] | |
| iter_ = 1 + (t - 1) // 4 | |
| first_frame = video[:, :, :1, :, :].repeat(1, 1, 3, 1, 1) | |
| video = torch.cat([first_frame, video], dim=2) | |
| # print(video.shape) | |
| out_x = [] | |
| for i in range(iter_): | |
| x = self.pixel_shuffle(video[:, :, i * 4 : (i + 1) * 4, :, :]) | |
| cache1_x = x[:, :, -CACHE_T:, :, :].clone() | |
| x = self.conv1(x, self.cache["conv1"]) | |
| self.cache["conv1"] = cache1_x | |
| x = self.norm1(x) | |
| x = self.act1(x) | |
| cache2_x = x[:, :, -CACHE_T:, :, :].clone() | |
| if i == 0: | |
| self.cache["conv2"] = cache2_x | |
| continue | |
| x = self.conv2(x, self.cache["conv2"]) | |
| self.cache["conv2"] = cache2_x | |
| x = self.norm2(x) | |
| x = self.act2(x) | |
| out_x.append(x) | |
| out_x = torch.cat(out_x, dim=2) | |
| out_x = rearrange(out_x, "b c f h w -> b (f h w) c") | |
| outputs = [] | |
| for i in range(self.layer_num): | |
| outputs.append(self.linear_layers[i](out_x)) | |
| return outputs | |
| def clear_cache(self): | |
| self.cache = {} | |
| self.cache["conv1"] = None | |
| self.cache["conv2"] = None | |
| self.clip_idx = 0 | |
| def stream_forward(self, video_clip): | |
| if self.clip_idx == 0: | |
| # self.clear_cache() | |
| first_frame = video_clip[:, :, :1, :, :].repeat(1, 1, 3, 1, 1) | |
| video_clip = torch.cat([first_frame, video_clip], dim=2) | |
| x = self.pixel_shuffle(video_clip) | |
| cache1_x = x[:, :, -CACHE_T:, :, :].clone() | |
| x = self.conv1(x, self.cache["conv1"]) | |
| self.cache["conv1"] = cache1_x | |
| x = self.norm1(x) | |
| x = self.act1(x) | |
| cache2_x = x[:, :, -CACHE_T:, :, :].clone() | |
| self.cache["conv2"] = cache2_x | |
| self.clip_idx += 1 | |
| return None | |
| else: | |
| x = self.pixel_shuffle(video_clip) | |
| cache1_x = x[:, :, -CACHE_T:, :, :].clone() | |
| x = self.conv1(x, self.cache["conv1"]) | |
| self.cache["conv1"] = cache1_x | |
| x = self.norm1(x) | |
| x = self.act1(x) | |
| cache2_x = x[:, :, -CACHE_T:, :, :].clone() | |
| x = self.conv2(x, self.cache["conv2"]) | |
| self.cache["conv2"] = cache2_x | |
| x = self.norm2(x) | |
| x = self.act2(x) | |
| out_x = rearrange(x, "b c f h w -> b (f h w) c") | |
| outputs = [] | |
| for i in range(self.layer_num): | |
| outputs.append(self.linear_layers[i](out_x)) | |
| self.clip_idx += 1 | |
| return outputs | |
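| # FrameStreamBuffer wraps a frame generator in a sliding window: get_chunk(start, end) | |
| # pulls frames on demand, returns a (1, C, T, H, W) tensor, and drops frames older than | |
| # buffer_size so memory stays bounded during streaming. | |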
| class FrameStreamBuffer: | |
| def __init__( | |
| self, | |
| frame_generator: types.GeneratorType, | |
| buffer_size: int = 60, | |
| device="cpu", | |
| dtype=torch.float16, | |
| ): | |
| self.generator = frame_generator | |
| self.buffer_size = buffer_size | |
| self.device = device | |
| self.dtype = dtype | |
| self.buffer = deque() | |
| self.start_frame_index = 0 | |
| self._fill_buffer(initial_fill_count=self.buffer_size) | |
| def _fill_buffer(self, initial_fill_count: int): | |
| try: | |
| for _ in range(initial_fill_count): | |
| frame = next(self.generator) | |
| self.buffer.append(frame) | |
| except StopIteration: | |
| pass | |
| def get_chunk(self, start: int, end: int) -> torch.Tensor: | |
| if start < self.start_frame_index: | |
| raise IndexError( | |
| f"Start frame {start} has already been discarded (current buffer starts at {self.start_frame_index})" | |
| ) | |
| while end > self.start_frame_index + len(self.buffer): | |
| try: | |
| self.buffer.append(next(self.generator)) | |
| except StopIteration: | |
| if end > self.start_frame_index + len(self.buffer): | |
| print( | |
| f"End frame {end} is out of range! It will be truncated to {self.start_frame_index + len(self.buffer)}" | |
| ) | |
| end = self.start_frame_index + len(self.buffer) | |
| break | |
| while len(self.buffer) > self.buffer_size: | |
| self.buffer.popleft() | |
| self.start_frame_index += 1 | |
| relative_start = start - self.start_frame_index | |
| relative_end = end - self.start_frame_index | |
| chunk_list = [self.buffer[i] for i in range(relative_start, relative_end)] | |
| if not chunk_list: | |
| C, H, W = self.buffer[0].shape | |
| return torch.empty((1, C, 0, H, W), device=self.device, dtype=self.dtype) | |
| chunk_tensor = torch.stack(chunk_list, dim=1) # (C, chunk_len, H, W) | |
| return chunk_tensor.unsqueeze(0).to( | |
| device=self.device | |
| ) # (1, C, chunk_len, H, W) | |
| class TensorAsBuffer: | |
| def __init__(self, tensor: torch.Tensor): | |
| self.tensor = tensor | |
| def get_chunk(self, start: int, end: int) -> torch.Tensor: | |
| return self.tensor[:, :, start:end, :, :] | |
| def tensor_to_imageio_frame(frame_tensor: torch.Tensor) -> np.ndarray: | |
| img_tensor = (frame_tensor + 1.0) / 2.0 | |
| img_tensor_hwc = img_tensor.permute(1, 2, 0) | |
| img_tensor_hwc_u8 = (img_tensor_hwc * 255.0).clamp(0, 255).to(torch.uint8) | |
| img_np = img_tensor_hwc_u8.cpu().numpy() | |
| return img_np | |
| root = os.path.dirname(os.path.abspath(__file__)) | |
| temp = os.path.join(root, "_temp") | |
| devices = get_device_list() | |
| def model_downlod(model_name="JunhaoZhuang/FlashVSR"): | |
| model_dir = os.path.join(root, "models", model_name.split("/")[-1]) | |
| print(f"model dir: {model_dir}") | |
| if not os.path.exists(model_dir): | |
| log( | |
| f"Downloading model '{model_name}' from huggingface...", message_type="info" | |
| ) | |
| snapshot_download( | |
| repo_id=model_name, | |
| local_dir=model_dir, | |
| local_dir_use_symlinks=False, | |
| resume_download=True, | |
| ) | |
| def is_ffmpeg_available(): | |
| ffmpeg_path = shutil.which("ffmpeg") | |
| if ffmpeg_path is None: | |
| log("[FlashVSR] FFmpeg not found!", message_type="warning") | |
| log("Please install FFmpeg and ensure it is in your system's PATH.") | |
| log( | |
| "- Windows: Download from https://www.ffmpeg.org/download.html and add the 'bin' directory to PATH." | |
| ) | |
| log("- macOS (via Homebrew): brew install ffmpeg") | |
| log("- Linux (Ubuntu/Debian): sudo apt-get install ffmpeg") | |
| return False | |
| return True | |
| def tensor2video(frames: torch.Tensor): | |
| video_squeezed = frames.squeeze(0) | |
| video_permuted = rearrange(video_squeezed, "C F H W -> F H W C") | |
| video_final = (video_permuted.float() + 1.0) / 2.0 | |
| return video_final | |
| def natural_key(name: str): | |
| return [ | |
| int(t) if t.isdigit() else t.lower() | |
| for t in re.split(r"([0-9]+)", os.path.basename(name)) | |
| ] | |
| def list_images_natural(folder: str): | |
| exts = (".png", ".jpg", ".jpeg", ".PNG", ".JPG", ".JPEG") | |
| fs = [os.path.join(folder, f) for f in os.listdir(folder) if f.endswith(exts)] | |
| fs.sort(key=natural_key) | |
| return fs | |
| def largest_8n1_leq(n): # 8n+1 | |
| return 0 if n < 1 else ((n - 1) // 8) * 8 + 1 | |
| def next_8n5(n): # next 8n+5 | |
| return 21 if n < 21 else ((n - 5 + 7) // 8) * 8 + 5 | |
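| # Worked examples: largest_8n1_leq(104) == 97 (largest value of the form 8n+1 that is | |
| # <= 104); next_8n5(100) == 101 (smallest value of the form 8n+5 that is >= 100, with a | |
| # floor of 21). | |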
| def is_video(path): | |
| return os.path.isfile(path) and path.lower().endswith( | |
| (".mp4", ".mov", ".avi", ".mkv") | |
| ) | |
| def save_video(frames, save_path, fps=30, quality=5): | |
| os.makedirs(os.path.dirname(save_path), exist_ok=True) | |
| frames_np = (frames.cpu().float() * 255.0).clip(0, 255).numpy().astype(np.uint8) | |
| w = imageio.get_writer(save_path, fps=fps, quality=quality) | |
| for frame_np in tqdm(frames_np, desc=f"[FlashVSR] Saving video"): | |
| w.append_data(frame_np) | |
| w.close() | |
| def merge_video_with_audio(video_path, audio_source_path): | |
| temp = video_path + "temp.mp4" | |
| if os.path.isdir(audio_source_path): | |
| log(f"[FlashVSR] Output video saved to '{video_path}'", message_type="info") | |
| return | |
| if not is_ffmpeg_available(): | |
| log(f"[FlashVSR] Output video saved to '{video_path}'", message_type="info") | |
| return | |
| try: | |
| probe = ffmpeg.probe(audio_source_path) | |
| audio_streams = [s for s in probe["streams"] if s["codec_type"] == "audio"] | |
| if not audio_streams: | |
| log(f"[FlashVSR] Output video saved to '{video_path}'", message_type="info") | |
| return | |
| log("[FlashVSR] Copying audio tracks...") | |
| os.rename(video_path, temp) | |
| input_video = ffmpeg.input(temp)["v"] | |
| input_audio = ffmpeg.input(audio_source_path)["a"] | |
| ffmpeg.output( | |
| input_video, input_audio, video_path, vcodec="copy", acodec="copy" | |
| ).run(overwrite_output=True, quiet=True) | |
| log(f"[FlashVSR] Output video saved to '{video_path}'", message_type="info") | |
| except ffmpeg.Error as e: | |
| print( | |
| "[ERROR] FFmpeg error during merge:", | |
| e.stderr.decode() if e.stderr else "Unknown error", | |
| ) | |
| log( | |
| f"[FlashVSR] Audio merge failed. A silent video has been saved to '{video_path}'.", | |
| message_type="warning", | |
| ) | |
| finally: | |
| if os.path.exists(temp): | |
| try: | |
| os.remove(temp) | |
| except OSError as e: | |
| log( | |
| f"[FlashVSR] Could not remove temporary file '{temp}': {e}", | |
| message_type="error", | |
| ) | |
| def compute_scaled_and_target_dims( | |
| w0: int, h0: int, scale: int = 4, multiple: int = 128 | |
| ): | |
| if w0 <= 0 or h0 <= 0: | |
| raise ValueError("invalid original size") | |
| sW, sH = w0 * scale, h0 * scale | |
| tW = max(multiple, (sW // multiple) * multiple) | |
| tH = max(multiple, (sH // multiple) * multiple) | |
| return sW, sH, tW, tH | |
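| # Example: compute_scaled_and_target_dims(480, 270, scale=4, multiple=128) -> (1920, 1080, 1920, 1024). | |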
| def tensor_upscale_then_center_crop( | |
| frame_tensor: torch.Tensor, scale: int, tW: int, tH: int | |
| ) -> torch.Tensor: | |
| h0, w0, c = frame_tensor.shape | |
| tensor_bchw = frame_tensor.permute(2, 0, 1).unsqueeze(0) # HWC -> CHW -> BCHW | |
| sW, sH = w0 * scale, h0 * scale | |
| upscaled_tensor = F.interpolate( | |
| tensor_bchw, size=(sH, sW), mode="bicubic", align_corners=False | |
| ) | |
| l = max(0, (sW - tW) // 2) | |
| t = max(0, (sH - tH) // 2) | |
| cropped_tensor = upscaled_tensor[:, :, t : t + tH, l : l + tW] | |
| return cropped_tensor.squeeze(0) | |
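| # tensor_upscale_then_center_crop: (H, W, C) frame -> bicubic upscale by scale -> center crop to (tH, tW) -> (C, tH, tW). | |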
| def prepare_tensors(path: str, dtype=torch.bfloat16): | |
| if os.path.isdir(path): | |
| paths0 = list_images_natural(path) | |
| if not paths0: | |
| raise FileNotFoundError(f"No images in {path}") | |
| with Image.open(paths0[0]) as _img0: | |
| w0, h0 = _img0.size | |
| N0 = len(paths0) | |
| frames = [] | |
| for p in paths0: | |
| with Image.open(p).convert("RGB") as img: | |
| img_np = np.array(img).astype(np.float32) / 255.0 | |
| frames.append(torch.from_numpy(img_np).to(dtype)) | |
| vid = torch.stack(frames, 0) | |
| fps = 30 | |
| return vid, fps | |
| if is_video(path): | |
| rdr = imageio.get_reader(path) | |
| meta = {} | |
| try: | |
| meta = rdr.get_meta_data() | |
| first_frame = rdr.get_data(0) | |
| h0, w0, _ = first_frame.shape | |
| except Exception: | |
| first_frame = rdr.get_data(0) | |
| h0, w0, _ = first_frame.shape | |
| fps_val = meta.get("fps", 30) | |
| fps = int(round(fps_val)) if isinstance(fps_val, (int, float)) else 30 | |
| total = meta.get("nframes", rdr.count_frames()) | |
| if total is None or total <= 0: | |
| total = len([_ for _ in rdr]) | |
| rdr = imageio.get_reader(path) | |
| if total <= 0: | |
| rdr.close() | |
| raise RuntimeError(f"Cannot read frames from {path}") | |
| frames = [] | |
| try: | |
| for frame_data in rdr: | |
| frame_np = frame_data.astype(np.float32) / 255.0 | |
| frames.append(torch.from_numpy(frame_np).to(dtype)) | |
| finally: | |
| try: | |
| rdr.close() | |
| except Exception: | |
| pass | |
| vid = torch.stack(frames, 0) | |
| return vid, fps | |
| raise ValueError(f"Unsupported input: {path}") | |
| def get_input_params(image_tensor, scale): | |
| N0, h0, w0, _ = image_tensor.shape | |
| multiple = 128 | |
| sW, sH, tW, tH = compute_scaled_and_target_dims( | |
| w0, h0, scale=scale, multiple=multiple | |
| ) | |
| num_frames_with_padding = N0 + 4 | |
| F = largest_8n1_leq(num_frames_with_padding) | |
| if F == 0: | |
| raise RuntimeError( | |
| f"Not enough frames after padding. Got {num_frames_with_padding}." | |
| ) | |
| return tH, tW, F | |
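| # Example: a 480x270, 20-frame clip at scale=4 gives tH=1024, tW=1920, F = largest_8n1_leq(24) = 17. | |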
| def input_tensor_generator( | |
| image_tensor: torch.Tensor, device, scale: int = 4, dtype=torch.bfloat16 | |
| ): | |
| """ | |
| A generator that prepares and yields one frame tensor at a time to save memory. | |
| Each yielded tensor has shape (C, H, W). | |
| """ | |
| N0, h0, w0, _ = image_tensor.shape | |
| tH, tW, F = get_input_params(image_tensor, scale) | |
| for i in range(F): | |
| frame_idx = min(i, N0 - 1) | |
| frame_slice = image_tensor[frame_idx].to(device) | |
| tensor_chw = tensor_upscale_then_center_crop( | |
| frame_slice, scale=scale, tW=tW, tH=tH | |
| ) | |
| tensor_out = tensor_chw * 2.0 - 1.0 | |
| del tensor_chw | |
| yield tensor_out.to("cpu").to(dtype) | |
| def prepare_input_tensor( | |
| image_tensor: torch.Tensor, device, scale: int = 4, dtype=torch.bfloat16 | |
| ): | |
| N0, h0, w0, _ = image_tensor.shape | |
| multiple = 128 | |
| sW, sH, tW, tH = compute_scaled_and_target_dims( | |
| w0, h0, scale=scale, multiple=multiple | |
| ) | |
| num_frames_with_padding = N0 + 4 | |
| F = largest_8n1_leq(num_frames_with_padding) | |
| if F == 0: | |
| raise RuntimeError( | |
| f"Not enough frames after padding. Got {num_frames_with_padding}." | |
| ) | |
| frames = [] | |
| for i in range(F): | |
| frame_idx = min(i, N0 - 1) | |
| frame_slice = image_tensor[frame_idx].to(device) | |
| tensor_chw = tensor_upscale_then_center_crop( | |
| frame_slice, scale=scale, tW=tW, tH=tH | |
| ) | |
| tensor_out = tensor_chw * 2.0 - 1.0 | |
| tensor_out = tensor_out.to("cpu").to(dtype) | |
| frames.append(tensor_out) | |
| vid_stacked = torch.stack(frames, 0) | |
| vid_final = vid_stacked.permute(1, 0, 2, 3).unsqueeze(0) | |
| del vid_stacked | |
| clean_vram() | |
| return vid_final, tH, tW, F | |
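| # prepare_input_tensor returns a (1, C, F, tH, tW) tensor in [-1, 1]; indices past the last input frame repeat it. | |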
| def calculate_tile_coords(height, width, tile_size, overlap): | |
| coords = [] | |
| stride = tile_size - overlap | |
| num_rows = math.ceil((height - overlap) / stride) | |
| num_cols = math.ceil((width - overlap) / stride) | |
| for r in range(num_rows): | |
| for c in range(num_cols): | |
| y1 = r * stride | |
| x1 = c * stride | |
| y2 = min(y1 + tile_size, height) | |
| x2 = min(x1 + tile_size, width) | |
| if y2 - y1 < tile_size: | |
| y1 = max(0, y2 - tile_size) | |
| if x2 - x1 < tile_size: | |
| x1 = max(0, x2 - tile_size) | |
| coords.append((x1, y1, x2, y2)) | |
| return coords | |
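| # Example: calculate_tile_coords(512, 512, 256, 24) uses stride 232 and yields 3x3 tiles, from (0, 0, 256, 256) to (256, 256, 512, 512). | |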
| def create_feather_mask_numpy(size, overlap): | |
| H, W = size | |
| mask = np.ones((H, W, 1), dtype=np.float32) | |
| ramp = np.linspace(0, 1, overlap, dtype=np.float32) | |
| mask[:, :overlap, :] *= ramp[np.newaxis, :, np.newaxis] | |
| mask[:, -overlap:, :] *= np.flip(ramp)[np.newaxis, :, np.newaxis] | |
| mask[:overlap, :, :] *= ramp[:, np.newaxis, np.newaxis] | |
| mask[-overlap:, :, :] *= np.flip(ramp)[:, np.newaxis, np.newaxis] | |
| return mask | |
| def create_feather_mask(size, overlap): | |
| H, W = size | |
| mask = torch.ones(1, 1, H, W) | |
| ramp = torch.linspace(0, 1, overlap) | |
| mask[:, :, :, :overlap] = torch.minimum( | |
| mask[:, :, :, :overlap], ramp.view(1, 1, 1, -1) | |
| ) | |
| mask[:, :, :, -overlap:] = torch.minimum( | |
| mask[:, :, :, -overlap:], ramp.flip(0).view(1, 1, 1, -1) | |
| ) | |
| mask[:, :, :overlap, :] = torch.minimum( | |
| mask[:, :, :overlap, :], ramp.view(1, 1, -1, 1) | |
| ) | |
| mask[:, :, -overlap:, :] = torch.minimum( | |
| mask[:, :, -overlap:, :], ramp.flip(0).view(1, 1, -1, 1) | |
| ) | |
| return mask | |
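| # Both feather-mask builders ramp linearly from 0 to 1 over the overlap border so overlapping tiles blend; | |
| # the NumPy version returns (H, W, 1), the torch version (1, 1, H, W). | |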
| def stitch_video_tiles( | |
| tile_paths, | |
| tile_coords, | |
| final_dims, | |
| scale, | |
| overlap, | |
| output_path, | |
| fps, | |
| quality, | |
| cleanup=True, | |
| chunk_size=40,  # number of frames held in memory per stitching pass | |
| ): | |
| if not tile_paths: | |
| log("No tile videos found to stitch.", message_type="error") | |
| return | |
| final_W, final_H = final_dims | |
| # 1. Open all tile videos up front | |
| readers = [imageio.get_reader(p) for p in tile_paths] | |
| try: | |
| # Determine the total frame count | |
| num_frames = readers[0].count_frames() | |
| if num_frames is None or num_frames <= 0: | |
| num_frames = len([_ for _ in readers[0]]) | |
| for r in readers: | |
| r.close() | |
| readers = [imageio.get_reader(p) for p in tile_paths] | |
| # Open the final output writer | |
| with imageio.get_writer(output_path, fps=fps, quality=quality) as writer: | |
| # 2. Iterate over all frames in chunks of chunk_size | |
| # tqdm now reports progress in chunks rather than individual frames | |
| for start_frame in tqdm( | |
| range(0, num_frames, chunk_size), desc="[FlashVSR] Stitching Chunks" | |
| ): | |
| end_frame = min(start_frame + chunk_size, num_frames) | |
| current_chunk_size = end_frame - start_frame | |
| # 3. Allocate an in-memory canvas for the whole chunk | |
| # Shape: (Frames, Height, Width, Channels) | |
| chunk_canvas = np.zeros( | |
| (current_chunk_size, final_H, final_W, 3), dtype=np.float32 | |
| ) | |
| weight_canvas = np.zeros_like(chunk_canvas, dtype=np.float32) | |
| # 4. Iterate over each tile video | |
| for i, reader in enumerate(readers): | |
| # 5. Read all of this tile's frames for the current chunk in one pass | |
| # Sequential reading is the key optimization here | |
| try: | |
| # reader.iter_data() is an efficient way to read consecutive frames | |
| tile_chunk_frames = [ | |
| frame.astype(np.float32) / 255.0 | |
| for idx, frame in enumerate(reader.iter_data()) | |
| if start_frame <= idx < end_frame | |
| ] | |
| # Stack the list of frames into a single NumPy array | |
| tile_chunk_np = np.stack(tile_chunk_frames, axis=0) | |
| except Exception as e: | |
| log( | |
| f"Warning: Could not read chunk from tile {i}. Error: {e}", | |
| message_type="warning", | |
| ) | |
| continue | |
| if tile_chunk_np.shape[0] != current_chunk_size: | |
| log( | |
| f"Warning: Tile {i} chunk has incorrect frame count. Skipping.", | |
| message_type="warning", | |
| ) | |
| continue | |
| # 6. Build the feather mask (only needs to be created once) | |
| tile_H, tile_W, _ = tile_chunk_np.shape[1:] | |
| ramp = np.linspace(0, 1, overlap * scale, dtype=np.float32) | |
| mask = np.ones((tile_H, tile_W, 1), dtype=np.float32) | |
| mask[:, : overlap * scale, :] *= ramp[np.newaxis, :, np.newaxis] | |
| mask[:, -overlap * scale :, :] *= np.flip(ramp)[ | |
| np.newaxis, :, np.newaxis | |
| ] | |
| mask[: overlap * scale, :, :] *= ramp[:, np.newaxis, np.newaxis] | |
| mask[-overlap * scale :, :, :] *= np.flip(ramp)[ | |
| :, np.newaxis, np.newaxis | |
| ] | |
| # Add a leading axis so the mask broadcasts over the chunk's frame dimension | |
| mask_4d = mask[np.newaxis, :, :, :]  # Shape: (1, H, W, C) | |
| # 7. Stitch the whole chunk in memory | |
| x1_orig, y1_orig, _, _ = tile_coords[i] | |
| out_y1, out_x1 = y1_orig * scale, x1_orig * scale | |
| out_y2, out_x2 = out_y1 + tile_H, out_x1 + tile_W | |
| # Rely on NumPy broadcasting | |
| chunk_canvas[:, out_y1:out_y2, out_x1:out_x2, :] += ( | |
| tile_chunk_np * mask_4d | |
| ) | |
| weight_canvas[:, out_y1:out_y2, out_x1:out_x2, :] += mask_4d | |
| # 8. Normalize the whole chunk by the accumulated weights | |
| weight_canvas[weight_canvas == 0] = 1.0 | |
| stitched_chunk = chunk_canvas / weight_canvas | |
| # 9. Write all of this chunk's frames to the output file | |
| for frame_idx_in_chunk in range(current_chunk_size): | |
| frame_uint8 = ( | |
| np.clip(stitched_chunk[frame_idx_in_chunk], 0, 1) * 255 | |
| ).astype(np.uint8) | |
| writer.append_data(frame_uint8) | |
| finally: | |
| log("Closing all tile reader instances...") | |
| for reader in readers: | |
| reader.close() | |
| if cleanup: | |
| log("Cleaning up temporary tile files...") | |
| for path in tile_paths: | |
| try: | |
| os.remove(path) | |
| except OSError as e: | |
| log( | |
| f"Could not remove temporary file '{path}': {e}", | |
| message_type="warning", | |
| ) | |
| def init_pipeline(version, mode, device, dtype): | |
| if version == "10": | |
| model = "FlashVSR" | |
| else: | |
| model = "FlashVSR-v1.1" | |
| model_downlod(model_name="JunhaoZhuang/" + model) | |
| model_path = os.path.join(root, "models", model) | |
| if not os.path.exists(model_path): | |
| raise RuntimeError( | |
| f'Model directory does not exist! Please save all weights to "{model_path}"' | |
| ) | |
| ckpt_path = os.path.join( | |
| model_path, "diffusion_pytorch_model_streaming_dmd.safetensors" | |
| ) | |
| if not os.path.exists(ckpt_path): | |
| raise RuntimeError( | |
| f'"diffusion_pytorch_model_streaming_dmd.safetensors" does not exist! Please save it to "{model_path}"' | |
| ) | |
| vae_path = os.path.join(model_path, "Wan2.1_VAE.pth") | |
| if not os.path.exists(vae_path): | |
| raise RuntimeError( | |
| f'"Wan2.1_VAE.pth" does not exist! Please save it to "{model_path}"' | |
| ) | |
| lq_path = os.path.join(model_path, "LQ_proj_in.ckpt") | |
| if not os.path.exists(lq_path): | |
| raise RuntimeError( | |
| f'"LQ_proj_in.ckpt" does not exist! Please save it to "{model_path}"' | |
| ) | |
| tcd_path = os.path.join(model_path, "TCDecoder.ckpt") | |
| if not os.path.exists(tcd_path): | |
| raise RuntimeError( | |
| f'"TCDecoder.ckpt" does not exist! Please save it to "{model_path}"' | |
| ) | |
| prompt_path = os.path.join(root, "models", "posi_prompt.pth") | |
| mm = ModelManager(torch_dtype=dtype, device="cpu") | |
| if mode == "full": | |
| mm.load_models([ckpt_path, vae_path]) | |
| pipe = FlashVSRFullPipeline.from_model_manager(mm, device=device) | |
| pipe.vae.model.encoder = None | |
| pipe.vae.model.conv1 = None | |
| else: | |
| mm.load_models([ckpt_path]) | |
| if mode == "tiny": | |
| pipe = FlashVSRTinyPipeline.from_model_manager(mm, device=device) | |
| else: | |
| pipe = FlashVSRTinyLongPipeline.from_model_manager(mm, device=device) | |
| multi_scale_channels = [512, 256, 128, 128] | |
| pipe.TCDecoder = build_tcdecoder( | |
| new_channels=multi_scale_channels, | |
| device=device, | |
| dtype=dtype, | |
| new_latent_channels=16 + 768, | |
| ) | |
| mis = pipe.TCDecoder.load_state_dict( | |
| torch.load(tcd_path, map_location=device), strict=False | |
| ) | |
| pipe.TCDecoder.clean_mem() | |
| if model == "FlashVSR": | |
| pipe.denoising_model().LQ_proj_in = Buffer_LQ4x_Proj( | |
| in_dim=3, out_dim=1536, layer_num=1 | |
| ).to(device, dtype=dtype) | |
| else: | |
| pipe.denoising_model().LQ_proj_in = Causal_LQ4x_Proj( | |
| in_dim=3, out_dim=1536, layer_num=1 | |
| ).to(device, dtype=dtype) | |
| pipe.denoising_model().LQ_proj_in.load_state_dict( | |
| torch.load(lq_path, map_location="cpu"), strict=True | |
| ) | |
| pipe.denoising_model().LQ_proj_in.to(device) | |
| pipe.to(device, dtype=dtype) | |
| pipe.enable_vram_management(num_persistent_param_in_dit=None) | |
| pipe.init_cross_kv(prompt_path=prompt_path) | |
| pipe.load_models_to_device(["dit", "vae"]) | |
| return pipe | |
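| # Usage sketch (assumes the checkpoints already sit under os.path.join(root, "models") and that | |
| # root, ModelManager, and the pipeline classes are the module-level objects defined earlier in this file): | |
| #   pipe = init_pipeline(version="10", mode="tiny", device="cuda:0", dtype=torch.bfloat16) | |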
| def main( | |
| input, | |
| version, | |
| mode, | |
| scale, | |
| color_fix, | |
| tiled_vae, | |
| tiled_dit, | |
| tile_size, | |
| tile_overlap, | |
| unload_dit, | |
| dtype, | |
| sparse_ratio=2, | |
| kv_ratio=3, | |
| local_range=11, | |
| seed=0, | |
| device="auto", | |
| quality=6, | |
| output=None, | |
| ): | |
| _device = device | |
| if device == "auto": | |
| _device = ( | |
| "cuda:0" | |
| if torch.cuda.is_available() | |
| else "mps" if torch.backends.mps.is_available() else device | |
| ) | |
| if _device == "auto" or _device not in devices: | |
| raise RuntimeError("No devices found to run FlashVSR!") | |
| if _device.startswith("cuda"): | |
| torch.cuda.set_device(_device) | |
| if tiled_dit and (tile_overlap > tile_size / 2): | |
| raise ValueError('The "tile_overlap" must be less than half of "tile_size"!') | |
| _frames, fps = prepare_tensors(input, dtype=dtype) | |
| _fps = fps if is_video(input) else args.fps | |
| add = next_8n5(_frames.shape[0]) - _frames.shape[0] | |
| padding_frames = _frames[-1:, :, :, :].repeat(add, 1, 1, 1) | |
| frames = torch.cat([_frames, padding_frames], dim=0) | |
| frame_count = _frames.shape[0] | |
| del _frames | |
| clean_vram() | |
| log("[FlashVSR] Preparing frames...", message_type="finish") | |
| if tiled_dit: | |
| N, H, W, C = frames.shape | |
| num_aligned_frames = largest_8n1_leq(N + 4) - 4 | |
| if mode == "tiny-long": | |
| local_temp = os.path.join(temp, str(uuid.uuid4())) | |
| os.makedirs(local_temp, exist_ok=True) | |
| else: | |
| final_output_canvas = torch.zeros( | |
| (num_aligned_frames, H * scale, W * scale, C), dtype=dtype, device="cpu" | |
| ) | |
| weight_sum_canvas = torch.zeros_like(final_output_canvas) | |
| tile_coords = calculate_tile_coords(H, W, tile_size, tile_overlap) | |
| latent_tiles_cpu = [] | |
| temp_videos = [] | |
| pipe = init_pipeline(version, mode, _device, dtype) | |
| for i, (x1, y1, x2, y2) in enumerate(tile_coords): | |
| input_tile = frames[:, y1:y2, x1:x2, :] | |
| if mode == "tiny-long": | |
| temp_name = os.path.join(local_temp, f"{i+1:05d}.mp4") | |
| th, tw, F = get_input_params(input_tile, scale=scale) | |
| LQ_tile = input_tensor_generator( | |
| input_tile, _device, scale=scale, dtype=dtype | |
| ) | |
| else: | |
| temp_name = None  # only "tiny-long" writes per-tile videos; avoids a NameError in the pipe call below | |
| LQ_tile, th, tw, F = prepare_input_tensor( | |
| input_tile, _device, scale=scale, dtype=dtype | |
| ) | |
| LQ_tile = LQ_tile.to(_device) | |
| if i == 0: | |
| log( | |
| f"[FlashVSR] Processing {frame_count} frames...", | |
| message_type="info", | |
| ) | |
| log( | |
| f"[FlashVSR] Processing tile {i+1}/{len(tile_coords)}: ({x1},{y1}) to ({x2},{y2})", | |
| message_type="info", | |
| ) | |
| output_tile_gpu = pipe( | |
| prompt="", | |
| negative_prompt="", | |
| cfg_scale=1.0, | |
| num_inference_steps=1, | |
| seed=seed, | |
| tiled=tiled_vae, | |
| LQ_video=LQ_tile, | |
| num_frames=F, | |
| height=th, | |
| width=tw, | |
| is_full_block=False, | |
| if_buffer=True, | |
| topk_ratio=sparse_ratio * 768 * 1280 / (th * tw), | |
| kv_ratio=kv_ratio, | |
| local_range=local_range, | |
| color_fix=color_fix, | |
| unload_dit=unload_dit, | |
| fps=_fps, | |
| quality=10, | |
| output_path=temp_name, | |
| tiled_dit=True, | |
| ) | |
| temp_videos.append(temp_name) | |
| if mode == "tiny-long": | |
| final_output = output_tile_gpu | |
| del LQ_tile, input_tile | |
| clean_vram() | |
| continue | |
| processed_tile_cpu = tensor2video(output_tile_gpu).to("cpu") | |
| mask_nchw = create_feather_mask( | |
| (processed_tile_cpu.shape[1], processed_tile_cpu.shape[2]), | |
| tile_overlap * scale, | |
| ).to("cpu") | |
| mask_nhwc = mask_nchw.permute(0, 2, 3, 1) | |
| out_x1, out_y1 = x1 * scale, y1 * scale | |
| tile_H_scaled = processed_tile_cpu.shape[1] | |
| tile_W_scaled = processed_tile_cpu.shape[2] | |
| out_x2, out_y2 = out_x1 + tile_W_scaled, out_y1 + tile_H_scaled | |
| final_output_canvas[:, out_y1:out_y2, out_x1:out_x2, :] += ( | |
| processed_tile_cpu * mask_nhwc | |
| ) | |
| weight_sum_canvas[:, out_y1:out_y2, out_x1:out_x2, :] += mask_nhwc | |
| del LQ_tile, output_tile_gpu, processed_tile_cpu, input_tile | |
| clean_vram() | |
| if mode == "tiny-long": | |
| stitch_video_tiles( | |
| tile_paths=temp_videos, | |
| tile_coords=tile_coords, | |
| final_dims=(W * scale, H * scale), | |
| scale=scale, | |
| overlap=tile_overlap, | |
| output_path=output, | |
| fps=_fps, | |
| quality=quality, | |
| cleanup=True, | |
| ) | |
| shutil.rmtree(local_temp) | |
| else: | |
| weight_sum_canvas[weight_sum_canvas == 0] = 1.0 | |
| final_output = final_output_canvas / weight_sum_canvas | |
| else: | |
| if mode == "tiny-long": | |
| th, tw, F = get_input_params(frames, scale=scale) | |
| LQ = input_tensor_generator(frames, _device, scale=scale, dtype=dtype) | |
| else: | |
| LQ, th, tw, F = prepare_input_tensor( | |
| frames, _device, scale=scale, dtype=dtype | |
| ) | |
| LQ = LQ.to(_device) | |
| pipe = init_pipeline(version, mode, _device, dtype) | |
| log(f"[FlashVSR] Processing {frame_count} frames...", message_type="info") | |
| video = pipe( | |
| prompt="", | |
| negative_prompt="", | |
| cfg_scale=1.0, | |
| num_inference_steps=1, | |
| seed=seed, | |
| tiled=tiled_vae, | |
| LQ_video=LQ, | |
| num_frames=F, | |
| height=th, | |
| width=tw, | |
| is_full_block=False, | |
| if_buffer=True, | |
| topk_ratio=sparse_ratio * 768 * 1280 / (th * tw), | |
| kv_ratio=kv_ratio, | |
| local_range=local_range, | |
| color_fix=color_fix, | |
| unload_dit=unload_dit, | |
| fps=_fps, | |
| output_path=output, | |
| tiled_dit=True, | |
| ) | |
| if mode == "tiny-long": | |
| del pipe, LQ | |
| clean_vram() | |
| return video, _fps | |
| log("[FlashVSR] Preparing frames...") | |
| final_output = tensor2video(video).to("cpu") | |
| del pipe, video, LQ | |
| clean_vram() | |
| return final_output[:frame_count, :, :, :], fps | |
| if __name__ == "__main__": | |
| dtype_map = { | |
| "fp32": torch.float32, | |
| "fp16": torch.float16, | |
| "bf16": torch.bfloat16, | |
| } | |
| try: | |
| dtype = dtype_map[args.dtype] | |
| except KeyError: | |
| dtype = torch.bfloat16 | |
| if args.attention == "sage": | |
| USE_BLOCK_ATTN = False | |
| else: | |
| USE_BLOCK_ATTN = True | |
| if os.path.exists(temp): | |
| shutil.rmtree(temp) | |
| os.makedirs(temp, exist_ok=True) | |
| name = os.path.basename(args.input.rstrip("/")) | |
| final = "/root/output.mp4" | |
| result, fps = main( | |
| args.input, | |
| args.version, | |
| args.mode, | |
| args.scale, | |
| args.color_fix, | |
| args.tiled_vae, | |
| args.tiled_dit, | |
| args.tile_size, | |
| args.overlap, | |
| args.unload_dit, | |
| dtype, | |
| seed=args.seed, | |
| device=args.device, | |
| quality=args.quality, | |
| output=final, | |
| ) | |
| if args.mode != "tiny-long": | |
| save_video(result, final, fps=fps, quality=args.quality) | |
| merge_video_with_audio(final, args.input) | |
| log("[FlashVSR] Done.", message_type="finish") |