Skip to content

Instantly share code, notes, and snippets.

@laksjdjf
Created June 20, 2025 07:26
Show Gist options
  • Save laksjdjf/8e615f312aec4bee2236df7602f335b8 to your computer and use it in GitHub Desktop.
Save laksjdjf/8e615f312aec4bee2236df7602f335b8 to your computer and use it in GitHub Desktop.
import numpy as np
import pandas as pd
import gradio as gr
# Example presets for gr.Examples. Each list is positional and lines up with the
# `inputs` list wired in gradio_interface():
#   [0] preset name (lands in the hidden textbox; ignored by the callback)
#   [1] comma-separated mag ratios
#   [2] target length (number of sampling steps)
#   [3] accumulated-error threshold
#   [4] single-step threshold (a negative value disables this check)
#   [5] K, max consecutive skipped steps (a negative value means unlimited)
#   [6] retention, fraction of the earliest steps that are never skipped
# Hunyuan Video preset; these values are also the form's initial defaults.
HUNYUAN_VIDEO_DEFAULT = [
    "Hunyuan Video 544p",
    "1.0, 1.06971, 1.29073, 1.11245, 1.09596, 1.05233, 1.01415, 1.05672, 1.00848, 1.03632, 1.02974, 1.00984, 1.03028, 1.00681, 1.06614, 1.05022, 1.02592, 1.01776, 1.02985, 1.00726, 1.03727, 1.01502, 1.00992, 1.03371, 0.9976, 1.02742, 1.0093, 1.01869, 1.00815, 1.01461, 1.01152, 1.03082, 1.0061, 1.02162, 1.01999, 0.99063, 1.01186, 1.0217, 0.99947, 1.01711, 0.9904, 1.00258, 1.00878, 0.97039, 0.97686, 0.94315, 0.97728, 0.91154, 0.86139, 0.76592",
    50,
    0.24,
    -0.01,
    6,
    0.2,
]
# FramePack preset (same field layout as above); enables the single-step check.
FRAMEPACK_DEFAULT = [
    "FramePack",
    "1.0, 1.0754, 1.27807, 1.11596, 1.09504, 1.05188, 1.00844, 1.05779, 1.00657, 1.04142, 1.03101, 1.00679, 1.02556, 1.00908, 1.06949, 1.05438, 1.02214, 1.02321, 1.03019, 1.00779, 1.03381, 1.01886, 1.01161, 1.02968, 1.00544, 1.02822, 1.00689, 1.02119, 1.0105, 1.01044, 1.01572, 1.02972, 1.0094, 1.02368, 1.0226, 0.98965, 1.01588, 1.02146, 1.0018, 1.01687, 0.99436, 1.00283, 1.01139, 0.97122, 0.98251, 0.94513, 0.97656, 0.90943, 0.85703, 0.75456",
    50,
    0.1,
    0.07,
    3,
    0.2,
]
# Flux preset (same field layout as above), 28 sampling steps.
FLUX_DEFAULT = [
    "Flux",
    "1.0, 1.21094, 1.11719, 1.07812, 1.0625, 1.03906, 1.03125, 1.03906, 1.02344, 1.03125, 1.02344, 0.98047, 1.01562, 1.00781, 1.0, 1.00781, 1.0, 1.00781, 1.0, 1.0, 0.99609, 0.99609, 0.98047, 0.98828, 0.96484, 0.95703, 0.93359, 0.89062",
    28,
    0.24,
    -0.01,
    5,
    0.1,
]
# WAN2.1 14B preset (same field layout as above), target length 100.
# NOTE(review): the ratios appear to come in near-identical pairs; presumably
# two entries per sampling step (e.g. cond/uncond passes) — confirm upstream.
WAN_DEFAULT = [
    "WAN2.1 14B",
    "1.0, 1.0, 1.02504, 1.03017, 1.00025, 1.00251, 0.9985, 0.99962, 0.99779, 0.99771, 0.9966, 0.99658, 0.99482, 0.99476, 0.99467, 0.99451, 0.99664, 0.99656, 0.99434, 0.99431, 0.99533, 0.99545, 0.99468, 0.99465, 0.99438, 0.99434, 0.99516, 0.99517, 0.99384, 0.9938, 0.99404, 0.99401, 0.99517, 0.99516, 0.99409, 0.99408, 0.99428, 0.99426, 0.99347, 0.99343, 0.99418, 0.99416, 0.99271, 0.99269, 0.99313, 0.99311, 0.99215, 0.99215, 0.99218, 0.99215, 0.99216, 0.99217, 0.99163, 0.99161, 0.99138, 0.99135, 0.98982, 0.9898, 0.98996, 0.98995, 0.9887, 0.98866, 0.98772, 0.9877, 0.98767, 0.98765, 0.98573, 0.9857, 0.98501, 0.98498, 0.9838, 0.98376, 0.98177, 0.98173, 0.98037, 0.98035, 0.97678, 0.97677, 0.97546, 0.97543, 0.97184, 0.97183, 0.96711, 0.96708, 0.96349, 0.96345, 0.95629, 0.95625, 0.94926, 0.94929, 0.93964, 0.93961, 0.92511, 0.92504, 0.90693, 0.90678, 0.8796, 0.87945, 0.86111, 0.86189",
    100,
    0.24,
    -0.01,
    6,
    0.2,
]
def nearest_interp(src_array, target_length):
    """Resample ``src_array`` to ``target_length`` values by nearest-index lookup.

    Output position ``j`` takes the source element at index
    ``round(j * (len(src_array) - 1) / (target_length - 1))``, so the first and
    last source values map to the first and last output values.

    Args:
        src_array: Non-empty sequence of numbers (e.g. per-step mag ratios).
        target_length: Number of values to produce. Non-integral values (such
            as a float coming from a UI number field) are truncated with ``int()``.

    Returns:
        A plain Python list of ``target_length`` values; empty when
        ``target_length <= 0``.

    Raises:
        ValueError: If ``src_array`` is empty and ``target_length > 0``.
    """
    target_length = int(target_length)
    if target_length <= 0:
        return []
    src = np.asarray(src_array)
    if src.size == 0:
        raise ValueError("src_array must contain at least one value")
    if target_length == 1:
        # No spacing to scale by; keep the final ratio. Return a list so the
        # result type is the same as in the general case.
        return src[-1:].tolist()
    scale = (len(src) - 1) / (target_length - 1)
    # np.round breaks .5 ties to the nearest even index.
    mapped_indices = np.round(np.arange(target_length) * scale).astype(int)
    return src[mapped_indices].tolist()


