Created
June 20, 2025 07:26
-
-
Save laksjdjf/8e615f312aec4bee2236df7602f335b8 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import numpy as np | |
import pandas as pd | |
import gradio as gr | |
# Example presets for the UI. Every row lists its values in the same order as
# the ``inputs`` handed to ``gr.Examples`` in ``gradio_interface``:
#   [preset name, mag ratios (comma-separated string), target length,
#    threshold, single-step threshold, K, retention]
# A negative single-step threshold or K means "no limit" (see ``magcache``).
HUNYUAN_VIDEO_DEFAULT = [
    "Hunyuan Video 544p",  # preset name (fed to the hidden textbox)
    "1.0, 1.06971, 1.29073, 1.11245, 1.09596, 1.05233, 1.01415, 1.05672, 1.00848, 1.03632, 1.02974, 1.00984, 1.03028, 1.00681, 1.06614, 1.05022, 1.02592, 1.01776, 1.02985, 1.00726, 1.03727, 1.01502, 1.00992, 1.03371, 0.9976, 1.02742, 1.0093, 1.01869, 1.00815, 1.01461, 1.01152, 1.03082, 1.0061, 1.02162, 1.01999, 0.99063, 1.01186, 1.0217, 0.99947, 1.01711, 0.9904, 1.00258, 1.00878, 0.97039, 0.97686, 0.94315, 0.97728, 0.91154, 0.86139, 0.76592",
    50,     # target length
    0.24,   # threshold
    -0.01,  # single-step threshold
    6,      # K
    0.2,    # retention
]
FRAMEPACK_DEFAULT = [
    "FramePack",  # preset name
    "1.0, 1.0754, 1.27807, 1.11596, 1.09504, 1.05188, 1.00844, 1.05779, 1.00657, 1.04142, 1.03101, 1.00679, 1.02556, 1.00908, 1.06949, 1.05438, 1.02214, 1.02321, 1.03019, 1.00779, 1.03381, 1.01886, 1.01161, 1.02968, 1.00544, 1.02822, 1.00689, 1.02119, 1.0105, 1.01044, 1.01572, 1.02972, 1.0094, 1.02368, 1.0226, 0.98965, 1.01588, 1.02146, 1.0018, 1.01687, 0.99436, 1.00283, 1.01139, 0.97122, 0.98251, 0.94513, 0.97656, 0.90943, 0.85703, 0.75456",
    50,    # target length
    0.1,   # threshold
    0.07,  # single-step threshold
    3,     # K
    0.2,   # retention
]
FLUX_DEFAULT = [
    "Flux",  # preset name
    "1.0, 1.21094, 1.11719, 1.07812, 1.0625, 1.03906, 1.03125, 1.03906, 1.02344, 1.03125, 1.02344, 0.98047, 1.01562, 1.00781, 1.0, 1.00781, 1.0, 1.00781, 1.0, 1.0, 0.99609, 0.99609, 0.98047, 0.98828, 0.96484, 0.95703, 0.93359, 0.89062",
    28,     # target length
    0.24,   # threshold
    -0.01,  # single-step threshold
    5,      # K
    0.1,    # retention
]
WAN_DEFAULT = [
    "WAN2.1 14B",  # preset name
    "1.0, 1.0, 1.02504, 1.03017, 1.00025, 1.00251, 0.9985, 0.99962, 0.99779, 0.99771, 0.9966, 0.99658, 0.99482, 0.99476, 0.99467, 0.99451, 0.99664, 0.99656, 0.99434, 0.99431, 0.99533, 0.99545, 0.99468, 0.99465, 0.99438, 0.99434, 0.99516, 0.99517, 0.99384, 0.9938, 0.99404, 0.99401, 0.99517, 0.99516, 0.99409, 0.99408, 0.99428, 0.99426, 0.99347, 0.99343, 0.99418, 0.99416, 0.99271, 0.99269, 0.99313, 0.99311, 0.99215, 0.99215, 0.99218, 0.99215, 0.99216, 0.99217, 0.99163, 0.99161, 0.99138, 0.99135, 0.98982, 0.9898, 0.98996, 0.98995, 0.9887, 0.98866, 0.98772, 0.9877, 0.98767, 0.98765, 0.98573, 0.9857, 0.98501, 0.98498, 0.9838, 0.98376, 0.98177, 0.98173, 0.98037, 0.98035, 0.97678, 0.97677, 0.97546, 0.97543, 0.97184, 0.97183, 0.96711, 0.96708, 0.96349, 0.96345, 0.95629, 0.95625, 0.94926, 0.94929, 0.93964, 0.93961, 0.92511, 0.92504, 0.90693, 0.90678, 0.8796, 0.87945, 0.86111, 0.86189",
    100,    # target length
    0.24,   # threshold
    -0.01,  # single-step threshold
    6,      # K
    0.2,    # retention
]
def nearest_interp(src_array, target_length):
    """Resample ``src_array`` to ``target_length`` points by nearest-index lookup.

    Output index ``j`` reads source index
    ``round(j * (len(src_array) - 1) / (target_length - 1))``, so the first and
    last source values land on the first and last output points. Ties are
    decided by ``np.round`` (round half to even).

    Args:
        src_array: Sequence of numbers to resample.
        target_length: Number of output points; coerced to ``int``.

    Returns:
        A plain Python list of ``target_length`` values (empty when
        ``target_length <= 0``). A single requested point yields the last
        source value.

    Raises:
        ValueError: If ``src_array`` is empty but at least one point is requested.
    """
    target_length = int(target_length)
    if target_length <= 0:
        return []
    values = np.asarray(src_array)
    if values.size == 0:
        raise ValueError("src_array must contain at least one value")
    if target_length == 1:
        # One point cannot span the source range; keep the final value.
        # Index with a list so the result stays an array and `.tolist()` gives
        # the same list type as the general path below.
        return values[[-1]].tolist()
    scale = (len(values) - 1) / (target_length - 1)
    mapped_indices = np.round(np.arange(target_length) * scale).astype(int)
    return values[mapped_indices].tolist()
