Compress and optimize FLUX to make it run faster
# /// script
# requires-python = ">=3.11,<3.12"
# dependencies = [
#     "pruna",
#     "pyarrow<20"
# ]
# ///
import gc
import os
import shutil

import torch
from diffusers import FluxPipeline
from pruna import SmashConfig, smash


def clear_gpu_cache():
    """Clear GPU cache and run garbage collection"""
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        torch.cuda.synchronize()
    gc.collect()


def clear_local_cache():
    """Clear local Hugging Face cache"""
    cache_dir = os.path.expanduser("~/.cache/huggingface")
    if os.path.exists(cache_dir):
        try:
            shutil.rmtree(cache_dir)
            print(" Cleared local HuggingFace cache")
        except Exception as e:
            print(f" Warning: Could not clear cache: {e}")

    # Also clear transformers cache if it exists
    transformers_cache = os.path.expanduser("~/.cache/transformers")
    if os.path.exists(transformers_cache):
        try:
            shutil.rmtree(transformers_cache)
            print(" Cleared transformers cache")
        except Exception as e:
            print(f" Warning: Could not clear transformers cache: {e}")


def process_model(model_name, hub_name_compiled, hub_name_no_compile):
    """Load model, apply smash optimization with and without compilation, and save to hub"""
    print(f"Processing {model_name}...")

    # Process with compilation
    print(" Processing with compilation...")
    pipe = FluxPipeline.from_pretrained(model_name, torch_dtype=torch.bfloat16).to("cpu")
    smash_config = SmashConfig()
    smash_config["cacher"] = "fora"
    smash_config["fora_interval"] = 3  # or 2 for even faster inference
    smash_config["compiler"] = "torch_compile"
    smash_config["torch_compile_mode"] = "max-autotune-no-cudagraphs"
    smash_config["quantizer"] = "torchao"
    smash_config["torchao_quant_type"] = "int8dq"  # you can also try fp8dq
    smash_config["torchao_excluded_modules"] = "norm+embedding"
    smashed_pipe = smash(pipe, smash_config)
    smashed_pipe.save_to_hub(hub_name_compiled)
    print(f" Saved {hub_name_compiled} to hub")

    del pipe
    del smashed_pipe
    clear_gpu_cache()

    # Process without compilation
    print(" Processing without compilation...")
    pipe = FluxPipeline.from_pretrained(model_name, torch_dtype=torch.bfloat16).to("cpu")
    smash_config = SmashConfig(device="cpu")
    smash_config["cacher"] = "fora"
    smash_config["fora_interval"] = 3  # or 2 for even faster inference
    smash_config["quantizer"] = "torchao"
    smash_config["torchao_quant_type"] = "int8dq"  # you can also try fp8dq
    smash_config["torchao_excluded_modules"] = "norm+embedding"
    smashed_pipe = smash(pipe, smash_config)
    smashed_pipe.save_to_hub(hub_name_no_compile)
    print(f" Saved {hub_name_no_compile} to hub")

    del pipe
    del smashed_pipe
    clear_gpu_cache()


def main():
    """Main function to process all FLUX models listed below"""
    models = [
        ("black-forest-labs/FLUX.1-dev", "PrunaAI/FLUX.1-dev-smashed", "PrunaAI/FLUX.1-dev-smashed-no-compile"),
        ("black-forest-labs/FLUX.1-schnell", "PrunaAI/FLUX.1-schnell-smashed", "PrunaAI/FLUX.1-schnell-smashed-no-compile"),
        ("black-forest-labs/FLUX.1-Krea-dev", "PrunaAI/FLUX.1-Krea-dev-smashed", "PrunaAI/FLUX.1-Krea-dev-smashed-no-compile"),
    ]

    for model_name, hub_name_compiled, hub_name_no_compile in models:
        process_model(model_name, hub_name_compiled, hub_name_no_compile)

    # Final cleanup
    clear_local_cache()
    print("All models processed successfully and local cache cleared!")


if __name__ == "__main__":
    main()
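
Once the smashed checkpoints are on the Hub, they can be pulled back down and used like a regular diffusers pipeline. The snippet below is a minimal sketch rather than part of the script above: it assumes Pruna's PrunaModel.from_hub loader, that the wrapped pipeline forwards the usual FluxPipeline call signature, and that one of the repo ids produced by the script is available to you; the prompt and sampling parameters are illustrative.

from pruna import PrunaModel

# Assumption: PrunaModel.from_hub restores the smashed FluxPipeline together
# with its SmashConfig (FORA caching + torchao int8 quantization).
smashed_pipe = PrunaModel.from_hub("PrunaAI/FLUX.1-dev-smashed-no-compile")

# Depending on how the pipeline was saved, you may need to move the underlying
# diffusers pipeline to the GPU before generating.
image = smashed_pipe(
    "a photo of a red fox in the snow",
    num_inference_steps=28,
    guidance_scale=3.5,
).images[0]
image.save("fox.png")

The -no-compile variants skip torch.compile, which avoids its warm-up cost on the first call at the expense of some steady-state speed compared to the compiled variants.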