@tamnguyenvan
Last active December 27, 2025 02:09
import os
import re
import math
import uuid
import shutil
import random
import hashlib
import gc
import copy
import json
from dataclasses import dataclass
from typing import List, Tuple, Optional, Literal
from contextlib import contextmanager
import numpy as np
import mlx.core as mx
import mlx.nn as nn
import mlx.optimizers as optim
from mlx.utils import tree_flatten, tree_unflatten
from PIL import Image
from tqdm.auto import tqdm
from huggingface_hub import snapshot_download
from safetensors import safe_open
import imageio
import ffmpeg
# ==============================================================================
# Configuration & Globals
# ==============================================================================
@dataclass
class Args:
input: str = "input.mp4"
output_folder: str = "./outputs"
scale: int = 4
version: str = "10"
mode: str = "tiny"
tiled_vae: bool = True
tiled_dit: bool = True
tile_size: int = 16
overlap: int = 8
unload_dit: bool = True
color_fix: bool = False
seed: int = 0
dtype: str = "fp16" # MLX uses fp16 or fp32
fps: int = 30
quality: int = 6
attention: str = "sdpa"
args = Args()
CACHE_T = 2
root = os.path.dirname(os.path.abspath(__file__))
temp = os.path.join(root, "_temp")
def log(message: str, message_type: str = "normal"):
if message_type == "error":
message = "\033[1;41m" + message + "\033[m"
elif message_type == "warning":
message = "\033[1;31m" + message + "\033[m"
elif message_type == "finish":
message = "\033[1;32m" + message + "\033[m"
elif message_type == "info":
message = "\033[1;33m" + message + "\033[m"
print(f"{message}")
# ==============================================================================
# MLX Utils & Helpers
# ==============================================================================
def rearrange(x, pattern, **axes_lengths):
    # Placeholder for an einops-style rearrange. The specific patterns this port
    # needs are implemented inline with explicit reshape/transpose ops, both to
    # avoid an external dependency and to keep mx.compile happy.
return x # Placeholder, logic implemented inline for optimization
def sinusoidal_embedding_1d(dim, position):
# position: (N,)
half_dim = dim // 2
emb = math.log(10000) / (half_dim - 1)
emb = mx.exp(mx.arange(half_dim) * -emb)
emb = position[:, None] * emb[None, :]
emb = mx.concatenate([mx.cos(emb), mx.sin(emb)], axis=-1)
if dim % 2 == 1:
emb = mx.pad(emb, ((0,0), (0,1)))
return emb
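# Quick example (a sketch): sinusoidal_embedding_1d(256, mx.array([1000.0])) returns a
# (1, 256) embedding -- 128 cosine features followed by 128 sine features for the
# single timestep 1000, which is how WanModel embeds its diffusion timestep below.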
def precompute_freqs_cis_3d(dim: int, end: int = 1024, theta: float = 10000.0):
f_freqs_cis = precompute_freqs_cis(dim - 2 * (dim // 3), end, theta)
h_freqs_cis = precompute_freqs_cis(dim // 3, end, theta)
w_freqs_cis = precompute_freqs_cis(dim // 3, end, theta)
return f_freqs_cis, h_freqs_cis, w_freqs_cis
def precompute_freqs_cis(dim: int, end: int = 1024, theta: float = 10000.0):
freqs = 1.0 / (theta ** (mx.arange(0, dim, 2).astype(mx.float32) / dim))
t = mx.arange(end).astype(mx.float32)
freqs = mx.outer(t, freqs) # (end, dim//2)
# MLX doesn't have complex number support like PyTorch's polar for this exact usage yet
# We store as (cos, sin) for rope_apply
freqs_cos = mx.cos(freqs)
freqs_sin = mx.sin(freqs)
return mx.stack([freqs_cos, freqs_sin], axis=-1)
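# precompute_freqs_cis(d) returns an array of shape (end, d // 2, 2): for position t
# and frequency index j it stores (cos(t * w_j), sin(t * w_j)) with w_j = theta ** (-2j / d).
# rope_apply below consumes these pairs to rotate (real, imag) feature pairs, mirroring
# the complex `torch.polar` formulation mentioned in the comment above.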
def rope_apply(x, freqs, num_heads):
# x: (B, S, n*d) -> (B, S, n, d)
B, S, _ = x.shape
x = x.reshape(B, S, num_heads, -1)
# x is formatted as real numbers, logic needs to treat pairs as complex numbers
# x: (..., d) where d is even
x0 = x[..., 0::2]
x1 = x[..., 1::2]
    # freqs arrives as (S, 1, d) with interleaved (cos, sin) pairs; reshape to
    # (S, 1, d//2, 2) and broadcast over batch and heads.
    freqs = freqs.reshape(S, 1, -1, 2)
    freqs_cos = freqs[None, ..., 0]  # (1, S, 1, D/2)
    freqs_sin = freqs[None, ..., 1]  # (1, S, 1, D/2)
# Rotate
x_out0 = x0 * freqs_cos - x1 * freqs_sin
x_out1 = x0 * freqs_sin + x1 * freqs_cos
# Stack back (interleave)
x_out = mx.stack([x_out0, x_out1], axis=-1).reshape(B, S, num_heads, -1)
return x_out.reshape(B, S, -1)
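# The rotation above is just complex multiplication written with real numbers:
#   (x0 + i*x1) * (cos + i*sin) = (x0*cos - x1*sin) + i*(x0*sin + x1*cos)
# applied independently to each (even, odd) feature pair of every head.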
def modulate(x, shift, scale):
return x * (1 + scale) + shift
# ==============================================================================
# Layers (MLX Port)
# ==============================================================================
class RMSNorm(nn.Module):
def __init__(self, dim, eps=1e-5):
super().__init__()
self.eps = eps
self.weight = mx.ones((dim,))
def __call__(self, x):
# x: (..., dim)
return mx.fast.rms_norm(x, self.weight, self.eps)
class RMS_norm_General(nn.Module):
def __init__(self, dim, channel_first=True, images=True, bias=False):
super().__init__()
# MLX prefers channel last, so we adapt input if channel_first is True
self.channel_first = channel_first
self.images = images
self.scale = dim**0.5
self.gamma = mx.ones((dim,))
self.bias = mx.zeros((dim,)) if bias else None
def __call__(self, x):
# Standardize to channel last for norm, then convert back if needed
if self.channel_first:
# (B, C, ...) -> (B, ..., C)
if x.ndim == 5: # B, C, T, H, W
x = x.transpose(0, 2, 3, 4, 1)
elif x.ndim == 4: # B, C, H, W
x = x.transpose(0, 2, 3, 1)
# Norm
out = mx.fast.rms_norm(x, self.gamma, 1e-6) # Using fast kernel
out = out * self.scale
if self.bias is not None:
out = out + self.bias
if self.channel_first:
# (B, ..., C) -> (B, C, ...)
if x.ndim == 5:
out = out.transpose(0, 4, 1, 2, 3)
elif x.ndim == 4:
out = out.transpose(0, 3, 1, 2)
return out
class CausalConv3d(nn.Module):
"""
MLX Conv3d expects (N, D, H, W, C).
PyTorch Conv3d expects (N, C, D, H, W).
We wrap this to handle layout internally but optimize weights for NHWC.
"""
def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, bias=True, mode='zeros'):
super().__init__()
if isinstance(kernel_size, int): kernel_size = (kernel_size, kernel_size, kernel_size)
if isinstance(stride, int): stride = (stride, stride, stride)
if isinstance(padding, int): padding = (padding, padding, padding)
self.mode = mode
self.padding_val = padding
# Causal padding on Time dimension (dim 0 of kernel in MLX terms for 3D)
# padding is (time, height, width)
self.time_pad = 2 * padding[0]
self.spatial_padding = (padding[1], padding[2])
        # MLX Conv3d weight shape: (C_out, kD, kH, kW, C_in)
self.conv = nn.Conv3d(in_channels, out_channels, kernel_size, stride, padding=0, bias=bias)
# Adjust internal padding for spatial
# We handle Time padding manually for causality
def __call__(self, x, cache_x=None):
# Input x is expected to be (N, C, T, H, W) to match PyTorch code flow,
# OR (N, T, H, W, C) for MLX optimization.
# Let's standardize on (N, T, H, W, C) for MLX layers.
# Handle Causal Padding
        if cache_x is not None and self.time_pad > 0:
            # cache_x: (N, T_cache, H, W, C) -- prepend the cached frames on the
            # time axis; this history replaces the causal (past-only) padding.
            x = mx.concatenate([cache_x, x], axis=1)
            # Spatial padding is still required.
            if self.spatial_padding[0] > 0 or self.spatial_padding[1] > 0:
                pad_width = [(0, 0), (0, 0),
                             (self.spatial_padding[0], self.spatial_padding[0]),
                             (self.spatial_padding[1], self.spatial_padding[1]),
                             (0, 0)]
                x = mx.pad(x, pad_width)
else:
# Pad only the "past"
if self.time_pad > 0:
# pad (0,0) for C, (0,0) for W, (0,0) for H, (pad, 0) for T, (0,0) for N
if self.mode == 'replicate':
# simple replicate pad for time (index 1)
first_frame = x[:, :1, ...]
padding = mx.repeat(first_frame, self.time_pad, axis=1)
x = mx.concatenate([padding, x], axis=1)
else:
# Zero pad
pad_width = [(0,0), (self.time_pad, 0), (self.spatial_padding[0], self.spatial_padding[0]), (self.spatial_padding[1], self.spatial_padding[1]), (0,0)]
x = mx.pad(x, pad_width)
else:
# Handle spatial padding if no time padding needed (1x1 convolution usually)
if self.spatial_padding[0] > 0 or self.spatial_padding[1] > 0:
pad_width = [(0,0), (0, 0), (self.spatial_padding[0], self.spatial_padding[0]), (self.spatial_padding[1], self.spatial_padding[1]), (0,0)]
x = mx.pad(x, pad_width)
out = self.conv(x)
return out
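# Usage sketch (assumed shapes): CausalConv3d(3, 96, 3, padding=1) applied to a
# (1, T, H, W, 3) clip pads two frames of "past" at the start of the time axis
# (zeros, or replicated first frames in 'replicate' mode), so frame t only sees
# frames <= t. Passing cache_x with the last CACHE_T frames of the previous chunk
# replaces that padding when decoding in a streaming fashion.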
class Resample(nn.Module):
def __init__(self, dim, mode):
super().__init__()
self.dim = dim
self.mode = mode
if mode == "upsample2d":
self.conv = nn.Conv2d(dim, dim // 2, 3, padding=1)
elif mode == "upsample3d":
self.conv = nn.Conv2d(dim, dim // 2, 3, padding=1)
self.time_conv = CausalConv3d(dim, dim * 2, (3, 1, 1), padding=(1, 0, 0))
elif mode == "downsample2d":
self.conv = nn.Conv2d(dim, dim, 3, stride=2) # Manual padding needed
elif mode == "downsample3d":
self.conv = nn.Conv2d(dim, dim, 3, stride=2)
self.time_conv = CausalConv3d(dim, dim, (3, 1, 1), stride=(2, 1, 1), padding=(0, 0, 0))
def __call__(self, x, feat_cache=None, feat_idx=[0]):
# x: (N, T, H, W, C)
B, T, H, W, C = x.shape
if self.mode == "upsample3d":
if feat_cache is not None:
idx = feat_idx[0]
# Caching logic
if feat_cache[idx] is None:
feat_cache[idx] = "Rep"
feat_idx[0] += 1
else:
cache_x = x[:, -CACHE_T:, ...].astype(x.dtype)
if cache_x.shape[1] < 2 and feat_cache[idx] is not None and isinstance(feat_cache[idx], mx.array):
cache_x = mx.concatenate([feat_cache[idx][:, -1:, ...], cache_x], axis=1)
if isinstance(feat_cache[idx], str) and feat_cache[idx] == "Rep":
# Time conv handles replicate
x_t = self.time_conv(x) # Handle mode='replicate' inside CausalConv3d if needed or pad manually
else:
x_t = self.time_conv(x, cache_x=feat_cache[idx])
feat_cache[idx] = cache_x
feat_idx[0] += 1
# Reshape for channel mix
# x_t: (B, T, H, W, C*2) -> split C
x_t = x_t.reshape(B, T, H, W, 2, C)
x = mx.stack([x_t[..., 0, :], x_t[..., 1, :]], axis=2) # (B, T, 2, H, W, C)
x = x.reshape(B, T * 2, H, W, C)
# Spatial Resampling
# Flatten time: (B*T, H, W, C)
T_curr = x.shape[1]
x_flat = x.reshape(-1, x.shape[2], x.shape[3], x.shape[4])
if "upsample" in self.mode:
# Nearest neighbor upsampling
# MLX doesn't have Upsample layer, utilize broadcast or repeat
# (B*T, H, W, C) -> (B*T, H*2, W*2, C)
x_up = mx.repeat(x_flat, 2, axis=1)
x_up = mx.repeat(x_up, 2, axis=2)
x_flat = self.conv(x_up)
elif "downsample" in self.mode:
# Pad for stride 2 conv
x_flat = mx.pad(x_flat, ((0,0), (0,1), (0,1), (0,0)))
x_flat = self.conv(x_flat)
x = x_flat.reshape(B, T_curr, x_flat.shape[1], x_flat.shape[2], -1)
if self.mode == "downsample3d":
# Time downsample
if feat_cache is not None:
idx = feat_idx[0]
if feat_cache[idx] is None:
feat_cache[idx] = x
feat_idx[0] += 1
else:
cache_x = x[:, -1:, ...].astype(x.dtype)
# concat time
inp = mx.concatenate([feat_cache[idx][:, -1:, ...], x], axis=1)
x = self.time_conv(inp)
feat_cache[idx] = cache_x
feat_idx[0] += 1
return x
class PixelShuffle3d(nn.Module):
def __init__(self, ff, hh, ww):
super().__init__()
self.ff, self.hh, self.ww = ff, hh, ww
def __call__(self, x):
        # FIX: This is actually a PixelUnshuffle (space-to-depth), not a PixelShuffle (upscale)
# Input x: (B, T, H, W, C) -> Output: (B, T, H/hh, W/ww, C*hh*ww)
B, T, H, W, C = x.shape
        # Validate spatial dimensions
if H % self.hh != 0 or W % self.ww != 0:
raise ValueError(f"Input spatial dims ({H},{W}) must be divisible by patch size ({self.hh},{self.ww})")
        # 1. Reshape to split out the pixel blocks
        # Shape: (B, T, H//hh, hh, W//ww, ww, C)
        # Note: ff (the temporal factor) is ignored in this reshape because ff=1 in
        # the default config; the spatial logic is unaffected.
x = x.reshape(B, T, H // self.hh, self.hh, W // self.ww, self.ww, C)
        # 2. Transpose to move the block dims (hh, ww) next to the channel dim
        # Current: (B, T, NewH, hh, NewW, ww, C)
        # Target:  (B, T, NewH, NewW, C, hh, ww)
        # Permutation axes: 0, 1, 2, 4, 6, 3, 5
x = x.transpose(0, 1, 2, 4, 6, 3, 5)
        # 3. Flatten the last three dims into the channel dim
# Output: (B, T, H//hh, W//ww, C * hh * ww)
x = x.reshape(B, T, H // self.hh, W // self.ww, -1)
return x
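# Shape example: with hh = ww = 16, a (1, 4, 64, 64, 3) clip becomes (1, 4, 4, 4, 768),
# i.e. every 16x16 patch is folded into 3 * 16 * 16 = 768 channels.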
class PixelShuffle3dInverse(nn.Module):
# Used for unpatchify logic if needed, or simple reshapes
pass
# ==============================================================================
# Blocks
# ==============================================================================
class ResidualBlock(nn.Module):
def __init__(self, in_dim, out_dim, dropout=0.0):
super().__init__()
# Using MLX layers (NHWC)
self.norm1 = RMS_norm_General(in_dim, channel_first=False, images=False)
self.act1 = nn.SiLU()
self.conv1 = CausalConv3d(in_dim, out_dim, 3, padding=1)
self.norm2 = RMS_norm_General(out_dim, channel_first=False, images=False)
self.act2 = nn.SiLU()
self.dropout = nn.Dropout(dropout)
self.conv2 = CausalConv3d(out_dim, out_dim, 3, padding=1)
self.shortcut = CausalConv3d(in_dim, out_dim, 1) if in_dim != out_dim else nn.Identity()
def __call__(self, x, feat_cache=None, feat_idx=[0]):
h = self.shortcut(x)
x = self.norm1(x)
x = self.act1(x)
# Conv1 with cache handling
if feat_cache is not None:
idx = feat_idx[0]
cache_x = x[:, -CACHE_T:, ...].astype(x.dtype)
if cache_x.shape[1] < 2 and feat_cache[idx] is not None:
cache_x = mx.concatenate([feat_cache[idx][:, -1:, ...], cache_x], axis=1) # naive pad
x = self.conv1(x, cache_x=feat_cache[idx])
feat_cache[idx] = cache_x
feat_idx[0] += 1
else:
x = self.conv1(x)
x = self.norm2(x)
x = self.act2(x)
x = self.dropout(x)
# Conv2 with cache handling
if feat_cache is not None:
idx = feat_idx[0]
cache_x = x[:, -CACHE_T:, ...].astype(x.dtype)
if cache_x.shape[1] < 2 and feat_cache[idx] is not None:
cache_x = mx.concatenate([feat_cache[idx][:, -1:, ...], cache_x], axis=1)
x = self.conv2(x, cache_x=feat_cache[idx])
feat_cache[idx] = cache_x
feat_idx[0] += 1
else:
x = self.conv2(x)
return x + h
class AttentionBlock(nn.Module):
def __init__(self, dim):
super().__init__()
self.norm = RMS_norm_General(dim, channel_first=False)
# Conv2d for (B*T, H, W, C)
self.to_qkv = nn.Conv2d(dim, dim * 3, 1)
self.proj = nn.Conv2d(dim, dim, 1)
def __call__(self, x):
# x: (B, T, H, W, C)
identity = x
B, T, H, W, C = x.shape
x = self.norm(x)
# Merge B, T for Conv2d
x_flat = x.reshape(B * T, H, W, C)
qkv = self.to_qkv(x_flat)
q, k, v = mx.split(qkv, 3, axis=-1)
# Reshape for Attention: (B*T, H*W, C)
q = q.reshape(B*T, H*W, C)
k = k.reshape(B*T, H*W, C)
v = v.reshape(B*T, H*W, C)
        # Single-head scaled dot product attention over the spatial tokens
        # (to_qkv maps dim -> dim * 3, i.e. one head of width C).
        # mx.fast.scaled_dot_product_attention expects (B, n_heads, L, D),
        # so insert an explicit head axis of size 1.
        scale = 1 / math.sqrt(C)
        x_attn = mx.fast.scaled_dot_product_attention(
            q[:, None], k[:, None], v[:, None], scale=scale
        )
        x_attn = x_attn[:, 0].reshape(B*T, H, W, C)
x_attn = self.proj(x_attn)
x_out = x_attn.reshape(B, T, H, W, C)
return x_out + identity
# ==============================================================================
# DiT / Transformer Components
# ==============================================================================
class SelfAttention(nn.Module):
def __init__(self, dim, num_heads, eps=1e-6):
super().__init__()
self.dim = dim
self.num_heads = num_heads
self.head_dim = dim // num_heads
self.q = nn.Linear(dim, dim)
self.k = nn.Linear(dim, dim)
self.v = nn.Linear(dim, dim)
self.o = nn.Linear(dim, dim)
self.norm_q = RMSNorm(dim, eps=eps)
self.norm_k = RMSNorm(dim, eps=eps)
def __call__(self, x, freqs, f=None, h=None, w=None, is_stream=False, pre_cache_k=None, pre_cache_v=None):
B, L, D = x.shape
q = self.norm_q(self.q(x))
k = self.norm_k(self.k(x))
v = self.v(x)
# Apply RoPE
q = rope_apply(q, freqs, self.num_heads)
k = rope_apply(k, freqs, self.num_heads)
# Reshape for multi-head attention: (B, L, n_heads, head_dim)
q = q.reshape(B, L, self.num_heads, self.head_dim)
k = k.reshape(B, L, self.num_heads, self.head_dim)
v = v.reshape(B, L, self.num_heads, self.head_dim)
# Caching logic for stream
# Source uses Window Partitioning logic which is complex.
# For this MLX port, we implement standard Attention with optional caching.
# Ideally, we would reimplement the window partitioning, but full attention is simpler and faster on M-chips for small tiles.
# FlashAttention on M-chips is very efficient.
if pre_cache_k is not None:
k = mx.concatenate([pre_cache_k, k], axis=1)
v = mx.concatenate([pre_cache_v, v], axis=1)
        # Compute attention.
        # q: (B, L_q, n_heads, head_dim), k/v: (B, L_k, n_heads, head_dim).
        # mx.fast.scaled_dot_product_attention expects (B, n_heads, L, head_dim),
        # so transpose the head axis forward.
q_t = q.transpose(0, 2, 1, 3)
k_t = k.transpose(0, 2, 1, 3)
v_t = v.transpose(0, 2, 1, 3)
x_out = mx.fast.scaled_dot_product_attention(q_t, k_t, v_t, scale=1.0/math.sqrt(self.head_dim))
# Transpose back: (B, L, n_heads, head_dim) -> (B, L, D)
x_out = x_out.transpose(0, 2, 1, 3).reshape(B, L, D)
out = self.o(x_out)
cache_k = k if is_stream else None
cache_v = v if is_stream else None
# Handle KV cache trimming if logic requires window sliding (simplified here)
if is_stream and cache_k.shape[1] > 2000: # heuristic limit from config
cache_k = cache_k[:, -1000:, ...]
cache_v = cache_v[:, -1000:, ...]
if is_stream:
return out, cache_k, cache_v
return out
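# In streaming mode SelfAttention returns (out, cache_k, cache_v); the caller feeds
# these back in as pre_cache_k / pre_cache_v on the next chunk so that new tokens can
# attend to previously processed ones, with the cache length capped by the simple
# sliding-window trim above.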
class CrossAttention(nn.Module):
def __init__(self, dim, num_heads, eps=1e-6):
super().__init__()
self.dim = dim
self.num_heads = num_heads
self.q = nn.Linear(dim, dim)
self.k = nn.Linear(dim, dim)
self.v = nn.Linear(dim, dim)
self.o = nn.Linear(dim, dim)
self.norm_q = RMSNorm(dim, eps)
self.norm_k = RMSNorm(dim, eps)
self.cache_k = None
self.cache_v = None
def init_cache(self, ctx):
# ctx: (B, S, D)
k = self.norm_k(self.k(ctx))
v = self.v(ctx)
# Pre-shape for attention: (B, n_heads, S, head_dim)
B, S, _ = ctx.shape
self.cache_k = k.reshape(B, S, self.num_heads, -1).transpose(0, 2, 1, 3)
self.cache_v = v.reshape(B, S, self.num_heads, -1).transpose(0, 2, 1, 3)
def clear_cache(self):
self.cache_k = None
self.cache_v = None
def __call__(self, x, context, is_stream=False):
B, L, D = x.shape
q = self.norm_q(self.q(x))
q = q.reshape(B, L, self.num_heads, -1).transpose(0, 2, 1, 3)
# Use cached K, V if available (typically context is static per prompt)
k = self.cache_k
v = self.cache_v
if k is None: # Fallback
k = self.norm_k(self.k(context)).reshape(B, -1, self.num_heads, D // self.num_heads).transpose(0, 2, 1, 3)
v = self.v(context).reshape(B, -1, self.num_heads, D // self.num_heads).transpose(0, 2, 1, 3)
x_out = mx.fast.scaled_dot_product_attention(q, k, v, scale=1.0/math.sqrt(D // self.num_heads))
x_out = x_out.transpose(0, 2, 1, 3).reshape(B, L, D)
return self.o(x_out)
class DiTBlock(nn.Module):
def __init__(self, dim, num_heads, ffn_dim, eps=1e-6):
super().__init__()
self.self_attn = SelfAttention(dim, num_heads, eps)
self.cross_attn = CrossAttention(dim, num_heads, eps)
self.norm1 = nn.LayerNorm(dim, eps=eps, affine=False)
self.norm2 = nn.LayerNorm(dim, eps=eps, affine=False)
self.norm3 = nn.LayerNorm(dim, eps=eps) # Affine True default
self.ffn = nn.Sequential(
nn.Linear(dim, ffn_dim),
nn.GELU(), # Approximate tanh not in standard MLX nn yet, standard GELU is fine or custom
nn.Linear(ffn_dim, dim)
)
self.modulation = mx.random.normal((1, 6, dim)) / dim**0.5
def __call__(self, x, context, t_mod, freqs, f, h, w, is_stream=False, pre_cache_k=None, pre_cache_v=None, **kwargs):
# Modulation
# t_mod: (1, 6, dim)
mod = self.modulation + t_mod
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = mx.split(mod, 6, axis=1)
# Self Attention
norm1_x = self.norm1(x)
norm1_x = modulate(norm1_x, shift_msa, scale_msa)
attn_out = self.self_attn(norm1_x, freqs, f, h, w, is_stream, pre_cache_k, pre_cache_v)
if is_stream:
attn_out, cache_k, cache_v = attn_out
x = x + gate_msa * attn_out
# Cross Attention
x = x + self.cross_attn(self.norm3(x), context, is_stream)
# FFN
norm2_x = self.norm2(x)
norm2_x = modulate(norm2_x, shift_mlp, scale_mlp)
x = x + gate_mlp * self.ffn(norm2_x)
if is_stream:
return x, cache_k, cache_v
return x
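# Each DiTBlock uses adaLN-style modulation: the six chunks of (modulation + t_mod)
# act as (shift, scale, gate) triplets for the self-attention and FFN branches, i.e.
# x <- x + gate * f(norm(x) * (1 + scale) + shift), while cross-attention is added
# without a gate.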
class Head(nn.Module):
def __init__(self, dim, out_dim, patch_size, eps):
super().__init__()
self.norm = nn.LayerNorm(dim, eps=eps, affine=False)
self.head = nn.Linear(dim, out_dim * math.prod(patch_size))
self.modulation = mx.random.normal((1, 2, dim)) / dim**0.5
def __call__(self, x, t_mod):
mod = self.modulation + t_mod
shift, scale = mx.split(mod, 2, axis=1)
x = self.norm(x)
x = x * (1 + scale) + shift
return self.head(x)
# ==============================================================================
# Models: WanModel, VAE
# ==============================================================================
class Clamp(nn.Module):
def __call__(self, x):
return mx.tanh(x / 3) * 3
class IdentityConv2d(nn.Module):
# Used in TAEHV deep layers
def __init__(self, channels, kernel_size=3):
super().__init__()
# In MLX, identity conv can be initialized with a dirac delta weight
        # MLX Conv2d weight shape: (C_out, kH, kW, C_in)
        self.conv = nn.Conv2d(channels, channels, kernel_size, padding=kernel_size//2, bias=False)
        # Manually init weights to an identity (Dirac delta) kernel
        w = np.zeros((channels, kernel_size, kernel_size, channels), dtype=np.float32)
        center = kernel_size // 2
        w[np.arange(channels), center, center, np.arange(channels)] = 1.0
        self.conv.weight = mx.array(w)
def __call__(self, x):
return self.conv(x)
# --- VAE Components ---
class Encoder3d(nn.Module):
def __init__(self, dim=128, z_dim=4, dim_mult=[1, 2, 4, 4], num_res_blocks=2, attn_scales=[], temperal_downsample=[True, True, False], dropout=0.0):
super().__init__()
self.dim = dim
self.dim_mult = dim_mult
self.num_res_blocks = num_res_blocks
dims = [dim * u for u in [1] + dim_mult]
scale = 1.0
# Initial Conv: Input 3 channels -> dims[0]
self.conv1 = CausalConv3d(3, dims[0], 3, padding=1)
self.downsamples = []
for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])):
# ResBlocks
for _ in range(num_res_blocks):
self.downsamples.append(ResidualBlock(in_dim, out_dim, dropout))
if scale in attn_scales:
self.downsamples.append(AttentionBlock(out_dim))
in_dim = out_dim
# Downsample Layer
if i != len(dim_mult) - 1:
mode = "downsample3d" if temperal_downsample[i] else "downsample2d"
self.downsamples.append(Resample(out_dim, mode=mode))
scale /= 2.0
# Middle
self.middle = [
ResidualBlock(out_dim, out_dim, dropout),
AttentionBlock(out_dim),
ResidualBlock(out_dim, out_dim, dropout)
]
# Head
self.head_norm = RMS_norm_General(out_dim, channel_first=False, images=False)
self.head_act = nn.SiLU()
self.head_conv = CausalConv3d(out_dim, z_dim, 3, padding=1)
def __call__(self, x, feat_cache=None, feat_idx=[0]):
# x: (N, T, H, W, C)
# Conv1
if feat_cache is not None:
idx = feat_idx[0]
cache_x = x[:, -CACHE_T:, ...].astype(x.dtype)
# Causal Pad logic
if cache_x.shape[1] < 2 and feat_cache[idx] is not None:
cache_x = mx.concatenate([feat_cache[idx][:, -1:, ...], cache_x], axis=1)
x = self.conv1(x, cache_x=feat_cache[idx])
feat_cache[idx] = cache_x
feat_idx[0] += 1
else:
x = self.conv1(x)
# Downsamples
for layer in self.downsamples:
if isinstance(layer, (ResidualBlock, Resample)) and feat_cache is not None:
x = layer(x, feat_cache, feat_idx)
else:
x = layer(x)
# Middle
for layer in self.middle:
if isinstance(layer, ResidualBlock) and feat_cache is not None:
x = layer(x, feat_cache, feat_idx)
else:
x = layer(x)
# Head
x = self.head_norm(x)
x = self.head_act(x)
if feat_cache is not None:
idx = feat_idx[0]
cache_x = x[:, -CACHE_T:, ...].astype(x.dtype)
if cache_x.shape[1] < 2 and feat_cache[idx] is not None:
cache_x = mx.concatenate([feat_cache[idx][:, -1:, ...], cache_x], axis=1)
x = self.head_conv(x, cache_x=feat_cache[idx])
feat_cache[idx] = cache_x
feat_idx[0] += 1
else:
x = self.head_conv(x)
return x
class Decoder3d(nn.Module):
def __init__(self, dim=128, z_dim=4, dim_mult=[1, 2, 4, 4], num_res_blocks=2, attn_scales=[], temperal_upsample=[False, True, True], dropout=0.0):
super().__init__()
dims = [dim * u for u in [dim_mult[-1]] + dim_mult[::-1]]
scale = 1.0 / 2 ** (len(dim_mult) - 2)
self.conv1 = CausalConv3d(z_dim, dims[0], 3, padding=1)
self.middle = [
ResidualBlock(dims[0], dims[0], dropout),
AttentionBlock(dims[0]),
ResidualBlock(dims[0], dims[0], dropout)
]
self.upsamples = []
for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])):
# Channel adjustment logic from original code
if i == 1 or i == 2 or i == 3:
in_dim = in_dim // 2
# Blocks
for _ in range(num_res_blocks + 1):
self.upsamples.append(ResidualBlock(in_dim, out_dim, dropout))
if scale in attn_scales:
self.upsamples.append(AttentionBlock(out_dim))
in_dim = out_dim
# Upsample Layer
if i != len(dim_mult) - 1:
mode = "upsample3d" if temperal_upsample[i] else "upsample2d"
self.upsamples.append(Resample(out_dim, mode=mode))
scale *= 2.0
self.head_norm = RMS_norm_General(out_dim, channel_first=False, images=False)
self.head_act = nn.SiLU()
self.head_conv = CausalConv3d(out_dim, 3, 3, padding=1) # Output RGB
def __call__(self, x, feat_cache=None, feat_idx=[0]):
# x: (N, T, H, W, C)
# Conv1
if feat_cache is not None:
idx = feat_idx[0]
cache_x = x[:, -CACHE_T:, ...].astype(x.dtype)
if cache_x.shape[1] < 2 and feat_cache[idx] is not None:
cache_x = mx.concatenate([feat_cache[idx][:, -1:, ...], cache_x], axis=1)
x = self.conv1(x, cache_x=feat_cache[idx])
feat_cache[idx] = cache_x
feat_idx[0] += 1
else:
x = self.conv1(x)
# Middle
for layer in self.middle:
if isinstance(layer, ResidualBlock) and feat_cache is not None:
x = layer(x, feat_cache, feat_idx)
else:
x = layer(x)
# Upsamples
for layer in self.upsamples:
if isinstance(layer, (ResidualBlock, Resample)) and feat_cache is not None:
x = layer(x, feat_cache, feat_idx)
else:
x = layer(x)
# Head
x = self.head_norm(x)
x = self.head_act(x)
if feat_cache is not None:
idx = feat_idx[0]
cache_x = x[:, -CACHE_T:, ...].astype(x.dtype)
if cache_x.shape[1] < 2 and feat_cache[idx] is not None:
cache_x = mx.concatenate([feat_cache[idx][:, -1:, ...], cache_x], axis=1)
x = self.head_conv(x, cache_x=feat_cache[idx])
feat_cache[idx] = cache_x
feat_idx[0] += 1
else:
x = self.head_conv(x)
return x
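# Streaming note: feat_cache is a list with one slot per CausalConv3d and feat_idx a
# single-element counter consumed in call order, so the same cache list must be reused
# across chunks for temporal state to carry over. Callers should pass a fresh
# feat_idx=[0] on every call: the mutable default argument persists between calls and
# would otherwise keep incrementing.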
class WanModel(nn.Module):
def __init__(self, dim, in_dim, ffn_dim, out_dim, text_dim, freq_dim, eps, patch_size, num_heads, num_layers):
super().__init__()
self.dim = dim
self.freq_dim = freq_dim
self.patch_size = patch_size
self.num_heads = num_heads
# Embeddings
self.patch_embedding = nn.Conv3d(in_dim, dim, kernel_size=patch_size, stride=patch_size)
self.text_embedding = nn.Sequential(
nn.Linear(text_dim, dim),
nn.GELU(),
nn.Linear(dim, dim)
)
self.time_embedding = nn.Sequential(
nn.Linear(freq_dim, dim),
nn.SiLU(),
nn.Linear(dim, dim)
)
self.time_projection = nn.Sequential(
nn.SiLU(),
nn.Linear(dim, dim * 6)
)
self.blocks = [DiTBlock(dim, num_heads, ffn_dim, eps) for _ in range(num_layers)]
self.head = Head(dim, out_dim, patch_size, eps)
# Freqs
self.freqs = precompute_freqs_cis_3d(dim // num_heads)
# LQ Projector placeholder
self.LQ_proj_in = None
def patchify(self, x):
x = self.patch_embedding(x)
grid_size = (x.shape[1], x.shape[2], x.shape[3])
x = x.reshape(x.shape[0], -1, x.shape[-1])
return x, grid_size
def unpatchify(self, x, grid_size):
N, L, _ = x.shape
f, h, w = grid_size
        pt, ph, pw = self.patch_size  # (time, height, width) patch factors
        x = x.reshape(N, f, h, w, pt, ph, pw, -1)
        x = x.transpose(0, 1, 4, 2, 5, 3, 6, 7)
        x = x.reshape(N, f * pt, h * ph, w * pw, -1)
return x
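    # Patchify/unpatchify sketch: with patch_size (1, 2, 2) a latent of shape
    # (N, F, H, W, C) becomes N x (F * H/2 * W/2) tokens of width `dim`; unpatchify
    # reverses this using the grid size recorded by patchify.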
def __call__(self, x, timestep, context, LQ_latents=None, is_stream=False, pre_cache_k=None, pre_cache_v=None, **kwargs):
# Time embeddings
t = self.time_embedding(sinusoidal_embedding_1d(self.freq_dim, timestep))
t_mod = self.time_projection(t).reshape(1, 6, self.dim)
# Patchify
x, (f, h, w) = self.patchify(x)
# --- FIX FREQS LOGIC (ROBUST V3) ---
        # Internal helper that fetches positional embeddings safely, recomputing them on the fly if the cache is too short
def get_freq_emb(freq_cache, idx, length, dim_sub):
# freq_cache: (MaxLen, DimSub, 2)
            # If length > MaxLen, recompute on the fly
if length > freq_cache.shape[0]:
theta = 10000.0
freqs = 1.0 / (theta ** (mx.arange(0, dim_sub * 2, 2).astype(mx.float32) / (dim_sub * 2)))
t_seq = mx.arange(length).astype(mx.float32)
freqs_outer = mx.outer(t_seq, freqs)
freqs_cos = mx.cos(freqs_outer)
freqs_sin = mx.sin(freqs_outer)
emb = mx.stack([freqs_cos, freqs_sin], axis=-1)
return emb.reshape(length, -1)
else:
return freq_cache[:length].reshape(length, -1)
head_dim = self.dim // self.num_heads
dim_h = head_dim // 3
dim_w = head_dim // 3
dim_f = head_dim - dim_h - dim_w
        # Fetch the per-axis positional embeddings
f_emb = get_freq_emb(self.freqs[0], 0, f, dim_f // 2)
h_emb = get_freq_emb(self.freqs[1], 1, h, dim_h // 2)
w_emb = get_freq_emb(self.freqs[2], 2, w, dim_w // 2)
# Reshape & Broadcast
f_broad = mx.broadcast_to(f_emb.reshape(f, 1, 1, -1), (f, h, w, f_emb.shape[-1]))
h_broad = mx.broadcast_to(h_emb.reshape(1, h, 1, -1), (f, h, w, h_emb.shape[-1]))
w_broad = mx.broadcast_to(w_emb.reshape(1, 1, w, -1), (f, h, w, w_emb.shape[-1]))
# Concatenate
freqs = mx.concatenate([f_broad, h_broad, w_broad], axis=-1)
freqs = freqs.reshape(f*h*w, 1, -1)
# -------------------------------------
# Blocks
out_caches_k = []
out_caches_v = []
for i, block in enumerate(self.blocks):
if LQ_latents is not None and i < len(LQ_latents):
x = x + LQ_latents[i]
pck = pre_cache_k[i] if pre_cache_k else None
pcv = pre_cache_v[i] if pre_cache_v else None
if is_stream:
x, ck, cv = block(x, context, t_mod, freqs, f, h, w, is_stream=True, pre_cache_k=pck, pre_cache_v=pcv)
out_caches_k.append(ck)
out_caches_v.append(cv)
else:
x = block(x, context, t_mod, freqs, f, h, w)
# Head
t_for_head = t.reshape(1, 1, -1) if t.ndim==2 else t
x = self.head(x, t_for_head)
x = self.unpatchify(x, (f, h, w))
if is_stream:
return x, out_caches_k, out_caches_v
return x
class WanVideoVAE(nn.Module):
def __init__(self, z_dim=16, dim=96):
super().__init__()
# Standard configs for WanVAE
dim_mult = [1, 2, 4, 4]
num_res_blocks = 2
attn_scales = []
temperal_downsample = [False, True, True]
temperal_upsample = temperal_downsample[::-1]
self.z_dim = z_dim
# Encoder output is z_dim * 2 (mean + logvar)
self.encoder = Encoder3d(dim, z_dim * 2, dim_mult, num_res_blocks, attn_scales, temperal_downsample)
self.quant_conv = CausalConv3d(z_dim * 2, z_dim * 2, 1)
self.post_quant_conv = CausalConv3d(z_dim, z_dim, 1)
self.decoder = Decoder3d(dim, z_dim, dim_mult, num_res_blocks, attn_scales, temperal_upsample)
def encode(self, x):
# x: (N, T, H, W, 3)
h = self.encoder(x)
moments = self.quant_conv(h)
mean, logvar = mx.split(moments, 2, axis=-1)
return mean, logvar
def decode(self, z):
# z: (N, T, H, W, z_dim)
z = self.post_quant_conv(z)
dec = self.decoder(z)
return dec
# Helper to count cache size needed
def get_cache_size(self, is_encoder=False):
# Rough heuristic count of CausalConv layers
# In real usage, one would traverse the graph.
# For this specific architecture:
if is_encoder: return 50 # Safe upper bound
return 50 # Safe upper bound
# --- TAEHV (FlashVSR Decoder) Components ---
class Buffer_LQ4x_Proj(nn.Module):
def __init__(self, in_dim, out_dim, layer_num=30):
super().__init__()
self.hidden_dim1, self.hidden_dim2, self.layer_num = 2048, 3072, layer_num
self.pixel_shuffle = PixelShuffle3d(1, 16, 16)
# Input channel calc: in_dim * 1 * 16 * 16
in_c = in_dim * 256
# Note: Padding is handled manually in stream_forward to satisfy MLX Kernel constraints
self.conv1 = CausalConv3d(in_c, self.hidden_dim1, (4, 3, 3), stride=(2, 1, 1), padding=(0, 1, 1), mode='replicate')
self.norm1 = RMS_norm_General(self.hidden_dim1, channel_first=False, images=False)
self.act1 = nn.SiLU()
self.conv2 = CausalConv3d(self.hidden_dim1, self.hidden_dim2, (4, 3, 3), stride=(2, 1, 1), padding=(0, 1, 1), mode='replicate')
self.norm2 = RMS_norm_General(self.hidden_dim2, channel_first=False, images=False)
self.act2 = nn.SiLU()
self.linear_layers = [nn.Linear(self.hidden_dim2, out_dim) for _ in range(layer_num)]
self.cache = {"conv1": None, "conv2": None}
self.clip_idx = 0
def stream_forward(self, video_clip):
# FIX: Ensure input Time dimension >= Kernel Time dimension (4)
if self.clip_idx == 0:
# First frame: Pad generously at start (Replicate)
# Input: [0]. Target Conv Input: [0,0,0,0]
first = video_clip[:, :1, ...]
# Pad 3 frames at start so total is 4
video_clip_padded = mx.concatenate([mx.repeat(first, 3, axis=1), video_clip], axis=1)
x = self.pixel_shuffle(video_clip_padded) # (B, 4, ...)
# Save raw input to cache (last 2 frames) for next step
self.cache["conv1"] = x[:, -CACHE_T:, ...]
x = self.conv1(x) # Input T=4, Kernel=4, Stride=2 -> Out T=1
x = self.act1(self.norm1(x))
# Save feature to cache 2
self.cache["conv2"] = x[:, -CACHE_T:, ...]
self.clip_idx += 1
return None
else:
x = self.pixel_shuffle(video_clip)
# --- CONV 1 ---
# Retrieve cache
prev_cache = self.cache["conv1"]
# Update cache for NEXT step immediately with current raw input (Last CACHE_T frames)
combined_raw = mx.concatenate([prev_cache, x], axis=1)
self.cache["conv1"] = combined_raw[:, -CACHE_T:, ...]
# Construct Input for Conv: Need T >= 4
inp1 = combined_raw
if inp1.shape[1] < 4:
pad_amt = 4 - inp1.shape[1]
first = inp1[:, :1, ...]
inp1 = mx.concatenate([mx.repeat(first, pad_amt, axis=1), inp1], axis=1)
# Run Conv
x = self.conv1(inp1)
x = self.act1(self.norm1(x))
# --- CONV 2 ---
prev_cache_2 = self.cache["conv2"]
combined_feat = mx.concatenate([prev_cache_2, x], axis=1)
# Update cache
self.cache["conv2"] = combined_feat[:, -CACHE_T:, ...]
# Pad for Conv: Need T >= 4
inp2 = combined_feat
if inp2.shape[1] < 4:
pad_amt = 4 - inp2.shape[1]
first = inp2[:, :1, ...]
inp2 = mx.concatenate([mx.repeat(first, pad_amt, axis=1), inp2], axis=1)
x = self.conv2(inp2)
x = self.act2(self.norm2(x))
# Linear projections
B, T, H, W, C = x.shape
x_flat = x.reshape(B, T*H*W, C)
out = [l(x_flat) for l in self.linear_layers]
self.clip_idx += 1
return out
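# Streaming note: the first call only primes the conv caches and returns None; every
# later call returns one list of per-layer LQ projections, which the pipeline
# concatenates along the token axis before adding them to the DiT hidden states.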
class MemBlock(nn.Module):
def __init__(self, n_in, n_out):
super().__init__()
# Note: FlashVSR uses Conv2d for these blocks treating Time as Batch for spatial mixing
# Inputs will be (N*T, H, W, C)
self.conv_main_0 = nn.Conv2d(n_in * 2, n_out, 3, padding=1)
self.conv_main_1 = nn.Conv2d(n_out, n_out, 3, padding=1)
self.conv_main_2 = nn.Conv2d(n_out, n_out, 3, padding=1)
self.skip = nn.Conv2d(n_in, n_out, 1, bias=False) if n_in != n_out else None
self.act = nn.ReLU()
def __call__(self, x, past):
# x: (B, H, W, C)
# past: (B, H, W, C) - Memory from previous timestep
# Concat along Channel
inp = mx.concatenate([x, past], axis=-1)
h = self.act(self.conv_main_0(inp))
h = self.act(self.conv_main_1(h))
h = self.conv_main_2(h)
skip_x = self.skip(x) if self.skip is not None else x
return self.act(h + skip_x)
class TGrow(nn.Module):
def __init__(self, n_f, stride):
super().__init__()
self.stride = stride
self.conv = nn.Conv2d(n_f, n_f * stride, 1, bias=False)
def __call__(self, x):
# x: (B*T, H, W, C) -> (B*T, H, W, C*stride)
x = self.conv(x)
if self.stride > 1:
# We need to expand the time dimension.
# Current layout: (NT, H, W, C*stride)
# Target: (NT * stride, H, W, C)
NT, H, W, C_full = x.shape
C = C_full // self.stride
# Reshape to separate stride
x = x.reshape(NT, H, W, self.stride, C)
# Permute to bring stride next to NT
x = x.transpose(0, 3, 1, 2, 4) # (NT, stride, H, W, C)
x = x.reshape(NT * self.stride, H, W, C)
return x
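# Shape example: with stride 2, TGrow maps (B*T, H, W, C) -> (B*T*2, H, W, C), which
# decode_video reshapes back to (B, T*2, H, W, C) to double the number of frames.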
class PixelShuffle3dTAEHV(nn.Module):
def __init__(self, ff, hh, ww):
super().__init__()
self.ff, self.hh, self.ww = ff, hh, ww
def __call__(self, x):
        # Input x: (B, T, H, W, C) in MLX channel-last layout (the conditioning
        # video, so H and W are large and C is small).
        # Equivalent to the Torch rearrange
        #   "b c (f ff) (h hh) (w ww) -> b (c ff hh ww) f h w",
        # i.e. a pixel-unshuffle / space-to-depth over both time and space.
        B, T, H, W, C = x.shape
# Check divisibility
pad_t = (self.ff - T % self.ff) % self.ff
if pad_t > 0:
first = x[:, :1, ...]
x = mx.concatenate([mx.repeat(first, pad_t, axis=1), x], axis=1)
T = x.shape[1]
# Reshape for unshuffle
# (B, T//ff, ff, H//hh, hh, W//ww, ww, C)
x = x.reshape(B, T//self.ff, self.ff, H//self.hh, self.hh, W//self.ww, self.ww, C)
# Transpose to gather blocks into channel
# Target: (B, T//ff, H//hh, W//ww, C * ff * hh * ww)
# Permute: (0, 1, 3, 5, 7, 2, 4, 6)
x = x.transpose(0, 1, 3, 5, 7, 2, 4, 6)
# Flatten
x = x.reshape(B, T//self.ff, H//self.hh, W//self.ww, -1)
return x
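# Shape example: with (ff, hh, ww) = (4, 8, 8), a (1, 8, 512, 512, 3) LQ clip becomes
# (1, 2, 64, 64, 768), so the conditioning video lines up with the latent grid.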
class TAEHV(nn.Module):
image_channels = 3
def __init__(self, channels=[256, 128, 64, 64], latent_channels=16, decoder_time_upscale=(True, True), decoder_space_upscale=(True, True, True)):
super().__init__()
self.latent_channels = latent_channels
self.frames_to_trim = 2 ** sum(decoder_time_upscale) - 1
n_f = channels
# Building the layers list (Manual construction to match PyTorch Sequential)
self.layers = []
# --- Block 0 ---
self.layers.append(Clamp())
self.layers.append(nn.Conv2d(self.latent_channels, n_f[0], 3, padding=1))
self.layers.append(nn.ReLU())
self.layers.append(MemBlock(n_f[0], n_f[0]))
self.layers.append(MemBlock(n_f[0], n_f[0]))
self.layers.append(MemBlock(n_f[0], n_f[0]))
# Upsample 0
if decoder_space_upscale[0]:
# In MLX, use UpSampling2d or Repeat
self.layers.append("upsample_space_2x")
self.layers.append(TGrow(n_f[0], 1)) # Stride 1 (No T upsample)
self.layers.append(nn.Conv2d(n_f[0], n_f[1], 3, padding=1, bias=False))
# --- Block 1 ---
self.layers.append(MemBlock(n_f[1], n_f[1]))
self.layers.append(MemBlock(n_f[1], n_f[1]))
self.layers.append(MemBlock(n_f[1], n_f[1]))
# Upsample 1
if decoder_space_upscale[1]:
self.layers.append("upsample_space_2x")
self.layers.append(TGrow(n_f[1], 2 if decoder_time_upscale[0] else 1))
self.layers.append(nn.Conv2d(n_f[1], n_f[2], 3, padding=1, bias=False))
# --- Block 2 ---
self.layers.append(MemBlock(n_f[2], n_f[2]))
self.layers.append(MemBlock(n_f[2], n_f[2]))
self.layers.append(MemBlock(n_f[2], n_f[2]))
# Upsample 2
if decoder_space_upscale[2]:
self.layers.append("upsample_space_2x")
self.layers.append(TGrow(n_f[2], 2 if decoder_time_upscale[1] else 1))
self.layers.append(nn.Conv2d(n_f[2], n_f[3], 3, padding=1, bias=False))
self.layers.append(nn.ReLU())
self.layers.append(nn.Conv2d(n_f[3], TAEHV.image_channels, 3, padding=1))
# Apply "Identity Deepen" logic (Inserting Identity Conv + ReLU after existing ReLUs)
# To simplify porting, we assume the provided weights match the deepened structure
# or we manually construct the exact list.
# Here we apply the expansion dynamically similar to PyTorch code.
self.final_layers = []
for layer in self.layers:
self.final_layers.append(layer)
# Check if layer is ReLU
is_relu = isinstance(layer, nn.ReLU)
if is_relu:
# Need to find channel count of previous layer
prev = self.final_layers[-2]
C = None
                # MLX Conv2d weight shape is (C_out, kH, kW, C_in): out channels come first
                if isinstance(prev, nn.Conv2d): C = prev.weight.shape[0]
                elif isinstance(prev, MemBlock): C = prev.conv_main_2.weight.shape[0]
if C is not None:
# Add Identity Block (how_many_each=1)
self.final_layers.append(IdentityConv2d(C))
self.final_layers.append(nn.ReLU())
self.pixel_shuffle = PixelShuffle3dTAEHV(4, 8, 8)
self.mem_state = [None] * len(self.final_layers)
def clean_mem(self):
self.mem_state = [None] * len(self.final_layers)
def decode_video(self, x, cond=None):
# x: (B, T, H, W, C) - Latents
# cond: (B, T_full, H_full, W_full, 3) - LQ Video input
        trim_flag = all(m is None for m in self.mem_state)  # True only on the first chunk
if cond is not None:
# Preprocess condition (pixel unshuffle)
cond_feat = self.pixel_shuffle(cond)
# x should match T dimension logic or be compatible
# x is latents, cond_feat is reshaped LQ
# Concat along channel
x = mx.concatenate([cond_feat, x], axis=-1)
        # Apply the layer stack with MemBlock memory handling. MemBlocks depend on
        # the previous frame, which we provide by shifting the sequence along T
        # (the original `apply_model_with_memblocks` splits T into chunks in the
        # same way), while TGrow layers expand the time dimension.
B, T, H, W, C = x.shape
        current_x = x  # (B, T, H, W, C)
        # Iterate through the layers:
        #   - spatial layers (Conv2d, ReLU, ...) run on a flattened (B*T, H, W, C) view,
        #   - MemBlock keeps T explicit so the previous frame can act as memory,
        #   - TGrow expands the time dimension.
for i, layer in enumerate(self.final_layers):
if layer == "upsample_space_2x":
# Nearest Upsample
# current_x: (B, T, H, W, C)
current_x = mx.repeat(current_x, 2, axis=2)
current_x = mx.repeat(current_x, 2, axis=3)
continue
# Helper to reshape for spatial op: (B*T, H, W, C)
B_curr, T_curr, H_curr, W_curr, C_curr = current_x.shape
if isinstance(layer, MemBlock):
# Need explicit time iteration or parallel shift
# Original code: pad with 1 zero frame at start, slice off end.
# (Equivalent to prev_frame)
# Check memory state
if self.mem_state[i] is None:
# First chunk: prev is 0
prev = mx.zeros_like(current_x[:, 0:1, ...]) # (B, 1, H, W, C)
# Pad x at T=-1 with 0
x_padded = mx.concatenate([prev, current_x[:, :-1, ...]], axis=1)
# Save last frame for next chunk
self.mem_state[i] = current_x[:, -1:, ...]
else:
# Subsequent chunk
prev_mem = self.mem_state[i]
# Pad x at T=-1 with prev_mem
x_padded = mx.concatenate([prev_mem, current_x[:, :-1, ...]], axis=1)
# Save last frame
self.mem_state[i] = current_x[:, -1:, ...]
# Apply MemBlock spatially
# Flatten T
x_in_flat = current_x.reshape(-1, H_curr, W_curr, C_curr)
x_past_flat = x_padded.reshape(-1, H_curr, W_curr, C_curr)
out_flat = layer(x_in_flat, x_past_flat)
current_x = out_flat.reshape(B_curr, T_curr, H_curr, W_curr, -1)
elif isinstance(layer, TGrow):
# Flatten
x_flat = current_x.reshape(-1, H_curr, W_curr, C_curr)
out_flat = layer(x_flat)
                # TGrow takes (B*T, H, W, C) and, for stride > 1, returns
                # (B*T*stride, H, W, C); reshape back to (B, T*stride, H, W, C).
new_stride = layer.stride
current_x = out_flat.reshape(B_curr, T_curr * new_stride, H_curr, W_curr, -1)
else:
# Standard Spatial Layer
x_flat = current_x.reshape(-1, H_curr, W_curr, C_curr)
out_flat = layer(x_flat)
current_x = out_flat.reshape(B_curr, T_curr, H_curr, W_curr, -1)
# Trim frames if first run
if trim_flag and self.frames_to_trim > 0:
current_x = current_x[:, self.frames_to_trim:, ...]
return current_x
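# The MemBlock memories (mem_state) persist across decode_video calls, so chunks of
# latents can be decoded sequentially; clean_mem() resets them, and only the first
# chunk drops the `frames_to_trim` warm-up frames introduced by temporal upsampling.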
# ==============================================================================
# Weight Loading & Conversion
# ==============================================================================
def map_torch_to_mlx(key, val):
# Transpose weights for MLX layout
if "conv" in key and val.ndim == 4: # Conv2d (Out, In, H, W) -> (Out, H, W, In)
val = val.transpose(0, 2, 3, 1)
elif "conv" in key and val.ndim == 5: # Conv3d (Out, In, D, H, W) -> (Out, D, H, W, In)
val = val.transpose(0, 2, 3, 4, 1)
elif "linear" in key or "proj" in key:
if val.ndim == 2: # (Out, In) -> (Out, In) -- MLX Linear weights are (In, Out) usually transposed in loading
val = val.transpose(1, 0)
return key, mx.array(val.numpy())
def load_weights(model, ckpt_path):
print(f"Loading weights from {ckpt_path}...")
if ckpt_path.endswith(".safetensors"):
weights = {}
with safe_open(ckpt_path, framework="numpy") as f:
for k in f.keys():
weights[k] = f.get_tensor(k)
else:
# Load torch pth using numpy
import torch
weights = torch.load(ckpt_path, map_location="cpu", weights_only=True)
weights = {k: v.numpy() for k, v in weights.items()}
# Conversion loop
mlx_weights = {}
for k, v in weights.items():
# Heuristic mapping
new_k = k
# Map specific PyTorch keys to MLX model structure
# Example: blocks.0.attn1.to_q.weight -> blocks[0].self_attn.q.weight
# This requires a detailed mapping dictionary based on the implemented class names
# Generic Transpose
if v.ndim == 4: # Conv2d
v = v.transpose(0, 2, 3, 1)
elif v.ndim == 5: # Conv3d
v = v.transpose(0, 2, 3, 4, 1)
        # 2-D Linear weights are (Out, In) in both frameworks: no transpose needed.
mlx_weights[new_k] = mx.array(v)
# Load into model (partial loading allowed)
model.load_weights(list(mlx_weights.items()), strict=False)
mx.eval(model.parameters())
print("Weights loaded.")
# ==============================================================================
# Pipeline
# ==============================================================================
class FlashVSRTinyPipeline:
def __init__(self, device="gpu"):
self.dit = None
self.vae = None
self.lq_proj = None
def load_models(self):
# Config WanModel
self.dit = WanModel(
dim=5120, in_dim=16, ffn_dim=13824, out_dim=16, text_dim=4096,
freq_dim=256, eps=1e-6, patch_size=(1,2,2), num_heads=40, num_layers=40
)
self.lq_proj = Buffer_LQ4x_Proj(in_dim=3, out_dim=5120, layer_num=1)
# Compile
self.step_fn = mx.compile(self.dit)
def __call__(self, video_path, num_frames=16, tile_size=128):
# 1. Read Video
print("Reading video...")
frames, fps = self.read_video(video_path, num_frames)
frames = mx.array(frames)
frames = frames[None, ...]
# 2. Latent Projection
print("Running LQ Projection...")
lq_latents_raw = []
for i in tqdm(range(frames.shape[1])):
clip = frames[:, i:i+1, ...]
lat = self.lq_proj.stream_forward(clip)
if lat is not None:
lq_latents_raw.append(lat)
final_lq_latents = []
if lq_latents_raw:
num_layers = len(lq_latents_raw[0])
for layer_idx in range(num_layers):
time_steps_for_layer = [step[layer_idx] for step in lq_latents_raw]
concatenated = mx.concatenate(time_steps_for_layer, axis=1)
final_lq_latents.append(concatenated)
if not final_lq_latents:
print("Warning: No latents produced (video too short?). Returning None.")
return None
# 3. DiT Inference
print("Running DiT...")
ref_latent = final_lq_latents[0] # (B, L, 5120)
B, L, D = ref_latent.shape
# Noise input
x_start = mx.random.normal((B, L, 5120))
# --- FIX: CREATE DUMMY CONTEXT ---
        # WanModel text_dim = 4096 (as defined in load_models).
        # Create an empty context (e.g. a 512-token sequence of dim 4096);
        # in practice this would be the output of the T5 text encoder.
text_dim = 4096
context_len = 512
dummy_context = mx.zeros((B, context_len, text_dim))
# ---------------------------------
        # Monkey-patch patchify so already-tokenized (B, L, D) inputs skip the Conv3d patch embedding
original_patchify = self.dit.patchify
def smart_patchify(inp):
if inp.ndim == 3:
return inp, (1, 1, inp.shape[1])
else:
return original_patchify(inp)
self.dit.patchify = smart_patchify
t = mx.array([1000.0])
        # Pass dummy_context instead of None
out = self.step_fn(x_start, t, context=dummy_context, LQ_latents=final_lq_latents, is_stream=True)
self.dit.patchify = original_patchify
# 4. Decode
print("Decoding (Mock)...")
return out
def read_video(self, path, num_frames):
reader = imageio.get_reader(path)
frames = []
for i, im in enumerate(reader):
if i >= num_frames: break
frames.append(im)
return np.array(frames) / 255.0, 30
# ==============================================================================
# Main
# ==============================================================================
def main():
if not os.path.exists(args.output_folder):
os.makedirs(args.output_folder)
# Check for M-series
print(f"Running on Device: {mx.default_device()}")
pipe = FlashVSRTinyPipeline()
    pipe.load_models()  # Builds the (randomly initialized) models; real weights must be loaded separately
# Mock run
print("Initializing pipeline (Mock mode for code verification)...")
# Example flow
pipe(args.input)
print("Done. (Note: Weights are needed to produce actual video)")
if __name__ == "__main__":
main()