from dataclasses import dataclass
import os
import re
import gc
import copy  # needed for copy.deepcopy in AutoWrappedModule
import math
import json
import time
import types
import uuid
import random
import shutil
import hashlib
import importlib
from collections import deque, namedtuple
from contextlib import contextmanager
from typing import List, Optional, Tuple, Literal
from typing_extensions import TypeAlias

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.init as init
import triton
import triton.language as tl
import imageio
import ffmpeg
from PIL import Image
from einops import rearrange, repeat
from tqdm.auto import tqdm
from torchvision.transforms import GaussianBlur
from safetensors import safe_open
from huggingface_hub import snapshot_download
@dataclass
class Args:
input: str = "/root/example0.mp4"
output_folder: str = "./outputs"
scale: int = 4
version: str = "10"
mode: str = "tiny"
tiled_vae: bool = False
tiled_dit: bool = False
tile_size: int = 256
overlap: int = 24
unload_dit: bool = False
color_fix: bool = False
seed: int = 0
dtype: str = "bf16"
device: str = "auto"
fps: int = 30
quality: int = 6
attention: str = "sage"
args = Args()
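# The defaults above mirror the pipeline's CLI-style options; a hypothetical override
# for a local run might look like:
# args = Args(input="/path/to/video.mp4", scale=4, tiled_vae=True, tile_size=256)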
def log(message: str, message_type: str = "normal"):
if message_type == "error":
message = "\033[1;41m" + message + "\033[m"
elif message_type == "warning":
message = "\033[1;31m" + message + "\033[m"
elif message_type == "finish":
message = "\033[1;32m" + message + "\033[m"
elif message_type == "info":
message = "\033[1;33m" + message + "\033[m"
else:
message = message
print(f"{message}")
# ----------------------------
# WanModel (no image branch): the cross-attention KV cache is created at init time
# ----------------------------
class WanModel(torch.nn.Module):
def __init__(
self,
dim: int,
in_dim: int,
ffn_dim: int,
out_dim: int,
text_dim: int,
freq_dim: int,
eps: float,
patch_size: Tuple[int, int, int],
num_heads: int,
num_layers: int,
        # init_context: torch.Tensor,  # <<<< required: used in __init__ to build the cross-attn KV cache
has_image_input: bool = False,
):
super().__init__()
self.dim = dim
self.freq_dim = freq_dim
self.patch_size = patch_size
# patch embed
self.patch_embedding = nn.Conv3d(
in_dim, dim, kernel_size=patch_size, stride=patch_size
)
# text / time embed
self.text_embedding = nn.Sequential(
nn.Linear(text_dim, dim), nn.GELU(approximate="tanh"), nn.Linear(dim, dim)
)
self.time_embedding = nn.Sequential(
nn.Linear(freq_dim, dim), nn.SiLU(), nn.Linear(dim, dim)
)
self.time_projection = nn.Sequential(nn.SiLU(), nn.Linear(dim, dim * 6))
# blocks
self.blocks = nn.ModuleList(
[DiTBlock(dim, num_heads, ffn_dim, eps) for _ in range(num_layers)]
)
self.head = Head(dim, out_dim, patch_size, eps)
head_dim = dim // num_heads
self.freqs = precompute_freqs_cis_3d(head_dim)
self._cross_kv_initialized = False
    # Optional: manually clear / re-initialize the cross-attention KV cache
def clear_cross_kv(self):
for blk in self.blocks:
blk.cross_attn.clear_cache()
self._cross_kv_initialized = False
@torch.no_grad()
def reinit_cross_kv(self, new_context: torch.Tensor):
ctx_txt = self.text_embedding(new_context)
for blk in self.blocks:
blk.cross_attn.init_cache(ctx_txt)
self._cross_kv_initialized = True
def patchify(self, x: torch.Tensor):
x = self.patch_embedding(x)
grid_size = x.shape[2:]
x = rearrange(x, "b c f h w -> b (f h w) c").contiguous()
return x, grid_size # x, grid_size: (f, h, w)
def unpatchify(self, x: torch.Tensor, grid_size: torch.Tensor):
return rearrange(
x,
"b (f h w) (x y z c) -> b c (f x) (h y) (w z)",
f=grid_size[0],
h=grid_size[1],
w=grid_size[2],
x=self.patch_size[0],
y=self.patch_size[1],
z=self.patch_size[2],
)
def forward(
self,
x: torch.Tensor,
timestep: torch.Tensor,
context: torch.Tensor,
use_gradient_checkpointing: bool = False,
use_gradient_checkpointing_offload: bool = False,
LQ_latents: Optional[List[torch.Tensor]] = None,
train_img: bool = False,
topk_ratio: Optional[float] = None,
kv_ratio: Optional[float] = None,
local_num: Optional[int] = None,
is_full_block: bool = False,
causal_idx: Optional[int] = None,
**kwargs,
):
# time / text embeds
t = self.time_embedding(sinusoidal_embedding_1d(self.freq_dim, timestep))
t_mod = self.time_projection(t).unflatten(1, (6, self.dim))
        # Text would still be embedded here (CrossAttention ignores it once its KV cache is set)
        # context = self.text_embedding(context)
        # Patchify the input
x, (f, h, w) = self.patchify(x)
B = x.shape[0]
        # Window / mask hyperparameters
win = (2, 8, 8)
seqlen = f // win[0]
if local_num is None:
local_random = random.random()
if local_random < 0.3:
local_num = seqlen - 3
elif local_random < 0.4:
local_num = seqlen - 4
elif local_random < 0.5:
local_num = seqlen - 2
else:
local_num = seqlen
window_size = win[0] * h * w // 128
square_num = window_size * window_size
topk_ratio = 2.0
topk = min(max(int(square_num * topk_ratio), 1), int(square_num * seqlen) - 1)
if kv_ratio is None:
kv_ratio = (random.uniform(0.0, 1.0) ** 2) * (local_num - 2 - 2) + 2
kv_len = min(max(int(window_size * kv_ratio), 1), int(window_size * seqlen) - 1)
decay_ratio = random.uniform(0.7, 1.0)
# RoPE 3D
freqs = (
torch.cat(
[
self.freqs[0][:f].view(f, 1, 1, -1).expand(f, h, w, -1),
self.freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1),
self.freqs[2][:w].view(1, 1, w, -1).expand(f, h, w, -1),
],
dim=-1,
)
.reshape(f * h * w, 1, -1)
.to(x.device)
)
def create_custom_forward(module):
def custom_forward(*inputs):
return module(*inputs)
return custom_forward
# blocks
for block_id, block in enumerate(self.blocks):
if LQ_latents is not None and block_id < len(LQ_latents):
x += LQ_latents[block_id]
if self.training and use_gradient_checkpointing:
if use_gradient_checkpointing_offload:
with torch.autograd.graph.save_on_cpu():
x = torch.utils.checkpoint.checkpoint(
create_custom_forward(block),
x,
context,
t_mod,
freqs,
f,
h,
w,
local_num,
topk,
train_img,
block_id,
kv_len,
is_full_block,
False,
None,
None,
use_reentrant=False,
)
else:
x = torch.utils.checkpoint.checkpoint(
create_custom_forward(block),
x,
context,
t_mod,
freqs,
f,
h,
w,
local_num,
topk,
train_img,
block_id,
kv_len,
is_full_block,
False,
None,
None,
use_reentrant=False,
)
else:
x = block(
x,
context,
t_mod,
freqs,
f,
h,
w,
local_num,
topk,
train_img,
block_id,
kv_len,
is_full_block,
False,
None,
None,
)
x = self.head(x, t)
x = self.unpatchify(x, (f, h, w))
return x
@staticmethod
def state_dict_converter():
return WanModelStateDictConverter()
class WanVideoVAE(nn.Module):
def __init__(self, z_dim=16, dim=96):
super().__init__()
mean = [
-0.7571,
-0.7089,
-0.9113,
0.1075,
-0.1745,
0.9653,
-0.1517,
1.5508,
0.4134,
-0.0715,
0.5517,
-0.3632,
-0.1922,
-0.9497,
0.2503,
-0.2921,
]
std = [
2.8184,
1.4541,
2.3275,
2.6558,
1.2196,
1.7708,
2.6052,
2.0743,
3.2687,
2.1526,
2.8652,
1.5579,
1.6382,
1.1253,
2.8251,
1.9160,
]
self.mean = torch.tensor(mean)
self.std = torch.tensor(std)
self.scale = [self.mean, 1.0 / self.std]
# init model
self.model = VideoVAE_(z_dim=z_dim, dim=dim).eval().requires_grad_(False)
self.upsampling_factor = 8
def build_1d_mask(self, length, left_bound, right_bound, border_width):
x = torch.ones((length,))
if not left_bound:
x[:border_width] = (torch.arange(border_width) + 1) / border_width
if not right_bound:
x[-border_width:] = torch.flip(
(torch.arange(border_width) + 1) / border_width, dims=(0,)
)
return x
def build_mask(self, data, is_bound, border_width):
_, _, _, H, W = data.shape
h = self.build_1d_mask(H, is_bound[0], is_bound[1], border_width[0])
w = self.build_1d_mask(W, is_bound[2], is_bound[3], border_width[1])
h = repeat(h, "H -> H W", H=H, W=W)
w = repeat(w, "W -> H W", H=H, W=W)
mask = torch.stack([h, w]).min(dim=0).values
mask = rearrange(mask, "H W -> 1 1 1 H W")
return mask
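    # Feathering sketch: build_1d_mask(6, left_bound=False, right_bound=True, border_width=3)
    # returns tensor([0.3333, 0.6667, 1.0, 1.0, 1.0, 1.0]); build_mask then takes the
    # per-pixel minimum of the H and W ramps, so overlapping tile regions blend linearly
    # while true image borders keep full weight.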
def tiled_decode(self, hidden_states, device, tile_size, tile_stride):
_, _, T, H, W = hidden_states.shape
size_h, size_w = tile_size
stride_h, stride_w = tile_stride
# Split tasks
tasks = []
for h in range(0, H, stride_h):
if h - stride_h >= 0 and h - stride_h + size_h >= H:
continue
for w in range(0, W, stride_w):
if w - stride_w >= 0 and w - stride_w + size_w >= W:
continue
h_, w_ = h + size_h, w + size_w
tasks.append((h, h_, w, w_))
data_device = "cpu"
computation_device = device
out_T = T * 4 - 3
weight = torch.zeros(
(1, 1, out_T, H * self.upsampling_factor, W * self.upsampling_factor),
dtype=hidden_states.dtype,
device=data_device,
)
values = torch.zeros(
(1, 3, out_T, H * self.upsampling_factor, W * self.upsampling_factor),
dtype=hidden_states.dtype,
device=data_device,
)
for h, h_, w, w_ in tqdm(tasks, desc="VAE decoding"):
hidden_states_batch = hidden_states[:, :, :, h:h_, w:w_].to(
computation_device
)
hidden_states_batch = self.model.decode(hidden_states_batch, self.scale).to(
data_device
)
mask = self.build_mask(
hidden_states_batch,
is_bound=(h == 0, h_ >= H, w == 0, w_ >= W),
border_width=(
(size_h - stride_h) * self.upsampling_factor,
(size_w - stride_w) * self.upsampling_factor,
),
).to(dtype=hidden_states.dtype, device=data_device)
target_h = h * self.upsampling_factor
target_w = w * self.upsampling_factor
values[
:,
:,
:,
target_h : target_h + hidden_states_batch.shape[3],
target_w : target_w + hidden_states_batch.shape[4],
] += (
hidden_states_batch * mask
)
weight[
:,
:,
:,
target_h : target_h + hidden_states_batch.shape[3],
target_w : target_w + hidden_states_batch.shape[4],
] += mask
values = values / weight
values = values.clamp_(-1, 1)
return values
def tiled_encode(self, video, device, tile_size, tile_stride):
_, _, T, H, W = video.shape
size_h, size_w = tile_size
stride_h, stride_w = tile_stride
# Split tasks
tasks = []
for h in range(0, H, stride_h):
if h - stride_h >= 0 and h - stride_h + size_h >= H:
continue
for w in range(0, W, stride_w):
if w - stride_w >= 0 and w - stride_w + size_w >= W:
continue
h_, w_ = h + size_h, w + size_w
tasks.append((h, h_, w, w_))
data_device = "cpu"
computation_device = device
out_T = (T + 3) // 4
weight = torch.zeros(
(1, 1, out_T, H // self.upsampling_factor, W // self.upsampling_factor),
dtype=video.dtype,
device=data_device,
)
values = torch.zeros(
(1, 16, out_T, H // self.upsampling_factor, W // self.upsampling_factor),
dtype=video.dtype,
device=data_device,
)
for h, h_, w, w_ in tqdm(tasks, desc="VAE encoding"):
hidden_states_batch = video[:, :, :, h:h_, w:w_].to(computation_device)
hidden_states_batch = self.model.encode(hidden_states_batch, self.scale).to(
data_device
)
mask = self.build_mask(
hidden_states_batch,
is_bound=(h == 0, h_ >= H, w == 0, w_ >= W),
border_width=(
(size_h - stride_h) // self.upsampling_factor,
(size_w - stride_w) // self.upsampling_factor,
),
).to(dtype=video.dtype, device=data_device)
target_h = h // self.upsampling_factor
target_w = w // self.upsampling_factor
values[
:,
:,
:,
target_h : target_h + hidden_states_batch.shape[3],
target_w : target_w + hidden_states_batch.shape[4],
] += (
hidden_states_batch * mask
)
weight[
:,
:,
:,
target_h : target_h + hidden_states_batch.shape[3],
target_w : target_w + hidden_states_batch.shape[4],
] += mask
values = values / weight
return values
def single_encode(self, video, device):
video = video.to(device)
x = self.model.encode(video, self.scale)
return x
def single_decode(self, hidden_state, device):
hidden_state = hidden_state.to(device)
video = self.model.decode(hidden_state, self.scale)
return video.clamp_(-1, 1)
def encode(
self, videos, device, tiled=False, tile_size=(34, 34), tile_stride=(18, 16)
):
videos = [video.to("cpu") for video in videos]
hidden_states = []
for video in videos:
video = video.unsqueeze(0)
if tiled:
tile_size = (tile_size[0] * 8, tile_size[1] * 8)
tile_stride = (tile_stride[0] * 8, tile_stride[1] * 8)
hidden_state = self.tiled_encode(video, device, tile_size, tile_stride)
else:
hidden_state = self.single_encode(video, device)
hidden_state = hidden_state.squeeze(0)
hidden_states.append(hidden_state)
hidden_states = torch.stack(hidden_states)
return hidden_states
def decode(
self,
hidden_states,
device,
tiled=False,
tile_size=(34, 34),
tile_stride=(18, 16),
):
hidden_states = [hidden_state.to("cpu") for hidden_state in hidden_states]
videos = []
for hidden_state in hidden_states:
hidden_state = hidden_state.unsqueeze(0)
if tiled:
video = self.tiled_decode(hidden_state, device, tile_size, tile_stride)
else:
video = self.single_decode(hidden_state, device)
video = video.squeeze(0)
videos.append(video)
videos = torch.stack(videos)
return videos
def clear_cache(self):
self.model.clear_cache()
def stream_decode(
self, hidden_states, tiled=False, tile_size=(34, 34), tile_stride=(18, 16)
):
hidden_states = [hidden_state for hidden_state in hidden_states]
assert len(hidden_states) == 1
hidden_state = hidden_states[0]
video = self.model.stream_decode(hidden_state, self.scale)
return video
@staticmethod
def state_dict_converter():
return WanVideoVAEStateDictConverter()
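# Minimal WanVideoVAE usage sketch (assumes a loaded `vae` and a video tensor `frames`
# of shape (3, T, H, W) in [-1, 1]; tile_size/tile_stride are in latent units and are
# scaled by 8 inside encode() when tiled=True):
# latents = vae.encode([frames], device="cuda", tiled=True)   # -> (1, 16, (T+3)//4, H//8, W//8)
# video = vae.decode(latents, device="cuda", tiled=True)      # -> (1, 3, 4*T'-3, H, W)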
model_loader_configs = [
# These configs are provided for detecting model type automatically.
# The format is (state_dict_keys_hash, state_dict_keys_hash_with_shape, model_names, model_classes, model_resource)
(
None,
"9269f8db9040a9d860eaca435be61814",
["wan_video_dit"],
[WanModel],
"civitai",
),
(
None,
"aafcfd9672c3a2456dc46e1cb6e52c70",
["wan_video_dit"],
[WanModel],
"civitai",
),
(
None,
"6bfcfb3b342cb286ce886889d519a77e",
["wan_video_dit"],
[WanModel],
"civitai",
),
(
None,
"6d6ccde6845b95ad9114ab993d917893",
["wan_video_dit"],
[WanModel],
"civitai",
),
(
None,
"349723183fc063b2bfc10bb2835cf677",
["wan_video_dit"],
[WanModel],
"civitai",
),
(
None,
"efa44cddf936c70abd0ea28b6cbe946c",
["wan_video_dit"],
[WanModel],
"civitai",
),
(
None,
"3ef3b1f8e1dab83d5b71fd7b617f859f",
["wan_video_dit"],
[WanModel],
"civitai",
),
(
None,
"cb104773c6c2cb6df4f9529ad5c60d0b",
["wan_video_dit"],
[WanModel],
"diffusers",
),
(
None,
"1378ea763357eea97acdef78e65d6d96",
["wan_video_vae"],
[WanVideoVAE],
"civitai",
),
(
None,
"ccc42284ea13e1ad04693284c7a09be6",
["wan_video_vae"],
[WanVideoVAE],
"civitai",
),
]
huggingface_model_loader_configs = [
# These configs are provided for detecting model type automatically.
# The format is (architecture_in_huggingface_config, huggingface_lib, model_name, redirected_architecture)
]
patch_model_loader_configs = [
# These configs are provided for detecting model type automatically.
# The format is (state_dict_keys_hash_with_shape, model_name, model_class, extra_kwargs)
]
def load_model_from_single_file(
state_dict, model_names, model_classes, model_resource, torch_dtype, device
):
loaded_model_names, loaded_models = [], []
for model_name, model_class in zip(model_names, model_classes):
# print(f" model_name: {model_name} model_class: {model_class.__name__}")
state_dict_converter = model_class.state_dict_converter()
if model_resource == "civitai":
state_dict_results = state_dict_converter.from_civitai(state_dict)
elif model_resource == "diffusers":
state_dict_results = state_dict_converter.from_diffusers(state_dict)
if isinstance(state_dict_results, tuple):
model_state_dict, extra_kwargs = state_dict_results
# print(f" This model is initialized with extra kwargs: {extra_kwargs}")
else:
model_state_dict, extra_kwargs = state_dict_results, {}
torch_dtype = (
torch.float32
if extra_kwargs.get("upcast_to_float32", False)
else torch_dtype
)
with init_weights_on_device():
model = model_class(**extra_kwargs)
if hasattr(model, "eval"):
model = model.eval()
model.load_state_dict(model_state_dict, assign=True)
model = model.to(dtype=torch_dtype, device=device)
loaded_model_names.append(model_name)
loaded_models.append(model)
return loaded_model_names, loaded_models
def load_model_from_huggingface_folder(
file_path, model_names, model_classes, torch_dtype, device
):
loaded_model_names, loaded_models = [], []
for model_name, model_class in zip(model_names, model_classes):
if torch_dtype in [torch.float32, torch.float16, torch.bfloat16]:
model = model_class.from_pretrained(
file_path, torch_dtype=torch_dtype
).eval()
else:
model = model_class.from_pretrained(file_path).eval().to(dtype=torch_dtype)
if torch_dtype == torch.float16 and hasattr(model, "half"):
model = model.half()
try:
model = model.to(device=device)
except:
pass
loaded_model_names.append(model_name)
loaded_models.append(model)
return loaded_model_names, loaded_models
def load_single_patch_model_from_single_file(
state_dict, model_name, model_class, base_model, extra_kwargs, torch_dtype, device
):
# print(f" model_name: {model_name} model_class: {model_class.__name__} extra_kwargs: {extra_kwargs}")
base_state_dict = base_model.state_dict()
base_model.to("cpu")
del base_model
model = model_class(**extra_kwargs)
model.load_state_dict(base_state_dict, strict=False)
model.load_state_dict(state_dict, strict=False)
model.to(dtype=torch_dtype, device=device)
return model
def load_patch_model_from_single_file(
state_dict,
model_names,
model_classes,
extra_kwargs,
model_manager,
torch_dtype,
device,
):
loaded_model_names, loaded_models = [], []
for model_name, model_class in zip(model_names, model_classes):
while True:
for model_id in range(len(model_manager.model)):
base_model_name = model_manager.model_name[model_id]
if base_model_name == model_name:
base_model_path = model_manager.model_path[model_id]
base_model = model_manager.model[model_id]
print(
f" Adding patch model to {base_model_name} ({base_model_path})"
)
patched_model = load_single_patch_model_from_single_file(
state_dict,
model_name,
model_class,
base_model,
extra_kwargs,
torch_dtype,
device,
)
loaded_model_names.append(base_model_name)
loaded_models.append(patched_model)
model_manager.model.pop(model_id)
model_manager.model_path.pop(model_id)
model_manager.model_name.pop(model_id)
break
else:
break
return loaded_model_names, loaded_models
class ModelDetectorTemplate:
def __init__(self):
pass
def match(self, file_path="", state_dict={}):
return False
def load(
self,
file_path="",
state_dict={},
device="cuda",
torch_dtype=torch.float16,
**kwargs,
):
return [], []
class ModelDetectorFromSingleFile:
def __init__(self, model_loader_configs=[]):
self.keys_hash_with_shape_dict = {}
self.keys_hash_dict = {}
for metadata in model_loader_configs:
self.add_model_metadata(*metadata)
def add_model_metadata(
self,
keys_hash,
keys_hash_with_shape,
model_names,
model_classes,
model_resource,
):
self.keys_hash_with_shape_dict[keys_hash_with_shape] = (
model_names,
model_classes,
model_resource,
)
if keys_hash is not None:
self.keys_hash_dict[keys_hash] = (
model_names,
model_classes,
model_resource,
)
def match(self, file_path="", state_dict={}):
if isinstance(file_path, str) and os.path.isdir(file_path):
return False
if len(state_dict) == 0:
state_dict = load_state_dict(file_path)
keys_hash_with_shape = hash_state_dict_keys(state_dict, with_shape=True)
if keys_hash_with_shape in self.keys_hash_with_shape_dict:
return True
keys_hash = hash_state_dict_keys(state_dict, with_shape=False)
if keys_hash in self.keys_hash_dict:
return True
return False
def load(
self,
file_path="",
state_dict={},
device="cuda",
torch_dtype=torch.float16,
**kwargs,
):
if len(state_dict) == 0:
state_dict = load_state_dict(file_path)
# Load models with strict matching
keys_hash_with_shape = hash_state_dict_keys(state_dict, with_shape=True)
if keys_hash_with_shape in self.keys_hash_with_shape_dict:
model_names, model_classes, model_resource = self.keys_hash_with_shape_dict[
keys_hash_with_shape
]
loaded_model_names, loaded_models = load_model_from_single_file(
state_dict,
model_names,
model_classes,
model_resource,
torch_dtype,
device,
)
return loaded_model_names, loaded_models
# Load models without strict matching
# (the shape of parameters may be inconsistent, and the state_dict_converter will modify the model architecture)
keys_hash = hash_state_dict_keys(state_dict, with_shape=False)
if keys_hash in self.keys_hash_dict:
model_names, model_classes, model_resource = self.keys_hash_dict[keys_hash]
loaded_model_names, loaded_models = load_model_from_single_file(
state_dict,
model_names,
model_classes,
model_resource,
torch_dtype,
device,
)
return loaded_model_names, loaded_models
        return [], []
class ModelDetectorFromSplitedSingleFile(ModelDetectorFromSingleFile):
def __init__(self, model_loader_configs=[]):
super().__init__(model_loader_configs)
def match(self, file_path="", state_dict={}):
if isinstance(file_path, str) and os.path.isdir(file_path):
return False
if len(state_dict) == 0:
state_dict = load_state_dict(file_path)
splited_state_dict = split_state_dict_with_prefix(state_dict)
for sub_state_dict in splited_state_dict:
if super().match(file_path, sub_state_dict):
return True
return False
def load(
self,
file_path="",
state_dict={},
device="cuda",
torch_dtype=torch.float16,
**kwargs,
):
        # Split the state_dict and load from each component
        if len(state_dict) == 0:
            state_dict = load_state_dict(file_path)
        splited_state_dict = split_state_dict_with_prefix(state_dict)
valid_state_dict = {}
for sub_state_dict in splited_state_dict:
if super().match(file_path, sub_state_dict):
valid_state_dict.update(sub_state_dict)
if super().match(file_path, valid_state_dict):
loaded_model_names, loaded_models = super().load(
file_path, valid_state_dict, device, torch_dtype
)
else:
loaded_model_names, loaded_models = [], []
for sub_state_dict in splited_state_dict:
if super().match(file_path, sub_state_dict):
loaded_model_names_, loaded_models_ = super().load(
                        file_path, sub_state_dict, device, torch_dtype
)
loaded_model_names += loaded_model_names_
loaded_models += loaded_models_
return loaded_model_names, loaded_models
class ModelDetectorFromHuggingfaceFolder:
def __init__(self, model_loader_configs=[]):
self.architecture_dict = {}
for metadata in model_loader_configs:
self.add_model_metadata(*metadata)
def add_model_metadata(
self, architecture, huggingface_lib, model_name, redirected_architecture
):
self.architecture_dict[architecture] = (
huggingface_lib,
model_name,
redirected_architecture,
)
def match(self, file_path="", state_dict={}):
if not isinstance(file_path, str) or os.path.isfile(file_path):
return False
file_list = os.listdir(file_path)
if "config.json" not in file_list:
return False
with open(os.path.join(file_path, "config.json"), "r") as f:
config = json.load(f)
if "architectures" not in config and "_class_name" not in config:
return False
return True
def load(
self,
file_path="",
state_dict={},
device="cuda",
torch_dtype=torch.float16,
**kwargs,
):
with open(os.path.join(file_path, "config.json"), "r") as f:
config = json.load(f)
loaded_model_names, loaded_models = [], []
architectures = (
config["architectures"]
if "architectures" in config
else [config["_class_name"]]
)
for architecture in architectures:
huggingface_lib, model_name, redirected_architecture = (
self.architecture_dict[architecture]
)
if redirected_architecture is not None:
architecture = redirected_architecture
model_class = importlib.import_module(huggingface_lib).__getattribute__(
architecture
)
loaded_model_names_, loaded_models_ = load_model_from_huggingface_folder(
file_path, [model_name], [model_class], torch_dtype, device
)
loaded_model_names += loaded_model_names_
loaded_models += loaded_models_
return loaded_model_names, loaded_models
class ModelDetectorFromPatchedSingleFile:
def __init__(self, model_loader_configs=[]):
self.keys_hash_with_shape_dict = {}
for metadata in model_loader_configs:
self.add_model_metadata(*metadata)
def add_model_metadata(
self, keys_hash_with_shape, model_name, model_class, extra_kwargs
):
self.keys_hash_with_shape_dict[keys_hash_with_shape] = (
model_name,
model_class,
extra_kwargs,
)
def match(self, file_path="", state_dict={}):
if not isinstance(file_path, str) or os.path.isdir(file_path):
return False
if len(state_dict) == 0:
state_dict = load_state_dict(file_path)
keys_hash_with_shape = hash_state_dict_keys(state_dict, with_shape=True)
if keys_hash_with_shape in self.keys_hash_with_shape_dict:
return True
return False
def load(
self,
file_path="",
state_dict={},
device="cuda",
torch_dtype=torch.float16,
model_manager=None,
**kwargs,
):
if len(state_dict) == 0:
state_dict = load_state_dict(file_path)
# Load models with strict matching
loaded_model_names, loaded_models = [], []
keys_hash_with_shape = hash_state_dict_keys(state_dict, with_shape=True)
if keys_hash_with_shape in self.keys_hash_with_shape_dict:
model_names, model_classes, extra_kwargs = self.keys_hash_with_shape_dict[
keys_hash_with_shape
]
loaded_model_names_, loaded_models_ = load_patch_model_from_single_file(
state_dict,
model_names,
model_classes,
extra_kwargs,
model_manager,
torch_dtype,
device,
)
loaded_model_names += loaded_model_names_
loaded_models += loaded_models_
return loaded_model_names, loaded_models
class ModelManager:
def __init__(
self,
torch_dtype=torch.float16,
device="cuda",
file_path_list: List[str] = [],
):
self.torch_dtype = torch_dtype
self.device = device
self.model = []
self.model_path = []
self.model_name = []
self.model_detector = [
ModelDetectorFromSingleFile(model_loader_configs),
ModelDetectorFromSplitedSingleFile(model_loader_configs),
ModelDetectorFromHuggingfaceFolder(huggingface_model_loader_configs),
ModelDetectorFromPatchedSingleFile(patch_model_loader_configs),
]
self.load_models(file_path_list)
def load_model_from_single_file(
self,
file_path="",
state_dict={},
model_names=[],
model_classes=[],
model_resource=None,
):
print(f"Loading models from file: {file_path}")
if len(state_dict) == 0:
state_dict = load_state_dict(file_path)
model_names, models = load_model_from_single_file(
state_dict,
model_names,
model_classes,
model_resource,
self.torch_dtype,
self.device,
)
for model_name, model in zip(model_names, models):
self.model.append(model)
self.model_path.append(file_path)
self.model_name.append(model_name)
# print(f" The following models are loaded: {model_names}.")
def load_model_from_huggingface_folder(
self, file_path="", model_names=[], model_classes=[]
):
print(f"Loading models from folder: {file_path}")
model_names, models = load_model_from_huggingface_folder(
file_path, model_names, model_classes, self.torch_dtype, self.device
)
for model_name, model in zip(model_names, models):
self.model.append(model)
self.model_path.append(file_path)
self.model_name.append(model_name)
# print(f" The following models are loaded: {model_names}.")
def load_patch_model_from_single_file(
self,
file_path="",
state_dict={},
model_names=[],
model_classes=[],
extra_kwargs={},
):
print(f"Loading patch models from file: {file_path}")
model_names, models = load_patch_model_from_single_file(
state_dict,
model_names,
model_classes,
extra_kwargs,
self,
self.torch_dtype,
self.device,
)
for model_name, model in zip(model_names, models):
self.model.append(model)
self.model_path.append(file_path)
self.model_name.append(model_name)
print(f" The following patched models are loaded: {model_names}.")
def load_lora(self, file_path="", state_dict={}, lora_alpha=1.0):
if isinstance(file_path, list):
for file_path_ in file_path:
self.load_lora(file_path_, state_dict=state_dict, lora_alpha=lora_alpha)
else:
print(f"Loading LoRA models from file: {file_path}")
is_loaded = False
if len(state_dict) == 0:
state_dict = load_state_dict(file_path)
for model_name, model, model_path in zip(
self.model_name, self.model, self.model_path
):
for lora in get_lora_loaders():
match_results = lora.match(model, state_dict)
if match_results is not None:
print(f" Adding LoRA to {model_name} ({model_path}).")
lora_prefix, model_resource = match_results
lora.load(
model,
state_dict,
lora_prefix,
alpha=lora_alpha,
model_resource=model_resource,
)
is_loaded = True
break
if not is_loaded:
print(f" Cannot load LoRA: {file_path}")
def load_model(self, file_path, model_names=None, device=None, torch_dtype=None):
# print(f"Loading models from: {file_path}")
if device is None:
device = self.device
if torch_dtype is None:
torch_dtype = self.torch_dtype
if isinstance(file_path, list):
state_dict = {}
for path in file_path:
state_dict.update(load_state_dict(path))
elif os.path.isfile(file_path):
state_dict = load_state_dict(file_path)
else:
state_dict = None
for model_detector in self.model_detector:
if model_detector.match(file_path, state_dict):
model_names, models = model_detector.load(
file_path,
state_dict,
device=device,
torch_dtype=torch_dtype,
allowed_model_names=model_names,
model_manager=self,
)
for model_name, model in zip(model_names, models):
self.model.append(model)
self.model_path.append(file_path)
self.model_name.append(model_name)
# print(f" The following models are loaded: {model_names}.")
break
else:
print(f" We cannot detect the model type. No models are loaded.")
def load_models(
self, file_path_list, model_names=None, device=None, torch_dtype=None
):
for file_path in file_path_list:
self.load_model(
file_path, model_names, device=device, torch_dtype=torch_dtype
)
def fetch_model(self, model_name, file_path=None, require_model_path=False):
fetched_models = []
fetched_model_paths = []
for model, model_path, model_name_ in zip(
self.model, self.model_path, self.model_name
):
if file_path is not None and file_path != model_path:
continue
if model_name == model_name_:
fetched_models.append(model)
fetched_model_paths.append(model_path)
if len(fetched_models) == 0:
# print(f"No {model_name} models available.")
return None
if len(fetched_models) == 1:
print(f"Using {model_name} from {fetched_model_paths[0]}")
else:
print(
f"More than one {model_name} models are loaded in model manager: {fetched_model_paths}. Using {model_name} from {fetched_model_paths[0]}"
)
if require_model_path:
return fetched_models[0], fetched_model_paths[0]
else:
return fetched_models[0]
def to(self, device):
for model in self.model:
model.to(device)
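# ModelManager usage sketch (the checkpoint path is hypothetical; which model class gets
# instantiated is decided by the state-dict key hashes registered in model_loader_configs):
# manager = ModelManager(torch_dtype=torch.bfloat16, device="cuda",
#                        file_path_list=["/path/to/wan_dit.safetensors"])
# dit = manager.fetch_model("wan_video_dit")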
CACHE_T = 2
def check_is_instance(model, module_class):
if isinstance(model, module_class):
return True
if hasattr(model, "module") and isinstance(model.module, module_class):
return True
return False
def block_causal_mask(x, block_size):
# params
b, n, s, _, device = *x.size(), x.device
assert s % block_size == 0
num_blocks = s // block_size
# build mask
mask = torch.zeros(b, n, s, s, dtype=torch.bool, device=device)
for i in range(num_blocks):
mask[:, :, i * block_size : (i + 1) * block_size, : (i + 1) * block_size] = 1
return mask
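# Example: for x of shape (b, n, 4, d) and block_size=2 the mask is block lower-triangular:
# queries in block 0 (rows 0-1) see keys 0-1, queries in block 1 (rows 2-3) see keys 0-3.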
class CausalConv3d(nn.Conv3d):
"""
    Causal 3D convolution.
"""
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self._padding = (
self.padding[2],
self.padding[2],
self.padding[1],
self.padding[1],
2 * self.padding[0],
0,
)
self.padding = (0, 0, 0)
def forward(self, x, cache_x=None):
padding = list(self._padding)
if cache_x is not None and self._padding[4] > 0:
cache_x = cache_x.to(x.device)
# print('cache_x.shape', cache_x.shape, 'x.shape', x.shape)
x = torch.cat([cache_x, x], dim=2)
padding[4] -= cache_x.shape[2]
x = F.pad(x, padding)
return super().forward(x)
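# Causality note: with a (3, k, k) kernel the temporal padding (2 * padding[0]) is applied
# on the left only, so frame t never depends on later frames; `cache_x` carries the trailing
# frames of the previous chunk so that chunked decoding can match a full-sequence pass.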
class RMS_norm(nn.Module):
def __init__(self, dim, channel_first=True, images=True, bias=False):
super().__init__()
broadcastable_dims = (1, 1, 1) if not images else (1, 1)
shape = (dim, *broadcastable_dims) if channel_first else (dim,)
self.channel_first = channel_first
self.scale = dim**0.5
self.gamma = nn.Parameter(torch.ones(shape))
self.bias = nn.Parameter(torch.zeros(shape)) if bias else 0.0
def forward(self, x):
return (
F.normalize(x, dim=(1 if self.channel_first else -1))
* self.scale
* self.gamma
+ self.bias
)
class Upsample(nn.Upsample):
def forward(self, x):
"""
Fix bfloat16 support for nearest neighbor interpolation.
"""
return super().forward(x.float()).type_as(x)
class Resample(nn.Module):
def __init__(self, dim, mode):
assert mode in (
"none",
"upsample2d",
"upsample3d",
"downsample2d",
"downsample3d",
)
super().__init__()
self.dim = dim
self.mode = mode
# layers
if mode == "upsample2d":
self.resample = nn.Sequential(
Upsample(scale_factor=(2.0, 2.0), mode="nearest-exact"),
nn.Conv2d(dim, dim // 2, 3, padding=1),
)
elif mode == "upsample3d":
self.resample = nn.Sequential(
Upsample(scale_factor=(2.0, 2.0), mode="nearest-exact"),
nn.Conv2d(dim, dim // 2, 3, padding=1),
)
self.time_conv = CausalConv3d(dim, dim * 2, (3, 1, 1), padding=(1, 0, 0))
elif mode == "downsample2d":
self.resample = nn.Sequential(
nn.ZeroPad2d((0, 1, 0, 1)), nn.Conv2d(dim, dim, 3, stride=(2, 2))
)
elif mode == "downsample3d":
self.resample = nn.Sequential(
nn.ZeroPad2d((0, 1, 0, 1)), nn.Conv2d(dim, dim, 3, stride=(2, 2))
)
self.time_conv = CausalConv3d(
dim, dim, (3, 1, 1), stride=(2, 1, 1), padding=(0, 0, 0)
)
else:
self.resample = nn.Identity()
def forward(self, x, feat_cache=None, feat_idx=[0]):
b, c, t, h, w = x.size()
if self.mode == "upsample3d":
if feat_cache is not None:
idx = feat_idx[0]
if feat_cache[idx] is None:
feat_cache[idx] = "Rep"
feat_idx[0] += 1
else:
cache_x = x[:, :, -CACHE_T:, :, :].clone()
if (
cache_x.shape[2] < 2
and feat_cache[idx] is not None
and feat_cache[idx] != "Rep"
):
                        # cache the last frame of the last two chunks
cache_x = torch.cat(
[
feat_cache[idx][:, :, -1, :, :]
.unsqueeze(2)
.to(cache_x.device),
cache_x,
],
dim=2,
)
if (
cache_x.shape[2] < 2
and feat_cache[idx] is not None
and feat_cache[idx] == "Rep"
):
cache_x = torch.cat(
[torch.zeros_like(cache_x).to(cache_x.device), cache_x],
dim=2,
)
if feat_cache[idx] == "Rep":
x = self.time_conv(x)
else:
x = self.time_conv(x, feat_cache[idx])
feat_cache[idx] = cache_x
feat_idx[0] += 1
x = x.reshape(b, 2, c, t, h, w)
x = torch.stack((x[:, 0, :, :, :, :], x[:, 1, :, :, :, :]), 3)
x = x.reshape(b, c, t * 2, h, w)
t = x.shape[2]
x = rearrange(x, "b c t h w -> (b t) c h w")
x = self.resample(x)
x = rearrange(x, "(b t) c h w -> b c t h w", t=t)
if self.mode == "downsample3d":
if feat_cache is not None:
idx = feat_idx[0]
if feat_cache[idx] is None:
feat_cache[idx] = x.clone()
feat_idx[0] += 1
else:
cache_x = x[:, :, -1:, :, :].clone()
x = self.time_conv(
torch.cat([feat_cache[idx][:, :, -1:, :, :], x], 2)
)
feat_cache[idx] = cache_x
feat_idx[0] += 1
return x
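    # Note on "upsample3d": time_conv maps c -> 2c channels, and the reshape/stack in
    # forward() interleaves the two halves along the time axis, doubling t before the
    # nearest-exact 2D upsample runs frame by frame.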
def init_weight(self, conv):
conv_weight = conv.weight
nn.init.zeros_(conv_weight)
c1, c2, t, h, w = conv_weight.size()
one_matrix = torch.eye(c1, c2)
init_matrix = one_matrix
nn.init.zeros_(conv_weight)
conv_weight.data[:, :, 1, 0, 0] = init_matrix
conv.weight.data.copy_(conv_weight)
nn.init.zeros_(conv.bias.data)
def init_weight2(self, conv):
conv_weight = conv.weight.data
nn.init.zeros_(conv_weight)
c1, c2, t, h, w = conv_weight.size()
init_matrix = torch.eye(c1 // 2, c2)
conv_weight[: c1 // 2, :, -1, 0, 0] = init_matrix
conv_weight[c1 // 2 :, :, -1, 0, 0] = init_matrix
conv.weight.data.copy_(conv_weight)
nn.init.zeros_(conv.bias.data)
class ResidualBlock(nn.Module):
def __init__(self, in_dim, out_dim, dropout=0.0):
super().__init__()
self.in_dim = in_dim
self.out_dim = out_dim
# layers
self.residual = nn.Sequential(
RMS_norm(in_dim, images=False),
nn.SiLU(),
CausalConv3d(in_dim, out_dim, 3, padding=1),
RMS_norm(out_dim, images=False),
nn.SiLU(),
nn.Dropout(dropout),
CausalConv3d(out_dim, out_dim, 3, padding=1),
)
self.shortcut = (
CausalConv3d(in_dim, out_dim, 1) if in_dim != out_dim else nn.Identity()
)
def forward(self, x, feat_cache=None, feat_idx=[0]):
h = self.shortcut(x)
for layer in self.residual:
if check_is_instance(layer, CausalConv3d) and feat_cache is not None:
idx = feat_idx[0]
cache_x = x[:, :, -CACHE_T:, :, :].clone()
if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
                    # cache the last frame of the last two chunks
cache_x = torch.cat(
[
feat_cache[idx][:, :, -1, :, :]
.unsqueeze(2)
.to(cache_x.device),
cache_x,
],
dim=2,
)
x = layer(x, feat_cache[idx])
feat_cache[idx] = cache_x
feat_idx[0] += 1
else:
x = layer(x)
return x + h
class AttentionBlock(nn.Module):
"""
    Single-head self-attention applied per frame (the block-causal mask is currently disabled).
"""
def __init__(self, dim):
super().__init__()
self.dim = dim
# layers
self.norm = RMS_norm(dim)
self.to_qkv = nn.Conv2d(dim, dim * 3, 1)
self.proj = nn.Conv2d(dim, dim, 1)
# zero out the last layer params
nn.init.zeros_(self.proj.weight)
def forward(self, x):
identity = x
b, c, t, h, w = x.size()
x = rearrange(x, "b c t h w -> (b t) c h w")
x = self.norm(x)
# compute query, key, value
q, k, v = (
self.to_qkv(x)
.reshape(b * t, 1, c * 3, -1)
.permute(0, 1, 3, 2)
.contiguous()
.chunk(3, dim=-1)
)
# apply attention
x = F.scaled_dot_product_attention(
q,
k,
v,
# attn_mask=block_causal_mask(q, block_size=h * w)
)
x = x.squeeze(1).permute(0, 2, 1).reshape(b * t, c, h, w)
# output
x = self.proj(x)
x = rearrange(x, "(b t) c h w-> b c t h w", t=t)
return x + identity
class Encoder3d(nn.Module):
def __init__(
self,
dim=128,
z_dim=4,
dim_mult=[1, 2, 4, 4],
num_res_blocks=2,
attn_scales=[],
temperal_downsample=[True, True, False],
dropout=0.0,
):
super().__init__()
self.dim = dim
self.z_dim = z_dim
self.dim_mult = dim_mult
self.num_res_blocks = num_res_blocks
self.attn_scales = attn_scales
self.temperal_downsample = temperal_downsample
# dimensions
dims = [dim * u for u in [1] + dim_mult]
scale = 1.0
# init block
self.conv1 = CausalConv3d(3, dims[0], 3, padding=1)
# downsample blocks
downsamples = []
for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])):
# residual (+attention) blocks
for _ in range(num_res_blocks):
downsamples.append(ResidualBlock(in_dim, out_dim, dropout))
if scale in attn_scales:
downsamples.append(AttentionBlock(out_dim))
in_dim = out_dim
# downsample block
if i != len(dim_mult) - 1:
mode = "downsample3d" if temperal_downsample[i] else "downsample2d"
downsamples.append(Resample(out_dim, mode=mode))
scale /= 2.0
self.downsamples = nn.Sequential(*downsamples)
# middle blocks
self.middle = nn.Sequential(
ResidualBlock(out_dim, out_dim, dropout),
AttentionBlock(out_dim),
ResidualBlock(out_dim, out_dim, dropout),
)
# output blocks
self.head = nn.Sequential(
RMS_norm(out_dim, images=False),
nn.SiLU(),
CausalConv3d(out_dim, z_dim, 3, padding=1),
)
def forward(self, x, feat_cache=None, feat_idx=[0]):
if feat_cache is not None:
idx = feat_idx[0]
cache_x = x[:, :, -CACHE_T:, :, :].clone()
if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
                # cache the last frame of the last two chunks
cache_x = torch.cat(
[
feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device),
cache_x,
],
dim=2,
)
x = self.conv1(x, feat_cache[idx])
feat_cache[idx] = cache_x
feat_idx[0] += 1
else:
x = self.conv1(x)
## downsamples
for layer in self.downsamples:
if feat_cache is not None:
x = layer(x, feat_cache, feat_idx)
else:
x = layer(x)
## middle
for layer in self.middle:
if check_is_instance(layer, ResidualBlock) and feat_cache is not None:
x = layer(x, feat_cache, feat_idx)
else:
x = layer(x)
## head
for layer in self.head:
if check_is_instance(layer, CausalConv3d) and feat_cache is not None:
idx = feat_idx[0]
cache_x = x[:, :, -CACHE_T:, :, :].clone()
if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
                    # cache the last frame of the last two chunks
cache_x = torch.cat(
[
feat_cache[idx][:, :, -1, :, :]
.unsqueeze(2)
.to(cache_x.device),
cache_x,
],
dim=2,
)
x = layer(x, feat_cache[idx])
feat_cache[idx] = cache_x
feat_idx[0] += 1
else:
x = layer(x)
return x
class Decoder3d(nn.Module):
def __init__(
self,
dim=128,
z_dim=4,
dim_mult=[1, 2, 4, 4],
num_res_blocks=2,
attn_scales=[],
temperal_upsample=[False, True, True],
dropout=0.0,
):
super().__init__()
self.dim = dim
self.z_dim = z_dim
self.dim_mult = dim_mult
self.num_res_blocks = num_res_blocks
self.attn_scales = attn_scales
self.temperal_upsample = temperal_upsample
# dimensions
dims = [dim * u for u in [dim_mult[-1]] + dim_mult[::-1]]
scale = 1.0 / 2 ** (len(dim_mult) - 2)
# init block
self.conv1 = CausalConv3d(z_dim, dims[0], 3, padding=1)
# middle blocks
self.middle = nn.Sequential(
ResidualBlock(dims[0], dims[0], dropout),
AttentionBlock(dims[0]),
ResidualBlock(dims[0], dims[0], dropout),
)
# upsample blocks
upsamples = []
for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])):
# residual (+attention) blocks
if i == 1 or i == 2 or i == 3:
in_dim = in_dim // 2
for _ in range(num_res_blocks + 1):
upsamples.append(ResidualBlock(in_dim, out_dim, dropout))
if scale in attn_scales:
upsamples.append(AttentionBlock(out_dim))
in_dim = out_dim
# upsample block
if i != len(dim_mult) - 1:
mode = "upsample3d" if temperal_upsample[i] else "upsample2d"
upsamples.append(Resample(out_dim, mode=mode))
scale *= 2.0
self.upsamples = nn.Sequential(*upsamples)
# output blocks
self.head = nn.Sequential(
RMS_norm(out_dim, images=False),
nn.SiLU(),
CausalConv3d(out_dim, 3, 3, padding=1),
)
def forward(self, x, feat_cache=None, feat_idx=[0]):
## conv1
if feat_cache is not None:
idx = feat_idx[0]
cache_x = x[:, :, -CACHE_T:, :, :].clone()
if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
                # cache the last frame of the last two chunks
cache_x = torch.cat(
[
feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device),
cache_x,
],
dim=2,
)
x = self.conv1(x, feat_cache[idx])
feat_cache[idx] = cache_x
feat_idx[0] += 1
else:
x = self.conv1(x)
## middle
for layer in self.middle:
if check_is_instance(layer, ResidualBlock) and feat_cache is not None:
x = layer(x, feat_cache, feat_idx)
else:
x = layer(x)
## upsamples
for layer in self.upsamples:
if feat_cache is not None:
x = layer(x, feat_cache, feat_idx)
else:
x = layer(x)
## head
for layer in self.head:
if check_is_instance(layer, CausalConv3d) and feat_cache is not None:
idx = feat_idx[0]
cache_x = x[:, :, -CACHE_T:, :, :].clone()
if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
                    # cache the last frame of the last two chunks
cache_x = torch.cat(
[
feat_cache[idx][:, :, -1, :, :]
.unsqueeze(2)
.to(cache_x.device),
cache_x,
],
dim=2,
)
x = layer(x, feat_cache[idx])
feat_cache[idx] = cache_x
feat_idx[0] += 1
else:
x = layer(x)
return x
def count_conv3d(model):
count = 0
for m in model.modules():
if check_is_instance(m, CausalConv3d):
count += 1
return count
class VideoVAE_(nn.Module):
def __init__(
self,
dim=96,
z_dim=16,
dim_mult=[1, 2, 4, 4],
num_res_blocks=2,
attn_scales=[],
temperal_downsample=[False, True, True],
dropout=0.0,
):
super().__init__()
self.dim = dim
self.z_dim = z_dim
self.dim_mult = dim_mult
self.num_res_blocks = num_res_blocks
self.attn_scales = attn_scales
self.temperal_downsample = temperal_downsample
self.temperal_upsample = temperal_downsample[::-1]
# modules
self.encoder = Encoder3d(
dim,
z_dim * 2,
dim_mult,
num_res_blocks,
attn_scales,
self.temperal_downsample,
dropout,
)
self.conv1 = CausalConv3d(z_dim * 2, z_dim * 2, 1)
self.conv2 = CausalConv3d(z_dim, z_dim, 1)
self.decoder = Decoder3d(
dim,
z_dim,
dim_mult,
num_res_blocks,
attn_scales,
self.temperal_upsample,
dropout,
)
def forward(self, x):
mu, log_var = self.encode(x)
z = self.reparameterize(mu, log_var)
x_recon = self.decode(z)
return x_recon, mu, log_var
def encode(self, x, scale):
self.clear_cache()
## cache
t = x.shape[2]
iter_ = 1 + (t - 1) // 4
for i in range(iter_):
self._enc_conv_idx = [0]
if i == 0:
out = self.encoder(
x[:, :, :1, :, :],
feat_cache=self._enc_feat_map,
feat_idx=self._enc_conv_idx,
)
else:
out_ = self.encoder(
x[:, :, 1 + 4 * (i - 1) : 1 + 4 * i, :, :],
feat_cache=self._enc_feat_map,
feat_idx=self._enc_conv_idx,
)
out = torch.cat([out, out_], 2)
mu, log_var = self.conv1(out).chunk(2, dim=1)
if isinstance(scale[0], torch.Tensor):
scale = [s.to(dtype=mu.dtype, device=mu.device) for s in scale]
mu = (mu - scale[0].view(1, self.z_dim, 1, 1, 1)) * scale[1].view(
1, self.z_dim, 1, 1, 1
)
else:
scale = scale.to(dtype=mu.dtype, device=mu.device)
mu = (mu - scale[0]) * scale[1]
return mu
def decode(self, z, scale):
self.clear_cache()
# z: [b,c,t,h,w]
if isinstance(scale[0], torch.Tensor):
scale = [s.to(dtype=z.dtype, device=z.device) for s in scale]
z = z / scale[1].view(1, self.z_dim, 1, 1, 1) + scale[0].view(
1, self.z_dim, 1, 1, 1
)
else:
scale = scale.to(dtype=z.dtype, device=z.device)
z = z / scale[1] + scale[0]
iter_ = z.shape[2]
x = self.conv2(z)
for i in range(iter_):
self._conv_idx = [0]
if i == 0:
out = self.decoder(
x[:, :, i : i + 1, :, :],
feat_cache=self._feat_map,
feat_idx=self._conv_idx,
)
else:
out_ = self.decoder(
x[:, :, i : i + 1, :, :],
feat_cache=self._feat_map,
feat_idx=self._conv_idx,
)
out = torch.cat([out, out_], 2) # may add tensor offload
return out
def stream_decode(self, z, scale):
# self.clear_cache()
# z: [b,c,t,h,w]
if isinstance(scale[0], torch.Tensor):
scale = [s.to(dtype=z.dtype, device=z.device) for s in scale]
z = z / scale[1].view(1, self.z_dim, 1, 1, 1) + scale[0].view(
1, self.z_dim, 1, 1, 1
)
else:
scale = scale.to(dtype=z.dtype, device=z.device)
z = z / scale[1] + scale[0]
iter_ = z.shape[2]
x = self.conv2(z)
for i in range(iter_):
self._conv_idx = [0]
if i == 0:
out = self.decoder(
x[:, :, i : i + 1, :, :],
feat_cache=self._feat_map,
feat_idx=self._conv_idx,
)
else:
out_ = self.decoder(
x[:, :, i : i + 1, :, :],
feat_cache=self._feat_map,
feat_idx=self._conv_idx,
)
out = torch.cat([out, out_], 2) # may add tensor offload
return out
def reparameterize(self, mu, log_var):
std = torch.exp(0.5 * log_var)
eps = torch.randn_like(std)
return eps * std + mu
def sample(self, imgs, deterministic=False):
mu, log_var = self.encode(imgs)
if deterministic:
return mu
std = torch.exp(0.5 * log_var.clamp(-30.0, 20.0))
return mu + std * torch.randn_like(std)
def clear_cache(self):
self._conv_num = count_conv3d(self.decoder)
self._conv_idx = [0]
self._feat_map = [None] * self._conv_num
# print('self._feat_map', len(self._feat_map))
# cache encode
if self.encoder is not None:
self._enc_conv_num = count_conv3d(self.encoder)
self._enc_conv_idx = [0]
self._enc_feat_map = [None] * self._enc_conv_num
# print('self._enc_feat_map', len(self._enc_feat_map))
class WanVideoVAEStateDictConverter:
def __init__(self):
pass
def from_civitai(self, state_dict):
state_dict_ = {}
if "model_state" in state_dict:
state_dict = state_dict["model_state"]
for name in state_dict:
state_dict_["model." + name] = state_dict[name]
return state_dict_
class FlowMatchScheduler:
def __init__(
self,
num_inference_steps=100,
num_train_timesteps=1000,
shift=3.0,
sigma_max=1.0,
sigma_min=0.003 / 1.002,
inverse_timesteps=False,
extra_one_step=False,
reverse_sigmas=False,
):
self.num_train_timesteps = num_train_timesteps
self.shift = shift
self.sigma_max = sigma_max
self.sigma_min = sigma_min
self.inverse_timesteps = inverse_timesteps
self.extra_one_step = extra_one_step
self.reverse_sigmas = reverse_sigmas
self.set_timesteps(num_inference_steps)
def set_timesteps(
self,
num_inference_steps=100,
denoising_strength=1.0,
training=False,
shift=None,
):
if shift is not None:
self.shift = shift
sigma_start = (
self.sigma_min + (self.sigma_max - self.sigma_min) * denoising_strength
)
if self.extra_one_step:
self.sigmas = torch.linspace(
sigma_start, self.sigma_min, num_inference_steps + 1
)[:-1]
else:
self.sigmas = torch.linspace(
sigma_start, self.sigma_min, num_inference_steps
)
if self.inverse_timesteps:
self.sigmas = torch.flip(self.sigmas, dims=[0])
self.sigmas = self.shift * self.sigmas / (1 + (self.shift - 1) * self.sigmas)
if self.reverse_sigmas:
self.sigmas = 1 - self.sigmas
self.timesteps = self.sigmas * self.num_train_timesteps
if training:
x = self.timesteps
y = torch.exp(
-2 * ((x - num_inference_steps / 2) / num_inference_steps) ** 2
)
y_shifted = y - y.min()
bsmntw_weighing = y_shifted * (num_inference_steps / y_shifted.sum())
self.linear_timesteps_weights = bsmntw_weighing
def step(self, model_output, timestep, sample, to_final=False, **kwargs):
if isinstance(timestep, torch.Tensor):
timestep = timestep.cpu()
timestep_id = torch.argmin((self.timesteps - timestep).abs())
sigma = self.sigmas[timestep_id]
if to_final or timestep_id + 1 >= len(self.timesteps):
sigma_ = 1 if (self.inverse_timesteps or self.reverse_sigmas) else 0
else:
sigma_ = self.sigmas[timestep_id + 1]
prev_sample = sample + model_output * (sigma_ - sigma)
return prev_sample
def return_to_timestep(self, timestep, sample, sample_stablized):
if isinstance(timestep, torch.Tensor):
timestep = timestep.cpu()
timestep_id = torch.argmin((self.timesteps - timestep).abs())
sigma = self.sigmas[timestep_id]
model_output = (sample - sample_stablized) / sigma
return model_output
def add_noise(self, original_samples, noise, timestep):
if isinstance(timestep, torch.Tensor):
timestep = timestep.cpu()
timestep_id = torch.argmin((self.timesteps - timestep).abs())
sigma = self.sigmas[timestep_id]
sample = (1 - sigma) * original_samples + sigma * noise
return sample
def training_target(self, sample, noise, timestep):
target = noise - sample
return target
def training_weight(self, timestep):
timestep_id = torch.argmin(
(self.timesteps - timestep.to(self.timesteps.device)).abs()
)
weights = self.linear_timesteps_weights[timestep_id]
return weights
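# Hypothetical flow-matching sampling loop (model, latents and prompt_emb are placeholders;
# step() moves the sample from the current sigma to the next one along the predicted flow):
# scheduler = FlowMatchScheduler(num_inference_steps=4, shift=5.0)
# for t in scheduler.timesteps:
#     noise_pred = model(latents, timestep=t.unsqueeze(0), context=prompt_emb)
#     latents = scheduler.step(noise_pred, t, latents)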
class BasePipeline(torch.nn.Module):
def __init__(
self,
device="cuda",
torch_dtype=torch.float16,
height_division_factor=64,
width_division_factor=64,
):
super().__init__()
self.device = device
self.torch_dtype = torch_dtype
self.height_division_factor = height_division_factor
self.width_division_factor = width_division_factor
self.cpu_offload = False
self.model_names = []
def check_resize_height_width(self, height, width):
if height % self.height_division_factor != 0:
height = (
(height + self.height_division_factor - 1)
// self.height_division_factor
* self.height_division_factor
)
print(
f"The height cannot be evenly divided by {self.height_division_factor}. We round it up to {height}."
)
if width % self.width_division_factor != 0:
width = (
(width + self.width_division_factor - 1)
// self.width_division_factor
* self.width_division_factor
)
print(
f"The width cannot be evenly divided by {self.width_division_factor}. We round it up to {width}."
)
return height, width
def preprocess_image(self, image):
image = (
torch.Tensor(np.array(image, dtype=np.float32) * (2 / 255) - 1)
.permute(2, 0, 1)
.unsqueeze(0)
)
return image
def preprocess_images(self, images):
return [self.preprocess_image(image) for image in images]
def vae_output_to_image(self, vae_output):
image = vae_output[0].cpu().float().permute(1, 2, 0).numpy()
image = Image.fromarray(((image / 2 + 0.5).clip(0, 1) * 255).astype("uint8"))
return image
def vae_output_to_video(self, vae_output):
video = vae_output.cpu().permute(1, 2, 0).numpy()
video = [
Image.fromarray(((image / 2 + 0.5).clip(0, 1) * 255).astype("uint8"))
for image in video
]
return video
def merge_latents(
self, value, latents, masks, scales, blur_kernel_size=33, blur_sigma=10.0
):
if len(latents) > 0:
blur = GaussianBlur(kernel_size=blur_kernel_size, sigma=blur_sigma)
height, width = value.shape[-2:]
weight = torch.ones_like(value)
for latent, mask, scale in zip(latents, masks, scales):
mask = (
self.preprocess_image(mask.resize((width, height))).mean(
dim=1, keepdim=True
)
> 0
)
mask = mask.repeat(1, latent.shape[1], 1, 1).to(
dtype=latent.dtype, device=latent.device
)
mask = blur(mask)
value += latent * mask * scale
weight += mask * scale
value /= weight
return value
def control_noise_via_local_prompts(
self,
prompt_emb_global,
prompt_emb_locals,
masks,
mask_scales,
inference_callback,
special_kwargs=None,
special_local_kwargs_list=None,
):
if special_kwargs is None:
noise_pred_global = inference_callback(prompt_emb_global)
else:
noise_pred_global = inference_callback(prompt_emb_global, special_kwargs)
if special_local_kwargs_list is None:
noise_pred_locals = [
inference_callback(prompt_emb_local)
for prompt_emb_local in prompt_emb_locals
]
else:
noise_pred_locals = [
inference_callback(prompt_emb_local, special_kwargs)
for prompt_emb_local, special_kwargs in zip(
prompt_emb_locals, special_local_kwargs_list
)
]
noise_pred = self.merge_latents(
noise_pred_global, noise_pred_locals, masks, mask_scales
)
return noise_pred
def extend_prompt(self, prompt, local_prompts, masks, mask_scales):
local_prompts = local_prompts or []
masks = masks or []
mask_scales = mask_scales or []
extended_prompt_dict = self.prompter.extend_prompt(prompt)
prompt = extended_prompt_dict.get("prompt", prompt)
local_prompts += extended_prompt_dict.get("prompts", [])
masks += extended_prompt_dict.get("masks", [])
mask_scales += [100.0] * len(extended_prompt_dict.get("masks", []))
return prompt, local_prompts, masks, mask_scales
def enable_cpu_offload(self):
self.cpu_offload = True
def load_models_to_device(self, loadmodel_names=[]):
# only load models to device if cpu_offload is enabled
if not self.cpu_offload:
return
# offload the unneeded models to cpu
for model_name in self.model_names:
if model_name not in loadmodel_names:
model = getattr(self, model_name)
if model is not None:
if (
hasattr(model, "vram_management_enabled")
and model.vram_management_enabled
):
for module in model.modules():
if hasattr(module, "offload"):
module.offload()
else:
model.cpu()
# load the needed models to device
for model_name in loadmodel_names:
model = getattr(self, model_name)
if model is not None:
if (
hasattr(model, "vram_management_enabled")
and model.vram_management_enabled
):
for module in model.modules():
if hasattr(module, "onload"):
module.onload()
else:
model.to(self.device)
        # free the CUDA / MPS cache
if torch.cuda.is_available():
torch.cuda.empty_cache()
if torch.backends.mps.is_available():
torch.mps.empty_cache()
def generate_noise(self, shape, seed=None, device="cpu", dtype=torch.float16):
generator = None if seed is None else torch.Generator(device).manual_seed(seed)
noise = torch.randn(shape, generator=generator, device=device, dtype=dtype)
return noise
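# Hypothetical call (the shape values are only illustrative); a fixed seed gives a
# reproducible torch.Generator:
# noise = pipeline.generate_noise((1, 16, 21, 60, 104), seed=args.seed, dtype=torch.bfloat16)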
# -----------------------------
# Basic utilities: statistics needed for ADAIN (kept for optional use; the pipeline defaults to wavelet)
# -----------------------------
def _calc_mean_std(
feat: torch.Tensor, eps: float = 1e-5
) -> Tuple[torch.Tensor, torch.Tensor]:
    assert feat.dim() == 4, "feat must be (N, C, H, W)"
N, C = feat.shape[:2]
var = feat.view(N, C, -1).var(dim=2, unbiased=False) + eps
std = var.sqrt().view(N, C, 1, 1)
mean = feat.view(N, C, -1).mean(dim=2).view(N, C, 1, 1)
return mean, std
def _adain(content_feat: torch.Tensor, style_feat: torch.Tensor) -> torch.Tensor:
    assert content_feat.shape[:2] == style_feat.shape[:2], "ADAIN: N and C must match"
size = content_feat.size()
style_mean, style_std = _calc_mean_std(style_feat)
content_mean, content_std = _calc_mean_std(content_feat)
normalized = (content_feat - content_mean.expand(size)) / content_std.expand(size)
return normalized * style_std.expand(size) + style_mean.expand(size)
# -----------------------------
# Wavelet-style blur and decomposition/reconstruction (used by the ColorCorrector)
# -----------------------------
def _make_gaussian3x3_kernel(dtype, device) -> torch.Tensor:
vals = [
[0.0625, 0.125, 0.0625],
[0.125, 0.25, 0.125],
[0.0625, 0.125, 0.0625],
]
return torch.tensor(vals, dtype=dtype, device=device)
def _wavelet_blur(x: torch.Tensor, radius: int) -> torch.Tensor:
    assert x.dim() == 4, "x must be (N, C, H, W)"
N, C, H, W = x.shape
base = _make_gaussian3x3_kernel(x.dtype, x.device)
weight = base.view(1, 1, 3, 3).repeat(C, 1, 1, 1)
pad = radius
x_pad = F.pad(x, (pad, pad, pad, pad), mode="replicate")
out = F.conv2d(
x_pad, weight, bias=None, stride=1, padding=0, dilation=radius, groups=C
)
return out
def _wavelet_decompose(
x: torch.Tensor, levels: int = 5
) -> Tuple[torch.Tensor, torch.Tensor]:
    assert x.dim() == 4, "x must be (N, C, H, W)"
high = torch.zeros_like(x)
low = x
for i in range(levels):
radius = 2**i
blurred = _wavelet_blur(low, radius)
high = high + (low - blurred)
low = blurred
return high, low
def _wavelet_reconstruct(
content: torch.Tensor, style: torch.Tensor, levels: int = 5
) -> torch.Tensor:
c_high, _ = _wavelet_decompose(content, levels=levels)
_, s_low = _wavelet_decompose(style, levels=levels)
return c_high + s_low
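# By construction the decomposition telescopes, so high + low == x exactly; the
# reconstruction therefore keeps the high-frequency detail of `content` and swaps in the
# low-frequency (colour / illumination) band of `style`.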
# -----------------------------
# Stateless color-correction module (video-friendly; defaults to wavelet)
# -----------------------------
class TorchColorCorrectorWavelet(nn.Module):
def __init__(self, levels: int = 5):
super().__init__()
self.levels = levels
@staticmethod
def _flatten_time(x: torch.Tensor) -> Tuple[torch.Tensor, int, int]:
        assert x.dim() == 5, "input must be (B, C, f, H, W)"
B, C, f, H, W = x.shape
y = x.permute(0, 2, 1, 3, 4).reshape(B * f, C, H, W)
return y, B, f
@staticmethod
def _unflatten_time(y: torch.Tensor, B: int, f: int) -> torch.Tensor:
BF, C, H, W = y.shape
assert BF == B * f
return y.reshape(B, f, C, H, W).permute(0, 2, 1, 3, 4)
def forward(
self,
hq_image: torch.Tensor, # (B, C, f, H, W)
lq_image: torch.Tensor, # (B, C, f, H, W)
clip_range: Tuple[float, float] = (-1.0, 1.0),
method: Literal["wavelet", "adain"] = "wavelet",
chunk_size: Optional[int] = None,
) -> torch.Tensor:
        assert hq_image.shape == lq_image.shape, "HQ and LQ must have the same shape"
assert (
hq_image.dim() == 5 and hq_image.shape[1] == 3
), "输入必须是 (B, 3, f, H, W)"
B, C, f, H, W = hq_image.shape
if chunk_size is None or chunk_size >= f:
hq4, B, f = self._flatten_time(hq_image)
lq4, _, _ = self._flatten_time(lq_image)
if method == "wavelet":
out4 = _wavelet_reconstruct(hq4, lq4, levels=self.levels)
elif method == "adain":
out4 = _adain(hq4, lq4)
else:
raise ValueError(f"未知 method: {method}")
out4 = torch.clamp(out4, *clip_range)
out = self._unflatten_time(out4, B, f)
return out
outs = []
for start in range(0, f, chunk_size):
end = min(start + chunk_size, f)
hq_chunk = hq_image[:, :, start:end]
lq_chunk = lq_image[:, :, start:end]
hq4, B_, f_ = self._flatten_time(hq_chunk)
lq4, _, _ = self._flatten_time(lq_chunk)
if method == "wavelet":
out4 = _wavelet_reconstruct(hq4, lq4, levels=self.levels)
elif method == "adain":
out4 = _adain(hq4, lq4)
else:
raise ValueError(f"未知 method: {method}")
out4 = torch.clamp(out4, *clip_range)
out_chunk = self._unflatten_time(out4, B_, f_)
outs.append(out_chunk)
out = torch.cat(outs, dim=2)
return out
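# Illustrative usage sketch (assumed shapes, not executed anywhere): the corrector takes
# (B, C, f, H, W) videos in [-1, 1] and matches the HQ output's colors to the LQ source,
# optionally in temporal chunks to bound memory.
def _example_color_corrector_usage():
    corrector = TorchColorCorrectorWavelet(levels=5)
    hq = torch.rand(1, 3, 8, 64, 64) * 2 - 1  # e.g. super-resolved frames
    lq = torch.rand(1, 3, 8, 64, 64) * 2 - 1  # e.g. the low-quality source
    fixed = corrector(hq, lq, clip_range=(-1.0, 1.0), method="wavelet", chunk_size=4)
    return fixed  # same shape as hq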
def cast_to(weight, dtype, device):
r = torch.empty_like(weight, dtype=dtype, device=device)
r.copy_(weight)
return r
import copy  # used by AutoWrappedModule.forward for on-the-fly dtype/device casting
class AutoWrappedModule(torch.nn.Module):
def __init__(self, module: torch.nn.Module, offload_dtype, offload_device, onload_dtype, onload_device, computation_dtype, computation_device):
super().__init__()
self.module = module.to(dtype=offload_dtype, device=offload_device)
self.offload_dtype = offload_dtype
self.offload_device = offload_device
self.onload_dtype = onload_dtype
self.onload_device = onload_device
self.computation_dtype = computation_dtype
self.computation_device = computation_device
self.state = 0
def offload(self):
if self.state == 1 and (self.offload_dtype != self.onload_dtype or self.offload_device != self.onload_device):
self.module.to(dtype=self.offload_dtype, device=self.offload_device)
self.state = 0
def onload(self):
if self.state == 0 and (self.offload_dtype != self.onload_dtype or self.offload_device != self.onload_device):
self.module.to(dtype=self.onload_dtype, device=self.onload_device)
self.state = 1
def forward(self, *args, **kwargs):
if self.onload_dtype == self.computation_dtype and self.onload_device == self.computation_device:
module = self.module
else:
module = copy.deepcopy(self.module).to(dtype=self.computation_dtype, device=self.computation_device)
return module(*args, **kwargs)
class AutoWrappedLinear(torch.nn.Linear):
def __init__(self, module: torch.nn.Linear, offload_dtype, offload_device, onload_dtype, onload_device, computation_dtype, computation_device):
with init_weights_on_device(device=torch.device("meta")):
super().__init__(in_features=module.in_features, out_features=module.out_features, bias=module.bias is not None, dtype=offload_dtype, device=offload_device)
self.weight = module.weight
self.bias = module.bias
self.offload_dtype = offload_dtype
self.offload_device = offload_device
self.onload_dtype = onload_dtype
self.onload_device = onload_device
self.computation_dtype = computation_dtype
self.computation_device = computation_device
self.state = 0
def offload(self):
if self.state == 1 and (self.offload_dtype != self.onload_dtype or self.offload_device != self.onload_device):
self.to(dtype=self.offload_dtype, device=self.offload_device)
self.state = 0
def onload(self):
if self.state == 0 and (self.offload_dtype != self.onload_dtype or self.offload_device != self.onload_device):
self.to(dtype=self.onload_dtype, device=self.onload_device)
self.state = 1
def forward(self, x, *args, **kwargs):
if self.onload_dtype == self.computation_dtype and self.onload_device == self.computation_device:
weight, bias = self.weight, self.bias
else:
weight = cast_to(self.weight, self.computation_dtype, self.computation_device)
bias = None if self.bias is None else cast_to(self.bias, self.computation_dtype, self.computation_device)
return torch.nn.functional.linear(x, weight, bias)
def enable_vram_management_recursively(model: torch.nn.Module, module_map: dict, module_config: dict, max_num_param=None, overflow_module_config: dict = None, total_num_param=0):
for name, module in model.named_children():
for source_module, target_module in module_map.items():
if isinstance(module, source_module):
num_param = sum(p.numel() for p in module.parameters())
if max_num_param is not None and total_num_param + num_param > max_num_param:
module_config_ = overflow_module_config
else:
module_config_ = module_config
module_ = target_module(module, **module_config_)
setattr(model, name, module_)
total_num_param += num_param
break
else:
total_num_param = enable_vram_management_recursively(module, module_map, module_config, max_num_param, overflow_module_config, total_num_param)
return total_num_param
def enable_vram_management_(model: torch.nn.Module, module_map: dict, module_config: dict, max_num_param=None, overflow_module_config: dict = None):
enable_vram_management_recursively(model, module_map, module_config, max_num_param, overflow_module_config, total_num_param=0)
model.vram_management_enabled = True
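# Illustrative sketch (hypothetical dtype/device values): enable_vram_management_ wraps the
# matching submodules so that load_models_to_device() can call offload()/onload() on them
# instead of moving the whole model with .to().
def _example_enable_vram_management(model: torch.nn.Module):
    enable_vram_management_(
        model,
        module_map={
            torch.nn.Linear: AutoWrappedLinear,
            torch.nn.LayerNorm: AutoWrappedModule,
        },
        module_config=dict(
            offload_dtype=torch.bfloat16,
            offload_device="cpu",
            onload_dtype=torch.bfloat16,
            onload_device="cuda",
            computation_dtype=torch.bfloat16,
            computation_device="cuda",
        ),
        max_num_param=None,  # no parameter budget: every matched module uses module_config
        overflow_module_config=None,
    )
    return model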
# -----------------------------
# Simplified pipeline (dit + vae only)
# -----------------------------
class FlashVSRTinyPipeline(BasePipeline):
def __init__(self, device="cuda", torch_dtype=torch.float16):
super().__init__(device=device, torch_dtype=torch_dtype)
self.scheduler = FlowMatchScheduler(shift=5, sigma_min=0.0, extra_one_step=True)
self.dit: WanModel = None
self.vae: WanVideoVAE = None
self.model_names = ["dit", "vae"]
self.height_division_factor = 16
self.width_division_factor = 16
self.use_unified_sequence_parallel = False
self.prompt_emb_posi = None
self.ColorCorrector = TorchColorCorrectorWavelet(levels=5)
print(
r"""
███████╗██╗ █████╗ ███████╗██╗ ██╗██╗ ██╗███████╗█████╗
██╔════╝██║ ██╔══██╗██╔════╝██║ ██║██║ ██║██╔════╝██╔══██╗ ██╗
█████╗ ██║ ███████║███████╗███████║╚██╗ ██╔╝███████╗███████║ ██████╗
██╔══╝ ██║ ██╔══██║╚════██║██╔══██║ ╚████╔╝ ╚════██║██╔═██║ ██╔═╝
██║ ███████╗██║ ██║███████║██║ ██║ ╚██╔╝ ███████║██║ ██║ ╚═╝
╚═╝ ╚══════╝╚═╝ ╚═╝╚══════╝╚═╝ ╚═╝ ╚═╝ ╚══════╝╚═╝ ╚═╝
"""
)
def enable_vram_management(self, num_persistent_param_in_dit=None):
        # Only manage dit / vae
dtype = next(iter(self.dit.parameters())).dtype
enable_vram_management_(
self.dit,
module_map={
torch.nn.Linear: AutoWrappedLinear,
torch.nn.Conv3d: AutoWrappedModule,
torch.nn.LayerNorm: AutoWrappedModule,
RMSNorm: AutoWrappedModule,
},
module_config=dict(
offload_dtype=dtype,
offload_device="cpu",
onload_dtype=dtype,
onload_device=self.device,
computation_dtype=self.torch_dtype,
computation_device=self.device,
),
max_num_param=num_persistent_param_in_dit,
overflow_module_config=dict(
offload_dtype=dtype,
offload_device="cpu",
onload_dtype=dtype,
onload_device="cpu",
computation_dtype=self.torch_dtype,
computation_device=self.device,
),
)
self.enable_cpu_offload()
def fetch_models(self, model_manager: ModelManager):
self.dit = model_manager.fetch_model("wan_video_dit")
self.vae = model_manager.fetch_model("wan_video_vae")
@staticmethod
def from_model_manager(
model_manager: ModelManager, torch_dtype=None, device=None, use_usp=False
):
if device is None:
device = model_manager.device
if torch_dtype is None:
torch_dtype = model_manager.torch_dtype
pipe = FlashVSRTinyPipeline(device=device, torch_dtype=torch_dtype)
pipe.fetch_models(model_manager)
        # Optional: unified sequence parallel entry point (disabled by default here)
pipe.use_unified_sequence_parallel = False
return pipe
def denoising_model(self):
return self.dit
# -------------------------
    # New: explicit KV pre-initialization
# -------------------------
def init_cross_kv(
self, context_tensor: Optional[torch.Tensor] = None, prompt_path=None
):
        """
        Use the fixed prompt to generate the text context and initialize the KV caches of
        all CrossAttention modules in WanModel.
        Must be called explicitly once before __call__.
        """
        self.load_models_to_device(["dit"])
# prompt_path = "../../examples/WanVSR/prompt_tensor/posi_prompt.pth"
if self.dit is None:
raise RuntimeError(
"请先通过 fetch_models / from_model_manager 初始化 self.dit"
)
if context_tensor is None:
if prompt_path is None:
raise ValueError(
"init_cross_kv: 需要提供 prompt_path 或 context_tensor 其一"
)
ctx = torch.load(prompt_path, map_location=self.device)
else:
ctx = context_tensor
ctx = ctx.to(dtype=self.torch_dtype, device=self.device)
if self.prompt_emb_posi is None:
self.prompt_emb_posi = {}
self.prompt_emb_posi["context"] = ctx
self.prompt_emb_posi["stats"] = "load"
if hasattr(self.dit, "reinit_cross_kv"):
self.dit.reinit_cross_kv(ctx)
else:
raise AttributeError(
"WanModel 缺少 reinit_cross_kv(ctx) 方法,请在模型实现中加入该能力。"
)
self.timestep = torch.tensor(
[1000.0], device=self.device, dtype=self.torch_dtype
)
self.t = self.dit.time_embedding(
sinusoidal_embedding_1d(self.dit.freq_dim, self.timestep)
)
self.t_mod = self.dit.time_projection(self.t).unflatten(1, (6, self.dit.dim))
# Scheduler
self.scheduler.set_timesteps(1, denoising_strength=1.0, shift=5.0)
self.load_models_to_device([])
def prepare_unified_sequence_parallel(self):
return {"use_unified_sequence_parallel": self.use_unified_sequence_parallel}
def prepare_extra_input(self, latents=None):
return {}
def encode_video(
self, input_video, tiled=True, tile_size=(34, 34), tile_stride=(18, 16)
):
latents = self.vae.encode(
input_video,
device=self.device,
tiled=tiled,
tile_size=tile_size,
tile_stride=tile_stride,
)
return latents
def _decode_video(
self, latents, tiled=True, tile_size=(34, 34), tile_stride=(18, 16)
):
frames = self.vae.decode(
latents,
device=self.device,
tiled=tiled,
tile_size=tile_size,
tile_stride=tile_stride,
)
return frames
def decode_video(self, latents, cond=None, **kwargs):
frames = (
self.TCDecoder.decode_video(
                latents.transpose(1, 2),  # TCDecoder expects (B, F, C, H, W)
parallel=False,
show_progress_bar=False,
cond=cond,
)
.transpose(1, 2)
.mul_(2)
.sub_(1)
        )  # back to (B, C, F, H, W), range -1 to 1
return frames
def offload_model(self, keep_vae=False):
self.dit.clear_cross_kv()
self.prompt_emb_posi["stats"] = "offload"
self.load_models_to_device([])
if hasattr(self.dit, "LQ_proj_in"):
self.dit.LQ_proj_in.to("cpu")
if not keep_vae:
self.TCDecoder.to("cpu")
@torch.no_grad()
def __call__(
self,
prompt=None,
negative_prompt="",
denoising_strength=1.0,
seed=None,
rand_device="gpu",
height=480,
width=832,
num_frames=81,
cfg_scale=5.0,
num_inference_steps=50,
sigma_shift=5.0,
tiled=True,
tile_size=(60, 104),
tile_stride=(30, 52),
tea_cache_l1_thresh=None,
tea_cache_model_id="Wan2.1-T2V-1.3B",
progress_bar_cmd=tqdm,
progress_bar_st=None,
LQ_video=None,
is_full_block=False,
if_buffer=False,
topk_ratio=2.0,
kv_ratio=3.0,
local_range=9,
color_fix=True,
unload_dit=False,
force_offload=False,
**kwargs,
):
        # Only cfg_scale == 1.0 is accepted (consistent with the original code)
assert cfg_scale == 1.0, "cfg_scale must be 1.0"
        # Requirement: init_cross_kv() must be called first
if self.prompt_emb_posi is None or "context" not in self.prompt_emb_posi:
raise RuntimeError(
"Cross-Attention KV not initialized. Please call __call__ only after:\n"
" pipe.init_cross_kv()\n"
"Or provide a custom context:\n"
" pipe.init_cross_kv(context_tensor=your_context_tensor)"
)
        # Resolution correction
height, width = self.check_resize_height_width(height, width)
if num_frames % 4 != 1:
num_frames = (num_frames + 2) // 4 * 4 + 1
print(
f"Only `num_frames % 4 != 1` is acceptable. We round it up to {num_frames}."
)
        # Tiler parameters
tiler_kwargs = {
"tiled": tiled,
"tile_size": tile_size,
"tile_stride": tile_stride,
}
        # Initialize noise
if if_buffer:
noise = self.generate_noise(
(1, 16, (num_frames - 1) // 4, height // 8, width // 8),
seed=seed,
device=self.device,
dtype=self.torch_dtype,
)
else:
noise = self.generate_noise(
(1, 16, (num_frames - 1) // 4 + 1, height // 8, width // 8),
seed=seed,
device=self.device,
dtype=self.torch_dtype,
)
# noise = noise.to(dtype=self.torch_dtype, device=self.device)
latents = noise
process_total_num = (num_frames - 1) // 8 - 2
is_stream = True
if self.prompt_emb_posi["stats"] == "offload":
self.init_cross_kv(context_tensor=self.prompt_emb_posi["context"])
self.load_models_to_device(["dit"])
self.dit.LQ_proj_in.to(self.device)
self.TCDecoder.to(self.device)
        # Clear any stale LQ_proj_in cache
if hasattr(self.dit, "LQ_proj_in"):
self.dit.LQ_proj_in.clear_cache()
latents_total = []
self.TCDecoder.clean_mem()
LQ_pre_idx = 0
LQ_cur_idx = 0
with torch.no_grad():
for cur_process_idx in progress_bar_cmd(range(process_total_num)):
if cur_process_idx == 0:
pre_cache_k = [None] * len(self.dit.blocks)
pre_cache_v = [None] * len(self.dit.blocks)
LQ_latents = None
inner_loop_num = 7
for inner_idx in range(inner_loop_num):
cur = (
self.denoising_model().LQ_proj_in.stream_forward(
LQ_video[
:,
:,
max(0, inner_idx * 4 - 3) : (inner_idx + 1) * 4 - 3,
:,
:,
]
)
if LQ_video is not None
else None
)
if cur is None:
continue
if LQ_latents is None:
LQ_latents = cur
else:
for layer_idx in range(len(LQ_latents)):
LQ_latents[layer_idx] = torch.cat(
[LQ_latents[layer_idx], cur[layer_idx]], dim=1
)
LQ_cur_idx = (inner_loop_num - 1) * 4 - 3
cur_latents = latents[:, :, :6, :, :]
else:
LQ_latents = None
inner_loop_num = 2
for inner_idx in range(inner_loop_num):
cur = (
self.denoising_model().LQ_proj_in.stream_forward(
LQ_video[
:,
:,
cur_process_idx * 8
+ 17
+ inner_idx * 4 : cur_process_idx * 8
+ 21
+ inner_idx * 4,
:,
:,
]
)
if LQ_video is not None
else None
)
if cur is None:
continue
if LQ_latents is None:
LQ_latents = cur
else:
for layer_idx in range(len(LQ_latents)):
LQ_latents[layer_idx] = torch.cat(
[LQ_latents[layer_idx], cur[layer_idx]], dim=1
)
LQ_cur_idx = cur_process_idx * 8 + 21 + (inner_loop_num - 2) * 4
cur_latents = latents[
:, :, 4 + cur_process_idx * 2 : 6 + cur_process_idx * 2, :, :
]
                # Inference (no motion_controller / vace)
noise_pred_posi, pre_cache_k, pre_cache_v = model_fn_wan_video(
self.dit,
x=cur_latents,
timestep=self.timestep,
context=None,
tea_cache=None,
use_unified_sequence_parallel=False,
LQ_latents=LQ_latents,
is_full_block=is_full_block,
is_stream=is_stream,
pre_cache_k=pre_cache_k,
pre_cache_v=pre_cache_v,
topk_ratio=topk_ratio,
kv_ratio=kv_ratio,
cur_process_idx=cur_process_idx,
t_mod=self.t_mod,
t=self.t,
local_range=local_range,
)
                # Update latents
cur_latents = cur_latents - noise_pred_posi
latents_total.append(cur_latents)
LQ_pre_idx = LQ_cur_idx
if hasattr(self.dit, "LQ_proj_in"):
self.dit.LQ_proj_in.clear_cache()
if (
unload_dit
and hasattr(self, "dit")
and not next(self.dit.parameters()).is_cpu
):
print("[FlashVSR] Offloading DiT to the CPU to free up VRAM...")
self.offload_model(keep_vae=True)
latents = torch.cat(latents_total, dim=2)
# Decode
print("[FlashVSR] Starting VAE decoding...")
frames = (
self.TCDecoder.decode_video(
latents.transpose(1, 2),
parallel=False,
show_progress_bar=False,
cond=LQ_video[:, :, :LQ_cur_idx, :, :],
)
.transpose(1, 2)
.mul_(2)
.sub_(1)
)
self.TCDecoder.clean_mem()
if force_offload:
self.offload_model()
        # Color correction (via ColorCorrector)
try:
if color_fix:
frames = self.ColorCorrector(
frames.to(device=LQ_video.device),
LQ_video[:, :, : frames.shape[2], :, :],
clip_range=(-1, 1),
chunk_size=16,
method="adain",
)
        except Exception as e:
            log(f"[FlashVSR] Color correction failed; returning uncorrected frames: {e}", "warning")
return frames[0]
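# Illustrative end-to-end sketch (file paths and shapes are placeholders, not part of this
# script): the cross-attention KV cache must be initialized before the pipeline is called.
def _example_pipeline_usage(pipe: FlashVSRTinyPipeline, lq_video: torch.Tensor):
    # lq_video is assumed to be (1, 3, F, H, W) in [-1, 1], already on pipe.device.
    pipe.init_cross_kv(prompt_path="posi_prompt.pth")  # or context_tensor=...
    frames = pipe(
        seed=0,
        height=lq_video.shape[-2],
        width=lq_video.shape[-1],
        num_frames=lq_video.shape[2],
        cfg_scale=1.0,  # the only accepted value
        LQ_video=lq_video,
        color_fix=True,
    )
    return frames  # (3, F_out, H, W) tensor in [-1, 1]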
# -----------------------------
# TeaCache (original logic kept; not enabled by default here)
# -----------------------------
class TeaCache:
def __init__(self, num_inference_steps, rel_l1_thresh, model_id):
self.num_inference_steps = num_inference_steps
self.step = 0
self.accumulated_rel_l1_distance = 0
self.previous_modulated_input = None
self.rel_l1_thresh = rel_l1_thresh
self.previous_residual = None
self.previous_hidden_states = None
self.coefficients_dict = {
"Wan2.1-T2V-1.3B": [
-5.21862437e04,
9.23041404e03,
-5.28275948e02,
1.36987616e01,
-4.99875664e-02,
],
"Wan2.1-T2V-14B": [
-3.03318725e05,
4.90537029e04,
-2.65530556e03,
5.87365115e01,
-3.15583525e-01,
],
"Wan2.1-I2V-14B-480P": [
2.57151496e05,
-3.54229917e04,
1.40286849e03,
-1.35890334e01,
1.32517977e-01,
],
"Wan2.1-I2V-14B-720P": [
8.10705460e03,
2.13393892e03,
-3.72934672e02,
1.66203073e01,
-4.17769401e-02,
],
}
if model_id not in self.coefficients_dict:
supported_model_ids = ", ".join([i for i in self.coefficients_dict])
raise ValueError(
f"{model_id} is not a supported TeaCache model id. Please choose a valid model id in ({supported_model_ids})."
)
self.coefficients = self.coefficients_dict[model_id]
def check(self, dit: WanModel, x, t_mod):
modulated_inp = t_mod.clone()
if self.step == 0 or self.step == self.num_inference_steps - 1:
should_calc = True
self.accumulated_rel_l1_distance = 0
else:
coefficients = self.coefficients
rescale_func = np.poly1d(coefficients)
self.accumulated_rel_l1_distance += rescale_func(
(
(modulated_inp - self.previous_modulated_input).abs().mean()
/ self.previous_modulated_input.abs().mean()
)
.cpu()
.item()
)
should_calc = not (self.accumulated_rel_l1_distance < self.rel_l1_thresh)
if should_calc:
self.accumulated_rel_l1_distance = 0
self.previous_modulated_input = modulated_inp
self.step = (self.step + 1) % self.num_inference_steps
if should_calc:
self.previous_hidden_states = x.clone()
return not should_calc
def store(self, hidden_states):
self.previous_residual = hidden_states - self.previous_hidden_states
self.previous_hidden_states = None
def update(self, hidden_states):
hidden_states = hidden_states + self.previous_residual
return hidden_states
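# Illustrative sketch (hypothetical step loop, not used by this pipeline): TeaCache skips a
# full DiT forward when the modulated input barely changes between steps, replaying the
# cached residual instead.
def _example_tea_cache_step(dit, tea_cache, x, t_mod, run_blocks):
    # `run_blocks` is a hypothetical callable standing in for the real block stack.
    if tea_cache.check(dit, x, t_mod):
        return tea_cache.update(x)  # cheap path: add the stored residual
    hidden = run_blocks(x)          # expensive path: full forward
    tea_cache.store(hidden)         # remember the residual for later steps
    return hidden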
# -----------------------------
# Simplified model forward wrapper (no vace / no motion_controller)
# -----------------------------
def model_fn_wan_video(
dit: WanModel,
x: torch.Tensor,
timestep: torch.Tensor,
context: torch.Tensor,
tea_cache: Optional[TeaCache] = None,
use_unified_sequence_parallel: bool = False,
LQ_latents: Optional[torch.Tensor] = None,
is_full_block: bool = False,
is_stream: bool = False,
pre_cache_k: Optional[list[torch.Tensor]] = None,
pre_cache_v: Optional[list[torch.Tensor]] = None,
topk_ratio: float = 2.0,
kv_ratio: float = 3.0,
cur_process_idx: int = 0,
t_mod: torch.Tensor = None,
t: torch.Tensor = None,
local_range: int = 9,
**kwargs,
):
# patchify
x, (f, h, w) = dit.patchify(x)
win = (2, 8, 8)
seqlen = f // win[0]
local_num = seqlen
window_size = win[0] * h * w // 128
square_num = window_size * window_size
topk = int(square_num * topk_ratio) - 1
kv_len = int(kv_ratio)
    # RoPE positions (segmented for streaming)
if cur_process_idx == 0:
freqs = (
torch.cat(
[
dit.freqs[0][:f].view(f, 1, 1, -1).expand(f, h, w, -1),
dit.freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1),
dit.freqs[2][:w].view(1, 1, w, -1).expand(f, h, w, -1),
],
dim=-1,
)
.reshape(f * h * w, 1, -1)
.to(x.device)
)
else:
freqs = (
torch.cat(
[
dit.freqs[0][4 + cur_process_idx * 2 : 4 + cur_process_idx * 2 + f]
.view(f, 1, 1, -1)
.expand(f, h, w, -1),
dit.freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1),
dit.freqs[2][:w].view(1, 1, w, -1).expand(f, h, w, -1),
],
dim=-1,
)
.reshape(f * h * w, 1, -1)
.to(x.device)
)
    # TeaCache (disabled by default)
tea_cache_update = (
tea_cache.check(dit, x, t_mod) if tea_cache is not None else False
)
    # Unified sequence parallel (disabled by default here)
if use_unified_sequence_parallel:
import torch.distributed as dist
from xfuser.core.distributed import (
get_sequence_parallel_rank,
get_sequence_parallel_world_size,
get_sp_group,
)
if dist.is_initialized() and dist.get_world_size() > 1:
x = torch.chunk(x, get_sequence_parallel_world_size(), dim=1)[
get_sequence_parallel_rank()
]
    # Stacked DiT blocks
if tea_cache_update:
x = tea_cache.update(x)
else:
for block_id, block in enumerate(dit.blocks):
if LQ_latents is not None and block_id < len(LQ_latents):
x = x + LQ_latents[block_id]
x, last_pre_cache_k, last_pre_cache_v = block(
x,
context,
t_mod,
freqs,
f,
h,
w,
local_num,
topk,
block_id=block_id,
kv_len=kv_len,
is_full_block=is_full_block,
is_stream=is_stream,
pre_cache_k=pre_cache_k[block_id] if pre_cache_k is not None else None,
pre_cache_v=pre_cache_v[block_id] if pre_cache_v is not None else None,
local_range=local_range,
)
if pre_cache_k is not None:
pre_cache_k[block_id] = last_pre_cache_k
if pre_cache_v is not None:
pre_cache_v[block_id] = last_pre_cache_v
x = dit.head(x, t)
if use_unified_sequence_parallel:
import torch.distributed as dist
from xfuser.core.distributed import get_sp_group
if dist.is_initialized() and dist.get_world_size() > 1:
x = get_sp_group().all_gather(x, dim=1)
x = dit.unpatchify(x, (f, h, w))
return x, pre_cache_k, pre_cache_v
try:
import flash_attn_interface
FLASH_ATTN_3_AVAILABLE = True
except ModuleNotFoundError:
FLASH_ATTN_3_AVAILABLE = False
try:
import flash_attn
FLASH_ATTN_2_AVAILABLE = True
except ModuleNotFoundError:
FLASH_ATTN_2_AVAILABLE = False
try:
from sageattention import sageattn
SAGE_ATTN_AVAILABLE = True
except ModuleNotFoundError:
SAGE_ATTN_AVAILABLE = False
try:
from block_sparse_attn import block_sparse_attn_func
BLOCK_ATTN_AVAILABLE = True
except Exception:
BLOCK_ATTN_AVAILABLE = False
@triton.jit
def quant_per_block_int8_kernel(
Input,
Output,
Scale,
L,
stride_iz,
stride_ih,
stride_in,
stride_oz,
stride_oh,
stride_on,
stride_sz,
stride_sh,
sm_scale,
C: tl.constexpr,
BLK: tl.constexpr,
):
off_blk = tl.program_id(0)
off_h = tl.program_id(1)
off_b = tl.program_id(2)
offs_n = off_blk * BLK + tl.arange(0, BLK)
offs_k = tl.arange(0, C)
input_ptrs = (
Input
+ off_b * stride_iz
+ off_h * stride_ih
+ offs_n[:, None] * stride_in
+ offs_k[None, :]
)
output_ptrs = (
Output
+ off_b * stride_oz
+ off_h * stride_oh
+ offs_n[:, None] * stride_on
+ offs_k[None, :]
)
scale_ptrs = Scale + off_b * stride_sz + off_h * stride_sh + off_blk
x = tl.load(input_ptrs, mask=offs_n[:, None] < L)
x = x.to(tl.float32)
x *= sm_scale
scale = tl.max(tl.abs(x)) / 127.0
x_int8 = x / scale
x_int8 += 0.5 * tl.where(x_int8 >= 0, 1, -1)
x_int8 = x_int8.to(tl.int8)
tl.store(output_ptrs, x_int8, mask=offs_n[:, None] < L)
tl.store(scale_ptrs, scale)
def per_block_int8(
q, k, km=None, BLKQ=128, BLKK=64, sm_scale=None, tensor_layout="HND"
):
q_int8 = torch.empty(q.shape, dtype=torch.int8, device=q.device)
k_int8 = torch.empty(k.shape, dtype=torch.int8, device=k.device)
if km is not None:
k = k - km
if tensor_layout == "HND":
b, h_qo, qo_len, head_dim = q.shape
_, h_kv, kv_len, _ = k.shape
stride_bz_q, stride_h_q, stride_seq_q = q.stride(0), q.stride(1), q.stride(2)
stride_bz_qo, stride_h_qo, stride_seq_qo = (
q_int8.stride(0),
q_int8.stride(1),
q_int8.stride(2),
)
stride_bz_k, stride_h_k, stride_seq_k = k.stride(0), k.stride(1), k.stride(2)
stride_bz_ko, stride_h_ko, stride_seq_ko = (
k_int8.stride(0),
k_int8.stride(1),
k_int8.stride(2),
)
elif tensor_layout == "NHD":
b, qo_len, h_qo, head_dim = q.shape
_, kv_len, h_kv, _ = k.shape
stride_bz_q, stride_h_q, stride_seq_q = q.stride(0), q.stride(2), q.stride(1)
stride_bz_qo, stride_h_qo, stride_seq_qo = (
q_int8.stride(0),
q_int8.stride(2),
q_int8.stride(1),
)
stride_bz_k, stride_h_k, stride_seq_k = k.stride(0), k.stride(2), k.stride(1)
stride_bz_ko, stride_h_ko, stride_seq_ko = (
k_int8.stride(0),
k_int8.stride(2),
k_int8.stride(1),
)
else:
raise ValueError(f"Unknown tensor layout: {tensor_layout}")
q_scale = torch.empty(
(b, h_qo, (qo_len + BLKQ - 1) // BLKQ), device=q.device, dtype=torch.float32
)
k_scale = torch.empty(
(b, h_kv, (kv_len + BLKK - 1) // BLKK), device=q.device, dtype=torch.float32
)
if sm_scale is None:
sm_scale = head_dim**-0.5
grid = ((qo_len + BLKQ - 1) // BLKQ, h_qo, b)
quant_per_block_int8_kernel[grid](
q,
q_int8,
q_scale,
qo_len,
stride_bz_q,
stride_h_q,
stride_seq_q,
stride_bz_qo,
stride_h_qo,
stride_seq_qo,
q_scale.stride(0),
q_scale.stride(1),
sm_scale=(sm_scale * 1.44269504),
C=head_dim,
BLK=BLKQ,
)
grid = ((kv_len + BLKK - 1) // BLKK, h_kv, b)
quant_per_block_int8_kernel[grid](
k,
k_int8,
k_scale,
kv_len,
stride_bz_k,
stride_h_k,
stride_seq_k,
stride_bz_ko,
stride_h_ko,
stride_seq_ko,
k_scale.stride(0),
k_scale.stride(1),
sm_scale=1.0,
C=head_dim,
BLK=BLKK,
)
return q_int8, q_scale, k_int8, k_scale
@triton.jit
def _attn_fwd_inner(
acc,
l_i,
old_m,
q,
q_scale,
kv_len,
K_ptrs,
K_bid_ptr,
K_scale_ptr,
V_ptrs,
stride_kn,
stride_vn,
start_m,
BLOCK_M: tl.constexpr,
HEAD_DIM: tl.constexpr,
BLOCK_N: tl.constexpr,
STAGE: tl.constexpr,
offs_m: tl.constexpr,
offs_n: tl.constexpr,
):
if STAGE == 1:
lo, hi = 0, start_m * BLOCK_M
elif STAGE == 2:
lo, hi = start_m * BLOCK_M, (start_m + 1) * BLOCK_M
lo = tl.multiple_of(lo, BLOCK_M)
K_scale_ptr += lo // BLOCK_N
K_ptrs += stride_kn * lo
V_ptrs += stride_vn * lo
elif STAGE == 3:
lo, hi = 0, kv_len
for start_n in range(lo, hi, BLOCK_N):
kbid = tl.load(K_bid_ptr + start_n // BLOCK_N)
if kbid:
k_mask = offs_n[None, :] < (kv_len - start_n)
k = tl.load(K_ptrs, mask=k_mask)
k_scale = tl.load(K_scale_ptr)
qk = tl.dot(q, k).to(tl.float32) * q_scale * k_scale
if STAGE == 2:
mask = offs_m[:, None] >= (start_n + offs_n[None, :])
qk = qk + tl.where(mask, 0, -1.0e6)
local_m = tl.max(qk, 1)
new_m = tl.maximum(old_m, local_m)
qk -= new_m[:, None]
else:
local_m = tl.max(qk, 1)
new_m = tl.maximum(old_m, local_m)
qk = qk - new_m[:, None]
p = tl.math.exp2(qk)
l_ij = tl.sum(p, 1)
alpha = tl.math.exp2(old_m - new_m)
l_i = l_i * alpha + l_ij
acc = acc * alpha[:, None]
v = tl.load(V_ptrs, mask=offs_n[:, None] < (kv_len - start_n))
p = p.to(tl.float16)
acc += tl.dot(p, v, out_dtype=tl.float16)
old_m = new_m
K_ptrs += BLOCK_N * stride_kn
K_scale_ptr += 1
V_ptrs += BLOCK_N * stride_vn
return acc, l_i, old_m
@triton.jit
def _attn_fwd(
Q,
K,
K_blkid,
V,
Q_scale,
K_scale,
Out,
stride_qz,
stride_qh,
stride_qn,
stride_kz,
stride_kh,
stride_kn,
stride_vz,
stride_vh,
stride_vn,
stride_oz,
stride_oh,
stride_on,
stride_kbidq,
stride_kbidk,
qo_len,
kv_len,
H: tl.constexpr,
num_kv_groups: tl.constexpr,
HEAD_DIM: tl.constexpr,
BLOCK_M: tl.constexpr,
BLOCK_N: tl.constexpr,
STAGE: tl.constexpr,
):
start_m = tl.program_id(0)
off_z = tl.program_id(2).to(tl.int64)
off_h = tl.program_id(1).to(tl.int64)
q_scale_offset = (off_z * H + off_h) * tl.cdiv(qo_len, BLOCK_M)
k_scale_offset = (off_z * (H // num_kv_groups) + off_h // num_kv_groups) * tl.cdiv(
kv_len, BLOCK_N
)
k_bid_offset = (
off_z * (H // num_kv_groups) + off_h // num_kv_groups
) * stride_kbidq
offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
offs_n = tl.arange(0, BLOCK_N)
offs_k = tl.arange(0, HEAD_DIM)
Q_ptrs = (
Q
+ (off_z * stride_qz + off_h * stride_qh)
+ offs_m[:, None] * stride_qn
+ offs_k[None, :]
)
Q_scale_ptr = Q_scale + q_scale_offset + start_m
K_ptrs = (
K
+ (off_z * stride_kz + (off_h // num_kv_groups) * stride_kh)
+ offs_n[None, :] * stride_kn
+ offs_k[:, None]
)
K_scale_ptr = K_scale + k_scale_offset
K_bid_ptr = K_blkid + k_bid_offset + start_m * stride_kbidk
V_ptrs = (
V
+ (off_z * stride_vz + (off_h // num_kv_groups) * stride_vh)
+ offs_n[:, None] * stride_vn
+ offs_k[None, :]
)
O_block_ptr = (
Out
+ (off_z * stride_oz + off_h * stride_oh)
+ offs_m[:, None] * stride_on
+ offs_k[None, :]
)
m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + 1.0
acc = tl.zeros([BLOCK_M, HEAD_DIM], dtype=tl.float32)
q = tl.load(Q_ptrs, mask=offs_m[:, None] < qo_len)
q_scale = tl.load(Q_scale_ptr)
acc, l_i, m_i = _attn_fwd_inner(
acc,
l_i,
m_i,
q,
q_scale,
kv_len,
K_ptrs,
K_bid_ptr,
K_scale_ptr,
V_ptrs,
stride_kn,
stride_vn,
start_m,
BLOCK_M,
HEAD_DIM,
BLOCK_N,
4 - STAGE,
offs_m,
offs_n,
)
if STAGE != 1:
acc, l_i, _ = _attn_fwd_inner(
acc,
l_i,
m_i,
q,
q_scale,
kv_len,
K_ptrs,
K_bid_ptr,
K_scale_ptr,
V_ptrs,
stride_kn,
stride_vn,
start_m,
BLOCK_M,
HEAD_DIM,
BLOCK_N,
2,
offs_m,
offs_n,
)
acc = acc / l_i[:, None]
tl.store(O_block_ptr, acc.to(Out.type.element_ty), mask=(offs_m[:, None] < qo_len))
def sparse_sageattn_fwd(
q,
k,
k_block_id,
v,
q_scale,
k_scale,
is_causal=False,
tensor_layout="HND",
output_dtype=torch.float16,
):
BLOCK_M = 128
BLOCK_N = 64
stage = 3 if is_causal else 1
o = torch.empty(q.shape, dtype=output_dtype, device=q.device)
if tensor_layout == "HND":
b, h_qo, qo_len, head_dim = q.shape
_, h_kv, kv_len, _ = k.shape
stride_bz_q, stride_h_q, stride_seq_q = q.stride(0), q.stride(1), q.stride(2)
stride_bz_k, stride_h_k, stride_seq_k = k.stride(0), k.stride(1), k.stride(2)
stride_bz_v, stride_h_v, stride_seq_v = v.stride(0), v.stride(1), v.stride(2)
stride_bz_o, stride_h_o, stride_seq_o = o.stride(0), o.stride(1), o.stride(2)
elif tensor_layout == "NHD":
b, qo_len, h_qo, head_dim = q.shape
_, kv_len, h_kv, _ = k.shape
stride_bz_q, stride_h_q, stride_seq_q = q.stride(0), q.stride(2), q.stride(1)
stride_bz_k, stride_h_k, stride_seq_k = k.stride(0), k.stride(2), k.stride(1)
stride_bz_v, stride_h_v, stride_seq_v = v.stride(0), v.stride(2), v.stride(1)
stride_bz_o, stride_h_o, stride_seq_o = o.stride(0), o.stride(2), o.stride(1)
else:
raise ValueError(f"tensor_layout {tensor_layout} not supported")
if is_causal:
assert qo_len == kv_len, "qo_len and kv_len must be equal for causal attention"
HEAD_DIM_K = head_dim
num_kv_groups = h_qo // h_kv
grid = (triton.cdiv(qo_len, BLOCK_M), h_qo, b)
_attn_fwd[grid](
q,
k,
k_block_id,
v,
q_scale,
k_scale,
o,
stride_bz_q,
stride_h_q,
stride_seq_q,
stride_bz_k,
stride_h_k,
stride_seq_k,
stride_bz_v,
stride_h_v,
stride_seq_v,
stride_bz_o,
stride_h_o,
stride_seq_o,
k_block_id.stride(1),
k_block_id.stride(2),
qo_len,
kv_len,
h_qo,
num_kv_groups,
BLOCK_M=BLOCK_M,
BLOCK_N=BLOCK_N,
HEAD_DIM=HEAD_DIM_K,
STAGE=stage,
num_warps=4 if head_dim == 64 else 8,
num_stages=4,
)
return o
def sparse_sageattn(q, k, v, mask_id=None, is_causal=False, tensor_layout="HND"):
if mask_id is None:
mask_id = torch.ones(
(
q.shape[0],
q.shape[1],
(q.shape[2] + 128 - 1) // 128,
(q.shape[3] + 64 - 1) // 64,
),
dtype=torch.int8,
device=q.device,
) # TODO
output_dtype = q.dtype
if output_dtype == torch.bfloat16 or output_dtype == torch.float32:
v = v.to(torch.float16)
seq_dim = 1 if tensor_layout == "NHD" else 2
km = k.mean(dim=seq_dim, keepdim=True)
# km = torch.zeros((k.size(0), k.size(1), 1, k.size(3)), dtype=torch.float16, device=k.device) # Placeholder for mean, not used in quantization
q_int8, q_scale, k_int8, k_scale = per_block_int8(
q, k, km=km, tensor_layout=tensor_layout
)
o = sparse_sageattn_fwd(
q_int8,
k_int8,
mask_id,
v,
q_scale,
k_scale,
is_causal=is_causal,
tensor_layout=tensor_layout,
output_dtype=output_dtype,
)
return o
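# Illustrative usage sketch (requires a CUDA device with Triton; sizes are arbitrary):
# sparse_sageattn quantizes q/k to int8 per block and runs attention only over the
# (128 x 64) query/key blocks whose entry in the int8 block mask is non-zero.
def _example_sparse_sageattn():
    b, h, s, d = 1, 8, 1024, 64
    q = torch.randn(b, h, s, d, dtype=torch.float16, device="cuda")
    k = torch.randn(b, h, s, d, dtype=torch.float16, device="cuda")
    v = torch.randn(b, h, s, d, dtype=torch.float16, device="cuda")
    mask_id = torch.ones(  # dense mask: keep every block
        (b, h, (s + 127) // 128, (s + 63) // 64), dtype=torch.int8, device="cuda"
    )
    return sparse_sageattn(q, k, v, mask_id=mask_id, is_causal=False, tensor_layout="HND")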
USE_BLOCK_ATTN = False
# ----------------------------
# Local / window masks
# ----------------------------
@torch.no_grad()
def build_local_block_mask_shifted_vec(
block_h: int,
block_w: int,
win_h: int = 6,
win_w: int = 6,
include_self: bool = True,
device=None,
) -> torch.Tensor:
device = device or torch.device("cpu")
H, W = block_h, block_w
r = torch.arange(H, device=device)
c = torch.arange(W, device=device)
YY, XX = torch.meshgrid(r, c, indexing="ij")
r_all = YY.reshape(-1)
c_all = XX.reshape(-1)
r_half = win_h // 2
c_half = win_w // 2
start_r = torch.clamp(r_all - r_half, 0, H - win_h)
end_r = start_r + win_h - 1
start_c = torch.clamp(c_all - c_half, 0, W - win_w)
end_c = start_c + win_w - 1
in_row = (r_all[None, :] >= start_r[:, None]) & (r_all[None, :] <= end_r[:, None])
in_col = (c_all[None, :] >= start_c[:, None]) & (c_all[None, :] <= end_c[:, None])
mask = in_row & in_col
if not include_self:
mask.fill_diagonal_(False)
return mask
@torch.no_grad()
def build_local_block_mask_shifted_vec_normal_slide(
block_h: int,
block_w: int,
win_h: int = 6,
win_w: int = 6,
include_self: bool = True,
device=None,
) -> torch.Tensor:
device = device or torch.device("cpu")
H, W = block_h, block_w
r = torch.arange(H, device=device)
c = torch.arange(W, device=device)
YY, XX = torch.meshgrid(r, c, indexing="ij")
r_all = YY.reshape(-1)
c_all = XX.reshape(-1)
r_half = win_h // 2
c_half = win_w // 2
start_r = r_all - r_half
end_r = start_r + win_h - 1
start_c = c_all - c_half
end_c = start_c + win_w - 1
in_row = (r_all[None, :] >= start_r[:, None]) & (r_all[None, :] <= end_r[:, None])
in_col = (c_all[None, :] >= start_c[:, None]) & (c_all[None, :] <= end_c[:, None])
mask = in_row & in_col
if not include_self:
mask.fill_diagonal_(False)
return mask
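# Illustrative sketch (tiny grid, assumed sizes): both helpers return a boolean
# (H*W, H*W) matrix whose entry (i, j) says whether block j falls inside block i's local
# window; the "_normal_slide" variant lets the window slide past the border instead of
# clamping it inside the grid.
def _example_local_block_masks():
    clamped = build_local_block_mask_shifted_vec(4, 4, win_h=3, win_w=3)
    sliding = build_local_block_mask_shifted_vec_normal_slide(4, 4, win_h=3, win_w=3)
    return clamped.shape, sliding.shape  # both torch.Size([16, 16])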
class WindowPartition3D:
"""Partition / reverse-partition helpers for 5-D tensors (B,F,H,W,C)."""
@staticmethod
def partition(x: torch.Tensor, win: Tuple[int, int, int]):
B, F, H, W, C = x.shape
wf, wh, ww = win
assert (
F % wf == 0 and H % wh == 0 and W % ww == 0
), "Dims must divide by window size."
x = x.view(B, F // wf, wf, H // wh, wh, W // ww, ww, C)
x = x.permute(0, 1, 3, 5, 2, 4, 6, 7).contiguous()
return x.view(-1, wf * wh * ww, C)
@staticmethod
def reverse(
windows: torch.Tensor, win: Tuple[int, int, int], orig: Tuple[int, int, int]
):
F, H, W = orig
wf, wh, ww = win
nf, nh, nw = F // wf, H // wh, W // ww
B = windows.size(0) // (nf * nh * nw)
x = windows.view(B, nf, nh, nw, wf, wh, ww, -1)
x = x.permute(0, 1, 4, 2, 5, 3, 6, 7).contiguous()
return x.view(B, F, H, W, -1)
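# Illustrative round-trip sketch (assumed shapes): partition() tiles a (B, F, H, W, C)
# tensor into (num_windows * B, wf*wh*ww, C) windows, and reverse() restores the original.
def _example_window_partition_roundtrip():
    x = torch.randn(1, 4, 16, 16, 8)  # (B, F, H, W, C), divisible by the window
    win = (2, 8, 8)
    windows = WindowPartition3D.partition(x, win)  # (8, 128, 8): 8 windows of 128 tokens
    restored = WindowPartition3D.reverse(windows, win, (4, 16, 16))
    assert torch.allclose(x, restored)
    return windows.shape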
@torch.no_grad()
def generate_draft_block_mask(
batch_size, nheads, seqlen, q_w, k_w, topk=10, local_attn_mask=None
):
assert batch_size == 1, "Only batch_size=1 supported for now"
assert local_attn_mask is not None, "local_attn_mask must be provided"
avgpool_q = torch.mean(q_w, dim=1)
avgpool_k = torch.mean(k_w, dim=1)
avgpool_q = rearrange(avgpool_q, "s (h d) -> s h d", h=nheads)
avgpool_k = rearrange(avgpool_k, "s (h d) -> s h d", h=nheads)
q_heads = avgpool_q.permute(1, 0, 2)
k_heads = avgpool_k.permute(1, 0, 2)
D = avgpool_q.shape[-1]
scores = torch.einsum("hld,hmd->hlm", q_heads, k_heads) / math.sqrt(D)
repeat_head = scores.shape[0]
repeat_len = scores.shape[1] // local_attn_mask.shape[0]
repeat_num = scores.shape[2] // local_attn_mask.shape[1]
local_attn_mask = (
local_attn_mask.unsqueeze(1).unsqueeze(0).repeat(repeat_len, 1, repeat_num, 1)
)
local_attn_mask = rearrange(local_attn_mask, "x a y b -> (x a) (y b)")
local_attn_mask = local_attn_mask.unsqueeze(0).repeat(repeat_head, 1, 1)
local_attn_mask = local_attn_mask.to(torch.float32)
local_attn_mask = local_attn_mask.masked_fill(
local_attn_mask == False, -float("inf")
)
local_attn_mask = local_attn_mask.masked_fill(local_attn_mask == True, 0)
scores = scores + local_attn_mask
attn_map = torch.softmax(scores, dim=-1)
attn_map = rearrange(attn_map, "h (it s1) s2 -> (h it) s1 s2", it=seqlen)
loop_num, s1, s2 = attn_map.shape
flat = attn_map.reshape(loop_num, -1)
n = flat.shape[1]
apply_topk = min(flat.shape[1] - 1, topk)
thresholds = torch.topk(flat, k=apply_topk + 1, dim=1, largest=True).values[:, -1]
thresholds = thresholds.unsqueeze(1)
mask_new = (flat > thresholds).reshape(loop_num, s1, s2)
    mask_new = rearrange(mask_new, "(h it) s1 s2 -> h (it s1) s2", it=seqlen)
mask = mask_new.unsqueeze(0).repeat(batch_size, 1, 1, 1)
return mask
@torch.no_grad()
def generate_draft_block_mask_sage(
batch_size, nheads, seqlen, q_w, k_w, topk=10, local_attn_mask=None
):
assert batch_size == 1, "Only batch_size=1 supported for now"
assert local_attn_mask is not None, "local_attn_mask must be provided"
avgpool_q = torch.mean(q_w, dim=1)
avgpool_q = rearrange(avgpool_q, "s (h d) -> s h d", h=nheads)
q_heads = avgpool_q.permute(1, 0, 2)
D = avgpool_q.shape[-1]
k_w_split = k_w.view(k_w.shape[0], 2, 64, k_w.shape[2])
avgpool_k_split = torch.mean(k_w_split, dim=2)
avgpool_k_refined = rearrange(
avgpool_k_split, "s two d -> (s two) d", two=2
) # shape: (s*2, C)
avgpool_k_refined = rearrange(
avgpool_k_refined, "s (h d) -> s h d", h=nheads
) # shape: (s*2, h, d)
k_heads_doubled = avgpool_k_refined.permute(1, 0, 2) # shape: (h, s*2, d)
k_heads_1, k_heads_2 = torch.chunk(k_heads_doubled, 2, dim=1)
scores_1 = torch.einsum("hld,hmd->hlm", q_heads, k_heads_1) / math.sqrt(D)
scores_2 = torch.einsum("hld,hmd->hlm", q_heads, k_heads_2) / math.sqrt(D)
scores = torch.cat([scores_1, scores_2], dim=-1)
repeat_head = scores.shape[0]
repeat_len = scores.shape[1] // local_attn_mask.shape[0]
repeat_num = (scores.shape[2] // 2) // local_attn_mask.shape[1]
local_attn_mask = (
local_attn_mask.unsqueeze(1).unsqueeze(0).repeat(repeat_len, 1, repeat_num, 1)
)
local_attn_mask = rearrange(local_attn_mask, "x a y b -> (x a) (y b)")
local_attn_mask = local_attn_mask.repeat_interleave(2, dim=1)
local_attn_mask = local_attn_mask.unsqueeze(0).repeat(repeat_head, 1, 1)
assert (
scores.shape == local_attn_mask.shape
), f"Scores shape {scores.shape} != Mask shape {local_attn_mask.shape}"
local_attn_mask = local_attn_mask.to(torch.float32)
local_attn_mask = local_attn_mask.masked_fill(
local_attn_mask == False, -float("inf")
)
local_attn_mask = local_attn_mask.masked_fill(local_attn_mask == True, 0)
scores = scores + local_attn_mask
attn_map = torch.softmax(scores, dim=-1)
attn_map = rearrange(attn_map, "h (it s1) s2 -> (h it) s1 s2", it=seqlen)
loop_num, s1, s2 = attn_map.shape
flat = attn_map.reshape(loop_num, -1)
apply_topk = min(flat.shape[1] - 1, topk)
if apply_topk <= 0:
mask_new = torch.zeros_like(flat, dtype=torch.bool).reshape(loop_num, s1, s2)
else:
thresholds = torch.topk(flat, k=apply_topk + 1, dim=1, largest=True).values[
:, -1
]
thresholds = thresholds.unsqueeze(1)
mask_new = (flat > thresholds).reshape(loop_num, s1, s2)
mask_new = rearrange(mask_new, "(h it) s1 s2 -> h (it s1) s2", it=seqlen)
mask = mask_new.unsqueeze(0).repeat(batch_size, 1, 1, 1)
return mask
# ----------------------------
# Attention kernels
# ----------------------------
def flash_attention(
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
num_heads: int,
compatibility_mode=False,
attention_mask=None,
return_KV=False,
):
if attention_mask is not None:
seqlen = q.shape[1]
seqlen_kv = k.shape[1]
if USE_BLOCK_ATTN and BLOCK_ATTN_AVAILABLE:
q = rearrange(q, "b s (n d) -> (b s) n d", n=num_heads)
k = rearrange(k, "b s (n d) -> (b s) n d", n=num_heads)
v = rearrange(v, "b s (n d) -> (b s) n d", n=num_heads)
else:
q = rearrange(q, "b s (n d) -> b n s d", n=num_heads)
k = rearrange(k, "b s (n d) -> b n s d", n=num_heads)
v = rearrange(v, "b s (n d) -> b n s d", n=num_heads)
cu_seqlens_q = torch.tensor([0, seqlen], device=q.device, dtype=torch.int32)
cu_seqlens_k = torch.tensor([0, seqlen_kv], device=q.device, dtype=torch.int32)
head_mask_type = torch.tensor(
[1] * num_heads, device=q.device, dtype=torch.int32
)
streaming_info = None
base_blockmask = attention_mask
max_seqlen_q_ = seqlen
max_seqlen_k_ = seqlen_kv
p_dropout = 0.0
if USE_BLOCK_ATTN and BLOCK_ATTN_AVAILABLE:
x = block_sparse_attn_func(
q,
k,
v,
cu_seqlens_q,
cu_seqlens_k,
head_mask_type,
streaming_info,
base_blockmask,
max_seqlen_q_,
max_seqlen_k_,
p_dropout,
deterministic=False,
softmax_scale=None,
is_causal=False,
exact_streaming=False,
return_attn_probs=False,
).unsqueeze(0)
x = rearrange(x, "b s n d -> b s (n d)", n=num_heads)
else:
x = sparse_sageattn(
q,
k,
v,
mask_id=base_blockmask.to(torch.int8),
is_causal=False,
tensor_layout="HND",
)
x = rearrange(x, "b n s d -> b s (n d)", n=num_heads)
elif compatibility_mode:
q = rearrange(q, "b s (n d) -> b n s d", n=num_heads)
k = rearrange(k, "b s (n d) -> b n s d", n=num_heads)
v = rearrange(v, "b s (n d) -> b n s d", n=num_heads)
x = F.scaled_dot_product_attention(q, k, v)
x = rearrange(x, "b n s d -> b s (n d)", n=num_heads)
elif FLASH_ATTN_3_AVAILABLE:
q = rearrange(q, "b s (n d) -> b s n d", n=num_heads)
k = rearrange(k, "b s (n d) -> b s n d", n=num_heads)
v = rearrange(v, "b s (n d) -> b s n d", n=num_heads)
x = flash_attn_interface.flash_attn_func(q, k, v)
if isinstance(x, tuple):
x = x[0]
x = rearrange(x, "b s n d -> b s (n d)", n=num_heads)
elif FLASH_ATTN_2_AVAILABLE:
q = rearrange(q, "b s (n d) -> b s n d", n=num_heads)
k = rearrange(k, "b s (n d) -> b s n d", n=num_heads)
v = rearrange(v, "b s (n d) -> b s n d", n=num_heads)
x = flash_attn.flash_attn_func(q, k, v)
x = rearrange(x, "b s n d -> b s (n d)", n=num_heads)
elif SAGE_ATTN_AVAILABLE:
q = rearrange(q, "b s (n d) -> b n s d", n=num_heads)
k = rearrange(k, "b s (n d) -> b n s d", n=num_heads)
v = rearrange(v, "b s (n d) -> b n s d", n=num_heads)
x = sageattn(q, k, v)
x = rearrange(x, "b n s d -> b s (n d)", n=num_heads)
else:
q = rearrange(q, "b s (n d) -> b n s d", n=num_heads)
k = rearrange(k, "b s (n d) -> b n s d", n=num_heads)
v = rearrange(v, "b s (n d) -> b n s d", n=num_heads)
x = F.scaled_dot_product_attention(q, k, v)
x = rearrange(x, "b n s d -> b s (n d)", n=num_heads)
return x
def modulate(x: torch.Tensor, shift: torch.Tensor, scale: torch.Tensor):
return x * (1 + scale) + shift
def sinusoidal_embedding_1d(dim, position):
sinusoid = torch.outer(
position.type(torch.float64),
torch.pow(
10000,
-torch.arange(dim // 2, dtype=torch.float64, device=position.device).div(
dim // 2
),
),
)
x = torch.cat([torch.cos(sinusoid), torch.sin(sinusoid)], dim=1)
return x.to(position.dtype)
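# Illustrative sketch (assumed dim): this is the standard sinusoidal timestep embedding; the
# pipeline feeds it to dit.time_embedding with the timestep fixed at 1000.
def _example_timestep_embedding():
    timestep = torch.tensor([1000.0])
    emb = sinusoidal_embedding_1d(256, timestep)  # (1, 256): cos half then sin half
    return emb.shape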
def precompute_freqs_cis_3d(dim: int, end: int = 1024, theta: float = 10000.0):
f_freqs_cis = precompute_freqs_cis(dim - 2 * (dim // 3), end, theta)
h_freqs_cis = precompute_freqs_cis(dim // 3, end, theta)
w_freqs_cis = precompute_freqs_cis(dim // 3, end, theta)
return f_freqs_cis, h_freqs_cis, w_freqs_cis
def precompute_freqs_cis(dim: int, end: int = 1024, theta: float = 10000.0):
freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].double() / dim))
freqs = torch.outer(torch.arange(end, device=freqs.device), freqs)
freqs_cis = torch.polar(torch.ones_like(freqs), freqs)
return freqs_cis
def rope_apply(x, freqs, num_heads):
x = rearrange(x, "b s (n d) -> b s n d", n=num_heads)
x_out = torch.view_as_complex(
x.to(torch.float64).reshape(x.shape[0], x.shape[1], x.shape[2], -1, 2)
)
x_out = torch.view_as_real(x_out * freqs).flatten(2)
return x_out.to(x.dtype)
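# Illustrative sketch (toy sizes): precompute 1-D rotary frequencies for a single head
# dimension and apply them; the model itself concatenates three such tables (frame, height,
# width) into one (f*h*w, 1, head_dim/2) complex tensor before calling rope_apply.
def _example_rope_apply():
    dim, num_heads, seq = 128, 4, 16
    head_dim = dim // num_heads
    freqs = precompute_freqs_cis(head_dim, end=seq).reshape(seq, 1, -1)  # complex table
    x = torch.randn(1, seq, dim)
    return rope_apply(x, freqs, num_heads).shape  # (1, 16, 128), same as x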
# ----------------------------
# Norms & Blocks
# ----------------------------
class RMSNorm(nn.Module):
def __init__(self, dim, eps=1e-5):
super().__init__()
self.eps = eps
self.weight = nn.Parameter(torch.ones(dim))
def norm(self, x):
return x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps)
def forward(self, x):
dtype = x.dtype
return self.norm(x.float()).to(dtype) * self.weight
class AttentionModule(nn.Module):
def __init__(self, num_heads):
super().__init__()
self.num_heads = num_heads
def forward(self, q, k, v, attention_mask=None):
x = flash_attention(
q=q, k=k, v=v, num_heads=self.num_heads, attention_mask=attention_mask
)
return x
class SelfAttention(nn.Module):
def __init__(self, dim: int, num_heads: int, eps: float = 1e-6):
super().__init__()
self.dim = dim
self.num_heads = num_heads
self.head_dim = dim // num_heads
self.q = nn.Linear(dim, dim)
self.k = nn.Linear(dim, dim)
self.v = nn.Linear(dim, dim)
self.o = nn.Linear(dim, dim)
self.norm_q = RMSNorm(dim, eps=eps)
self.norm_k = RMSNorm(dim, eps=eps)
self.attn = AttentionModule(self.num_heads)
self.local_attn_mask = None
def forward(
self,
x,
freqs,
f=None,
h=None,
w=None,
local_num=None,
topk=None,
train_img=False,
block_id=None,
kv_len=None,
is_full_block=False,
is_stream=False,
pre_cache_k=None,
pre_cache_v=None,
local_range=9,
):
B, L, D = x.shape
if is_stream and pre_cache_k is not None and pre_cache_v is not None:
assert f == 2, "f must be 2"
if is_stream and (pre_cache_k is None or pre_cache_v is None):
            assert f == 6, "starting f must be 6"
assert L == f * h * w, "Sequence length mismatch with provided (f,h,w)."
q = self.norm_q(self.q(x))
k = self.norm_k(self.k(x))
v = self.v(x)
q = rope_apply(q, freqs, self.num_heads)
k = rope_apply(k, freqs, self.num_heads)
win = (2, 8, 8)
q = q.view(B, f, h, w, D)
k = k.view(B, f, h, w, D)
v = v.view(B, f, h, w, D)
q_w = WindowPartition3D.partition(q, win)
k_w = WindowPartition3D.partition(k, win)
v_w = WindowPartition3D.partition(v, win)
seqlen = f // win[0]
one_len = k_w.shape[0] // B // seqlen
if pre_cache_k is not None and pre_cache_v is not None:
k_w = torch.cat([pre_cache_k, k_w], dim=0)
v_w = torch.cat([pre_cache_v, v_w], dim=0)
block_n = q_w.shape[0] // B
block_s = q_w.shape[1]
block_n_kv = k_w.shape[0] // B
reorder_q = rearrange(
q_w,
"(b block_n) (block_s) d -> b (block_n block_s) d",
block_n=block_n,
block_s=block_s,
)
reorder_k = rearrange(
k_w,
"(b block_n) (block_s) d -> b (block_n block_s) d",
block_n=block_n_kv,
block_s=block_s,
)
reorder_v = rearrange(
v_w,
"(b block_n) (block_s) d -> b (block_n block_s) d",
block_n=block_n_kv,
block_s=block_s,
)
window_size = win[0] * h * w // 128
if (
self.local_attn_mask is None
or self.local_attn_mask_h != h // 8
or self.local_attn_mask_w != w // 8
or self.local_range != local_range
):
self.local_attn_mask = build_local_block_mask_shifted_vec_normal_slide(
h // 8,
w // 8,
local_range,
local_range,
include_self=True,
device=k_w.device,
)
self.local_attn_mask_h = h // 8
self.local_attn_mask_w = w // 8
self.local_range = local_range
if USE_BLOCK_ATTN and BLOCK_ATTN_AVAILABLE:
attention_mask = generate_draft_block_mask(
B,
self.num_heads,
seqlen,
q_w,
k_w,
topk=topk,
local_attn_mask=self.local_attn_mask,
)
else:
attention_mask = generate_draft_block_mask_sage(
B,
self.num_heads,
seqlen,
q_w,
k_w,
topk=topk,
local_attn_mask=self.local_attn_mask,
)
x = self.attn(reorder_q, reorder_k, reorder_v, attention_mask)
cur_block_n, cur_block_s, _ = k_w.shape
cache_num = cur_block_n // one_len
if cache_num > kv_len:
cache_k = k_w[one_len:, :, :]
cache_v = v_w[one_len:, :, :]
else:
cache_k = k_w
cache_v = v_w
x = rearrange(
x,
"b (block_n block_s) d -> (b block_n) (block_s) d",
block_n=block_n,
block_s=block_s,
)
x = WindowPartition3D.reverse(x, win, (f, h, w))
x = x.view(B, f * h * w, D)
if is_stream:
return self.o(x), cache_k, cache_v
return self.o(x)
class CrossAttention(nn.Module):
"""
仅考虑文本 context;提供持久 KV 缓存。
"""
def __init__(self, dim: int, num_heads: int, eps: float = 1e-6):
super().__init__()
self.dim = dim
self.num_heads = num_heads
self.head_dim = dim // num_heads
self.q = nn.Linear(dim, dim)
self.k = nn.Linear(dim, dim)
self.v = nn.Linear(dim, dim)
self.o = nn.Linear(dim, dim)
self.norm_q = RMSNorm(dim, eps=eps)
self.norm_k = RMSNorm(dim, eps=eps)
self.attn = AttentionModule(self.num_heads)
        # Persistent KV cache
self.cache_k = None
self.cache_v = None
@torch.no_grad()
def init_cache(self, ctx: torch.Tensor):
"""ctx: [B, S_ctx, dim] —— 经过 text_embedding 之后的上下文"""
self.cache_k = self.norm_k(self.k(ctx))
self.cache_v = self.v(ctx)
def clear_cache(self):
self.cache_k = None
self.cache_v = None
def forward(self, x: torch.Tensor, y: torch.Tensor, is_stream: bool = False):
"""
y 即文本上下文(未做其他分支)。
"""
q = self.norm_q(self.q(x))
assert self.cache_k is not None and self.cache_v is not None
k = self.cache_k
v = self.cache_v
x = self.attn(q, k, v)
return self.o(x)
class GateModule(nn.Module):
def __init__(
self,
):
super().__init__()
def forward(self, x, gate, residual):
return x + gate * residual
class DiTBlock(nn.Module):
def __init__(self, dim: int, num_heads: int, ffn_dim: int, eps: float = 1e-6):
super().__init__()
self.dim = dim
self.num_heads = num_heads
self.ffn_dim = ffn_dim
self.self_attn = SelfAttention(dim, num_heads, eps)
self.cross_attn = CrossAttention(dim, num_heads, eps)
self.norm1 = nn.LayerNorm(dim, eps=eps, elementwise_affine=False)
self.norm2 = nn.LayerNorm(dim, eps=eps, elementwise_affine=False)
self.norm3 = nn.LayerNorm(dim, eps=eps)
self.ffn = nn.Sequential(
nn.Linear(dim, ffn_dim),
nn.GELU(approximate="tanh"),
nn.Linear(ffn_dim, dim),
)
self.modulation = nn.Parameter(torch.randn(1, 6, dim) / dim**0.5)
self.gate = GateModule()
def forward(
self,
x,
context,
t_mod,
freqs,
f,
h,
w,
local_num=None,
topk=None,
train_img=False,
block_id=None,
kv_len=None,
is_full_block=False,
is_stream=False,
pre_cache_k=None,
pre_cache_v=None,
local_range=9,
):
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (
self.modulation.to(dtype=t_mod.dtype, device=t_mod.device) + t_mod
).chunk(6, dim=1)
input_x = modulate(self.norm1(x), shift_msa, scale_msa)
self_attn_output, self_attn_cache_k, self_attn_cache_v = self.self_attn(
input_x,
freqs,
f,
h,
w,
local_num,
topk,
train_img,
block_id,
kv_len=kv_len,
is_full_block=is_full_block,
is_stream=is_stream,
pre_cache_k=pre_cache_k,
pre_cache_v=pre_cache_v,
local_range=local_range,
)
x = self.gate(x, gate_msa, self_attn_output)
x = x + self.cross_attn(self.norm3(x), context, is_stream=is_stream)
input_x = modulate(self.norm2(x), shift_mlp, scale_mlp)
x = self.gate(x, gate_mlp, self.ffn(input_x))
if is_stream:
return x, self_attn_cache_k, self_attn_cache_v
return x
class MLP(torch.nn.Module):
def __init__(self, in_dim, out_dim, has_pos_emb=False):
super().__init__()
self.proj = torch.nn.Sequential(
nn.LayerNorm(in_dim),
nn.Linear(in_dim, in_dim),
nn.GELU(),
nn.Linear(in_dim, out_dim),
nn.LayerNorm(out_dim),
)
self.has_pos_emb = has_pos_emb
if has_pos_emb:
self.emb_pos = torch.nn.Parameter(torch.zeros((1, 514, 1280)))
def forward(self, x):
if self.has_pos_emb:
x = x + self.emb_pos.to(dtype=x.dtype, device=x.device)
return self.proj(x)
class Head(nn.Module):
def __init__(
self, dim: int, out_dim: int, patch_size: Tuple[int, int, int], eps: float
):
super().__init__()
self.dim = dim
self.patch_size = patch_size
self.norm = nn.LayerNorm(dim, eps=eps, elementwise_affine=False)
self.head = nn.Linear(dim, out_dim * math.prod(patch_size))
self.modulation = nn.Parameter(torch.randn(1, 2, dim) / dim**0.5)
def forward(self, x, t_mod):
shift, scale = (
self.modulation.to(dtype=t_mod.dtype, device=t_mod.device) + t_mod
).chunk(2, dim=1)
x = self.head(self.norm(x) * (1 + scale) + shift)
return x
# ----------------------------
# State dict converter (original mapping kept; has_image_input is unused)
# ----------------------------
class WanModelStateDictConverter:
def __init__(self):
pass
def from_diffusers(self, state_dict):
rename_dict = {
"blocks.0.attn1.norm_k.weight": "blocks.0.self_attn.norm_k.weight",
"blocks.0.attn1.norm_q.weight": "blocks.0.self_attn.norm_q.weight",
"blocks.0.attn1.to_k.bias": "blocks.0.self_attn.k.bias",
"blocks.0.attn1.to_k.weight": "blocks.0.self_attn.k.weight",
"blocks.0.attn1.to_out.0.bias": "blocks.0.self_attn.o.bias",
"blocks.0.attn1.to_out.0.weight": "blocks.0.self_attn.o.weight",
"blocks.0.attn1.to_q.bias": "blocks.0.self_attn.q.bias",
"blocks.0.attn1.to_q.weight": "blocks.0.self_attn.q.weight",
"blocks.0.attn1.to_v.bias": "blocks.0.self_attn.v.bias",
"blocks.0.attn1.to_v.weight": "blocks.0.self_attn.v.weight",
"blocks.0.attn2.norm_k.weight": "blocks.0.cross_attn.norm_k.weight",
"blocks.0.attn2.norm_q.weight": "blocks.0.cross_attn.norm_q.weight",
"blocks.0.attn2.to_k.bias": "blocks.0.cross_attn.k.bias",
"blocks.0.attn2.to_k.weight": "blocks.0.cross_attn.k.weight",
"blocks.0.attn2.to_out.0.bias": "blocks.0.cross_attn.o.bias",
"blocks.0.attn2.to_out.0.weight": "blocks.0.cross_attn.o.weight",
"blocks.0.attn2.to_q.bias": "blocks.0.cross_attn.q.bias",
"blocks.0.attn2.to_q.weight": "blocks.0.cross_attn.q.weight",
"blocks.0.attn2.to_v.bias": "blocks.0.cross_attn.v.bias",
"blocks.0.attn2.to_v.weight": "blocks.0.cross_attn.v.weight",
"blocks.0.ffn.net.0.proj.bias": "blocks.0.ffn.0.bias",
"blocks.0.ffn.net.0.proj.weight": "blocks.0.ffn.0.weight",
"blocks.0.ffn.net.2.bias": "blocks.0.ffn.2.bias",
"blocks.0.ffn.net.2.weight": "blocks.0.ffn.2.weight",
"blocks.0.norm2.bias": "blocks.0.norm3.bias",
"blocks.0.norm2.weight": "blocks.0.norm3.weight",
"blocks.0.scale_shift_table": "blocks.0.modulation",
"condition_embedder.text_embedder.linear_1.bias": "text_embedding.0.bias",
"condition_embedder.text_embedder.linear_1.weight": "text_embedding.0.weight",
"condition_embedder.text_embedder.linear_2.bias": "text_embedding.2.bias",
"condition_embedder.text_embedder.linear_2.weight": "text_embedding.2.weight",
"condition_embedder.time_embedder.linear_1.bias": "time_embedding.0.bias",
"condition_embedder.time_embedder.linear_1.weight": "time_embedding.0.weight",
"condition_embedder.time_embedder.linear_2.bias": "time_embedding.2.bias",
"condition_embedder.time_embedder.linear_2.weight": "time_embedding.2.weight",
"condition_embedder.time_proj.bias": "time_projection.1.bias",
"condition_embedder.time_proj.weight": "time_projection.1.weight",
"patch_embedding.bias": "patch_embedding.bias",
"patch_embedding.weight": "patch_embedding.weight",
"scale_shift_table": "head.modulation",
"proj_out.bias": "head.head.bias",
"proj_out.weight": "head.head.weight",
}
state_dict_ = {}
for name, param in state_dict.items():
if name in rename_dict:
state_dict_[rename_dict[name]] = param
else:
name_ = ".".join(name.split(".")[:1] + ["0"] + name.split(".")[2:])
if name_ in rename_dict:
name_ = rename_dict[name_]
name_ = ".".join(
name_.split(".")[:1]
+ [name.split(".")[1]]
+ name_.split(".")[2:]
)
state_dict_[name_] = param
if hash_state_dict_keys(state_dict) == "cb104773c6c2cb6df4f9529ad5c60d0b":
config = {
"model_type": "t2v",
"patch_size": (1, 2, 2),
"text_len": 512,
"in_dim": 16,
"dim": 5120,
"ffn_dim": 13824,
"freq_dim": 256,
"text_dim": 4096,
"out_dim": 16,
"num_heads": 40,
"num_layers": 40,
"window_size": (-1, -1),
"qk_norm": True,
"cross_attn_norm": True,
"eps": 1e-6,
}
else:
config = {}
return state_dict_, config
def from_civitai(self, state_dict):
state_dict = {
name: param
for name, param in state_dict.items()
if not name.startswith("vace")
}
        # Keep the configs returned by hash matching; the implementation itself has no has_image_input branch
if hash_state_dict_keys(state_dict) == "9269f8db9040a9d860eaca435be61814":
config = {
"has_image_input": False,
"patch_size": [1, 2, 2],
"in_dim": 16,
"dim": 1536,
"ffn_dim": 8960,
"freq_dim": 256,
"text_dim": 4096,
"out_dim": 16,
"num_heads": 12,
"num_layers": 30,
"eps": 1e-6,
}
elif hash_state_dict_keys(state_dict) == "aafcfd9672c3a2456dc46e1cb6e52c70":
config = {
"has_image_input": False,
"patch_size": [1, 2, 2],
"in_dim": 16,
"dim": 5120,
"ffn_dim": 13824,
"freq_dim": 256,
"text_dim": 4096,
"out_dim": 16,
"num_heads": 40,
"num_layers": 40,
"eps": 1e-6,
}
elif hash_state_dict_keys(state_dict) == "6bfcfb3b342cb286ce886889d519a77e":
config = {
"has_image_input": False,
"patch_size": [1, 2, 2],
"in_dim": 36,
"dim": 5120,
"ffn_dim": 13824,
"freq_dim": 256,
"text_dim": 4096,
"out_dim": 16,
"num_heads": 40,
"num_layers": 40,
"eps": 1e-6,
}
elif hash_state_dict_keys(state_dict) == "6d6ccde6845b95ad9114ab993d917893":
config = {
"has_image_input": False,
"patch_size": [1, 2, 2],
"in_dim": 36,
"dim": 1536,
"ffn_dim": 8960,
"freq_dim": 256,
"text_dim": 4096,
"out_dim": 16,
"num_heads": 12,
"num_layers": 30,
"eps": 1e-6,
}
elif hash_state_dict_keys(state_dict) == "349723183fc063b2bfc10bb2835cf677":
config = {
"has_image_input": False,
"patch_size": [1, 2, 2],
"in_dim": 48,
"dim": 1536,
"ffn_dim": 8960,
"freq_dim": 256,
"text_dim": 4096,
"out_dim": 16,
"num_heads": 12,
"num_layers": 30,
"eps": 1e-6,
}
elif hash_state_dict_keys(state_dict) == "efa44cddf936c70abd0ea28b6cbe946c":
config = {
"has_image_input": False,
"patch_size": [1, 2, 2],
"in_dim": 48,
"dim": 5120,
"ffn_dim": 13824,
"freq_dim": 256,
"text_dim": 4096,
"out_dim": 16,
"num_heads": 40,
"num_layers": 40,
"eps": 1e-6,
}
elif hash_state_dict_keys(state_dict) == "3ef3b1f8e1dab83d5b71fd7b617f859f":
config = {
"has_image_input": False,
"patch_size": [1, 2, 2],
"in_dim": 36,
"dim": 5120,
"ffn_dim": 13824,
"freq_dim": 256,
"text_dim": 4096,
"out_dim": 16,
"num_heads": 40,
"num_layers": 40,
"eps": 1e-6,
"has_image_pos_emb": False,
}
else:
config = {}
return state_dict, config
DecoderResult = namedtuple("DecoderResult", ("frame", "memory"))
TWorkItem = namedtuple("TWorkItem", ("input_tensor", "block_index"))
# ----------------------------
# Utility / building blocks
# ----------------------------
class IdentityConv2d(nn.Conv2d):
"""Same-shape Conv2d initialized to identity (Dirac)."""
def __init__(self, C, kernel_size=3, bias=False):
pad = kernel_size // 2
super().__init__(C, C, kernel_size, padding=pad, bias=bias)
with torch.no_grad():
init.dirac_(self.weight)
if self.bias is not None:
self.bias.zero_()
def conv(n_in, n_out, **kwargs):
return nn.Conv2d(n_in, n_out, 3, padding=1, **kwargs)
class Clamp(nn.Module):
def forward(self, x):
return torch.tanh(x / 3) * 3
class MemBlock(nn.Module):
def __init__(self, n_in, n_out):
super().__init__()
self.conv = nn.Sequential(
conv(n_in * 2, n_out),
nn.ReLU(inplace=True),
conv(n_out, n_out),
nn.ReLU(inplace=True),
conv(n_out, n_out),
)
self.skip = (
nn.Conv2d(n_in, n_out, 1, bias=False) if n_in != n_out else nn.Identity()
)
self.act = nn.ReLU(inplace=True)
def forward(self, x, past):
return self.act(self.conv(torch.cat([x, past], 1)) + self.skip(x))
class TPool(nn.Module):
def __init__(self, n_f, stride):
super().__init__()
self.stride = stride
self.conv = nn.Conv2d(n_f * stride, n_f, 1, bias=False)
def forward(self, x):
_NT, C, H, W = x.shape
return self.conv(x.reshape(-1, self.stride * C, H, W))
class TGrow(nn.Module):
def __init__(self, n_f, stride):
super().__init__()
self.stride = stride
self.conv = nn.Conv2d(n_f, n_f * stride, 1, bias=False)
def forward(self, x):
_NT, C, H, W = x.shape
x = self.conv(x)
return x.reshape(-1, C, H, W)
class PixelShuffle3dTAEHV(nn.Module):
def __init__(self, ff, hh, ww):
super().__init__()
self.ff = ff
self.hh = hh
self.ww = ww
def forward(self, x):
# x: (B, C, F, H, W)
B, C, F, H, W = x.shape
if F % self.ff != 0:
first_frame = x[:, :, 0:1, :, :].repeat(1, 1, self.ff - F % self.ff, 1, 1)
x = torch.cat([first_frame, x], dim=2)
return rearrange(
x,
"b c (f ff) (h hh) (w ww) -> b (c ff hh ww) f h w",
ff=self.ff,
hh=self.hh,
ww=self.ww,
).transpose(1, 2)
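# Shape example: with (ff, hh, ww) = (4, 8, 8), an RGB clip of shape
# (B, 3, F, H, W) whose F has been padded to a multiple of 4 becomes
# (B, F/4, 3*4*8*8, H/8, W/8) = (B, F/4, 768, H/8, W/8) after the final
# transpose, i.e. one 768-channel map per group of 4 frames. This matches the
# 16 + 768 latent channels that the conditioned decoder is built with below.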
# ----------------------------
# Generic NTCHW graph executor (kept; used by decoder)
# ----------------------------
def apply_model_with_memblocks(model, x, parallel, show_progress_bar, mem=None):
"""
Apply a sequential model with memblocks to the given input.
Args:
- model: nn.Sequential of blocks to apply
- x: input data, of dimensions NTCHW
- parallel: if True, parallelize over timesteps (fast but uses O(T) memory)
if False, each timestep will be processed sequentially (slow but uses O(1) memory)
    - show_progress_bar: if True, enables tqdm progressbar display
    - mem: per-block state list for the sequential path (pass [None] * len(model) initially)
    Returns a tuple (output, mem), where output is an NTCHW tensor.
"""
assert x.ndim == 5, f"TAEHV operates on NTCHW tensors, but got {x.ndim}-dim tensor"
N, T, C, H, W = x.shape
if parallel:
x = x.reshape(N * T, C, H, W)
for b in tqdm(model, disable=not show_progress_bar):
if isinstance(b, MemBlock):
NT, C, H, W = x.shape
T = NT // N
_x = x.reshape(N, T, C, H, W)
mem = F.pad(_x, (0, 0, 0, 0, 0, 0, 1, 0), value=0)[:, :T].reshape(
x.shape
)
x = b(x, mem)
else:
x = b(x)
NT, C, H, W = x.shape
T = NT // N
x = x.view(N, T, C, H, W)
else:
out = []
work_queue = [
TWorkItem(xt, 0)
for t, xt in enumerate(x.reshape(N, T * C, H, W).chunk(T, dim=1))
]
progress_bar = tqdm(range(T), disable=not show_progress_bar)
while work_queue:
xt, i = work_queue.pop(0)
if i == 0:
progress_bar.update(1)
if i == len(model):
out.append(xt)
else:
b = model[i]
if isinstance(b, MemBlock):
if mem[i] is None:
xt_new = b(xt, xt * 0)
mem[i] = xt
else:
xt_new = b(xt, mem[i])
mem[i].copy_(xt)
work_queue.insert(0, TWorkItem(xt_new, i + 1))
elif isinstance(b, TPool):
if mem[i] is None:
mem[i] = []
mem[i].append(xt)
if len(mem[i]) > b.stride:
raise ValueError("TPool internal state invalid.")
elif len(mem[i]) == b.stride:
N_, C_, H_, W_ = xt.shape
xt = b(torch.cat(mem[i], 1).view(N_ * b.stride, C_, H_, W_))
mem[i] = []
work_queue.insert(0, TWorkItem(xt, i + 1))
elif isinstance(b, TGrow):
xt = b(xt)
NT, C_, H_, W_ = xt.shape
for xt_next in reversed(
xt.view(N, b.stride * C_, H_, W_).chunk(b.stride, 1)
):
work_queue.insert(0, TWorkItem(xt_next, i + 1))
else:
xt = b(xt)
work_queue.insert(0, TWorkItem(xt, i + 1))
progress_bar.close()
x = torch.stack(out, 1)
return x, mem
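# Illustrative sketch (hypothetical helper, never called by the pipeline):
# how apply_model_with_memblocks is typically driven. The tiny stand-in model
# and shapes are assumptions for demonstration only.
def _demo_apply_model_with_memblocks():
    model = nn.Sequential(
        conv(4, 8), nn.ReLU(inplace=True), MemBlock(8, 8), conv(8, 3)
    )
    x = torch.randn(1, 5, 4, 16, 16)  # NTCHW
    # Parallel path: all timesteps at once (fast, O(T) memory); mem is unused.
    y_parallel, _ = apply_model_with_memblocks(
        model, x, parallel=True, show_progress_bar=False
    )
    # Sequential path: one timestep at a time (O(1) memory); mem carries state
    # across calls, which is how the streaming decoder below reuses it.
    mem = [None] * len(model)
    y_stream, mem = apply_model_with_memblocks(
        model, x, parallel=False, show_progress_bar=False, mem=mem
    )
    return y_parallel.shape, y_stream.shape  # both torch.Size([1, 5, 3, 16, 16])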
# ----------------------------
# Decoder-only TAEHV
# ----------------------------
class TAEHV(nn.Module):
image_channels = 3
def __init__(
self,
checkpoint_path="taehv.pth",
decoder_time_upscale=(True, True),
decoder_space_upscale=(True, True, True),
channels=[256, 128, 64, 64],
latent_channels=16,
):
"""Initialize TAEHV (decoder-only) with built-in deepening after every ReLU.
Deepening config: how_many_each=1, k=3 (fixed as requested).
"""
super().__init__()
self.latent_channels = latent_channels
n_f = channels
self.frames_to_trim = 2 ** sum(decoder_time_upscale) - 1
# Build the decoder "skeleton"
base_decoder = nn.Sequential(
Clamp(),
conv(self.latent_channels, n_f[0]),
nn.ReLU(inplace=True),
MemBlock(n_f[0], n_f[0]),
MemBlock(n_f[0], n_f[0]),
MemBlock(n_f[0], n_f[0]),
nn.Upsample(scale_factor=2 if decoder_space_upscale[0] else 1),
TGrow(n_f[0], 1),
conv(n_f[0], n_f[1], bias=False),
MemBlock(n_f[1], n_f[1]),
MemBlock(n_f[1], n_f[1]),
MemBlock(n_f[1], n_f[1]),
nn.Upsample(scale_factor=2 if decoder_space_upscale[1] else 1),
TGrow(n_f[1], 2 if decoder_time_upscale[0] else 1),
conv(n_f[1], n_f[2], bias=False),
MemBlock(n_f[2], n_f[2]),
MemBlock(n_f[2], n_f[2]),
MemBlock(n_f[2], n_f[2]),
nn.Upsample(scale_factor=2 if decoder_space_upscale[2] else 1),
TGrow(n_f[2], 2 if decoder_time_upscale[1] else 1),
conv(n_f[2], n_f[3], bias=False),
nn.ReLU(inplace=True),
conv(n_f[3], TAEHV.image_channels),
)
# Inline deepening: insert (IdentityConv2d(k=3) + ReLU) after every ReLU
self.decoder = self._apply_identity_deepen(base_decoder, how_many_each=1, k=3)
self.pixel_shuffle = PixelShuffle3dTAEHV(4, 8, 8)
if checkpoint_path is not None:
missing_keys = self.load_state_dict(
self.patch_tgrow_layers(
torch.load(checkpoint_path, map_location="cpu", weights_only=True)
),
strict=False,
)
print("missing_keys", missing_keys)
# Initialize decoder mem state
self.mem = [None] * len(self.decoder)
@staticmethod
def _apply_identity_deepen(
decoder: nn.Sequential, how_many_each=1, k=3
) -> nn.Sequential:
"""Return a new Sequential where every nn.ReLU is followed by how_many_each*(IdentityConv2d(k)+ReLU)."""
new_layers = []
for b in decoder:
new_layers.append(b)
if isinstance(b, nn.ReLU):
# Deduce channel count from preceding layer
C = None
if len(new_layers) >= 2 and isinstance(new_layers[-2], nn.Conv2d):
C = new_layers[-2].out_channels
elif len(new_layers) >= 2 and isinstance(new_layers[-2], MemBlock):
C = new_layers[-2].conv[-1].out_channels
if C is not None:
for _ in range(how_many_each):
new_layers.append(IdentityConv2d(C, kernel_size=k, bias=False))
new_layers.append(nn.ReLU(inplace=True))
return nn.Sequential(*new_layers)
def patch_tgrow_layers(self, sd):
"""Patch TGrow layers to use a smaller kernel if needed (decoder-only)."""
new_sd = self.state_dict()
for i, layer in enumerate(self.decoder):
if isinstance(layer, TGrow):
key = f"decoder.{i}.conv.weight"
if key in sd and sd[key].shape[0] > new_sd[key].shape[0]:
sd[key] = sd[key][-new_sd[key].shape[0] :]
return sd
def decode_video(self, x, parallel=True, show_progress_bar=False, cond=None):
"""Decode a sequence of frames from latents.
x: NTCHW latent tensor; returns NTCHW RGB in ~[0, 1].
"""
trim_flag = self.mem[-8] is None # keeps original relative check
if cond is not None:
x = torch.cat([self.pixel_shuffle(cond), x], dim=2)
x, self.mem = apply_model_with_memblocks(
self.decoder, x, parallel, show_progress_bar, mem=self.mem
)
if trim_flag:
return x[:, self.frames_to_trim :]
return x
def forward(self, *args, **kwargs):
raise NotImplementedError("Decoder-only model: call decode_video(...) instead.")
def clean_mem(self):
self.mem = [None] * len(self.decoder)
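# Illustrative check (hypothetical helper, never called): a Dirac-initialized
# IdentityConv2d is an exact identity map, so inserting (IdentityConv2d + ReLU)
# after an existing ReLU leaves the decoder's function unchanged until the new
# layers are trained.
def _demo_identity_deepen():
    x = torch.relu(torch.randn(2, 64, 8, 8))  # post-ReLU activations are non-negative
    extra = IdentityConv2d(64, kernel_size=3, bias=False)
    with torch.no_grad():
        y = torch.relu(extra(x))
    return torch.allclose(x, y, atol=1e-6)  # True: identity conv, then ReLU on non-negative input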
class DotDict(dict):
__getattr__ = dict.__getitem__
__setattr__ = dict.__setitem__
class TAEW2_1DiffusersWrapper(nn.Module):
def __init__(self, pretrained_path=None, channels=[256, 128, 64, 64]):
super().__init__()
self.dtype = torch.bfloat16
self.device = "cuda"
self.taehv = TAEHV(pretrained_path, channels=channels).to(self.dtype)
self.temperal_downsample = [True, True, False] # [sic]
self.config = DotDict(
scaling_factor=1.0,
latents_mean=torch.zeros(16),
z_dim=16,
latents_std=torch.ones(16),
)
def decode(self, latents, return_dict=None):
n, c, t, h, w = latents.shape
return (
self.taehv.decode_video(latents.transpose(1, 2), parallel=False)
.transpose(1, 2)
.mul_(2)
.sub_(1),
)
def stream_decode_with_cond(self, latents, tiled=False, cond=None):
n, c, t, h, w = latents.shape
return (
self.taehv.decode_video(latents.transpose(1, 2), parallel=False, cond=cond)
.transpose(1, 2)
.mul_(2)
.sub_(1)
)
def clean_mem(self):
self.taehv.clean_mem()
# ----------------------------
# Simplified builder (no small, no transplant, no post-hoc deepening)
# ----------------------------
def build_tcdecoder(
new_channels=[512, 256, 128, 128],
device="cuda",
dtype=torch.bfloat16,
new_latent_channels=None,
):
"""
构建“更宽”的 decoder;深度增强(IdentityConv2d+ReLU)已在 TAEHV 内部完成。
- 不创建 small / 不做移植
- base_ckpt_path 参数保留但不使用(接口兼容)
返回:big (单个模型)
"""
if new_latent_channels is not None:
big = (
TAEHV(
checkpoint_path=None,
channels=new_channels,
latent_channels=new_latent_channels,
)
.to(device)
.to(dtype)
.train()
)
else:
big = (
TAEHV(checkpoint_path=None, channels=new_channels)
.to(device)
.to(dtype)
.train()
)
big.clean_mem()
return big
CACHE_T = 2
@contextmanager
def init_weights_on_device(device=torch.device("meta"), include_buffers: bool = False):
old_register_parameter = torch.nn.Module.register_parameter
if include_buffers:
old_register_buffer = torch.nn.Module.register_buffer
def register_empty_parameter(module, name, param):
old_register_parameter(module, name, param)
if param is not None:
param_cls = type(module._parameters[name])
kwargs = module._parameters[name].__dict__
kwargs["requires_grad"] = param.requires_grad
module._parameters[name] = param_cls(
module._parameters[name].to(device), **kwargs
)
def register_empty_buffer(module, name, buffer, persistent=True):
old_register_buffer(module, name, buffer, persistent=persistent)
if buffer is not None:
module._buffers[name] = module._buffers[name].to(device)
def patch_tensor_constructor(fn):
def wrapper(*args, **kwargs):
kwargs["device"] = device
return fn(*args, **kwargs)
return wrapper
if include_buffers:
tensor_constructors_to_patch = {
torch_function_name: getattr(torch, torch_function_name)
for torch_function_name in ["empty", "zeros", "ones", "full"]
}
else:
tensor_constructors_to_patch = {}
try:
torch.nn.Module.register_parameter = register_empty_parameter
if include_buffers:
torch.nn.Module.register_buffer = register_empty_buffer
for torch_function_name in tensor_constructors_to_patch.keys():
setattr(
torch,
torch_function_name,
patch_tensor_constructor(getattr(torch, torch_function_name)),
)
yield
finally:
torch.nn.Module.register_parameter = old_register_parameter
if include_buffers:
torch.nn.Module.register_buffer = old_register_buffer
for (
torch_function_name,
old_torch_function,
) in tensor_constructors_to_patch.items():
setattr(torch, torch_function_name, old_torch_function)
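# Illustrative sketch (hypothetical helper, never called): init_weights_on_device
# lets a module be constructed with its parameters on the meta device, so large
# models can be instantiated without allocating memory before real weights are
# loaded from disk.
def _demo_init_on_meta():
    with init_weights_on_device(device=torch.device("meta")):
        layer = nn.Linear(4096, 4096)
    return layer.weight.device.type  # "meta": the parameters hold no real storage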
def load_state_dict_from_folder(file_path, torch_dtype=None):
state_dict = {}
for file_name in os.listdir(file_path):
if "." in file_name and file_name.split(".")[-1] in [
"safetensors",
"bin",
"ckpt",
"pth",
"pt",
]:
state_dict.update(
load_state_dict(
os.path.join(file_path, file_name), torch_dtype=torch_dtype
)
)
return state_dict
def load_state_dict(file_path, torch_dtype=None):
if file_path.endswith(".safetensors"):
return load_state_dict_from_safetensors(file_path, torch_dtype=torch_dtype)
else:
return load_state_dict_from_bin(file_path, torch_dtype=torch_dtype)
def load_state_dict_from_safetensors(file_path, torch_dtype=None):
state_dict = {}
with safe_open(file_path, framework="pt", device="cpu") as f:
for k in f.keys():
state_dict[k] = f.get_tensor(k)
if torch_dtype is not None:
state_dict[k] = state_dict[k].to(torch_dtype)
return state_dict
def load_state_dict_from_bin(file_path, torch_dtype=None):
state_dict = torch.load(file_path, map_location="cpu", weights_only=True)
if torch_dtype is not None:
for i in state_dict:
if isinstance(state_dict[i], torch.Tensor):
state_dict[i] = state_dict[i].to(torch_dtype)
return state_dict
def search_for_embeddings(state_dict):
embeddings = []
for k in state_dict:
if isinstance(state_dict[k], torch.Tensor):
embeddings.append(state_dict[k])
elif isinstance(state_dict[k], dict):
embeddings += search_for_embeddings(state_dict[k])
return embeddings
def search_parameter(param, state_dict):
for name, param_ in state_dict.items():
if param.numel() == param_.numel():
if param.shape == param_.shape:
if torch.dist(param, param_) < 1e-3:
return name
else:
if torch.dist(param.flatten(), param_.flatten()) < 1e-3:
return name
return None
def build_rename_dict(source_state_dict, target_state_dict, split_qkv=False):
matched_keys = set()
with torch.no_grad():
for name in source_state_dict:
rename = search_parameter(source_state_dict[name], target_state_dict)
if rename is not None:
print(f'"{name}": "{rename}",')
matched_keys.add(rename)
elif (
split_qkv
and len(source_state_dict[name].shape) >= 1
and source_state_dict[name].shape[0] % 3 == 0
):
length = source_state_dict[name].shape[0] // 3
rename = []
for i in range(3):
rename.append(
search_parameter(
source_state_dict[name][i * length : i * length + length],
target_state_dict,
)
)
if None not in rename:
print(f'"{name}": {rename},')
for rename_ in rename:
matched_keys.add(rename_)
for name in target_state_dict:
if name not in matched_keys:
print("Cannot find", name, target_state_dict[name].shape)
def search_for_files(folder, extensions):
files = []
if os.path.isdir(folder):
for file in sorted(os.listdir(folder)):
files += search_for_files(os.path.join(folder, file), extensions)
elif os.path.isfile(folder):
for extension in extensions:
if folder.endswith(extension):
files.append(folder)
break
return files
def convert_state_dict_keys_to_single_str(state_dict, with_shape=True):
keys = []
for key, value in state_dict.items():
if isinstance(key, str):
if isinstance(value, torch.Tensor):
if with_shape:
shape = "_".join(map(str, list(value.shape)))
keys.append(key + ":" + shape)
keys.append(key)
elif isinstance(value, dict):
keys.append(
key
+ "|"
+ convert_state_dict_keys_to_single_str(
value, with_shape=with_shape
)
)
keys.sort()
keys_str = ",".join(keys)
return keys_str
def split_state_dict_with_prefix(state_dict):
keys = sorted([key for key in state_dict if isinstance(key, str)])
prefix_dict = {}
for key in keys:
prefix = key if "." not in key else key.split(".")[0]
if prefix not in prefix_dict:
prefix_dict[prefix] = []
prefix_dict[prefix].append(key)
state_dicts = []
for prefix, keys in prefix_dict.items():
sub_state_dict = {key: state_dict[key] for key in keys}
state_dicts.append(sub_state_dict)
return state_dicts
def hash_state_dict_keys(state_dict, with_shape=True):
keys_str = convert_state_dict_keys_to_single_str(state_dict, with_shape=with_shape)
keys_str = keys_str.encode(encoding="UTF-8")
return hashlib.md5(keys_str).hexdigest()
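# Illustrative sketch (hypothetical tensor, never called): hash_state_dict_keys
# fingerprints a checkpoint from its sorted "key:shape" strings; the loader above
# compares this digest against hard-coded hashes to pick a WanModel config.
def _demo_hash_state_dict_keys():
    sd = {"patch_embedding.weight": torch.zeros(1536, 16, 1, 2, 2)}
    return hash_state_dict_keys(sd, with_shape=True)  # a 32-character hex digest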
def clean_vram():
gc.collect()
if torch.cuda.is_available():
torch.cuda.empty_cache()
torch.cuda.ipc_collect()
    if hasattr(torch, "mps") and hasattr(torch.mps, "is_available") and torch.mps.is_available():
        torch.mps.empty_cache()
def get_device_list():
devs = []
try:
if (
hasattr(torch, "cuda")
and hasattr(torch.cuda, "is_available")
and torch.cuda.is_available()
):
devs += [f"cuda:{i}" for i in range(torch.cuda.device_count())]
except Exception:
pass
try:
if (
hasattr(torch, "mps")
and hasattr(torch.mps, "is_available")
and torch.mps.is_available()
):
devs += [f"mps:{i}" for i in range(torch.mps.device_count())]
except Exception:
pass
return devs
class RMS_norm(nn.Module):
def __init__(self, dim, channel_first=True, images=True, bias=False):
super().__init__()
broadcastable_dims = (1, 1, 1) if not images else (1, 1)
shape = (dim, *broadcastable_dims) if channel_first else (dim,)
self.channel_first = channel_first
self.scale = dim**0.5
self.gamma = nn.Parameter(torch.ones(shape))
self.bias = nn.Parameter(torch.zeros(shape)) if bias else 0.0
def forward(self, x):
return (
F.normalize(x, dim=(1 if self.channel_first else -1))
* self.scale
* self.gamma
+ self.bias
)
class CausalConv3d(nn.Conv3d):
"""
    Causal 3D convolution.
"""
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self._padding = (
self.padding[2],
self.padding[2],
self.padding[1],
self.padding[1],
2 * self.padding[0],
0,
)
self.padding = (0, 0, 0)
def forward(self, x, cache_x=None):
padding = list(self._padding)
if cache_x is not None and self._padding[4] > 0:
cache_x = cache_x.to(x.device)
# print(cache_x.shape, x.shape)
x = torch.cat([cache_x, x], dim=2)
padding[4] -= cache_x.shape[2]
# print('cache!')
x = F.pad(x, padding, mode="replicate") # mode='replicate'
# print(x[0,0,:,0,0])
return super().forward(x)
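# Illustrative check (hypothetical helper, never called): CausalConv3d pads only
# the front of the time axis, so processing a clip in two halves while passing
# the last CACHE_T input frames of the first half as cache_x reproduces the
# full-clip result. This is the streaming pattern used by the LQ projection
# modules below.
def _demo_causal_conv3d_streaming():
    torch.manual_seed(0)
    causal = CausalConv3d(3, 3, (3, 3, 3), padding=(1, 1, 1))
    x = torch.randn(1, 3, 8, 16, 16)
    full = causal(x)
    first_half = causal(x[:, :, :4])
    cache = x[:, :, 4 - CACHE_T : 4]  # last CACHE_T input frames of the first half
    second_half = causal(x[:, :, 4:], cache_x=cache)
    streamed = torch.cat([first_half, second_half], dim=2)
    return torch.allclose(full, streamed, atol=1e-5)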
class PixelShuffle3d(nn.Module):
def __init__(self, ff, hh, ww):
super().__init__()
self.ff = ff
self.hh = hh
self.ww = ww
def forward(self, x):
# x: (B, C, F, H, W)
return rearrange(
x,
"b c (f ff) (h hh) (w ww) -> b (c ff hh ww) f h w",
ff=self.ff,
hh=self.hh,
ww=self.ww,
)
class Buffer_LQ4x_Proj(nn.Module):
def __init__(self, in_dim, out_dim, layer_num=30):
super().__init__()
self.ff = 1
self.hh = 16
self.ww = 16
self.hidden_dim1 = 2048
self.hidden_dim2 = 3072
self.layer_num = layer_num
self.pixel_shuffle = PixelShuffle3d(self.ff, self.hh, self.ww)
self.conv1 = CausalConv3d(
in_dim * self.ff * self.hh * self.ww,
self.hidden_dim1,
(4, 3, 3),
stride=(2, 1, 1),
padding=(1, 1, 1),
) # f -> f/2 h -> h w -> w
self.norm1 = RMS_norm(self.hidden_dim1, images=False)
self.act1 = nn.SiLU()
self.conv2 = CausalConv3d(
self.hidden_dim1,
self.hidden_dim2,
(4, 3, 3),
stride=(2, 1, 1),
padding=(1, 1, 1),
) # f -> f/2 h -> h w -> w
self.norm2 = RMS_norm(self.hidden_dim2, images=False)
self.act2 = nn.SiLU()
self.linear_layers = nn.ModuleList(
[nn.Linear(self.hidden_dim2, out_dim) for _ in range(layer_num)]
)
self.clip_idx = 0
def forward(self, video):
self.clear_cache()
# x: (B, C, F, H, W)
t = video.shape[2]
iter_ = 1 + (t - 1) // 4
first_frame = video[:, :, :1, :, :].repeat(1, 1, 3, 1, 1)
video = torch.cat([first_frame, video], dim=2)
# print(video.shape)
out_x = []
for i in range(iter_):
x = self.pixel_shuffle(video[:, :, i * 4 : (i + 1) * 4, :, :])
cache1_x = x[:, :, -CACHE_T:, :, :].clone()
self.cache["conv1"] = cache1_x
x = self.conv1(x, self.cache["conv1"])
x = self.norm1(x)
x = self.act1(x)
cache2_x = x[:, :, -CACHE_T:, :, :].clone()
self.cache["conv2"] = cache2_x
if i == 0:
continue
x = self.conv2(x, self.cache["conv2"])
x = self.norm2(x)
x = self.act2(x)
out_x.append(x)
out_x = torch.cat(out_x, dim=2)
# print(out_x.shape)
out_x = rearrange(out_x, "b c f h w -> b (f h w) c")
outputs = []
for i in range(self.layer_num):
outputs.append(self.linear_layers[i](out_x))
return outputs
def clear_cache(self):
self.cache = {}
self.cache["conv1"] = None
self.cache["conv2"] = None
self.clip_idx = 0
def stream_forward(self, video_clip):
if self.clip_idx == 0:
# self.clear_cache()
first_frame = video_clip[:, :, :1, :, :].repeat(1, 1, 3, 1, 1)
video_clip = torch.cat([first_frame, video_clip], dim=2)
x = self.pixel_shuffle(video_clip)
cache1_x = x[:, :, -CACHE_T:, :, :].clone()
self.cache["conv1"] = cache1_x
x = self.conv1(x, self.cache["conv1"])
x = self.norm1(x)
x = self.act1(x)
cache2_x = x[:, :, -CACHE_T:, :, :].clone()
self.cache["conv2"] = cache2_x
self.clip_idx += 1
return None
else:
x = self.pixel_shuffle(video_clip)
cache1_x = x[:, :, -CACHE_T:, :, :].clone()
self.cache["conv1"] = cache1_x
x = self.conv1(x, self.cache["conv1"])
x = self.norm1(x)
x = self.act1(x)
cache2_x = x[:, :, -CACHE_T:, :, :].clone()
self.cache["conv2"] = cache2_x
x = self.conv2(x, self.cache["conv2"])
x = self.norm2(x)
x = self.act2(x)
out_x = rearrange(x, "b c f h w -> b (f h w) c")
outputs = []
for i in range(self.layer_num):
outputs.append(self.linear_layers[i](out_x))
self.clip_idx += 1
return outputs
class Causal_LQ4x_Proj(nn.Module):
def __init__(self, in_dim, out_dim, layer_num=30):
super().__init__()
self.ff = 1
self.hh = 16
self.ww = 16
self.hidden_dim1 = 2048
self.hidden_dim2 = 3072
self.layer_num = layer_num
self.pixel_shuffle = PixelShuffle3d(self.ff, self.hh, self.ww)
self.conv1 = CausalConv3d(
in_dim * self.ff * self.hh * self.ww,
self.hidden_dim1,
(4, 3, 3),
stride=(2, 1, 1),
padding=(1, 1, 1),
) # f -> f/2 h -> h w -> w
self.norm1 = RMS_norm(self.hidden_dim1, images=False)
self.act1 = nn.SiLU()
self.conv2 = CausalConv3d(
self.hidden_dim1,
self.hidden_dim2,
(4, 3, 3),
stride=(2, 1, 1),
padding=(1, 1, 1),
) # f -> f/2 h -> h w -> w
self.norm2 = RMS_norm(self.hidden_dim2, images=False)
self.act2 = nn.SiLU()
self.linear_layers = nn.ModuleList(
[nn.Linear(self.hidden_dim2, out_dim) for _ in range(layer_num)]
)
self.clip_idx = 0
def forward(self, video):
self.clear_cache()
# x: (B, C, F, H, W)
t = video.shape[2]
iter_ = 1 + (t - 1) // 4
first_frame = video[:, :, :1, :, :].repeat(1, 1, 3, 1, 1)
video = torch.cat([first_frame, video], dim=2)
# print(video.shape)
out_x = []
for i in range(iter_):
x = self.pixel_shuffle(video[:, :, i * 4 : (i + 1) * 4, :, :])
cache1_x = x[:, :, -CACHE_T:, :, :].clone()
x = self.conv1(x, self.cache["conv1"])
self.cache["conv1"] = cache1_x
x = self.norm1(x)
x = self.act1(x)
cache2_x = x[:, :, -CACHE_T:, :, :].clone()
if i == 0:
self.cache["conv2"] = cache2_x
continue
x = self.conv2(x, self.cache["conv2"])
self.cache["conv2"] = cache2_x
x = self.norm2(x)
x = self.act2(x)
out_x.append(x)
out_x = torch.cat(out_x, dim=2)
out_x = rearrange(out_x, "b c f h w -> b (f h w) c")
outputs = []
for i in range(self.layer_num):
outputs.append(self.linear_layers[i](out_x))
return outputs
def clear_cache(self):
self.cache = {}
self.cache["conv1"] = None
self.cache["conv2"] = None
self.clip_idx = 0
def stream_forward(self, video_clip):
if self.clip_idx == 0:
# self.clear_cache()
first_frame = video_clip[:, :, :1, :, :].repeat(1, 1, 3, 1, 1)
video_clip = torch.cat([first_frame, video_clip], dim=2)
x = self.pixel_shuffle(video_clip)
cache1_x = x[:, :, -CACHE_T:, :, :].clone()
x = self.conv1(x, self.cache["conv1"])
self.cache["conv1"] = cache1_x
x = self.norm1(x)
x = self.act1(x)
cache2_x = x[:, :, -CACHE_T:, :, :].clone()
self.cache["conv2"] = cache2_x
self.clip_idx += 1
return None
else:
x = self.pixel_shuffle(video_clip)
cache1_x = x[:, :, -CACHE_T:, :, :].clone()
x = self.conv1(x, self.cache["conv1"])
self.cache["conv1"] = cache1_x
x = self.norm1(x)
x = self.act1(x)
cache2_x = x[:, :, -CACHE_T:, :, :].clone()
x = self.conv2(x, self.cache["conv2"])
self.cache["conv2"] = cache2_x
x = self.norm2(x)
x = self.act2(x)
out_x = rearrange(x, "b c f h w -> b (f h w) c")
outputs = []
for i in range(self.layer_num):
outputs.append(self.linear_layers[i](out_x))
self.clip_idx += 1
return outputs
class FrameStreamBuffer:
def __init__(
self,
frame_generator: types.GeneratorType,
buffer_size: int = 60,
device="cpu",
dtype=torch.float16,
):
self.generator = frame_generator
self.buffer_size = buffer_size
self.device = device
self.dtype = dtype
self.buffer = deque()
self.start_frame_index = 0
self._fill_buffer(initial_fill_count=self.buffer_size)
def _fill_buffer(self, initial_fill_count: int):
try:
for _ in range(initial_fill_count):
frame = next(self.generator)
self.buffer.append(frame)
except StopIteration:
pass
def get_chunk(self, start: int, end: int) -> torch.Tensor:
if start < self.start_frame_index:
raise IndexError(
f"Start frame {start} has already been discarded (current buffer starts at {self.start_frame_index})"
)
while end > self.start_frame_index + len(self.buffer):
try:
self.buffer.append(next(self.generator))
except StopIteration:
if end > self.start_frame_index + len(self.buffer):
print(
f"End frame {end} is out of range! It will be truncated to {self.start_frame_index + len(self.buffer)}"
)
end = self.start_frame_index + len(self.buffer)
break
while len(self.buffer) > self.buffer_size:
self.buffer.popleft()
self.start_frame_index += 1
relative_start = start - self.start_frame_index
relative_end = end - self.start_frame_index
chunk_list = [self.buffer[i] for i in range(relative_start, relative_end)]
if not chunk_list:
C, H, W = self.buffer[0].shape
return torch.empty((1, C, 0, H, W), device=self.device, dtype=self.dtype)
chunk_tensor = torch.stack(chunk_list, dim=1) # (C, chunk_len, H, W)
return chunk_tensor.unsqueeze(0).to(
device=self.device
) # (1, C, chunk_len, H, W)
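# Illustrative sketch (hypothetical frame source, never called): FrameStreamBuffer
# wraps a frame generator so callers can request sliding chunks by absolute frame
# index without materialising the whole video; frames the window has moved past
# are discarded and can no longer be requested.
def _demo_frame_stream_buffer():
    def fake_frames():
        for _ in range(100):
            yield torch.zeros(3, 64, 64)
    buf = FrameStreamBuffer(fake_frames(), buffer_size=16)
    first = buf.get_chunk(0, 8)   # torch.Size([1, 3, 8, 64, 64])
    later = buf.get_chunk(8, 24)  # pulls new frames and drops frames 0-7
    return first.shape, later.shape  # ([1, 3, 8, 64, 64], [1, 3, 16, 64, 64])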
class TensorAsBuffer:
def __init__(self, tensor: torch.Tensor):
self.tensor = tensor
def get_chunk(self, start: int, end: int) -> torch.Tensor:
return self.tensor[:, :, start:end, :, :]
def tensor_to_imageio_frame(frame_tensor: torch.Tensor) -> np.ndarray:
img_tensor = (frame_tensor + 1.0) / 2.0
img_tensor_hwc = img_tensor.permute(1, 2, 0)
img_tensor_hwc_u8 = (img_tensor_hwc * 255.0).clamp(0, 255).to(torch.uint8)
img_np = img_tensor_hwc_u8.cpu().numpy()
return img_np
root = os.path.dirname(os.path.abspath(__file__))
temp = os.path.join(root, "_temp")
devices = get_device_list()
def model_download(model_name="JunhaoZhuang/FlashVSR"):
model_dir = os.path.join(root, "models", model_name.split("/")[-1])
print(f"model dir: {model_dir}")
if not os.path.exists(model_dir):
log(
f"Downloading model '{model_name}' from huggingface...", message_type="info"
)
snapshot_download(
repo_id=model_name,
local_dir=model_dir,
local_dir_use_symlinks=False,
resume_download=True,
)
def is_ffmpeg_available():
ffmpeg_path = shutil.which("ffmpeg")
if ffmpeg_path is None:
log("[FlashVSR] FFmpeg not found!", message_type="warning")
log("Please install FFmpeg and ensure it is in your system's PATH.")
log(
"- Windows: Download from https://www.ffmpeg.org/download.html and add the 'bin' directory to PATH."
)
log("- macOS (via Homebrew): brew install ffmpeg")
log("- Linux (Ubuntu/Debian): sudo apt-get install ffmpeg")
return False
return True
def tensor2video(frames: torch.Tensor):
video_squeezed = frames.squeeze(0)
video_permuted = rearrange(video_squeezed, "C F H W -> F H W C")
video_final = (video_permuted.float() + 1.0) / 2.0
return video_final
def natural_key(name: str):
return [
int(t) if t.isdigit() else t.lower()
for t in re.split(r"([0-9]+)", os.path.basename(name))
]
def list_images_natural(folder: str):
exts = (".png", ".jpg", ".jpeg", ".PNG", ".JPG", ".JPEG")
fs = [os.path.join(folder, f) for f in os.listdir(folder) if f.endswith(exts)]
fs.sort(key=natural_key)
return fs
def largest_8n1_leq(n): # 8n+1
return 0 if n < 1 else ((n - 1) // 8) * 8 + 1
def next_8n5(n): # next 8n+5
return 21 if n < 21 else ((n - 5 + 7) // 8) * 8 + 5
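# Worked example of the frame-count helpers: for 100 input frames the pipeline
# pads to next_8n5(100) == 101 frames (8 * 12 + 5); the model then consumes a
# window of largest_8n1_leq(101 + 4) == 105 frames (8 * 13 + 1), and the output
# is trimmed back to the original 100 frames at the end.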
def is_video(path):
return os.path.isfile(path) and path.lower().endswith(
(".mp4", ".mov", ".avi", ".mkv")
)
def save_video(frames, save_path, fps=30, quality=5):
os.makedirs(os.path.dirname(save_path), exist_ok=True)
frames_np = (frames.cpu().float() * 255.0).clip(0, 255).numpy().astype(np.uint8)
w = imageio.get_writer(save_path, fps=fps, quality=quality)
for frame_np in tqdm(frames_np, desc=f"[FlashVSR] Saving video"):
w.append_data(frame_np)
w.close()
def merge_video_with_audio(video_path, audio_source_path):
temp = video_path + "temp.mp4"
if os.path.isdir(audio_source_path):
log(f"[FlashVSR] Output video saved to '{video_path}'", message_type="info")
return
if not is_ffmpeg_available():
log(f"[FlashVSR] Output video saved to '{video_path}'", message_type="info")
return
try:
probe = ffmpeg.probe(audio_source_path)
audio_streams = [s for s in probe["streams"] if s["codec_type"] == "audio"]
if not audio_streams:
log(f"[FlashVSR] Output video saved to '{video_path}'", message_type="info")
return
log("[FlashVSR] Copying audio tracks...")
os.rename(video_path, temp)
input_video = ffmpeg.input(temp)["v"]
input_audio = ffmpeg.input(audio_source_path)["a"]
output_ffmpeg = ffmpeg.output(
input_video, input_audio, video_path, vcodec="copy", acodec="copy"
).run(overwrite_output=True, quiet=True)
log(f"[FlashVSR] Output video saved to '{video_path}'", message_type="info")
except ffmpeg.Error as e:
print(
"[ERROR] FFmpeg error during merge:",
e.stderr.decode() if e.stderr else "Unknown error",
)
log(
f"[FlashVSR] Audio merge failed. A silent video has been saved to '{video_path}'.",
message_type="warning",
)
finally:
if os.path.exists(temp):
try:
os.remove(temp)
except OSError as e:
                log(
f"[FlashVSR] Could not remove temporary file '{temp}': {e}",
message_type="error",
)
def compute_scaled_and_target_dims(
w0: int, h0: int, scale: int = 4, multiple: int = 128
):
if w0 <= 0 or h0 <= 0:
raise ValueError("invalid original size")
sW, sH = w0 * scale, h0 * scale
tW = max(multiple, (sW // multiple) * multiple)
tH = max(multiple, (sH // multiple) * multiple)
return sW, sH, tW, tH
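# Worked example: a 480x270 input at scale 4 upscales to 1920x1080; snapping each
# side down to a multiple of 128 gives the working resolution, so
# compute_scaled_and_target_dims(480, 270, scale=4, multiple=128) returns
# (1920, 1080, 1920, 1024), and the 1080-pixel height is later center-cropped
# to 1024 by tensor_upscale_then_center_crop.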
def tensor_upscale_then_center_crop(
frame_tensor: torch.Tensor, scale: int, tW: int, tH: int
) -> torch.Tensor:
h0, w0, c = frame_tensor.shape
tensor_bchw = frame_tensor.permute(2, 0, 1).unsqueeze(0) # HWC -> CHW -> BCHW
sW, sH = w0 * scale, h0 * scale
upscaled_tensor = F.interpolate(
tensor_bchw, size=(sH, sW), mode="bicubic", align_corners=False
)
l = max(0, (sW - tW) // 2)
t = max(0, (sH - tH) // 2)
cropped_tensor = upscaled_tensor[:, :, t : t + tH, l : l + tW]
return cropped_tensor.squeeze(0)
def prepare_tensors(path: str, dtype=torch.bfloat16):
if os.path.isdir(path):
paths0 = list_images_natural(path)
if not paths0:
raise FileNotFoundError(f"No images in {path}")
with Image.open(paths0[0]) as _img0:
w0, h0 = _img0.size
N0 = len(paths0)
frames = []
for p in paths0:
with Image.open(p).convert("RGB") as img:
img_np = np.array(img).astype(np.float32) / 255.0
frames.append(torch.from_numpy(img_np).to(dtype))
vid = torch.stack(frames, 0)
fps = 30
return vid, fps
if is_video(path):
rdr = imageio.get_reader(path)
meta = {}
try:
meta = rdr.get_meta_data()
first_frame = rdr.get_data(0)
h0, w0, _ = first_frame.shape
except Exception:
first_frame = rdr.get_data(0)
h0, w0, _ = first_frame.shape
fps_val = meta.get("fps", 30)
fps = int(round(fps_val)) if isinstance(fps_val, (int, float)) else 30
total = meta.get("nframes", rdr.count_frames())
if total is None or total <= 0:
total = len([_ for _ in rdr])
rdr = imageio.get_reader(path)
if total <= 0:
rdr.close()
raise RuntimeError(f"Cannot read frames from {path}")
frames = []
try:
for frame_data in rdr:
frame_np = frame_data.astype(np.float32) / 255.0
frames.append(torch.from_numpy(frame_np).to(dtype))
finally:
try:
rdr.close()
except Exception:
pass
vid = torch.stack(frames, 0)
return vid, fps
raise ValueError(f"Unsupported input: {path}")
def get_input_params(image_tensor, scale):
N0, h0, w0, _ = image_tensor.shape
multiple = 128
sW, sH, tW, tH = compute_scaled_and_target_dims(
w0, h0, scale=scale, multiple=multiple
)
num_frames_with_padding = N0 + 4
F = largest_8n1_leq(num_frames_with_padding)
if F == 0:
raise RuntimeError(
f"Not enough frames after padding. Got {num_frames_with_padding}."
)
return tH, tW, F
def input_tensor_generator(
image_tensor: torch.Tensor, device, scale: int = 4, dtype=torch.bfloat16
):
"""
    A generator that prepares and yields one frame tensor at a time to save memory.
    Each yielded tensor has shape (C, H, W).
"""
N0, h0, w0, _ = image_tensor.shape
tH, tW, F = get_input_params(image_tensor, scale)
for i in range(F):
frame_idx = min(i, N0 - 1)
frame_slice = image_tensor[frame_idx].to(device)
tensor_chw = tensor_upscale_then_center_crop(
frame_slice, scale=scale, tW=tW, tH=tH
)
tensor_out = tensor_chw * 2.0 - 1.0
del tensor_chw
yield tensor_out.to("cpu").to(dtype)
def prepare_input_tensor(
image_tensor: torch.Tensor, device, scale: int = 4, dtype=torch.bfloat16
):
N0, h0, w0, _ = image_tensor.shape
multiple = 128
sW, sH, tW, tH = compute_scaled_and_target_dims(
w0, h0, scale=scale, multiple=multiple
)
num_frames_with_padding = N0 + 4
F = largest_8n1_leq(num_frames_with_padding)
if F == 0:
raise RuntimeError(
f"Not enough frames after padding. Got {num_frames_with_padding}."
)
frames = []
for i in range(F):
frame_idx = min(i, N0 - 1)
frame_slice = image_tensor[frame_idx].to(device)
tensor_chw = tensor_upscale_then_center_crop(
frame_slice, scale=scale, tW=tW, tH=tH
)
tensor_out = tensor_chw * 2.0 - 1.0
tensor_out = tensor_out.to("cpu").to(dtype)
frames.append(tensor_out)
vid_stacked = torch.stack(frames, 0)
vid_final = vid_stacked.permute(1, 0, 2, 3).unsqueeze(0)
del vid_stacked
clean_vram()
return vid_final, tH, tW, F
def calculate_tile_coords(height, width, tile_size, overlap):
coords = []
stride = tile_size - overlap
num_rows = math.ceil((height - overlap) / stride)
num_cols = math.ceil((width - overlap) / stride)
for r in range(num_rows):
for c in range(num_cols):
y1 = r * stride
x1 = c * stride
y2 = min(y1 + tile_size, height)
x2 = min(x1 + tile_size, width)
if y2 - y1 < tile_size:
y1 = max(0, y2 - tile_size)
if x2 - x1 < tile_size:
x1 = max(0, x2 - tile_size)
coords.append((x1, y1, x2, y2))
return coords
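# Illustrative example: tiling a 720x1280 frame with tile_size=512 and overlap=128
# uses a stride of 384, and edge tiles are shifted back so every tile is exactly
# 512x512:
#   calculate_tile_coords(720, 1280, 512, 128)
#   -> [(0, 0, 512, 512), (384, 0, 896, 512), (768, 0, 1280, 512),
#       (0, 208, 512, 720), (384, 208, 896, 720), (768, 208, 1280, 720)]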
def create_feather_mask_numpy(size, overlap):
H, W = size
mask = np.ones((H, W, 1), dtype=np.float32)
ramp = np.linspace(0, 1, overlap, dtype=np.float32)
mask[:, :overlap, :] *= ramp[np.newaxis, :, np.newaxis]
mask[:, -overlap:, :] *= np.flip(ramp)[np.newaxis, :, np.newaxis]
mask[:overlap, :, :] *= ramp[:, np.newaxis, np.newaxis]
mask[-overlap:, :, :] *= np.flip(ramp)[:, np.newaxis, np.newaxis]
return mask
def create_feather_mask(size, overlap):
H, W = size
mask = torch.ones(1, 1, H, W)
ramp = torch.linspace(0, 1, overlap)
mask[:, :, :, :overlap] = torch.minimum(
mask[:, :, :, :overlap], ramp.view(1, 1, 1, -1)
)
mask[:, :, :, -overlap:] = torch.minimum(
mask[:, :, :, -overlap:], ramp.flip(0).view(1, 1, 1, -1)
)
mask[:, :, :overlap, :] = torch.minimum(
mask[:, :, :overlap, :], ramp.view(1, 1, -1, 1)
)
mask[:, :, -overlap:, :] = torch.minimum(
mask[:, :, -overlap:, :], ramp.flip(0).view(1, 1, -1, 1)
)
return mask
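# Worked example: create_feather_mask((64, 64), overlap=16) keeps the interior at
# 1.0 and ramps linearly down to 0.0 over the outer 16 pixels on each side; when
# overlapping tiles are accumulated and divided by the summed masks, seams become
# a smooth cross-fade instead of a hard edge.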
def stitch_video_tiles(
tile_paths,
tile_coords,
final_dims,
scale,
overlap,
output_path,
fps,
quality,
cleanup=True,
    chunk_size=40,  # number of frames processed in memory per pass
):
if not tile_paths:
log("No tile videos found to stitch.", message_type="error")
return
final_W, final_H = final_dims
    # 1. Open all tile videos up front
readers = [imageio.get_reader(p) for p in tile_paths]
try:
        # Get the total frame count
num_frames = readers[0].count_frames()
if num_frames is None or num_frames <= 0:
num_frames = len([_ for _ in readers[0]])
for r in readers:
r.close()
readers = [imageio.get_reader(p) for p in tile_paths]
        # Open the final output writer
with imageio.get_writer(output_path, fps=fps, quality=quality) as writer:
            # 2. Iterate over all frames in chunks of chunk_size
            # tqdm now reports how many chunks have been processed
for start_frame in tqdm(
range(0, num_frames, chunk_size), desc="[FlashVSR] Stitching Chunks"
):
end_frame = min(start_frame + chunk_size, num_frames)
current_chunk_size = end_frame - start_frame
                # 3. Allocate in-memory canvases for the whole chunk
                # Shape: (Frames, Height, Width, Channels)
chunk_canvas = np.zeros(
(current_chunk_size, final_H, final_W, 3), dtype=np.float32
)
weight_canvas = np.zeros_like(chunk_canvas, dtype=np.float32)
                # 4. Iterate over each tile video
for i, reader in enumerate(readers):
                    # 5. Read all frames of this tile that fall in the current chunk
                    # This sequential read is the key optimization
try:
                        # reader.iter_data() is the efficient way to read consecutive frames
tile_chunk_frames = [
frame.astype(np.float32) / 255.0
for idx, frame in enumerate(reader.iter_data())
if start_frame <= idx < end_frame
]
                        # Stack the frame list into a single NumPy array
tile_chunk_np = np.stack(tile_chunk_frames, axis=0)
except Exception as e:
log(
f"Warning: Could not read chunk from tile {i}. Error: {e}",
message_type="warning",
)
continue
if tile_chunk_np.shape[0] != current_chunk_size:
log(
f"Warning: Tile {i} chunk has incorrect frame count. Skipping.",
message_type="warning",
)
continue
                    # 6. Create the feather mask (only needs to be created once)
tile_H, tile_W, _ = tile_chunk_np.shape[1:]
ramp = np.linspace(0, 1, overlap * scale, dtype=np.float32)
mask = np.ones((tile_H, tile_W, 1), dtype=np.float32)
mask[:, : overlap * scale, :] *= ramp[np.newaxis, :, np.newaxis]
mask[:, -overlap * scale :, :] *= np.flip(ramp)[
np.newaxis, :, np.newaxis
]
mask[: overlap * scale, :, :] *= ramp[:, np.newaxis, np.newaxis]
mask[-overlap * scale :, :, :] *= np.flip(ramp)[
:, np.newaxis, np.newaxis
]
                    # Add a leading frame axis so the mask broadcasts over the chunk
                    mask_4d = mask[np.newaxis, :, :, :]  # shape: (1, H, W, C)
                    # 7. Stitch the whole chunk in memory
x1_orig, y1_orig, _, _ = tile_coords[i]
out_y1, out_x1 = y1_orig * scale, x1_orig * scale
out_y2, out_x2 = out_y1 + tile_H, out_x1 + tile_W
                    # Rely on NumPy broadcasting over the frame dimension
chunk_canvas[:, out_y1:out_y2, out_x1:out_x2, :] += (
tile_chunk_np * mask_4d
)
weight_canvas[:, out_y1:out_y2, out_x1:out_x2, :] += mask_4d
                # 8. Normalize the whole chunk by the accumulated weights
weight_canvas[weight_canvas == 0] = 1.0
stitched_chunk = chunk_canvas / weight_canvas
                # 9. Write all frames of this chunk to the output file
for frame_idx_in_chunk in range(current_chunk_size):
frame_uint8 = (
np.clip(stitched_chunk[frame_idx_in_chunk], 0, 1) * 255
).astype(np.uint8)
writer.append_data(frame_uint8)
finally:
log("Closing all tile reader instances...")
for reader in readers:
reader.close()
if cleanup:
log("Cleaning up temporary tile files...")
for path in tile_paths:
try:
os.remove(path)
except OSError as e:
log(
f"Could not remove temporary file '{path}': {e}",
message_type="warning",
)
def init_pipeline(version, mode, device, dtype):
if version == "10":
model = "FlashVSR"
else:
model = "FlashVSR-v1.1"
    model_download(model_name="JunhaoZhuang/" + model)
model_path = os.path.join(root, "models", model)
if not os.path.exists(model_path):
raise RuntimeError(
f'Model directory does not exist! Please save all weights to "{model_path}"'
)
ckpt_path = os.path.join(
model_path, "diffusion_pytorch_model_streaming_dmd.safetensors"
)
if not os.path.exists(ckpt_path):
raise RuntimeError(
f'"diffusion_pytorch_model_streaming_dmd.safetensors" does not exist! Please save it to "{model_path}"'
)
vae_path = os.path.join(model_path, "Wan2.1_VAE.pth")
if not os.path.exists(vae_path):
raise RuntimeError(
f'"Wan2.1_VAE.pth" does not exist! Please save it to "{model_path}"'
)
lq_path = os.path.join(model_path, "LQ_proj_in.ckpt")
if not os.path.exists(lq_path):
raise RuntimeError(
f'"LQ_proj_in.ckpt" does not exist! Please save it to "{model_path}"'
)
tcd_path = os.path.join(model_path, "TCDecoder.ckpt")
if not os.path.exists(tcd_path):
raise RuntimeError(
f'"TCDecoder.ckpt" does not exist! Please save it to "{model_path}"'
)
prompt_path = os.path.join(root, "models", "posi_prompt.pth")
mm = ModelManager(torch_dtype=dtype, device="cpu")
if mode == "full":
mm.load_models([ckpt_path, vae_path])
pipe = FlashVSRFullPipeline.from_model_manager(mm, device=device)
pipe.vae.model.encoder = None
pipe.vae.model.conv1 = None
else:
mm.load_models([ckpt_path])
if mode == "tiny":
pipe = FlashVSRTinyPipeline.from_model_manager(mm, device=device)
else:
pipe = FlashVSRTinyLongPipeline.from_model_manager(mm, device=device)
multi_scale_channels = [512, 256, 128, 128]
pipe.TCDecoder = build_tcdecoder(
new_channels=multi_scale_channels,
device=device,
dtype=dtype,
new_latent_channels=16 + 768,
)
mis = pipe.TCDecoder.load_state_dict(
torch.load(tcd_path, map_location=device), strict=False
)
pipe.TCDecoder.clean_mem()
if model == "FlashVSR":
pipe.denoising_model().LQ_proj_in = Buffer_LQ4x_Proj(
in_dim=3, out_dim=1536, layer_num=1
).to(device, dtype=dtype)
else:
pipe.denoising_model().LQ_proj_in = Causal_LQ4x_Proj(
in_dim=3, out_dim=1536, layer_num=1
).to(device, dtype=dtype)
pipe.denoising_model().LQ_proj_in.load_state_dict(
torch.load(lq_path, map_location="cpu"), strict=True
)
pipe.denoising_model().LQ_proj_in.to(device)
pipe.to(device, dtype=dtype)
pipe.enable_vram_management(num_persistent_param_in_dit=None)
pipe.init_cross_kv(prompt_path=prompt_path)
pipe.load_models_to_device(["dit", "vae"])
return pipe
def main(
input,
version,
mode,
scale,
color_fix,
tiled_vae,
tiled_dit,
tile_size,
tile_overlap,
unload_dit,
dtype,
sparse_ratio=2,
kv_ratio=3,
local_range=11,
seed=0,
device="auto",
quality=6,
output=None,
):
_device = device
if device == "auto":
_device = (
"cuda:0"
if torch.cuda.is_available()
else "mps" if torch.backends.mps.is_available() else device
)
if _device == "auto" or _device not in devices:
raise RuntimeError("No devices found to run FlashVSR!")
if _device.startswith("cuda"):
torch.cuda.set_device(_device)
if tiled_dit and (tile_overlap > tile_size / 2):
raise ValueError('The "tile_overlap" must be less than half of "tile_size"!')
_frames, fps = prepare_tensors(input, dtype=dtype)
_fps = fps if is_video(input) else args.fps
add = next_8n5(_frames.shape[0]) - _frames.shape[0]
padding_frames = _frames[-1:, :, :, :].repeat(add, 1, 1, 1)
frames = torch.cat([_frames, padding_frames], dim=0)
frame_count = _frames.shape[0]
del _frames
clean_vram()
log("[FlashVSR] Preparing frames...", message_type="finish")
if tiled_dit:
N, H, W, C = frames.shape
num_aligned_frames = largest_8n1_leq(N + 4) - 4
if mode == "tiny-long":
local_temp = os.path.join(temp, str(uuid.uuid4()))
os.makedirs(local_temp, exist_ok=True)
else:
final_output_canvas = torch.zeros(
(num_aligned_frames, H * scale, W * scale, C), dtype=dtype, device="cpu"
)
weight_sum_canvas = torch.zeros_like(final_output_canvas)
tile_coords = calculate_tile_coords(H, W, tile_size, tile_overlap)
latent_tiles_cpu = []
temp_videos = []
pipe = init_pipeline(version, mode, _device, dtype)
for i, (x1, y1, x2, y2) in enumerate(tile_coords):
input_tile = frames[:, y1:y2, x1:x2, :]
if mode == "tiny-long":
temp_name = os.path.join(local_temp, f"{i+1:05d}.mp4")
th, tw, F = get_input_params(input_tile, scale=scale)
LQ_tile = input_tensor_generator(
input_tile, _device, scale=scale, dtype=dtype
)
            else:
                temp_name = None  # per-tile videos are only written in tiny-long mode
                LQ_tile, th, tw, F = prepare_input_tensor(
                    input_tile, _device, scale=scale, dtype=dtype
                )
LQ_tile = LQ_tile.to(_device)
if i == 0:
log(
f"[FlashVSR] Processing {frame_count} frames...",
message_type="info",
)
log(
f"[FlashVSR] Processing tile {i+1}/{len(tile_coords)}: ({x1},{y1}) to ({x2},{y2})",
message_type="info",
)
output_tile_gpu = pipe(
prompt="",
negative_prompt="",
cfg_scale=1.0,
num_inference_steps=1,
seed=seed,
tiled=tiled_vae,
LQ_video=LQ_tile,
num_frames=F,
height=th,
width=tw,
is_full_block=False,
if_buffer=True,
topk_ratio=sparse_ratio * 768 * 1280 / (th * tw),
kv_ratio=kv_ratio,
local_range=local_range,
color_fix=color_fix,
unload_dit=unload_dit,
fps=_fps,
quality=10,
output_path=temp_name,
tiled_dit=True,
)
temp_videos.append(temp_name)
if mode == "tiny-long":
final_output = output_tile_gpu
del LQ_tile, input_tile
clean_vram()
continue
processed_tile_cpu = tensor2video(output_tile_gpu).to("cpu")
mask_nchw = create_feather_mask(
(processed_tile_cpu.shape[1], processed_tile_cpu.shape[2]),
tile_overlap * scale,
).to("cpu")
mask_nhwc = mask_nchw.permute(0, 2, 3, 1)
out_x1, out_y1 = x1 * scale, y1 * scale
tile_H_scaled = processed_tile_cpu.shape[1]
tile_W_scaled = processed_tile_cpu.shape[2]
out_x2, out_y2 = out_x1 + tile_W_scaled, out_y1 + tile_H_scaled
final_output_canvas[:, out_y1:out_y2, out_x1:out_x2, :] += (
processed_tile_cpu * mask_nhwc
)
weight_sum_canvas[:, out_y1:out_y2, out_x1:out_x2, :] += mask_nhwc
del LQ_tile, output_tile_gpu, processed_tile_cpu, input_tile
clean_vram()
if mode == "tiny-long":
stitch_video_tiles(
tile_paths=temp_videos,
tile_coords=tile_coords,
final_dims=(W * scale, H * scale),
scale=scale,
overlap=tile_overlap,
output_path=output,
fps=_fps,
quality=quality,
cleanup=True,
)
shutil.rmtree(local_temp)
else:
weight_sum_canvas[weight_sum_canvas == 0] = 1.0
final_output = final_output_canvas / weight_sum_canvas
else:
if mode == "tiny-long":
th, tw, F = get_input_params(frames, scale=scale)
LQ = input_tensor_generator(frames, _device, scale=scale, dtype=dtype)
else:
LQ, th, tw, F = prepare_input_tensor(
frames, _device, scale=scale, dtype=dtype
)
LQ = LQ.to(_device)
pipe = init_pipeline(version, mode, _device, dtype)
log(f"[FlashVSR] Processing {frame_count} frames...", message_type="info")
video = pipe(
prompt="",
negative_prompt="",
cfg_scale=1.0,
num_inference_steps=1,
seed=seed,
tiled=tiled_vae,
LQ_video=LQ,
num_frames=F,
height=th,
width=tw,
is_full_block=False,
if_buffer=True,
topk_ratio=sparse_ratio * 768 * 1280 / (th * tw),
kv_ratio=kv_ratio,
local_range=local_range,
color_fix=color_fix,
unload_dit=unload_dit,
fps=_fps,
output_path=output,
tiled_dit=True,
)
if mode == "tiny-long":
del pipe, LQ
clean_vram()
return video, _fps
log("[FlashVSR] Preparing frames...")
final_output = tensor2video(video).to("cpu")
del pipe, video, LQ
clean_vram()
return final_output[:frame_count, :, :, :], fps
if __name__ == "__main__":
dtype_map = {
"fp32": torch.float32,
"fp16": torch.float16,
"bf16": torch.bfloat16,
}
try:
dtype = dtype_map[args.dtype]
    except KeyError:
dtype = torch.bfloat16
if args.attention == "sage":
USE_BLOCK_ATTN = False
else:
USE_BLOCK_ATTN = True
if os.path.exists(temp):
shutil.rmtree(temp)
os.makedirs(temp, exist_ok=True)
name = os.path.basename(args.input.rstrip("/"))
final = "/root/output.mp4"
result, fps = main(
args.input,
args.version,
args.mode,
args.scale,
args.color_fix,
args.tiled_vae,
args.tiled_dit,
args.tile_size,
args.overlap,
args.unload_dit,
dtype,
seed=args.seed,
device=args.device,
quality=args.quality,
output=final,
)
if args.mode != "tiny-long":
save_video(result, final, fps=fps, quality=args.quality)
merge_video_with_audio(final, args.input)
log("[FlashVSR] Done.", message_type="finish")