def magcache(src_array, mag_single_step_threshold, target_length, threshold, k, retention):
    """Simulate which sampling steps MagCache would skip.

    ``src_array`` is first resampled to ``target_length`` steps with
    ``nearest_interp``. Walking the steps in order, the product of mag ratios
    since the last computed step is tracked. The skip error of a step is
    ``|1 - product|``, and it is summed into an accumulated error. Step ``i`` is
    skipped when all of the following hold:

    * ``i >= max(1, int(retention * target_length))`` (step 0 is always computed);
    * the accumulated error is ``<= threshold``;
    * it would be at most the ``k``-th consecutive skipped step;
    * its own ratio satisfies ``|1 - ratio| <= mag_single_step_threshold``.

    Otherwise the step is computed and all accumulators are reset.

    Args:
        src_array: Per-step magnitude ratios (non-empty).
        mag_single_step_threshold: Max ``|1 - ratio|`` allowed for a skipped
            step; a negative value disables this check.
        target_length: Number of sampling steps (truncated with ``int()``).
        threshold: Max accumulated error for a step to be skipped.
        k: Max number of consecutive skipped steps; negative means unlimited.
        retention: Fraction of the earliest steps that are never skipped.

    Returns:
        ``(skip_steps, accumulated_errors, errors)``:
        * ``skip_steps``: the indices of skipped steps;
        * ``accumulated_errors``: the accumulated error at each step, recorded
          before any reset;
        * ``errors``: the per-step skip error ``|1 - product|``.
    """
    target_length = int(target_length)  # range()/indexing need an int step count
    if mag_single_step_threshold < 0:
        mag_single_step_threshold = 100.0  # effectively disables the check
    if k < 0:
        k = target_length + 1  # can never be exceeded -> unlimited

    target_array = nearest_interp(src_array, target_length)
    first_skippable = max(1, int(retention * target_length))

    skip_steps, accumulated_errors, errors = [], [], []
    accumulated_ratio = 1.0
    accumulated_step = 0
    accumulated_error = 0.0
    for i, ratio in enumerate(target_array):
        accumulated_ratio *= ratio
        accumulated_step += 1
        # Reusing the cached residual drifts by the product of all mag ratios
        # since the last computed step (MagCache's skip error), not by the
        # current step's ratio alone.
        error = abs(1.0 - accumulated_ratio)
        accumulated_error += error
        errors.append(error)
        accumulated_errors.append(accumulated_error)
        if (
            i >= first_skippable
            and accumulated_error <= threshold
            and accumulated_step <= k
            and abs(1.0 - ratio) <= mag_single_step_threshold
        ):
            skip_steps.append(i)
        else:
            # The step is computed, which refreshes the cache: start over.
            accumulated_ratio = 1.0
            accumulated_step = 0
            accumulated_error = 0.0
    return skip_steps, accumulated_errors, errors
