@davidberenstein1957
Last active August 5, 2025 07:04
Compress and optimize FLUX to make it run faster
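This script loads the FLUX.1 dev, schnell, and Krea-dev pipelines, applies Pruna's smash with FORA step caching and torchao int8 dynamic quantization (once with torch.compile and once without), and pushes each optimized variant to the Hugging Face Hub.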
# /// script
# requires-python = ">=3.11,<3.12"
# dependencies = [
# "pruna",
# "pyarrow<20"
# ]
# ///
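# The block above is PEP 723 inline script metadata: running the file with
# `uv run` resolves the pinned dependencies automatically.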
import gc
import os
import shutil

import torch
from diffusers import FluxPipeline

from pruna import SmashConfig, smash

def clear_gpu_cache():
    """Clear GPU cache and run garbage collection"""
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        torch.cuda.synchronize()
    gc.collect()

def clear_local_cache():
    """Clear local Hugging Face cache"""
    cache_dir = os.path.expanduser("~/.cache/huggingface")
    if os.path.exists(cache_dir):
        try:
            shutil.rmtree(cache_dir)
            print(" Cleared local HuggingFace cache")
        except Exception as e:
            print(f" Warning: Could not clear cache: {e}")

    # Also clear transformers cache if it exists
    transformers_cache = os.path.expanduser("~/.cache/transformers")
    if os.path.exists(transformers_cache):
        try:
            shutil.rmtree(transformers_cache)
            print(" Cleared transformers cache")
        except Exception as e:
            print(f" Warning: Could not clear transformers cache: {e}")

def process_model(model_name, hub_name_compiled, hub_name_no_compile):
    """Load model, apply smash optimization with and without compilation, and save to hub"""
    print(f"Processing {model_name}...")

    # Process with compilation
    print(" Processing with compilation...")
    pipe = FluxPipeline.from_pretrained(model_name, torch_dtype=torch.bfloat16).to("cpu")
    smash_config = SmashConfig()
    smash_config["cacher"] = "fora"
    smash_config["fora_interval"] = 3  # or 2 for even faster inference
    smash_config["compiler"] = "torch_compile"
    smash_config["torch_compile_mode"] = "max-autotune-no-cudagraphs"
    smash_config["quantizer"] = "torchao"
    smash_config["torchao_quant_type"] = "int8dq"  # you can also try fp8dq
    smash_config["torchao_excluded_modules"] = "norm+embedding"
    smashed_pipe = smash(pipe, smash_config)
    smashed_pipe.save_to_hub(hub_name_compiled)
    print(f" Saved {hub_name_compiled} to hub")

    # Free memory before processing the next variant
    del pipe
    del smashed_pipe
    clear_gpu_cache()

    # Process without compilation
    print(" Processing without compilation...")
    pipe = FluxPipeline.from_pretrained(model_name, torch_dtype=torch.bfloat16).to("cpu")
    smash_config = SmashConfig(device="cpu")
    smash_config["cacher"] = "fora"
    smash_config["fora_interval"] = 3  # or 2 for even faster inference
    smash_config["quantizer"] = "torchao"
    smash_config["torchao_quant_type"] = "int8dq"  # you can also try fp8dq
    smash_config["torchao_excluded_modules"] = "norm+embedding"
    smashed_pipe = smash(pipe, smash_config)
    smashed_pipe.save_to_hub(hub_name_no_compile)
    print(f" Saved {hub_name_no_compile} to hub")

    del pipe
    del smashed_pipe
    clear_gpu_cache()

def main():
    """Main function to process all three FLUX models"""
    models = [
        ("black-forest-labs/FLUX.1-dev", "PrunaAI/FLUX.1-dev-smashed", "PrunaAI/FLUX.1-dev-smashed-no-compile"),
        ("black-forest-labs/FLUX.1-schnell", "PrunaAI/FLUX.1-schnell-smashed", "PrunaAI/FLUX.1-schnell-smashed-no-compile"),
        ("black-forest-labs/FLUX.1-Krea-dev", "PrunaAI/FLUX.1-Krea-dev-smashed", "PrunaAI/FLUX.1-Krea-dev-smashed-no-compile"),
    ]
    for model_name, hub_name_compiled, hub_name_no_compile in models:
        process_model(model_name, hub_name_compiled, hub_name_no_compile)

    # Final cleanup
    clear_local_cache()
    print("All models processed successfully and local cache cleared!")


if __name__ == "__main__":
    main()
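To use one of the uploaded pipelines afterwards, it can be loaded back from the Hub. A minimal sketch, assuming pruna's PrunaModel.from_hub loader (the entry point shown on PrunaAI model cards) and that the returned wrapper is callable like a regular diffusers FluxPipeline; the prompt and output filename are illustrative:

from pruna import PrunaModel

# Load the quantized, non-compiled variant produced above (assumed loader API)
pipe = PrunaModel.from_hub("PrunaAI/FLUX.1-dev-smashed-no-compile")
# The wrapper forwards the call to the underlying FluxPipeline
image = pipe("a tiny astronaut hatching from an egg on the moon").images[0]
image.save("flux_smashed.png")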