def magcache(src_array, mag_single_step_threshold, target_length, threshold, k, retention):
    """Simulate which of ``target_length`` steps a magnitude-ratio cache would skip.

    The source ratios are first resampled to ``target_length`` points with
    ``nearest_interp``. Step ``i`` is skipped only when all of these hold:

    * ``i >= max(1, int(retention * target_length))`` (step 0 is never skipped);
    * the error accumulated since the last computed step, including this
      step, is ``<= threshold``;
    * no more than ``k`` consecutive steps would be skipped;
    * this step's own error is ``<= mag_single_step_threshold``.

    Otherwise the step is computed and the accumulators reset.

    Args:
        src_array: Magnitude ratios (numbers), one per source step.
        mag_single_step_threshold: Per-step error cap. A negative value means
            no cap (replaced by 100.0).
        target_length: Number of steps to simulate; coerced to ``int``.
        threshold: Cap on the accumulated error.
        k: Cap on consecutive skipped steps. A negative value means no cap.
        retention: Fraction of leading steps that are never skipped.

    Returns:
        ``(skip_steps, accumulated_errors, errors)``:
        * ``skip_steps``: the skipped step indices;
        * ``accumulated_errors``: the accumulated error after each step,
          recorded before any reset;
        * ``errors``: each step's error ``|ratio - 1|``.
    """
    # `range`/indexing below need an int; UI number fields may deliver a float.
    target_length = int(target_length)
    if mag_single_step_threshold < 0:
        mag_single_step_threshold = 100.0
    if k < 0:
        k = target_length + 1
    target_array = nearest_interp(src_array, target_length)
    first_skippable = max(1, int(retention * target_length))

    skip_steps = []
    accumulated_errors = []
    errors = []
    accumulated_step = 0
    accumulated_error = 0.0
    for i, ratio in enumerate(target_array):
        # NOTE(review): a previous revision also kept a running product of the
        # ratios (reset to 1.0 on computed steps) but never read it. If the
        # intended per-step error is the drift of that product, |prod(r) - 1|,
        # rather than of this single ratio, change `error` accordingly — confirm.
        error = abs(ratio - 1)
        accumulated_step += 1
        accumulated_error += error
        errors.append(error)
        accumulated_errors.append(accumulated_error)
        can_skip = (
            i >= first_skippable
            and accumulated_error <= threshold
            and accumulated_step <= k
            and error <= mag_single_step_threshold
        )
        if can_skip:
            skip_steps.append(i)
        else:
            # This step is computed, so accumulation restarts from zero.
            accumulated_step = 0
            accumulated_error = 0.0
    return skip_steps, accumulated_errors, errors
