
@asomoza
Created September 4, 2024 02:07
import torch
from optimum.quanto import QuantizedDiffusersModel, freeze, qfloat8, quantize
from diffusers import FluxPipeline, FluxTransformer2DModel

# optimum-quanto wrapper for the Flux transformer (handy for saving/reloading quantized weights).
class QuantizedFluxTransformer2DModel(QuantizedDiffusersModel):
    base_class = FluxTransformer2DModel

dtype = torch.bfloat16

# Load the Flux transformer in bf16, quantize its weights to qfloat8, then freeze them.
transformer = FluxTransformer2DModel.from_pretrained(
    "black-forest-labs/FLUX.1-dev", subfolder="transformer", torch_dtype=dtype
)
quantize(transformer, weights=qfloat8)
freeze(transformer)

# Build the pipeline without a transformer and plug in the quantized one.
pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev", transformer=None, torch_dtype=dtype
)
pipe.transformer = transformer
pipe.enable_model_cpu_offload()

generator = torch.Generator().manual_seed(12345)
image = pipe(
    prompt="high quality photo of a dog sitting beside a tree with a red soccer ball at its side, the background it's a lake with a boat sailing in it and an airplane flying in the cloudy sky.",
    width=1024,
    height=1024,
    num_inference_steps=20,
    generator=generator,
    guidance_scale=3.5,
).images[0]
image.save("fluxfp8_image.png")
@asomoza
Author

asomoza commented Oct 11, 2024

Hi, this is kind of old; I only tested it with a 3090 back then.

I just tried to run it on an A100 and got a bunch of errors during quantization, probably because quanto has changed a lot. Sadly I don't have the time to look into it, but with an A100 or an A6000 you shouldn't need quantization anyway. In any case, this should stop being a problem once we have official support for quantization in diffusers.
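
For reference, the official route in diffusers is to pass a quantization_config to from_pretrained. A minimal sketch using the bitsandbytes backend (a different quantizer than qfloat8; the NF4 settings below are only illustrative):

import torch
from diffusers import BitsAndBytesConfig, FluxPipeline, FluxTransformer2DModel

# Illustrative NF4 config via the bitsandbytes backend, not qfloat8.
nf4_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
)
transformer = FluxTransformer2DModel.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    subfolder="transformer",
    quantization_config=nf4_config,
    torch_dtype=torch.bfloat16,
)
pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev", transformer=transformer, torch_dtype=torch.bfloat16
)
pipe.enable_model_cpu_offload()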

@Leommm-byte

Thank you very much for your response. Yeah, I do not need to use quantization, I just wanted to test out the speed difference, as I've seen (at least on Replicate) that the fp8 variant is considerably faster.

https://replicate.com/blog/flux-is-fast-and-open-source

I have used this quantization method with a 4090 (16 GB) previously and it worked well. I'm just surprised that it isn't working with Ampere GPUs. In fact, I went further and stored the frozen weights so I could load them directly (this worked well with my 4090), but it still generates noise on an A-series GPU.

Thank you once again.
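
For what it's worth, storing and reloading the frozen weights can go through optimum-quanto's QuantizedDiffusersModel wrapper defined in the gist; a minimal sketch, assuming a local ./flux-fp8-transformer directory (the path is just an example):

import torch
from diffusers import FluxPipeline, FluxTransformer2DModel
from optimum.quanto import QuantizedDiffusersModel, qfloat8

class QuantizedFluxTransformer2DModel(QuantizedDiffusersModel):
    base_class = FluxTransformer2DModel

dtype = torch.bfloat16

# One-time pass: quantize to qfloat8 and save the frozen weights to disk.
transformer = FluxTransformer2DModel.from_pretrained(
    "black-forest-labs/FLUX.1-dev", subfolder="transformer", torch_dtype=dtype
)
qtransformer = QuantizedFluxTransformer2DModel.quantize(transformer, weights=qfloat8)
qtransformer.save_pretrained("./flux-fp8-transformer")

# Later runs load the already-quantized transformer directly, skipping the quantization step.
qtransformer = QuantizedFluxTransformer2DModel.from_pretrained("./flux-fp8-transformer")
pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev", transformer=None, torch_dtype=dtype
)
pipe.transformer = qtransformer
pipe.enable_model_cpu_offload()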

@Leommm-byte

To follow up on this, I've been able to resolve the issue. Apparently there is a bug in the recent optimum-quanto 0.25.0 release that corrupts the transformer weights during qfloat8 quantization, so all I had to do was revert to the dev version at this commit:

https://github.com/huggingface/optimum-quanto.git@65ace79d6af6ccc27afbb3576541cc36b3e3a98b

Hopefully, it will be resolved in the next update.
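
For anyone hitting the same issue, installing directly from that commit should look roughly like this (assuming the link above points at the main huggingface/optimum-quanto repository):

pip install "git+https://github.com/huggingface/optimum-quanto.git@65ace79d6af6ccc27afbb3576541cc36b3e3a98b"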

@asomoza
Author

asomoza commented Oct 15, 2024

Thanks a lot for investigating this, I really appreciate it.
