| import os | |
| import re | |
| import math | |
| import uuid | |
| import shutil | |
| import random | |
| import hashlib | |
| import gc | |
| import copy | |
| import json | |
| from dataclasses import dataclass | |
| from typing import List, Tuple, Optional, Literal | |
| from contextlib import contextmanager | |
| import numpy as np | |
| import mlx.core as mx | |
| import mlx.nn as nn | |
| import mlx.optimizers as optim | |
| from mlx.utils import tree_flatten, tree_unflatten | |
| from PIL import Image | |
| from tqdm.auto import tqdm | |
| from huggingface_hub import snapshot_download | |
| from safetensors import safe_open | |
| import imageio | |
| import ffmpeg | |
| # ============================================================================== | |
| # Configuration & Globals | |
| # ============================================================================== | |
| @dataclass | |
| class Args: | |
| input: str = "input.mp4" | |
| output_folder: str = "./outputs" | |
| scale: int = 4 | |
| version: str = "10" | |
| mode: str = "tiny" | |
| tiled_vae: bool = True | |
| tiled_dit: bool = True | |
| tile_size: int = 16 | |
| overlap: int = 8 | |
| unload_dit: bool = True | |
| color_fix: bool = False | |
| seed: int = 0 | |
| dtype: str = "fp16" # MLX uses fp16 or fp32 | |
| fps: int = 30 | |
| quality: int = 6 | |
| attention: str = "sdpa" | |
| args = Args() | |
| CACHE_T = 2 | |
| root = os.path.dirname(os.path.abspath(__file__)) | |
| temp = os.path.join(root, "_temp") | |
| def log(message: str, message_type: str = "normal"): | |
| if message_type == "error": | |
| message = "\033[1;41m" + message + "\033[m" | |
| elif message_type == "warning": | |
| message = "\033[1;31m" + message + "\033[m" | |
| elif message_type == "finish": | |
| message = "\033[1;32m" + message + "\033[m" | |
| elif message_type == "info": | |
| message = "\033[1;33m" + message + "\033[m" | |
| print(f"{message}") | |
| # ============================================================================== | |
| # MLX Utils & Helpers | |
| # ============================================================================== | |
| def rearrange(x, pattern, **axes_lengths): | |
| # Simplified einops-like rearrange for MLX using reshape and transpose | |
| # Note: This is a manual implementation for specific patterns used in the code | |
| # to avoid external dependency issues with compiling. | |
| # In a real scenario, explicit MLX ops are preferred for speed. | |
| return x # Placeholder, logic implemented inline for optimization | |
| def sinusoidal_embedding_1d(dim, position): | |
| # position: (N,) | |
| half_dim = dim // 2 | |
| emb = math.log(10000) / (half_dim - 1) | |
| emb = mx.exp(mx.arange(half_dim) * -emb) | |
| emb = position[:, None] * emb[None, :] | |
| emb = mx.concatenate([mx.cos(emb), mx.sin(emb)], axis=-1) | |
| if dim % 2 == 1: | |
| emb = mx.pad(emb, ((0,0), (0,1))) | |
| return emb | |
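| # Illustrative sketch (not part of the original gist; the helper name below is hypothetical): | |
| # a quick shape check for the sinusoidal embedding above. | |
| def _example_time_embedding_shape(): | |
| # A single timestep t=1000 maps to a (1, dim) embedding (cos half, then sin half) | |
| emb = sinusoidal_embedding_1d(256, mx.array([1000.0])) | |
| assert emb.shape == (1, 256) | |
| return emb | |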
| def precompute_freqs_cis_3d(dim: int, end: int = 1024, theta: float = 10000.0): | |
| f_freqs_cis = precompute_freqs_cis(dim - 2 * (dim // 3), end, theta) | |
| h_freqs_cis = precompute_freqs_cis(dim // 3, end, theta) | |
| w_freqs_cis = precompute_freqs_cis(dim // 3, end, theta) | |
| return f_freqs_cis, h_freqs_cis, w_freqs_cis | |
| def precompute_freqs_cis(dim: int, end: int = 1024, theta: float = 10000.0): | |
| freqs = 1.0 / (theta ** (mx.arange(0, dim, 2).astype(mx.float32) / dim)) | |
| t = mx.arange(end).astype(mx.float32) | |
| freqs = mx.outer(t, freqs) # (end, dim//2) | |
| # MLX doesn't have complex number support like PyTorch's polar for this exact usage yet | |
| # We store as (cos, sin) for rope_apply | |
| freqs_cos = mx.cos(freqs) | |
| freqs_sin = mx.sin(freqs) | |
| return mx.stack([freqs_cos, freqs_sin], axis=-1) | |
| def rope_apply(x, freqs, num_heads): | |
| # x: (B, S, n*d) -> (B, S, n, d) | |
| B, S, _ = x.shape | |
| x = x.reshape(B, S, num_heads, -1) | |
| # x is formatted as real numbers, logic needs to treat pairs as complex numbers | |
| # x: (..., d) where d is even | |
| x0 = x[..., 0::2] | |
| x1 = x[..., 1::2] | |
| # freqs: (S, 1, d) with cos/sin interleaved on the last axis (matching x's even/odd split) | |
| # Broadcast to (1, S, 1, d//2) so it lines up with x0/x1 of shape (B, S, n, d//2) | |
| freqs_cos = freqs[..., 0::2][None] # (1, S, 1, d//2) | |
| freqs_sin = freqs[..., 1::2][None] # (1, S, 1, d//2) | |
| # Rotate (complex multiplication applied to the (x0, x1) pairs) | |
| x_out0 = x0 * freqs_cos - x1 * freqs_sin | |
| x_out1 = x0 * freqs_sin + x1 * freqs_cos | |
| # Stack back (interleave) | |
| x_out = mx.stack([x_out0, x_out1], axis=-1).reshape(B, S, num_heads, -1) | |
| return x_out.reshape(B, S, -1) | |
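| # Illustrative sketch (not part of the original gist; helper name hypothetical): | |
| # rope_apply expects per-position freqs of shape (S, 1, head_dim) with cos/sin interleaved, | |
| # which matches precompute_freqs_cis(head_dim, end=S).reshape(S, 1, -1). | |
| def _example_rope_shapes(): | |
| head_dim, num_heads, seq = 8, 2, 4 | |
| fc = precompute_freqs_cis(head_dim, end=seq) # (seq, head_dim//2, 2) | |
| freqs = fc.reshape(seq, 1, -1) # (seq, 1, head_dim) | |
| x = mx.random.normal((1, seq, num_heads * head_dim)) | |
| y = rope_apply(x, freqs, num_heads) | |
| assert y.shape == x.shape # RoPE is a pure rotation, so the shape is preserved | |
| return y | |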
| def modulate(x, shift, scale): | |
| return x * (1 + scale) + shift | |
| # ============================================================================== | |
| # Layers (MLX Port) | |
| # ============================================================================== | |
| class RMSNorm(nn.Module): | |
| def __init__(self, dim, eps=1e-5): | |
| super().__init__() | |
| self.eps = eps | |
| self.weight = mx.ones((dim,)) | |
| def __call__(self, x): | |
| # x: (..., dim) | |
| return mx.fast.rms_norm(x, self.weight, self.eps) | |
| class RMS_norm_General(nn.Module): | |
| def __init__(self, dim, channel_first=True, images=True, bias=False): | |
| super().__init__() | |
| # MLX prefers channel last, so we adapt input if channel_first is True | |
| self.channel_first = channel_first | |
| self.images = images | |
| self.scale = dim**0.5 | |
| self.gamma = mx.ones((dim,)) | |
| self.bias = mx.zeros((dim,)) if bias else None | |
| def __call__(self, x): | |
| # Standardize to channel last for norm, then convert back if needed | |
| if self.channel_first: | |
| # (B, C, ...) -> (B, ..., C) | |
| if x.ndim == 5: # B, C, T, H, W | |
| x = x.transpose(0, 2, 3, 4, 1) | |
| elif x.ndim == 4: # B, C, H, W | |
| x = x.transpose(0, 2, 3, 1) | |
| # Norm | |
| # mx.fast.rms_norm already reproduces F.normalize(x, dim=-1) * dim**0.5 * gamma, | |
| # so no extra multiplication by self.scale is needed here. | |
| out = mx.fast.rms_norm(x, self.gamma, 1e-6) # Using fast kernel | |
| if self.bias is not None: | |
| out = out + self.bias | |
| if self.channel_first: | |
| # (B, ..., C) -> (B, C, ...) | |
| if x.ndim == 5: | |
| out = out.transpose(0, 4, 1, 2, 3) | |
| elif x.ndim == 4: | |
| out = out.transpose(0, 3, 1, 2) | |
| return out | |
| class CausalConv3d(nn.Module): | |
| """ | |
| MLX Conv3d expects (N, D, H, W, C). | |
| PyTorch Conv3d expects (N, C, D, H, W). | |
| We wrap this to handle layout internally but optimize weights for NHWC. | |
| """ | |
| def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, bias=True, mode='zeros'): | |
| super().__init__() | |
| if isinstance(kernel_size, int): kernel_size = (kernel_size, kernel_size, kernel_size) | |
| if isinstance(stride, int): stride = (stride, stride, stride) | |
| if isinstance(padding, int): padding = (padding, padding, padding) | |
| self.mode = mode | |
| self.padding_val = padding | |
| # Causal padding on Time dimension (dim 0 of kernel in MLX terms for 3D) | |
| # padding is (time, height, width) | |
| self.time_pad = 2 * padding[0] | |
| self.spatial_padding = (padding[1], padding[2]) | |
| # MLX Conv3D: weight shape (kD, kH, kW, C_in, C_out) | |
| self.conv = nn.Conv3d(in_channels, out_channels, kernel_size, stride, padding=0, bias=bias) | |
| # Adjust internal padding for spatial | |
| # We handle Time padding manually for causality | |
| def __call__(self, x, cache_x=None): | |
| # Input x: (N, T, H, W, C). All MLX layers in this port use the channels-last layout, | |
| # unlike the PyTorch original, which works on (N, C, T, H, W). | |
| # Handle causal (past-only) time padding, then spatial padding. | |
| if cache_x is not None and self.time_pad > 0: | |
| # cache_x: (N, T_cache, H, W, C) -- cached history frames stand in for the causal pad | |
| x = mx.concatenate([cache_x, x], axis=1) # Concat on Time | |
| elif self.time_pad > 0: | |
| # No cache available: pad only the "past" | |
| if self.mode == 'replicate': | |
| # Replicate the first frame backwards in time | |
| first_frame = x[:, :1, ...] | |
| padding = mx.repeat(first_frame, self.time_pad, axis=1) | |
| x = mx.concatenate([padding, x], axis=1) | |
| else: | |
| # Zero pad the past | |
| x = mx.pad(x, [(0,0), (self.time_pad, 0), (0,0), (0,0), (0,0)]) | |
| # Spatial zero padding (applied in every branch, since the inner conv uses padding=0) | |
| if self.spatial_padding[0] > 0 or self.spatial_padding[1] > 0: | |
| pad_width = [(0,0), (0,0), (self.spatial_padding[0], self.spatial_padding[0]), (self.spatial_padding[1], self.spatial_padding[1]), (0,0)] | |
| x = mx.pad(x, pad_width) | |
| out = self.conv(x) | |
| return out | |
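| # Illustrative sketch (not part of the original gist; helper name hypothetical), assuming the | |
| # padding behaviour implemented above: with kernel 3 and padding 1 the causal conv preserves | |
| # (T, H, W), padding time only into the past. | |
| def _example_causal_conv_shapes(): | |
| conv = CausalConv3d(3, 8, 3, padding=1) | |
| x = mx.zeros((1, 4, 16, 16, 3)) # (N, T, H, W, C) | |
| y = conv(x) | |
| assert y.shape == (1, 4, 16, 16, 8) | |
| return y | |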
| class Resample(nn.Module): | |
| def __init__(self, dim, mode): | |
| super().__init__() | |
| self.dim = dim | |
| self.mode = mode | |
| if mode == "upsample2d": | |
| self.conv = nn.Conv2d(dim, dim // 2, 3, padding=1) | |
| elif mode == "upsample3d": | |
| self.conv = nn.Conv2d(dim, dim // 2, 3, padding=1) | |
| self.time_conv = CausalConv3d(dim, dim * 2, (3, 1, 1), padding=(1, 0, 0)) | |
| elif mode == "downsample2d": | |
| self.conv = nn.Conv2d(dim, dim, 3, stride=2) # Manual padding needed | |
| elif mode == "downsample3d": | |
| self.conv = nn.Conv2d(dim, dim, 3, stride=2) | |
| self.time_conv = CausalConv3d(dim, dim, (3, 1, 1), stride=(2, 1, 1), padding=(0, 0, 0)) | |
| def __call__(self, x, feat_cache=None, feat_idx=[0]): | |
| # x: (N, T, H, W, C) | |
| B, T, H, W, C = x.shape | |
| if self.mode == "upsample3d": | |
| if feat_cache is not None: | |
| idx = feat_idx[0] | |
| # Caching logic | |
| if feat_cache[idx] is None: | |
| feat_cache[idx] = "Rep" | |
| feat_idx[0] += 1 | |
| else: | |
| cache_x = x[:, -CACHE_T:, ...].astype(x.dtype) | |
| if cache_x.shape[1] < 2 and feat_cache[idx] is not None and isinstance(feat_cache[idx], mx.array): | |
| cache_x = mx.concatenate([feat_cache[idx][:, -1:, ...], cache_x], axis=1) | |
| if isinstance(feat_cache[idx], str) and feat_cache[idx] == "Rep": | |
| # Time conv handles replicate | |
| x_t = self.time_conv(x) # Handle mode='replicate' inside CausalConv3d if needed or pad manually | |
| else: | |
| x_t = self.time_conv(x, cache_x=feat_cache[idx]) | |
| feat_cache[idx] = cache_x | |
| feat_idx[0] += 1 | |
| # Reshape for channel mix | |
| # x_t: (B, T, H, W, C*2) -> split C | |
| x_t = x_t.reshape(B, T, H, W, 2, C) | |
| x = mx.stack([x_t[..., 0, :], x_t[..., 1, :]], axis=2) # (B, T, 2, H, W, C) | |
| x = x.reshape(B, T * 2, H, W, C) | |
| # Spatial Resampling | |
| # Flatten time: (B*T, H, W, C) | |
| T_curr = x.shape[1] | |
| x_flat = x.reshape(-1, x.shape[2], x.shape[3], x.shape[4]) | |
| if "upsample" in self.mode: | |
| # Nearest neighbor upsampling | |
| # MLX doesn't have Upsample layer, utilize broadcast or repeat | |
| # (B*T, H, W, C) -> (B*T, H*2, W*2, C) | |
| x_up = mx.repeat(x_flat, 2, axis=1) | |
| x_up = mx.repeat(x_up, 2, axis=2) | |
| x_flat = self.conv(x_up) | |
| elif "downsample" in self.mode: | |
| # Pad for stride 2 conv | |
| x_flat = mx.pad(x_flat, ((0,0), (0,1), (0,1), (0,0))) | |
| x_flat = self.conv(x_flat) | |
| x = x_flat.reshape(B, T_curr, x_flat.shape[1], x_flat.shape[2], -1) | |
| if self.mode == "downsample3d": | |
| # Time downsample | |
| if feat_cache is not None: | |
| idx = feat_idx[0] | |
| if feat_cache[idx] is None: | |
| feat_cache[idx] = x | |
| feat_idx[0] += 1 | |
| else: | |
| cache_x = x[:, -1:, ...].astype(x.dtype) | |
| # concat time | |
| inp = mx.concatenate([feat_cache[idx][:, -1:, ...], x], axis=1) | |
| x = self.time_conv(inp) | |
| feat_cache[idx] = cache_x | |
| feat_idx[0] += 1 | |
| return x | |
| class PixelShuffle3d(nn.Module): | |
| def __init__(self, ff, hh, ww): | |
| super().__init__() | |
| self.ff, self.hh, self.ww = ff, hh, ww | |
| def __call__(self, x): | |
| # FIX: Converted from PixelShuffle (upscale) to PixelUnshuffle (space-to-depth) | |
| # Input x: (B, T, H, W, C) -> Output: (B, T, H/hh, W/ww, C*hh*ww) | |
| B, T, H, W, C = x.shape | |
| # Validate the spatial dimensions | |
| if H % self.hh != 0 or W % self.ww != 0: | |
| raise ValueError(f"Input spatial dims ({H},{W}) must be divisible by patch size ({self.hh},{self.ww})") | |
| # 1. Reshape to split out the pixel blocks | |
| # Shape: (B, T, H//hh, hh, W//ww, ww, C) | |
| # Note: ff (time) is ignored in this reshape because ff=1 in the default config, | |
| # but the spatial logic is still correct. | |
| x = x.reshape(B, T, H // self.hh, self.hh, W // self.ww, self.ww, C) | |
| # 2. Transpose to gather the block dims (hh, ww) next to the channel dim | |
| # Current: (B, T, NewH, hh, NewW, ww, C) | |
| # Target: (B, T, NewH, NewW, C, hh, ww) | |
| # Permutation axes: 0, 1, 2, 4, 6, 3, 5 | |
| x = x.transpose(0, 1, 2, 4, 6, 3, 5) | |
| # 3. Flatten the last three dims into the channel dim | |
| # Output: (B, T, H//hh, W//ww, C * hh * ww) | |
| x = x.reshape(B, T, H // self.hh, W // self.ww, -1) | |
| return x | |
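| # Illustrative sketch (not part of the original gist; helper name hypothetical): | |
| # PixelShuffle3d above is a space-to-depth op, so spatial extents shrink by (hh, ww) | |
| # while the channel count grows by hh*ww. | |
| def _example_pixel_unshuffle_shapes(): | |
| ps = PixelShuffle3d(1, 2, 2) | |
| x = mx.zeros((1, 2, 8, 8, 3)) # (B, T, H, W, C) | |
| y = ps(x) | |
| assert y.shape == (1, 2, 4, 4, 12) | |
| return y | |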
| class PixelShuffle3dInverse(nn.Module): | |
| # Used for unpatchify logic if needed, or simple reshapes | |
| pass | |
| # ============================================================================== | |
| # Blocks | |
| # ============================================================================== | |
| class ResidualBlock(nn.Module): | |
| def __init__(self, in_dim, out_dim, dropout=0.0): | |
| super().__init__() | |
| # Using MLX layers (NHWC) | |
| self.norm1 = RMS_norm_General(in_dim, channel_first=False, images=False) | |
| self.act1 = nn.SiLU() | |
| self.conv1 = CausalConv3d(in_dim, out_dim, 3, padding=1) | |
| self.norm2 = RMS_norm_General(out_dim, channel_first=False, images=False) | |
| self.act2 = nn.SiLU() | |
| self.dropout = nn.Dropout(dropout) | |
| self.conv2 = CausalConv3d(out_dim, out_dim, 3, padding=1) | |
| self.shortcut = CausalConv3d(in_dim, out_dim, 1) if in_dim != out_dim else nn.Identity() | |
| def __call__(self, x, feat_cache=None, feat_idx=[0]): | |
| h = self.shortcut(x) | |
| x = self.norm1(x) | |
| x = self.act1(x) | |
| # Conv1 with cache handling | |
| if feat_cache is not None: | |
| idx = feat_idx[0] | |
| cache_x = x[:, -CACHE_T:, ...].astype(x.dtype) | |
| if cache_x.shape[1] < 2 and feat_cache[idx] is not None: | |
| cache_x = mx.concatenate([feat_cache[idx][:, -1:, ...], cache_x], axis=1) # naive pad | |
| x = self.conv1(x, cache_x=feat_cache[idx]) | |
| feat_cache[idx] = cache_x | |
| feat_idx[0] += 1 | |
| else: | |
| x = self.conv1(x) | |
| x = self.norm2(x) | |
| x = self.act2(x) | |
| x = self.dropout(x) | |
| # Conv2 with cache handling | |
| if feat_cache is not None: | |
| idx = feat_idx[0] | |
| cache_x = x[:, -CACHE_T:, ...].astype(x.dtype) | |
| if cache_x.shape[1] < 2 and feat_cache[idx] is not None: | |
| cache_x = mx.concatenate([feat_cache[idx][:, -1:, ...], cache_x], axis=1) | |
| x = self.conv2(x, cache_x=feat_cache[idx]) | |
| feat_cache[idx] = cache_x | |
| feat_idx[0] += 1 | |
| else: | |
| x = self.conv2(x) | |
| return x + h | |
| class AttentionBlock(nn.Module): | |
| def __init__(self, dim): | |
| super().__init__() | |
| self.norm = RMS_norm_General(dim, channel_first=False) | |
| # Conv2d for (B*T, H, W, C) | |
| self.to_qkv = nn.Conv2d(dim, dim * 3, 1) | |
| self.proj = nn.Conv2d(dim, dim, 1) | |
| def __call__(self, x): | |
| # x: (B, T, H, W, C) | |
| identity = x | |
| B, T, H, W, C = x.shape | |
| x = self.norm(x) | |
| # Merge B, T for Conv2d | |
| x_flat = x.reshape(B * T, H, W, C) | |
| qkv = self.to_qkv(x_flat) | |
| q, k, v = mx.split(qkv, 3, axis=-1) | |
| # Reshape for attention with an explicit single head: (B*T, 1, H*W, C) | |
| # (the source model uses single-head attention here; | |
| # mx.fast.scaled_dot_product_attention expects rank-4 (batch, heads, seq, dim) inputs) | |
| q = q.reshape(B*T, H*W, C)[:, None, :, :] | |
| k = k.reshape(B*T, H*W, C)[:, None, :, :] | |
| v = v.reshape(B*T, H*W, C)[:, None, :, :] | |
| # Standard scaled dot-product attention | |
| scale = 1 / math.sqrt(C) | |
| x_attn = mx.fast.scaled_dot_product_attention(q, k, v, scale=scale) | |
| x_attn = x_attn.reshape(B*T, H, W, C) | |
| x_attn = self.proj(x_attn) | |
| x_out = x_attn.reshape(B, T, H, W, C) | |
| return x_out + identity | |
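| # Illustrative sketch (not part of the original gist; helper name hypothetical), assuming the | |
| # single-head reshape used above: the block is shape-preserving thanks to its residual connection. | |
| def _example_attention_block_shapes(): | |
| blk = AttentionBlock(8) | |
| x = mx.zeros((1, 2, 4, 4, 8)) # (B, T, H, W, C) | |
| y = blk(x) | |
| assert y.shape == x.shape | |
| return y | |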
| # ============================================================================== | |
| # DiT / Transformer Components | |
| # ============================================================================== | |
| class SelfAttention(nn.Module): | |
| def __init__(self, dim, num_heads, eps=1e-6): | |
| super().__init__() | |
| self.dim = dim | |
| self.num_heads = num_heads | |
| self.head_dim = dim // num_heads | |
| self.q = nn.Linear(dim, dim) | |
| self.k = nn.Linear(dim, dim) | |
| self.v = nn.Linear(dim, dim) | |
| self.o = nn.Linear(dim, dim) | |
| self.norm_q = RMSNorm(dim, eps=eps) | |
| self.norm_k = RMSNorm(dim, eps=eps) | |
| def __call__(self, x, freqs, f=None, h=None, w=None, is_stream=False, pre_cache_k=None, pre_cache_v=None): | |
| B, L, D = x.shape | |
| q = self.norm_q(self.q(x)) | |
| k = self.norm_k(self.k(x)) | |
| v = self.v(x) | |
| # Apply RoPE | |
| q = rope_apply(q, freqs, self.num_heads) | |
| k = rope_apply(k, freqs, self.num_heads) | |
| # Reshape for multi-head attention: (B, L, n_heads, head_dim) | |
| q = q.reshape(B, L, self.num_heads, self.head_dim) | |
| k = k.reshape(B, L, self.num_heads, self.head_dim) | |
| v = v.reshape(B, L, self.num_heads, self.head_dim) | |
| # Caching logic for stream | |
| # Source uses Window Partitioning logic which is complex. | |
| # For this MLX port, we implement standard Attention with optional caching. | |
| # Ideally, we would reimplement the window partitioning, but full attention is simpler and faster on M-chips for small tiles. | |
| # FlashAttention on M-chips is very efficient. | |
| if pre_cache_k is not None: | |
| k = mx.concatenate([pre_cache_k, k], axis=1) | |
| v = mx.concatenate([pre_cache_v, v], axis=1) | |
| # Compute attention | |
| # q: (B, L_q, H, D), k: (B, L_k, H, D) | |
| # mx.fast.scaled_dot_product_attention takes (batch, heads, seq, head_dim) inputs, | |
| # so transpose to (B, n_heads, L, head_dim) first. | |
| q_t = q.transpose(0, 2, 1, 3) | |
| k_t = k.transpose(0, 2, 1, 3) | |
| v_t = v.transpose(0, 2, 1, 3) | |
| x_out = mx.fast.scaled_dot_product_attention(q_t, k_t, v_t, scale=1.0/math.sqrt(self.head_dim)) | |
| # Transpose back: (B, L, n_heads, head_dim) -> (B, L, D) | |
| x_out = x_out.transpose(0, 2, 1, 3).reshape(B, L, D) | |
| out = self.o(x_out) | |
| cache_k = k if is_stream else None | |
| cache_v = v if is_stream else None | |
| # Handle KV cache trimming if logic requires window sliding (simplified here) | |
| if is_stream and cache_k.shape[1] > 2000: # heuristic limit from config | |
| cache_k = cache_k[:, -1000:, ...] | |
| cache_v = cache_v[:, -1000:, ...] | |
| if is_stream: | |
| return out, cache_k, cache_v | |
| return out | |
| class CrossAttention(nn.Module): | |
| def __init__(self, dim, num_heads, eps=1e-6): | |
| super().__init__() | |
| self.dim = dim | |
| self.num_heads = num_heads | |
| self.q = nn.Linear(dim, dim) | |
| self.k = nn.Linear(dim, dim) | |
| self.v = nn.Linear(dim, dim) | |
| self.o = nn.Linear(dim, dim) | |
| self.norm_q = RMSNorm(dim, eps) | |
| self.norm_k = RMSNorm(dim, eps) | |
| self.cache_k = None | |
| self.cache_v = None | |
| def init_cache(self, ctx): | |
| # ctx: (B, S, D) | |
| k = self.norm_k(self.k(ctx)) | |
| v = self.v(ctx) | |
| # Pre-shape for attention: (B, n_heads, S, head_dim) | |
| B, S, _ = ctx.shape | |
| self.cache_k = k.reshape(B, S, self.num_heads, -1).transpose(0, 2, 1, 3) | |
| self.cache_v = v.reshape(B, S, self.num_heads, -1).transpose(0, 2, 1, 3) | |
| def clear_cache(self): | |
| self.cache_k = None | |
| self.cache_v = None | |
| def __call__(self, x, context, is_stream=False): | |
| B, L, D = x.shape | |
| q = self.norm_q(self.q(x)) | |
| q = q.reshape(B, L, self.num_heads, -1).transpose(0, 2, 1, 3) | |
| # Use cached K, V if available (typically context is static per prompt) | |
| k = self.cache_k | |
| v = self.cache_v | |
| if k is None: # Fallback | |
| k = self.norm_k(self.k(context)).reshape(B, -1, self.num_heads, D // self.num_heads).transpose(0, 2, 1, 3) | |
| v = self.v(context).reshape(B, -1, self.num_heads, D // self.num_heads).transpose(0, 2, 1, 3) | |
| x_out = mx.fast.scaled_dot_product_attention(q, k, v, scale=1.0/math.sqrt(D // self.num_heads)) | |
| x_out = x_out.transpose(0, 2, 1, 3).reshape(B, L, D) | |
| return self.o(x_out) | |
| class DiTBlock(nn.Module): | |
| def __init__(self, dim, num_heads, ffn_dim, eps=1e-6): | |
| super().__init__() | |
| self.self_attn = SelfAttention(dim, num_heads, eps) | |
| self.cross_attn = CrossAttention(dim, num_heads, eps) | |
| self.norm1 = nn.LayerNorm(dim, eps=eps, affine=False) | |
| self.norm2 = nn.LayerNorm(dim, eps=eps, affine=False) | |
| self.norm3 = nn.LayerNorm(dim, eps=eps) # Affine True default | |
| self.ffn = nn.Sequential( | |
| nn.Linear(dim, ffn_dim), | |
| nn.GELU(), # Approximate tanh not in standard MLX nn yet, standard GELU is fine or custom | |
| nn.Linear(ffn_dim, dim) | |
| ) | |
| self.modulation = mx.random.normal((1, 6, dim)) / dim**0.5 | |
| def __call__(self, x, context, t_mod, freqs, f, h, w, is_stream=False, pre_cache_k=None, pre_cache_v=None, **kwargs): | |
| # Modulation | |
| # t_mod: (1, 6, dim) | |
| mod = self.modulation + t_mod | |
| shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = mx.split(mod, 6, axis=1) | |
| # Self Attention | |
| norm1_x = self.norm1(x) | |
| norm1_x = modulate(norm1_x, shift_msa, scale_msa) | |
| attn_out = self.self_attn(norm1_x, freqs, f, h, w, is_stream, pre_cache_k, pre_cache_v) | |
| if is_stream: | |
| attn_out, cache_k, cache_v = attn_out | |
| x = x + gate_msa * attn_out | |
| # Cross Attention | |
| x = x + self.cross_attn(self.norm3(x), context, is_stream) | |
| # FFN | |
| norm2_x = self.norm2(x) | |
| norm2_x = modulate(norm2_x, shift_mlp, scale_mlp) | |
| x = x + gate_mlp * self.ffn(norm2_x) | |
| if is_stream: | |
| return x, cache_k, cache_v | |
| return x | |
| class Head(nn.Module): | |
| def __init__(self, dim, out_dim, patch_size, eps): | |
| super().__init__() | |
| self.norm = nn.LayerNorm(dim, eps=eps, affine=False) | |
| self.head = nn.Linear(dim, out_dim * math.prod(patch_size)) | |
| self.modulation = mx.random.normal((1, 2, dim)) / dim**0.5 | |
| def __call__(self, x, t_mod): | |
| mod = self.modulation + t_mod | |
| shift, scale = mx.split(mod, 2, axis=1) | |
| x = self.norm(x) | |
| x = x * (1 + scale) + shift | |
| return self.head(x) | |
| # ============================================================================== | |
| # Models: WanModel, VAE | |
| # ============================================================================== | |
| class Clamp(nn.Module): | |
| def __call__(self, x): | |
| return mx.tanh(x / 3) * 3 | |
| class IdentityConv2d(nn.Module): | |
| # Used in TAEHV deep layers | |
| def __init__(self, channels, kernel_size=3): | |
| super().__init__() | |
| # In MLX, an identity conv can be initialized with a Dirac delta weight. | |
| # MLX Conv2d weight shape: (Out, kH, kW, In) | |
| self.conv = nn.Conv2d(channels, channels, kernel_size, padding=kernel_size//2, bias=False) | |
| # Manually init weights to identity (Dirac) | |
| w = mx.zeros((channels, kernel_size, kernel_size, channels)) | |
| center = kernel_size // 2 | |
| for i in range(channels): | |
| w[i, center, center, i] = 1.0 | |
| self.conv.weight = w | |
| def __call__(self, x): | |
| return self.conv(x) | |
| # --- VAE Components --- | |
| class Encoder3d(nn.Module): | |
| def __init__(self, dim=128, z_dim=4, dim_mult=[1, 2, 4, 4], num_res_blocks=2, attn_scales=[], temperal_downsample=[True, True, False], dropout=0.0): | |
| super().__init__() | |
| self.dim = dim | |
| self.dim_mult = dim_mult | |
| self.num_res_blocks = num_res_blocks | |
| dims = [dim * u for u in [1] + dim_mult] | |
| scale = 1.0 | |
| # Initial Conv: Input 3 channels -> dims[0] | |
| self.conv1 = CausalConv3d(3, dims[0], 3, padding=1) | |
| self.downsamples = [] | |
| for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])): | |
| # ResBlocks | |
| for _ in range(num_res_blocks): | |
| self.downsamples.append(ResidualBlock(in_dim, out_dim, dropout)) | |
| if scale in attn_scales: | |
| self.downsamples.append(AttentionBlock(out_dim)) | |
| in_dim = out_dim | |
| # Downsample Layer | |
| if i != len(dim_mult) - 1: | |
| mode = "downsample3d" if temperal_downsample[i] else "downsample2d" | |
| self.downsamples.append(Resample(out_dim, mode=mode)) | |
| scale /= 2.0 | |
| # Middle | |
| self.middle = [ | |
| ResidualBlock(out_dim, out_dim, dropout), | |
| AttentionBlock(out_dim), | |
| ResidualBlock(out_dim, out_dim, dropout) | |
| ] | |
| # Head | |
| self.head_norm = RMS_norm_General(out_dim, channel_first=False, images=False) | |
| self.head_act = nn.SiLU() | |
| self.head_conv = CausalConv3d(out_dim, z_dim, 3, padding=1) | |
| def __call__(self, x, feat_cache=None, feat_idx=[0]): | |
| # x: (N, T, H, W, C) | |
| # Conv1 | |
| if feat_cache is not None: | |
| idx = feat_idx[0] | |
| cache_x = x[:, -CACHE_T:, ...].astype(x.dtype) | |
| # Causal Pad logic | |
| if cache_x.shape[1] < 2 and feat_cache[idx] is not None: | |
| cache_x = mx.concatenate([feat_cache[idx][:, -1:, ...], cache_x], axis=1) | |
| x = self.conv1(x, cache_x=feat_cache[idx]) | |
| feat_cache[idx] = cache_x | |
| feat_idx[0] += 1 | |
| else: | |
| x = self.conv1(x) | |
| # Downsamples | |
| for layer in self.downsamples: | |
| if isinstance(layer, (ResidualBlock, Resample)) and feat_cache is not None: | |
| x = layer(x, feat_cache, feat_idx) | |
| else: | |
| x = layer(x) | |
| # Middle | |
| for layer in self.middle: | |
| if isinstance(layer, ResidualBlock) and feat_cache is not None: | |
| x = layer(x, feat_cache, feat_idx) | |
| else: | |
| x = layer(x) | |
| # Head | |
| x = self.head_norm(x) | |
| x = self.head_act(x) | |
| if feat_cache is not None: | |
| idx = feat_idx[0] | |
| cache_x = x[:, -CACHE_T:, ...].astype(x.dtype) | |
| if cache_x.shape[1] < 2 and feat_cache[idx] is not None: | |
| cache_x = mx.concatenate([feat_cache[idx][:, -1:, ...], cache_x], axis=1) | |
| x = self.head_conv(x, cache_x=feat_cache[idx]) | |
| feat_cache[idx] = cache_x | |
| feat_idx[0] += 1 | |
| else: | |
| x = self.head_conv(x) | |
| return x | |
| class Decoder3d(nn.Module): | |
| def __init__(self, dim=128, z_dim=4, dim_mult=[1, 2, 4, 4], num_res_blocks=2, attn_scales=[], temperal_upsample=[False, True, True], dropout=0.0): | |
| super().__init__() | |
| dims = [dim * u for u in [dim_mult[-1]] + dim_mult[::-1]] | |
| scale = 1.0 / 2 ** (len(dim_mult) - 2) | |
| self.conv1 = CausalConv3d(z_dim, dims[0], 3, padding=1) | |
| self.middle = [ | |
| ResidualBlock(dims[0], dims[0], dropout), | |
| AttentionBlock(dims[0]), | |
| ResidualBlock(dims[0], dims[0], dropout) | |
| ] | |
| self.upsamples = [] | |
| for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])): | |
| # Channel adjustment logic from original code | |
| if i == 1 or i == 2 or i == 3: | |
| in_dim = in_dim // 2 | |
| # Blocks | |
| for _ in range(num_res_blocks + 1): | |
| self.upsamples.append(ResidualBlock(in_dim, out_dim, dropout)) | |
| if scale in attn_scales: | |
| self.upsamples.append(AttentionBlock(out_dim)) | |
| in_dim = out_dim | |
| # Upsample Layer | |
| if i != len(dim_mult) - 1: | |
| mode = "upsample3d" if temperal_upsample[i] else "upsample2d" | |
| self.upsamples.append(Resample(out_dim, mode=mode)) | |
| scale *= 2.0 | |
| self.head_norm = RMS_norm_General(out_dim, channel_first=False, images=False) | |
| self.head_act = nn.SiLU() | |
| self.head_conv = CausalConv3d(out_dim, 3, 3, padding=1) # Output RGB | |
| def __call__(self, x, feat_cache=None, feat_idx=[0]): | |
| # x: (N, T, H, W, C) | |
| # Conv1 | |
| if feat_cache is not None: | |
| idx = feat_idx[0] | |
| cache_x = x[:, -CACHE_T:, ...].astype(x.dtype) | |
| if cache_x.shape[1] < 2 and feat_cache[idx] is not None: | |
| cache_x = mx.concatenate([feat_cache[idx][:, -1:, ...], cache_x], axis=1) | |
| x = self.conv1(x, cache_x=feat_cache[idx]) | |
| feat_cache[idx] = cache_x | |
| feat_idx[0] += 1 | |
| else: | |
| x = self.conv1(x) | |
| # Middle | |
| for layer in self.middle: | |
| if isinstance(layer, ResidualBlock) and feat_cache is not None: | |
| x = layer(x, feat_cache, feat_idx) | |
| else: | |
| x = layer(x) | |
| # Upsamples | |
| for layer in self.upsamples: | |
| if isinstance(layer, (ResidualBlock, Resample)) and feat_cache is not None: | |
| x = layer(x, feat_cache, feat_idx) | |
| else: | |
| x = layer(x) | |
| # Head | |
| x = self.head_norm(x) | |
| x = self.head_act(x) | |
| if feat_cache is not None: | |
| idx = feat_idx[0] | |
| cache_x = x[:, -CACHE_T:, ...].astype(x.dtype) | |
| if cache_x.shape[1] < 2 and feat_cache[idx] is not None: | |
| cache_x = mx.concatenate([feat_cache[idx][:, -1:, ...], cache_x], axis=1) | |
| x = self.head_conv(x, cache_x=feat_cache[idx]) | |
| feat_cache[idx] = cache_x | |
| feat_idx[0] += 1 | |
| else: | |
| x = self.head_conv(x) | |
| return x | |
| class WanModel(nn.Module): | |
| def __init__(self, dim, in_dim, ffn_dim, out_dim, text_dim, freq_dim, eps, patch_size, num_heads, num_layers): | |
| super().__init__() | |
| self.dim = dim | |
| self.freq_dim = freq_dim | |
| self.patch_size = patch_size | |
| self.num_heads = num_heads | |
| # Embeddings | |
| self.patch_embedding = nn.Conv3d(in_dim, dim, kernel_size=patch_size, stride=patch_size) | |
| self.text_embedding = nn.Sequential( | |
| nn.Linear(text_dim, dim), | |
| nn.GELU(), | |
| nn.Linear(dim, dim) | |
| ) | |
| self.time_embedding = nn.Sequential( | |
| nn.Linear(freq_dim, dim), | |
| nn.SiLU(), | |
| nn.Linear(dim, dim) | |
| ) | |
| self.time_projection = nn.Sequential( | |
| nn.SiLU(), | |
| nn.Linear(dim, dim * 6) | |
| ) | |
| self.blocks = [DiTBlock(dim, num_heads, ffn_dim, eps) for _ in range(num_layers)] | |
| self.head = Head(dim, out_dim, patch_size, eps) | |
| # Freqs | |
| self.freqs = precompute_freqs_cis_3d(dim // num_heads) | |
| # LQ Projector placeholder | |
| self.LQ_proj_in = None | |
| def patchify(self, x): | |
| x = self.patch_embedding(x) | |
| grid_size = (x.shape[1], x.shape[2], x.shape[3]) | |
| x = x.reshape(x.shape[0], -1, x.shape[-1]) | |
| return x, grid_size | |
| def unpatchify(self, x, grid_size): | |
| N, L, _ = x.shape | |
| f, h, w = grid_size | |
| pt, ph, pw = self.patch_size # patch_size is ordered (time, height, width) | |
| x = x.reshape(N, f, h, w, pt, ph, pw, -1) | |
| x = x.transpose(0, 1, 4, 2, 5, 3, 6, 7) | |
| x = x.reshape(N, f*pt, h*ph, w*pw, -1) | |
| return x | |
| def __call__(self, x, timestep, context, LQ_latents=None, is_stream=False, pre_cache_k=None, pre_cache_v=None, **kwargs): | |
| # Time embeddings | |
| t = self.time_embedding(sinusoidal_embedding_1d(self.freq_dim, timestep)) | |
| t_mod = self.time_projection(t).reshape(1, 6, self.dim) | |
| # Patchify | |
| x, (f, h, w) = self.patchify(x) | |
| # --- FIX FREQS LOGIC (ROBUST V3) --- | |
| # Internal helper that fetches the embedding safely, recomputing it if the cache is too short | |
| def get_freq_emb(freq_cache, idx, length, dim_sub): | |
| # freq_cache: (MaxLen, DimSub, 2) | |
| # If length > MaxLen, recompute on the fly | |
| if length > freq_cache.shape[0]: | |
| theta = 10000.0 | |
| freqs = 1.0 / (theta ** (mx.arange(0, dim_sub * 2, 2).astype(mx.float32) / (dim_sub * 2))) | |
| t_seq = mx.arange(length).astype(mx.float32) | |
| freqs_outer = mx.outer(t_seq, freqs) | |
| freqs_cos = mx.cos(freqs_outer) | |
| freqs_sin = mx.sin(freqs_outer) | |
| emb = mx.stack([freqs_cos, freqs_sin], axis=-1) | |
| return emb.reshape(length, -1) | |
| else: | |
| return freq_cache[:length].reshape(length, -1) | |
| head_dim = self.dim // self.num_heads | |
| dim_h = head_dim // 3 | |
| dim_w = head_dim // 3 | |
| dim_f = head_dim - dim_h - dim_w | |
| # Fetch the per-axis embeddings | |
| f_emb = get_freq_emb(self.freqs[0], 0, f, dim_f // 2) | |
| h_emb = get_freq_emb(self.freqs[1], 1, h, dim_h // 2) | |
| w_emb = get_freq_emb(self.freqs[2], 2, w, dim_w // 2) | |
| # Reshape & Broadcast | |
| f_broad = mx.broadcast_to(f_emb.reshape(f, 1, 1, -1), (f, h, w, f_emb.shape[-1])) | |
| h_broad = mx.broadcast_to(h_emb.reshape(1, h, 1, -1), (f, h, w, h_emb.shape[-1])) | |
| w_broad = mx.broadcast_to(w_emb.reshape(1, 1, w, -1), (f, h, w, w_emb.shape[-1])) | |
| # Concatenate | |
| freqs = mx.concatenate([f_broad, h_broad, w_broad], axis=-1) | |
| freqs = freqs.reshape(f*h*w, 1, -1) | |
| # ------------------------------------- | |
| # Blocks | |
| out_caches_k = [] | |
| out_caches_v = [] | |
| for i, block in enumerate(self.blocks): | |
| if LQ_latents is not None and i < len(LQ_latents): | |
| x = x + LQ_latents[i] | |
| pck = pre_cache_k[i] if pre_cache_k else None | |
| pcv = pre_cache_v[i] if pre_cache_v else None | |
| if is_stream: | |
| x, ck, cv = block(x, context, t_mod, freqs, f, h, w, is_stream=True, pre_cache_k=pck, pre_cache_v=pcv) | |
| out_caches_k.append(ck) | |
| out_caches_v.append(cv) | |
| else: | |
| x = block(x, context, t_mod, freqs, f, h, w) | |
| # Head | |
| t_for_head = t.reshape(1, 1, -1) if t.ndim==2 else t | |
| x = self.head(x, t_for_head) | |
| x = self.unpatchify(x, (f, h, w)) | |
| if is_stream: | |
| return x, out_caches_k, out_caches_v | |
| return x | |
| class WanVideoVAE(nn.Module): | |
| def __init__(self, z_dim=16, dim=96): | |
| super().__init__() | |
| # Standard configs for WanVAE | |
| dim_mult = [1, 2, 4, 4] | |
| num_res_blocks = 2 | |
| attn_scales = [] | |
| temperal_downsample = [False, True, True] | |
| temperal_upsample = temperal_downsample[::-1] | |
| self.z_dim = z_dim | |
| # Encoder output is z_dim * 2 (mean + logvar) | |
| self.encoder = Encoder3d(dim, z_dim * 2, dim_mult, num_res_blocks, attn_scales, temperal_downsample) | |
| self.quant_conv = CausalConv3d(z_dim * 2, z_dim * 2, 1) | |
| self.post_quant_conv = CausalConv3d(z_dim, z_dim, 1) | |
| self.decoder = Decoder3d(dim, z_dim, dim_mult, num_res_blocks, attn_scales, temperal_upsample) | |
| def encode(self, x): | |
| # x: (N, T, H, W, 3) | |
| h = self.encoder(x) | |
| moments = self.quant_conv(h) | |
| mean, logvar = mx.split(moments, 2, axis=-1) | |
| return mean, logvar | |
| def decode(self, z): | |
| # z: (N, T, H, W, z_dim) | |
| z = self.post_quant_conv(z) | |
| dec = self.decoder(z) | |
| return dec | |
| # Helper to count cache size needed | |
| def get_cache_size(self, is_encoder=False): | |
| # Rough heuristic count of CausalConv layers | |
| # In real usage, one would traverse the graph. | |
| # For this specific architecture: | |
| if is_encoder: return 50 # Safe upper bound | |
| return 50 # Safe upper bound | |
| # --- TAEHV (FlashVSR Decoder) Components --- | |
| class Buffer_LQ4x_Proj(nn.Module): | |
| def __init__(self, in_dim, out_dim, layer_num=30): | |
| super().__init__() | |
| self.hidden_dim1, self.hidden_dim2, self.layer_num = 2048, 3072, layer_num | |
| self.pixel_shuffle = PixelShuffle3d(1, 16, 16) | |
| # Input channel calc: in_dim * 1 * 16 * 16 | |
| in_c = in_dim * 256 | |
| # Note: Padding is handled manually in stream_forward to satisfy MLX Kernel constraints | |
| self.conv1 = CausalConv3d(in_c, self.hidden_dim1, (4, 3, 3), stride=(2, 1, 1), padding=(0, 1, 1), mode='replicate') | |
| self.norm1 = RMS_norm_General(self.hidden_dim1, channel_first=False, images=False) | |
| self.act1 = nn.SiLU() | |
| self.conv2 = CausalConv3d(self.hidden_dim1, self.hidden_dim2, (4, 3, 3), stride=(2, 1, 1), padding=(0, 1, 1), mode='replicate') | |
| self.norm2 = RMS_norm_General(self.hidden_dim2, channel_first=False, images=False) | |
| self.act2 = nn.SiLU() | |
| self.linear_layers = [nn.Linear(self.hidden_dim2, out_dim) for _ in range(layer_num)] | |
| self.cache = {"conv1": None, "conv2": None} | |
| self.clip_idx = 0 | |
| def stream_forward(self, video_clip): | |
| # FIX: Ensure input Time dimension >= Kernel Time dimension (4) | |
| if self.clip_idx == 0: | |
| # First frame: Pad generously at start (Replicate) | |
| # Input: [0]. Target Conv Input: [0,0,0,0] | |
| first = video_clip[:, :1, ...] | |
| # Pad 3 frames at start so total is 4 | |
| video_clip_padded = mx.concatenate([mx.repeat(first, 3, axis=1), video_clip], axis=1) | |
| x = self.pixel_shuffle(video_clip_padded) # (B, 4, ...) | |
| # Save raw input to cache (last 2 frames) for next step | |
| self.cache["conv1"] = x[:, -CACHE_T:, ...] | |
| x = self.conv1(x) # Input T=4, Kernel=4, Stride=2 -> Out T=1 | |
| x = self.act1(self.norm1(x)) | |
| # Save feature to cache 2 | |
| self.cache["conv2"] = x[:, -CACHE_T:, ...] | |
| self.clip_idx += 1 | |
| return None | |
| else: | |
| x = self.pixel_shuffle(video_clip) | |
| # --- CONV 1 --- | |
| # Retrieve cache | |
| prev_cache = self.cache["conv1"] | |
| # Update cache for NEXT step immediately with current raw input (Last CACHE_T frames) | |
| combined_raw = mx.concatenate([prev_cache, x], axis=1) | |
| self.cache["conv1"] = combined_raw[:, -CACHE_T:, ...] | |
| # Construct Input for Conv: Need T >= 4 | |
| inp1 = combined_raw | |
| if inp1.shape[1] < 4: | |
| pad_amt = 4 - inp1.shape[1] | |
| first = inp1[:, :1, ...] | |
| inp1 = mx.concatenate([mx.repeat(first, pad_amt, axis=1), inp1], axis=1) | |
| # Run Conv | |
| x = self.conv1(inp1) | |
| x = self.act1(self.norm1(x)) | |
| # --- CONV 2 --- | |
| prev_cache_2 = self.cache["conv2"] | |
| combined_feat = mx.concatenate([prev_cache_2, x], axis=1) | |
| # Update cache | |
| self.cache["conv2"] = combined_feat[:, -CACHE_T:, ...] | |
| # Pad for Conv: Need T >= 4 | |
| inp2 = combined_feat | |
| if inp2.shape[1] < 4: | |
| pad_amt = 4 - inp2.shape[1] | |
| first = inp2[:, :1, ...] | |
| inp2 = mx.concatenate([mx.repeat(first, pad_amt, axis=1), inp2], axis=1) | |
| x = self.conv2(inp2) | |
| x = self.act2(self.norm2(x)) | |
| # Linear projections | |
| B, T, H, W, C = x.shape | |
| x_flat = x.reshape(B, T*H*W, C) | |
| out = [l(x_flat) for l in self.linear_layers] | |
| self.clip_idx += 1 | |
| return out | |
| class MemBlock(nn.Module): | |
| def __init__(self, n_in, n_out): | |
| super().__init__() | |
| # Note: FlashVSR uses Conv2d for these blocks treating Time as Batch for spatial mixing | |
| # Inputs will be (N*T, H, W, C) | |
| self.conv_main_0 = nn.Conv2d(n_in * 2, n_out, 3, padding=1) | |
| self.conv_main_1 = nn.Conv2d(n_out, n_out, 3, padding=1) | |
| self.conv_main_2 = nn.Conv2d(n_out, n_out, 3, padding=1) | |
| self.skip = nn.Conv2d(n_in, n_out, 1, bias=False) if n_in != n_out else None | |
| self.act = nn.ReLU() | |
| def __call__(self, x, past): | |
| # x: (B, H, W, C) | |
| # past: (B, H, W, C) - Memory from previous timestep | |
| # Concat along Channel | |
| inp = mx.concatenate([x, past], axis=-1) | |
| h = self.act(self.conv_main_0(inp)) | |
| h = self.act(self.conv_main_1(h)) | |
| h = self.conv_main_2(h) | |
| skip_x = self.skip(x) if self.skip is not None else x | |
| return self.act(h + skip_x) | |
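| # Illustrative sketch (not part of the original gist; helper name hypothetical): | |
| # MemBlock concatenates the current frame with the previous frame's features along the | |
| # channel axis, so the output follows n_out channels while spatial extents are preserved. | |
| def _example_memblock_shapes(): | |
| mb = MemBlock(8, 16) | |
| x = mx.zeros((2, 16, 16, 8)) # (B*T, H, W, C) | |
| past = mx.zeros((2, 16, 16, 8)) | |
| y = mb(x, past) | |
| assert y.shape == (2, 16, 16, 16) | |
| return y | |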
| class TGrow(nn.Module): | |
| def __init__(self, n_f, stride): | |
| super().__init__() | |
| self.stride = stride | |
| self.conv = nn.Conv2d(n_f, n_f * stride, 1, bias=False) | |
| def __call__(self, x): | |
| # x: (B*T, H, W, C) -> (B*T, H, W, C*stride) | |
| x = self.conv(x) | |
| if self.stride > 1: | |
| # We need to expand the time dimension. | |
| # Current layout: (NT, H, W, C*stride) | |
| # Target: (NT * stride, H, W, C) | |
| NT, H, W, C_full = x.shape | |
| C = C_full // self.stride | |
| # Reshape to separate stride | |
| x = x.reshape(NT, H, W, self.stride, C) | |
| # Permute to bring stride next to NT | |
| x = x.transpose(0, 3, 1, 2, 4) # (NT, stride, H, W, C) | |
| x = x.reshape(NT * self.stride, H, W, C) | |
| return x | |
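| # Illustrative sketch (not part of the original gist; helper name hypothetical): | |
| # TGrow with stride 2 doubles the time dimension of a flattened (B*T, H, W, C) batch. | |
| def _example_tgrow_shapes(): | |
| tg = TGrow(16, stride=2) | |
| x = mx.zeros((4, 8, 8, 16)) | |
| y = tg(x) | |
| assert y.shape == (8, 8, 8, 16) | |
| return y | |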
| class PixelShuffle3dTAEHV(nn.Module): | |
| def __init__(self, ff, hh, ww): | |
| super().__init__() | |
| self.ff, self.hh, self.ww = ff, hh, ww | |
| def __call__(self, x): | |
| # Input x: (B, T, H, W, C) in MLX channels-last layout (the PyTorch original is (B, C, F, H, W)). | |
| # The original operation is: | |
| # rearrange(x, "b c (f ff) (h hh) (w ww) -> b (c ff hh ww) f h w") | |
| # i.e. a pixel unshuffle (space-to-depth): T -> T//ff, H -> H//hh, W -> W//ww, | |
| # with each ff*hh*ww block folded into the channel dimension. | |
| B, T, H, W, C = x.shape | |
| # Pad the time dimension so it is divisible by ff (replicate the first frame) | |
| pad_t = (self.ff - T % self.ff) % self.ff | |
| if pad_t > 0: | |
| first = x[:, :1, ...] | |
| x = mx.concatenate([mx.repeat(first, pad_t, axis=1), x], axis=1) | |
| T = x.shape[1] | |
| # Reshape for unshuffle | |
| # (B, T//ff, ff, H//hh, hh, W//ww, ww, C) | |
| x = x.reshape(B, T//self.ff, self.ff, H//self.hh, self.hh, W//self.ww, self.ww, C) | |
| # Transpose to gather blocks into channel | |
| # Target: (B, T//ff, H//hh, W//ww, C * ff * hh * ww) | |
| # Permute: (0, 1, 3, 5, 7, 2, 4, 6) | |
| x = x.transpose(0, 1, 3, 5, 7, 2, 4, 6) | |
| # Flatten | |
| x = x.reshape(B, T//self.ff, H//self.hh, W//self.ww, -1) | |
| return x | |
| class TAEHV(nn.Module): | |
| image_channels = 3 | |
| def __init__(self, channels=[256, 128, 64, 64], latent_channels=16, decoder_time_upscale=(True, True), decoder_space_upscale=(True, True, True)): | |
| super().__init__() | |
| self.latent_channels = latent_channels | |
| self.frames_to_trim = 2 ** sum(decoder_time_upscale) - 1 | |
| n_f = channels | |
| # Building the layers list (Manual construction to match PyTorch Sequential) | |
| self.layers = [] | |
| # --- Block 0 --- | |
| self.layers.append(Clamp()) | |
| self.layers.append(nn.Conv2d(self.latent_channels, n_f[0], 3, padding=1)) | |
| self.layers.append(nn.ReLU()) | |
| self.layers.append(MemBlock(n_f[0], n_f[0])) | |
| self.layers.append(MemBlock(n_f[0], n_f[0])) | |
| self.layers.append(MemBlock(n_f[0], n_f[0])) | |
| # Upsample 0 | |
| if decoder_space_upscale[0]: | |
| # In MLX, use UpSampling2d or Repeat | |
| self.layers.append("upsample_space_2x") | |
| self.layers.append(TGrow(n_f[0], 1)) # Stride 1 (No T upsample) | |
| self.layers.append(nn.Conv2d(n_f[0], n_f[1], 3, padding=1, bias=False)) | |
| # --- Block 1 --- | |
| self.layers.append(MemBlock(n_f[1], n_f[1])) | |
| self.layers.append(MemBlock(n_f[1], n_f[1])) | |
| self.layers.append(MemBlock(n_f[1], n_f[1])) | |
| # Upsample 1 | |
| if decoder_space_upscale[1]: | |
| self.layers.append("upsample_space_2x") | |
| self.layers.append(TGrow(n_f[1], 2 if decoder_time_upscale[0] else 1)) | |
| self.layers.append(nn.Conv2d(n_f[1], n_f[2], 3, padding=1, bias=False)) | |
| # --- Block 2 --- | |
| self.layers.append(MemBlock(n_f[2], n_f[2])) | |
| self.layers.append(MemBlock(n_f[2], n_f[2])) | |
| self.layers.append(MemBlock(n_f[2], n_f[2])) | |
| # Upsample 2 | |
| if decoder_space_upscale[2]: | |
| self.layers.append("upsample_space_2x") | |
| self.layers.append(TGrow(n_f[2], 2 if decoder_time_upscale[1] else 1)) | |
| self.layers.append(nn.Conv2d(n_f[2], n_f[3], 3, padding=1, bias=False)) | |
| self.layers.append(nn.ReLU()) | |
| self.layers.append(nn.Conv2d(n_f[3], TAEHV.image_channels, 3, padding=1)) | |
| # Apply "Identity Deepen" logic (Inserting Identity Conv + ReLU after existing ReLUs) | |
| # To simplify porting, we assume the provided weights match the deepened structure | |
| # or we manually construct the exact list. | |
| # Here we apply the expansion dynamically similar to PyTorch code. | |
| self.final_layers = [] | |
| for layer in self.layers: | |
| self.final_layers.append(layer) | |
| # Check if layer is ReLU | |
| is_relu = isinstance(layer, nn.ReLU) | |
| if is_relu: | |
| # Need the output-channel count of the previous layer | |
| prev = self.final_layers[-2] | |
| C = None | |
| if isinstance(prev, nn.Conv2d): C = prev.weight.shape[0] # MLX Conv2d weight: (Out, kH, kW, In) | |
| elif isinstance(prev, MemBlock): C = prev.conv_main_2.weight.shape[0] | |
| if C is not None: | |
| # Add Identity Block (how_many_each=1) | |
| self.final_layers.append(IdentityConv2d(C)) | |
| self.final_layers.append(nn.ReLU()) | |
| self.pixel_shuffle = PixelShuffle3dTAEHV(4, 8, 8) | |
| self.mem_state = [None] * len(self.final_layers) | |
| def clean_mem(self): | |
| self.mem_state = [None] * len(self.final_layers) | |
| def decode_video(self, x, cond=None): | |
| # x: (B, T, H, W, C) - Latents | |
| # cond: (B, T_full, H_full, W_full, 3) - LQ Video input | |
| trim_flag = self.mem_state[0] is None # Simple check if it's first run | |
| if cond is not None: | |
| # Preprocess condition (pixel unshuffle) | |
| cond_feat = self.pixel_shuffle(cond) | |
| # x should match T dimension logic or be compatible | |
| # x is latents, cond_feat is reshaped LQ | |
| # Concat along channel | |
| x = mx.concatenate([cond_feat, x], axis=-1) | |
| # Apply the layers with MemBlock recurrence. | |
| # MemBlocks depend on the previous frame, so the time axis is handled by shifting the | |
| # sequence one step (padding with zeros, or with the cached last frame on later chunks), | |
| # which mirrors `apply_model_with_memblocks` in the original code. | |
| # TGrow layers expand the time dimension, so shapes are re-derived per layer. | |
| B, T, H, W, C = x.shape | |
| # Spatial layers (Conv2d, ReLU, ...) run on a flattened (B*T, H, W, C) batch; | |
| # MemBlock and TGrow need the time dimension kept separate, so shapes are rebuilt per layer. | |
| # Initial chunking | |
| current_x = x # (B, T, H, W, C) | |
| # We iterate through layers. | |
| # If layer is Spatial (Conv2d, ReLU), flatten T. | |
| # If layer is MemBlock, unflatten T, apply recurrent, flatten back. | |
| # If layer is TGrow, expand T. | |
| for i, layer in enumerate(self.final_layers): | |
| if layer == "upsample_space_2x": | |
| # Nearest Upsample | |
| # current_x: (B, T, H, W, C) | |
| current_x = mx.repeat(current_x, 2, axis=2) | |
| current_x = mx.repeat(current_x, 2, axis=3) | |
| continue | |
| # Helper to reshape for spatial op: (B*T, H, W, C) | |
| B_curr, T_curr, H_curr, W_curr, C_curr = current_x.shape | |
| if isinstance(layer, MemBlock): | |
| # Need explicit time iteration or parallel shift | |
| # Original code: pad with 1 zero frame at start, slice off end. | |
| # (Equivalent to prev_frame) | |
| # Check memory state | |
| if self.mem_state[i] is None: | |
| # First chunk: prev is 0 | |
| prev = mx.zeros_like(current_x[:, 0:1, ...]) # (B, 1, H, W, C) | |
| # Pad x at T=-1 with 0 | |
| x_padded = mx.concatenate([prev, current_x[:, :-1, ...]], axis=1) | |
| # Save last frame for next chunk | |
| self.mem_state[i] = current_x[:, -1:, ...] | |
| else: | |
| # Subsequent chunk | |
| prev_mem = self.mem_state[i] | |
| # Pad x at T=-1 with prev_mem | |
| x_padded = mx.concatenate([prev_mem, current_x[:, :-1, ...]], axis=1) | |
| # Save last frame | |
| self.mem_state[i] = current_x[:, -1:, ...] | |
| # Apply MemBlock spatially | |
| # Flatten T | |
| x_in_flat = current_x.reshape(-1, H_curr, W_curr, C_curr) | |
| x_past_flat = x_padded.reshape(-1, H_curr, W_curr, C_curr) | |
| out_flat = layer(x_in_flat, x_past_flat) | |
| current_x = out_flat.reshape(B_curr, T_curr, H_curr, W_curr, -1) | |
| elif isinstance(layer, TGrow): | |
| # Flatten | |
| x_flat = current_x.reshape(-1, H_curr, W_curr, C_curr) | |
| out_flat = layer(x_flat) | |
| # TGrow takes a (B*T, H, W, C) input and, when stride > 1, returns | |
| # (B*T*stride, H, W, C); reshape that back to (B, T*stride, H, W, C). | |
| new_stride = layer.stride | |
| current_x = out_flat.reshape(B_curr, T_curr * new_stride, H_curr, W_curr, -1) | |
| else: | |
| # Standard Spatial Layer | |
| x_flat = current_x.reshape(-1, H_curr, W_curr, C_curr) | |
| out_flat = layer(x_flat) | |
| current_x = out_flat.reshape(B_curr, T_curr, H_curr, W_curr, -1) | |
| # Trim frames if first run | |
| if trim_flag and self.frames_to_trim > 0: | |
| current_x = current_x[:, self.frames_to_trim:, ...] | |
| return current_x | |
| # ============================================================================== | |
| # Weight Loading & Conversion | |
| # ============================================================================== | |
| def map_torch_to_mlx(key, val): | |
| # val: numpy array in the PyTorch layout; transpose into the MLX layout | |
| if "conv" in key and val.ndim == 4: # Conv2d (Out, In, H, W) -> (Out, H, W, In) | |
| val = val.transpose(0, 2, 3, 1) | |
| elif "conv" in key and val.ndim == 5: # Conv3d (Out, In, D, H, W) -> (Out, D, H, W, In) | |
| val = val.transpose(0, 2, 3, 4, 1) | |
| # Linear weights are (Out, In) in both PyTorch and MLX, so they need no transpose | |
| return key, mx.array(val) | |
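| # Illustrative sketch (not part of the original gist; helper and weight names hypothetical): | |
| # converting a single PyTorch-layout Conv2d weight, already as a numpy array, to the MLX layout. | |
| def _example_weight_conversion(): | |
| w = np.zeros((8, 4, 3, 3), dtype=np.float32) # PyTorch Conv2d layout (Out, In, kH, kW) | |
| _, w_mlx = map_torch_to_mlx("decoder.conv.weight", w) # key name is made up | |
| assert w_mlx.shape == (8, 3, 3, 4) # MLX layout (Out, kH, kW, In) | |
| return w_mlx | |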
| def load_weights(model, ckpt_path): | |
| print(f"Loading weights from {ckpt_path}...") | |
| if ckpt_path.endswith(".safetensors"): | |
| weights = {} | |
| with safe_open(ckpt_path, framework="numpy") as f: | |
| for k in f.keys(): | |
| weights[k] = f.get_tensor(k) | |
| else: | |
| # Load torch pth using numpy | |
| import torch | |
| weights = torch.load(ckpt_path, map_location="cpu", weights_only=True) | |
| weights = {k: v.numpy() for k, v in weights.items()} | |
| # Conversion loop | |
| mlx_weights = {} | |
| for k, v in weights.items(): | |
| # Heuristic mapping | |
| new_k = k | |
| # Map specific PyTorch keys to MLX model structure | |
| # Example: blocks.0.attn1.to_q.weight -> blocks[0].self_attn.q.weight | |
| # This requires a detailed mapping dictionary based on the implemented class names | |
| # Generic Transpose | |
| if v.ndim == 4: # Conv2d | |
| v = v.transpose(0, 2, 3, 1) | |
| elif v.ndim == 5: # Conv3d | |
| v = v.transpose(0, 2, 3, 4, 1) | |
| # 2D Linear weights keep the (Out, In) layout -- it is the same in PyTorch and MLX | |
| mlx_weights[new_k] = mx.array(v) | |
| # Load into model (partial loading allowed) | |
| model.load_weights(list(mlx_weights.items()), strict=False) | |
| mx.eval(model.parameters()) | |
| print("Weights loaded.") | |
| # ============================================================================== | |
| # Pipeline | |
| # ============================================================================== | |
| class FlashVSRTinyPipeline: | |
| def __init__(self, device="gpu"): | |
| self.dit = None | |
| self.vae = None | |
| self.lq_proj = None | |
| def load_models(self): | |
| # Config WanModel | |
| self.dit = WanModel( | |
| dim=5120, in_dim=16, ffn_dim=13824, out_dim=16, text_dim=4096, | |
| freq_dim=256, eps=1e-6, patch_size=(1,2,2), num_heads=40, num_layers=40 | |
| ) | |
| self.lq_proj = Buffer_LQ4x_Proj(in_dim=3, out_dim=5120, layer_num=1) | |
| # Compile | |
| self.step_fn = mx.compile(self.dit) | |
| def __call__(self, video_path, num_frames=16, tile_size=128): | |
| # 1. Read Video | |
| print("Reading video...") | |
| frames, fps = self.read_video(video_path, num_frames) | |
| frames = mx.array(frames) | |
| frames = frames[None, ...] | |
| # 2. Latent Projection | |
| print("Running LQ Projection...") | |
| lq_latents_raw = [] | |
| for i in tqdm(range(frames.shape[1])): | |
| clip = frames[:, i:i+1, ...] | |
| lat = self.lq_proj.stream_forward(clip) | |
| if lat is not None: | |
| lq_latents_raw.append(lat) | |
| final_lq_latents = [] | |
| if lq_latents_raw: | |
| num_layers = len(lq_latents_raw[0]) | |
| for layer_idx in range(num_layers): | |
| time_steps_for_layer = [step[layer_idx] for step in lq_latents_raw] | |
| concatenated = mx.concatenate(time_steps_for_layer, axis=1) | |
| final_lq_latents.append(concatenated) | |
| if not final_lq_latents: | |
| print("Warning: No latents produced (video too short?). Returning None.") | |
| return None | |
| # 3. DiT Inference | |
| print("Running DiT...") | |
| ref_latent = final_lq_latents[0] # (B, L, 5120) | |
| B, L, D = ref_latent.shape | |
| # Noise input | |
| x_start = mx.random.normal((B, L, 5120)) | |
| # --- FIX: CREATE DUMMY CONTEXT --- | |
| # WanModel text_dim = 4096 (as defined in load_models) | |
| # Create an empty context (e.g. a 512-token sequence with dim 4096) | |
| # In practice, this is the output of the T5 text encoder. | |
| text_dim = 4096 | |
| context_len = 512 | |
| dummy_context = mx.zeros((B, context_len, text_dim)) | |
| # --------------------------------- | |
| # Patching Patchify | |
| original_patchify = self.dit.patchify | |
| def smart_patchify(inp): | |
| if inp.ndim == 3: | |
| return inp, (1, 1, inp.shape[1]) | |
| else: | |
| return original_patchify(inp) | |
| self.dit.patchify = smart_patchify | |
| t = mx.array([1000.0]) | |
| # Pass dummy_context instead of None | |
| out = self.step_fn(x_start, t, context=dummy_context, LQ_latents=final_lq_latents, is_stream=True) | |
| self.dit.patchify = original_patchify | |
| # 4. Decode | |
| print("Decoding (Mock)...") | |
| return out | |
| def read_video(self, path, num_frames): | |
| reader = imageio.get_reader(path) | |
| frames = [] | |
| for i, im in enumerate(reader): | |
| if i >= num_frames: break | |
| frames.append(im) | |
| return np.array(frames) / 255.0, 30 | |
| # ============================================================================== | |
| # Main | |
| # ============================================================================== | |
| def main(): | |
| if not os.path.exists(args.output_folder): | |
| os.makedirs(args.output_folder) | |
| # Check for M-series | |
| print(f"Running on Device: {mx.default_device()}") | |
| pipe = FlashVSRTinyPipeline() | |
| pipe.load_models() # Builds the model graphs; real checkpoints must still be downloaded and loaded | |
| # Mock run | |
| print("Initializing pipeline (Mock mode for code verification)...") | |
| # Example flow | |
| pipe(args.input) | |
| print("Done. (Note: Weights are needed to produce actual video)") | |
| if __name__ == "__main__": | |
| main() |