def visualize_magcache(hidden_text, mag_ratios, target_length, threshold, mag_single_step_threshold, k, retention):
    """Gradio callback: compute the skip schedule and build the plot data.

    Args:
        hidden_text: Preset name from the hidden textbox. Unused; kept so the
            signature matches the Gradio ``inputs`` list.
        mag_ratios: Comma-separated magnitude ratios, e.g. ``"1.0, 1.07, 1.29"``.
            Empty entries (such as a trailing comma) are ignored.
        target_length: Number of steps to simulate; must be at least 1.
        threshold: Accumulated-error cap, also drawn as a horizontal line.
        mag_single_step_threshold: Forwarded to ``magcache``.
        k: Forwarded to ``magcache``.
        retention: Forwarded to ``magcache``.

    Returns:
        A summary string and a long-format DataFrame with columns ``x``, ``y``
        and ``label`` (one series per label) for ``gr.LinePlot``.

    Raises:
        gr.Error: On unparsable ratios, no ratios at all, or a missing or
            non-positive ``target_length``. The message is shown in the UI.
    """
    try:
        ratios = [float(token) for token in mag_ratios.split(",") if token.strip()]
    except ValueError as err:
        raise gr.Error(f"Mag Ratios must be comma-separated numbers: {err}") from err
    if not ratios:
        raise gr.Error("Mag Ratios must contain at least one number.")
    # A cleared Number field delivers None; 0 would divide by zero below.
    if target_length is None or target_length < 1:
        raise gr.Error("Target Length must be at least 1.")
    target_length = int(target_length)

    skip_steps, accumulated_errors, errors = magcache(ratios, mag_single_step_threshold, target_length, threshold, k, retention)
    df = pd.DataFrame({
        'x': list(range(len(accumulated_errors))) + list(range(len(errors))) + [0, target_length],
        'y': accumulated_errors + errors + [threshold, threshold],
        "label": ["Accumulated Errors"] * len(accumulated_errors) + ["Errors"] * len(errors) + ["Threshold"] * 2
    })
    return f"Skip steps: {skip_steps}\n Total {len(skip_steps)}({len(skip_steps)/target_length:.0%}) steps skipped.", df
# Gradio UI | |
def gradio_interface():
    """Build the MagCache visualization UI (not launched).

    The left column holds the simulation parameters, initialised from
    ``HUNYUAN_VIDEO_DEFAULT``. The right column shows the skip-step summary and
    a line plot of per-step and accumulated errors. Clicking a row of the
    examples table, or the Submit button, runs ``visualize_magcache``.

    Returns:
        The ``gr.Blocks`` app.
    """
    # Take the initial widget values from the Hunyuan preset so they cannot
    # drift from the first examples row.
    (
        _,
        default_ratios,
        default_length,
        default_threshold,
        default_single_step,
        default_k,
        default_retention,
    ) = HUNYUAN_VIDEO_DEFAULT
    with gr.Blocks() as demo:
        gr.Markdown("""# MagCache Visualization
    Enter the parameters below to compute skip steps and visualize accumulated errors.
    """)
        with gr.Row():
            with gr.Column():
                mag_ratios_input = gr.Textbox(value=default_ratios, label="Mag Ratios (comma-separated)", placeholder="e.g., 1.0,1.1,1.2,...")
                target_length_input = gr.Number(label="Target Length", value=default_length)
                threshold_input = gr.Slider(label="Threshold", value=default_threshold, minimum=0.0, maximum=2.0, step=0.01)
                # The negative minimum is the "no single-step limit" setting.
                mag_single_step_threshold = gr.Slider(label="Single step Threshold", value=default_single_step, minimum=-0.01, maximum=2.0, step=0.01)
                # -1 is the "no limit on consecutive skips" setting.
                k_input = gr.Slider(label="K", value=default_k, minimum=-1, maximum=20, step=1)
                retention_input = gr.Slider(label="Retention", value=default_retention, minimum=0.0, maximum=1.0, step=0.01)
                submit_button = gr.Button("Submit")
            with gr.Column():
                output_text = gr.Textbox(label="Skip Steps")
                output_plot = gr.LinePlot(x="x", y="y", color="label", label="Accumulated Errors Graph")
        # Receives the preset-name column of the examples table; never displayed.
        hidden_text = gr.Textbox(visible=False, value="This is a hidden text box for internal use.")
        # Shared by the examples table and the Submit button; the order must
        # match both the preset rows and visualize_magcache's parameters.
        inputs = [hidden_text, mag_ratios_input, target_length_input, threshold_input, mag_single_step_threshold, k_input, retention_input]
        outputs = [output_text, output_plot]
        gr.Examples(
            examples=[
                HUNYUAN_VIDEO_DEFAULT,
                FRAMEPACK_DEFAULT,
                FLUX_DEFAULT,
                WAN_DEFAULT,
            ],
            inputs=inputs,
            fn=visualize_magcache,
            outputs=outputs,
            cache_mode="eager",
            run_on_click=True,
        )
        submit_button.click(visualize_magcache, inputs=inputs, outputs=outputs)
    return demo
if __name__ == "__main__":
    # Script entry point: build the Blocks app and serve it with Gradio's
    # default launch settings. The app stays bound to the module-level name `demo`.
    demo = gradio_interface()
    demo.launch()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment