@anthonywu
Last active August 30, 2025 07:10
mflux 0.9.0 scheduler sigmas snapshot for unit testing
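#!/bin/bash -e
# "before": released mflux 0.9.0 via uvx; "after": the local mflux install with --scheduler linear
common_seed=12345
common_args="--steps 20 --prompt sunset --seed ${common_seed}"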
uvx --from mflux==0.9.0 mflux-generate --model dev --output sunset_dev_before.png $common_args
uvx --from mflux==0.9.0 mflux-generate --model schnell --output sunset_schnell_before.png $common_args
mflux-generate --scheduler linear --model dev --output sunset_dev_after.png $common_args
mflux-generate --scheduler linear --model schnell --output sunset_schnell_after.png $common_args
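The script writes one before/after pair per model. Below is a minimal standard-library sketch for checking whether each pair is byte-identical. It assumes it runs in the directory where the PNGs were written, and `sha256_of` and `byte_identical` are illustrative helpers, not part of mflux. Byte equality is stricter than visual equality: it also fails when only embedded metadata differs.

```python
import hashlib
from pathlib import Path


def sha256_of(path: Path) -> str:
    # Hash the file in 1 MiB chunks so large images need not fit in memory at once
    digest = hashlib.sha256()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(1 << 20), b""):
            digest.update(chunk)
    return digest.hexdigest()


def byte_identical(before: Path, after: Path) -> bool:
    # True only if both files have exactly the same bytes
    return sha256_of(before) == sha256_of(after)


if __name__ == "__main__":
    # Hypothetical check over the file names used by the script above
    for model in ("dev", "schnell"):
        before = Path(f"sunset_{model}_before.png")
        after = Path(f"sunset_{model}_after.png")
        if before.exists() and after.exists():
            print(model, "identical" if byte_identical(before, after) else "differs")
```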
#!/usr/bin/env uvx --from mflux==0.9.0 python
import io
import mlx.core as mx
from mflux.config.config import Config
from mflux.config.runtime_config import RuntimeConfig
from mflux.config.model_config import ModelConfig
rt_config = RuntimeConfig(
    Config(
        # these are the only attributes relevant to schedulers
        num_inference_steps=14,
        width=1024,
        height=1024,
    ),
    ModelConfig.dev(),  # requires_sigma_shift=True
)
sigmas_with_shift = RuntimeConfig._create_sigmas(rt_config, ModelConfig.dev())
print(sigmas_with_shift)
sigmas_without_shift = RuntimeConfig._create_sigmas(rt_config, ModelConfig.schnell())
print(sigmas_without_shift)
out1, out2 = io.BytesIO(), io.BytesIO()
mx.save(out1, sigmas_with_shift)
mx.save(out2, sigmas_without_shift)
compare1 = out1.getvalue().hex()
compare2 = out2.getvalue().hex()
print(compare1, compare2)
"""python
# to reverse the process above: hex string -> mx.array
# (mx.save writes .npy data, so numpy can read it back)
import numpy as np

mx.array(
    np.load(
        io.BytesIO(
            bytes.fromhex(
                "abcde...12345"
            )
        )
    )
)
"""
@anthonywu

The before and after images are visually the same, but their binary data differs. I suspect this is due to an MLX version difference.
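One way to test the MLX-version hypothesis is to print the installed package versions in each environment and compare them. A minimal sketch using the standard library's `importlib.metadata`; the list of distributions is an assumption, so adjust it as needed.

```python
from importlib.metadata import PackageNotFoundError, version
from typing import Optional


def installed_version(dist_name: str) -> Optional[str]:
    # Return the installed distribution's version, or None if it isn't installed
    try:
        return version(dist_name)
    except PackageNotFoundError:
        return None


if __name__ == "__main__":
    for dist in ("mflux", "mlx", "mlx-metal", "numpy", "pillow"):
        print(f"{dist}=={installed_version(dist)}")
```

Saved as, say, a hypothetical `versions.py`, it could be run once with `uvx --from mflux==0.9.0 python versions.py` and once in the local environment, then the two outputs diffed.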


@anthonywu

In my image_generation_test_helper.py I have an uncommitted diff that does this:

actual_data = np.array(Image.open(output_image_path))
desired_data = np.array(Image.open(reference_image_path))
try:
    np.testing.assert_allclose(
        actual_data,
        desired_data,
        atol=5,
        rtol=0.03,
        err_msg=f"Generated image doesn't match reference image. Check {output_image_path} vs {reference_image_path}",
    )
except AssertionError:
    # Calculate the absolute difference for each pixel/channel.
    # Widen to int16 first: subtracting uint8 arrays wraps around instead of going negative.
    diff = np.abs(actual_data.astype(np.int16) - desired_data.astype(np.int16))

    # Scale the difference so it's visible: a difference of 50 or more becomes pure white (255).
    diff_visual = (diff / 50 * 255).clip(0, 255).astype(np.uint8)
    import cv2

    # str() for OpenCV versions that only accept string paths; note cv2 assumes BGR channel order
    cv2.imwrite(str(output_image_path.with_stem(output_image_path.stem + "_diff")), diff_visual)
    raise

It shows the spots where the difference exceeds the assert_allclose tolerances set above.
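To mark exactly the pixels that fail the tolerance check, rather than a scaled difference, one option is to reproduce the element-wise criterion that `np.testing.assert_allclose` applies, |actual − desired| ≤ atol + rtol·|desired|, as a boolean mask. A minimal sketch; the helper names are hypothetical, and the int16 widening avoids uint8 wraparound (without it, 10 − 20 becomes 246).

```python
import numpy as np


def signed_diff(actual: np.ndarray, desired: np.ndarray) -> np.ndarray:
    # Widen uint8 pixels to int16 before subtracting; uint8 - uint8 wraps
    # around (10 - 20 -> 246) instead of going negative
    return actual.astype(np.int16) - desired.astype(np.int16)


def out_of_tolerance(actual, desired, atol=5.0, rtol=0.03) -> np.ndarray:
    # True where the assert_allclose criterion fails:
    # |actual - desired| <= atol + rtol * |desired|
    diff = np.abs(signed_diff(actual, desired))
    return diff > atol + rtol * np.abs(desired.astype(np.float64))


def tolerance_mask_image(actual, desired, atol=5.0, rtol=0.03) -> np.ndarray:
    # White (255) wherever any channel of a pixel fails the check, black elsewhere
    mask = out_of_tolerance(actual, desired, atol, rtol)
    if mask.ndim == 3:  # H x W x C image: collapse the channel axis
        mask = mask.any(axis=-1)
    return mask.astype(np.uint8) * 255
```

The result is a single-channel uint8 image, so it can be written with the same `cv2.imwrite` call used for `diff_visual`.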

My `pip freeze`:
backports-tarfile==1.2.0
certifi==2025.8.3
charset-normalizer==3.4.3
contourpy==1.3.3
cycler==0.12.1
docutils==0.22
filelock==3.19.1
fonttools==4.59.1
fsspec==2025.7.0
hf-xet==1.1.8
huggingface-hub==0.34.4
id==1.5.0
idna==3.10
importlib-metadata==8.7.0
iniconfig==2.1.0
jaraco-classes==3.4.0
jaraco-context==6.0.1
jaraco-functools==4.3.0
jinja2==3.1.6
keyring==25.6.0
kiwisolver==1.4.9
markdown-it-py==4.0.0
markupsafe==3.0.2
matplotlib==3.10.5
mdurl==0.1.2
-e file:///Users/anthonywu/workspace/mflux-3
mlx==0.27.1
mlx-metal==0.27.1
more-itertools==10.7.0
mpmath==1.3.0
networkx==3.5
nh3==0.3.0
numpy==2.3.2
opencv-python==4.11.0.86
packaging==25.0
piexif==1.1.3
pillow==10.4.0
platformdirs==4.3.8
pluggy==1.6.0
pygments==2.19.2
pyparsing==3.2.3
pytest==8.4.1
python-dateutil==2.9.0.post0
pyyaml==6.0.2
readme-renderer==44.0
regex==2025.7.34
requests==2.32.5
requests-toolbelt==1.0.0
rfc3986==2.0.0
rich==14.1.0
safetensors==0.6.2
sentencepiece==0.2.1
six==1.17.0
sympy==1.14.0
tokenizers==0.21.4
toml==0.10.2
torch==2.8.0
tqdm==4.67.1
transformers==4.55.4
twine==6.1.0
typing-extensions==4.15.0
urllib3==2.5.0
zipp==3.23.0
