-
-
Save sayakpaul/e1f28e86d0756d587c0b898c73822c47 to your computer and use it in GitHub Desktop.
# Load a pre-quantized (torchao int8 weight-only) FLUX.1-schnell transformer
# from a serialized state dict, then run the full pipeline with it.
import torch
from huggingface_hub import hf_hub_download
from diffusers import FluxTransformer2DModel, DiffusionPipeline

dtype, device = torch.bfloat16, "cuda"
ckpt_id = "black-forest-labs/FLUX.1-schnell"

# Build the transformer skeleton on the meta device so no real memory is
# allocated for weights that are about to be replaced by the checkpoint.
with torch.device("meta"):
    config = FluxTransformer2DModel.load_config(ckpt_id, subfolder="transformer")
    model = FluxTransformer2DModel.from_config(config).to(dtype)

ckpt_path = hf_hub_download(repo_id="sayakpaul/flux.1-schell-int8wo", filename="flux_schnell_int8wo.pt")
# The checkpoint is a pickled state dict containing torchao quantized tensor
# subclasses, which torch.load's `weights_only=True` default (PyTorch >= 2.6)
# refuses to unpickle, so opt out explicitly.
# SECURITY: weights_only=False can execute arbitrary code embedded in the
# pickle -- only load checkpoints from sources you trust.
state_dict = torch.load(ckpt_path, map_location="cpu", weights_only=False)
# assign=True swaps the meta-device placeholders for the loaded tensors
# instead of copying values into them.
model.load_state_dict(state_dict, assign=True)

pipeline = DiffusionPipeline.from_pretrained(ckpt_id, transformer=model, torch_dtype=dtype).to(device)
image = pipeline(
    "cat", guidance_scale=0.0, num_inference_steps=4, max_sequence_length=256
).images[0]
image.save("flux_schnell_int8.png")
# Install `torchao` from source: https://github.com/pytorch/ao
# Install PyTorch nightly
#
# Load each FLUX.1-schnell component, optionally quantize it with torchao
# int8 weight-only quantization, run the pipeline, and report the CUDA memory
# that is still allocated afterwards.
from diffusers import DiffusionPipeline, FluxTransformer2DModel, AutoencoderKL
from transformers import T5EncoderModel, CLIPTextModel
from torchao.quantization import quantize_, int8_weight_only
import torch

dtype, device = torch.bfloat16, "cuda"
ckpt_id = "black-forest-labs/FLUX.1-schnell"


def _load_component(model_cls, subfolder, quantize=True):
    """Load one pipeline component from `ckpt_id` in `dtype`.

    When `quantize` is true, the component is quantized in place with
    torchao int8 weight-only quantization before being returned.
    """
    component = model_cls.from_pretrained(ckpt_id, subfolder=subfolder, torch_dtype=dtype)
    if quantize:
        quantize_(component, int8_weight_only())
    return component


# Quantize the components individually.
# If quality is taking a hit then don't quantize all components:
# mix and match by passing quantize=False for the affected component.
# NOTE(review): torchao's quantize_ targets nn.Linear layers by default, so
# conv-heavy components (e.g. the VAE) may see little savings -- confirm
# against your torchao version.
transformer = _load_component(FluxTransformer2DModel, "transformer")
text_encoder = _load_component(CLIPTextModel, "text_encoder")
text_encoder_2 = _load_component(T5EncoderModel, "text_encoder_2")
vae = _load_component(AutoencoderKL, "vae")

# Initialize the pipeline now, reusing the (quantized) components above.
pipeline = DiffusionPipeline.from_pretrained(
    ckpt_id,
    transformer=transformer,
    vae=vae,
    text_encoder=text_encoder,
    text_encoder_2=text_encoder_2,
    torch_dtype=dtype,
).to(device)
image = pipeline(
    "cat", guidance_scale=0.0, num_inference_steps=4, max_sequence_length=256
).images[0]

# Release cached-but-unused blocks so memory_allocated reflects live tensors.
torch.cuda.empty_cache()
memory = torch.cuda.memory_allocated() / 1024**3
print(f"{memory=:.3f} GB")
image.save("quantized_image.png")
# Quantize the FLUX.1-schnell transformer with torchao int8 weight-only
# quantization and serialize its state dict for later reuse.
from diffusers import FluxTransformer2DModel
from torchao.quantization import quantize_, int8_weight_only
import torch

ckpt_id = "black-forest-labs/FLUX.1-schnell"
transformer = FluxTransformer2DModel.from_pretrained(
    ckpt_id, subfolder="transformer", torch_dtype=torch.bfloat16
)
quantize_(transformer, int8_weight_only())

# Ideally this would be saved with safetensors, but that is currently blocked:
# https://github.com/huggingface/safetensors/issues/515
# The resulting pickle checkpoint is 12GB instead of 23GB.
# NOTE: because it is a pickle of torchao tensor subclasses, reloading it with
# torch.load on PyTorch >= 2.6 requires weights_only=False (trusted files only).
torch.save(transformer.state_dict(), "flux_schnell_int8wo.pt")
al-swaiti commented on Aug 23, 2024
@tin2tin we don't support Windows really well simply because `torch.compile()` doesn't support Windows very well. Feel free to open an issue on our repo, though we might choose to prioritize this depending on demand. For now, try installing ao from source using `USE_CPP=0 pip install .`
This works on Windows, in case anyone wants to know.
The commands in the Windows cmd prompt are actually:
set USE_CPP=0
pip install <path to downloaded torchao folder>
Edit 12-01-2025:
As of now, even this throws an error:
AttributeError: Can't get attribute 'PlainAQTLayout' on <module 'torchao.dtypes.affine_quantized_tensor' from '...Python\\Python311\\site-packages\\torchao\\dtypes\\affine_quantized_tensor.py'>
For those trying to use the under-17 GB version: the original link is dead, but you can use this instead:
# Run FLUX.1-schnell with the pre-quantized int8 weight-only transformer
# published at sayakpaul/flux.1-schell-int8wo-improved.
import torch
from diffusers import FluxTransformer2DModel, DiffusionPipeline

dtype, device = torch.bfloat16, "cuda"
ckpt_id = "black-forest-labs/FLUX.1-schnell"

# use_safetensors=False: this checkpoint is stored in PyTorch's pickle format
# rather than safetensors, so only load it if you trust the source.
model = FluxTransformer2DModel.from_pretrained(
    "sayakpaul/flux.1-schell-int8wo-improved", torch_dtype=dtype, use_safetensors=False
)
pipeline = DiffusionPipeline.from_pretrained(ckpt_id, transformer=model, torch_dtype=dtype).to(device)
image = pipeline(
    "cat", guidance_scale=0.0, num_inference_steps=4, max_sequence_length=256
).images[0]
image.save("flux_schnell_int8.png")
@MostHumble
Getting an error running that code:
AttributeError: Can't get attribute 'PlainAQTLayout' on <module 'torchao.dtypes.affine_quantized_tensor' from '...Python\\Python311\\site-packages\\torchao\\dtypes\\affine_quantized_tensor.py'>
This happens on torchao == 0.7.0.
@tin2tin
I ended up having the same issue and reported it here: https://huggingface.co/sayakpaul/flux.1-schell-int8wo-improved/discussions/2#6783e9e49c52f42c530ff7c1
Please take the issue up with `torchao`. Until it's resolved, either:
- Use the `torchao` integration from Diffusers, or
- Downgrade your `torchao` installation.
Other than that, I can't provide additional suggestions.