def visualize_magcache(hidden_text, mag_ratios, target_length, threshold, mag_single_step_threshold, k, retention):
    """Gradio callback: compute the MagCache skip schedule and plot its errors.

    Args:
        hidden_text: Ignored. It is the first wired input; the example presets
            put their display name in this slot.
        mag_ratios: Comma-separated mag ratios as typed by the user. Blank
            entries, such as one left by a trailing comma, are ignored.
        target_length: Number of sampling steps. ``gr.Number`` may deliver a
            float, so it is truncated with ``int()``.
        threshold: Accumulated-error threshold, forwarded to ``magcache``; it is
            also drawn as a horizontal line.
        mag_single_step_threshold: Forwarded to ``magcache``.
        k: Forwarded to ``magcache``.
        retention: Forwarded to ``magcache``.

    Returns:
        ``(summary, df)``:
        * ``summary``: a text summary of the skipped steps;
        * ``df``: a long-format DataFrame with columns ``x``, ``y`` and
          ``label`` for ``gr.LinePlot``.

    Raises:
        ValueError: If ``mag_ratios`` has a non-numeric entry or no values at all.
    """
    target_length = int(target_length)
    ratios = [float(token) for token in mag_ratios.split(",") if token.strip()]
    if not ratios:
        raise ValueError("mag_ratios must contain at least one number")

    skip_steps, accumulated_errors, errors = magcache(
        ratios, mag_single_step_threshold, target_length, threshold, k, retention
    )

    df = pd.DataFrame({
        "x": list(range(len(accumulated_errors))) + list(range(len(errors))) + [0, target_length],
        "y": list(accumulated_errors) + list(errors) + [threshold, threshold],
        "label": (
            ["Accumulated Errors"] * len(accumulated_errors)
            + ["Errors"] * len(errors)
            + ["Threshold"] * 2  # two points span the threshold line across the x range
        ),
    })

    # Avoid ZeroDivisionError when the user enters 0 steps.
    skipped_fraction = len(skip_steps) / target_length if target_length > 0 else 0.0
    summary = f"Skip steps: {skip_steps}\n Total {len(skip_steps)}({skipped_fraction:.0%}) steps skipped."
    return summary, df
# Gradio UI
def gradio_interface():
    """Build the MagCache visualization UI.

    Returns:
        The ``gr.Blocks`` app; call ``.launch()`` on it to serve it.
    """
    # Pre-fill the form from the Hunyuan Video preset, so the initial values
    # and that example row stay in sync.
    (
        _preset_name,
        default_ratios,
        default_length,
        default_threshold,
        default_single_step_threshold,
        default_k,
        default_retention,
    ) = HUNYUAN_VIDEO_DEFAULT

    with gr.Blocks() as demo:
        gr.Markdown("""# MagCache Visualization
Enter the parameters below to compute skip steps and visualize accumulated errors.
""")
        with gr.Row():
            with gr.Column():
                mag_ratios_input = gr.Textbox(value=default_ratios, label="Mag Ratios (comma-separated)", placeholder="e.g., 1.0,1.1,1.2,...")
                # precision=0 makes Gradio hand the callback an int step count.
                target_length_input = gr.Number(label="Target Length", value=default_length, precision=0)
                threshold_input = gr.Slider(label="Threshold", value=default_threshold, minimum=0.0, maximum=2.0, step=0.01)
                # Negative values (the slider minimum) disable the single-step check.
                mag_single_step_threshold_input = gr.Slider(label="Single step Threshold", value=default_single_step_threshold, minimum=-0.01, maximum=2.0, step=0.01)
                # -1 removes the limit on consecutive skipped steps.
                k_input = gr.Slider(label="K", value=default_k, minimum=-1, maximum=20, step=1)
                retention_input = gr.Slider(label="Retention", value=default_retention, minimum=0.0, maximum=1.0, step=0.01)
                submit_button = gr.Button("Submit")
            with gr.Column():
                output_text = gr.Textbox(label="Skip Steps")
                output_plot = gr.LinePlot(x="x", y="y", color="label", label="Accumulated Errors Graph")

        # Receives each preset's name (its first field); never shown.
        hidden_text = gr.Textbox(visible=False, value="This is a hidden text box for internal use.")

        # Shared by the examples and the Submit button. The order must match
        # the preset lists and visualize_magcache's parameters.
        inputs = [
            hidden_text,
            mag_ratios_input,
            target_length_input,
            threshold_input,
            mag_single_step_threshold_input,
            k_input,
            retention_input,
        ]
        outputs = [output_text, output_plot]

        gr.Examples(
            examples=[
                HUNYUAN_VIDEO_DEFAULT,
                FRAMEPACK_DEFAULT,
                FLUX_DEFAULT,
                WAN_DEFAULT,
            ],
            inputs=inputs,
            fn=visualize_magcache,
            outputs=outputs,
            cache_mode="eager",
            run_on_click=True,
        )
        submit_button.click(visualize_magcache, inputs=inputs, outputs=outputs)
    return demo
if __name__ == "__main__":
    # When run as a script, build the UI and serve it with Gradio's default
    # launch() settings. `demo` is bound at module level only on this path.
    demo = gradio_interface()
    demo.launch()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment