Skip to content

Instantly share code, notes, and snippets.

@sayakpaul
Last active January 13, 2025 01:51
Show Gist options
  • Save sayakpaul/e1f28e86d0756d587c0b898c73822c47 to your computer and use it in GitHub Desktop.
Save sayakpaul/e1f28e86d0756d587c0b898c73822c47 to your computer and use it in GitHub Desktop.
Shows how to run Flux schnell under 17GBs without bells and whistles. It additionally shows how to serialize the quantized checkpoint and load it back.
import torch
from huggingface_hub import hf_hub_download
from diffusers import FluxTransformer2DModel, DiffusionPipeline

# Run Flux.1-schnell with an int8 weight-only quantized transformer (< 17GB VRAM).
dtype, device = torch.bfloat16, "cuda"
ckpt_id = "black-forest-labs/FLUX.1-schnell"

# Build the transformer on the meta device: no real memory is allocated, so the
# quantized weights can be materialized directly via load_state_dict(assign=True).
# (Fix: the two lines below must be indented inside the `with` block.)
with torch.device("meta"):
    config = FluxTransformer2DModel.load_config(ckpt_id, subfolder="transformer")
    model = FluxTransformer2DModel.from_config(config).to(dtype)

ckpt_path = hf_hub_download(repo_id="sayakpaul/flux.1-schell-int8wo", filename="flux_schnell_int8wo.pt")
# NOTE(review): the checkpoint holds torchao tensor subclasses, so it presumably
# cannot be loaded with `weights_only=True` — confirm before tightening this.
state_dict = torch.load(ckpt_path, map_location="cpu")
model.load_state_dict(state_dict, assign=True)

pipeline = DiffusionPipeline.from_pretrained(ckpt_id, transformer=model, torch_dtype=dtype).to("cuda")
image = pipeline(
    "cat", guidance_scale=0.0, num_inference_steps=4, max_sequence_length=256
).images[0]
image.save("flux_schnell_int8.png")
# Install `torchao` from source: https://github.com/pytorch/ao
# Install PyTorch nightly
from diffusers import DiffusionPipeline, FluxTransformer2DModel, AutoencoderKL
from transformers import T5EncoderModel, CLIPTextModel
from torchao.quantization import quantize_, int8_weight_only
import torch

ckpt_id = "black-forest-labs/FLUX.1-schnell"


def _load_int8(model_cls, subfolder):
    """Load one pipeline component in bf16 and int8-quantize its weights in place."""
    module = model_cls.from_pretrained(
        ckpt_id, subfolder=subfolder, torch_dtype=torch.bfloat16
    )
    quantize_(module, int8_weight_only())
    return module


# Quantize the components individually.
# If quality is taking a hit then don't quantize all components — mix and match.
transformer = _load_int8(FluxTransformer2DModel, "transformer")
text_encoder = _load_int8(CLIPTextModel, "text_encoder")
text_encoder_2 = _load_int8(T5EncoderModel, "text_encoder_2")
vae = _load_int8(AutoencoderKL, "vae")

# Initialize the pipeline from the quantized components.
pipeline = DiffusionPipeline.from_pretrained(
    ckpt_id,
    transformer=transformer,
    vae=vae,
    text_encoder=text_encoder,
    text_encoder_2=text_encoder_2,
    torch_dtype=torch.bfloat16,
).to("cuda")

image = pipeline(
    "cat", guidance_scale=0.0, num_inference_steps=4, max_sequence_length=256
).images[0]

# Report allocated GPU memory after releasing the caching allocator's slack.
torch.cuda.empty_cache()
memory = (torch.cuda.memory_allocated() / 1024 / 1024 / 1024)
print(f"{memory=:.3f} GB")
image.save("quantized_image.png")
from diffusers import FluxTransformer2DModel
from torchao.quantization import quantize_, int8_weight_only
import torch

ckpt_id = "black-forest-labs/FLUX.1-schnell"

# Load the transformer in bf16, then quantize its weights to int8 in place.
flux_transformer = FluxTransformer2DModel.from_pretrained(
    ckpt_id, subfolder="transformer", torch_dtype=torch.bfloat16
)
quantize_(flux_transformer, int8_weight_only())

# Serializing with safetensors should ideally be possible, but is blocked by
# https://github.com/huggingface/safetensors/issues/515.
# The quantized checkpoint is 12GB instead of 23GB.
torch.save(flux_transformer.state_dict(), "flux_schnell_int8wo.pt")
@sayakpaul
Copy link
Author

Oh okay, the error you're facing is a known issue: huggingface/safetensors#515. Cc: @jerryzh168.

@al-swaiti
Copy link

al-swaiti commented Aug 23, 2024

Okay, I tried saving each component separately:

from diffusers import DiffusionPipeline, FluxTransformer2DModel, AutoencoderKL
from transformers import T5EncoderModel, CLIPTextModel
from torchao.quantization import quantize_, int8_weight_only
import torch
import os

ckpt_id = "sayakpaul/FLUX.1-merged"


def load_and_quantize(model_class, subfolder):
    """Load one pipeline component in fp16 and int8-quantize its weights in place."""
    component = model_class.from_pretrained(
        ckpt_id, subfolder=subfolder, torch_dtype=torch.float16
    )
    quantize_(component, int8_weight_only())
    return component


# Load and quantize every pipeline component.
transformer = load_and_quantize(FluxTransformer2DModel, "transformer")
text_encoder = load_and_quantize(CLIPTextModel, "text_encoder")
text_encoder_2 = load_and_quantize(T5EncoderModel, "text_encoder_2")
vae = load_and_quantize(AutoencoderKL, "vae")

# Assemble the pipeline from the quantized components.
pipeline = DiffusionPipeline.from_pretrained(
    ckpt_id,
    transformer=transformer,
    vae=vae,
    text_encoder=text_encoder,
    text_encoder_2=text_encoder_2,
    torch_dtype=torch.float16
)

# Directory for the per-component checkpoints.
save_directory = "quantized_flux_pipeline"
os.makedirs(save_directory, exist_ok=True)

# Move everything to CPU first so the saved tensors are device-agnostic.
pipeline.to("cpu")

# Persist each component's state dict as its own .pt file.
for component_name in ("transformer", "vae", "text_encoder", "text_encoder_2"):
    module = getattr(pipeline, component_name)
    torch.save(
        module.state_dict(),
        os.path.join(save_directory, f"{component_name}.pt"),
    )

print(f"Quantized pipeline components saved to {save_directory}")

# Optional: clear the CUDA cache and report allocated GPU memory.
torch.cuda.empty_cache()
memory = torch.cuda.memory_allocated() / 1024 / 1024 / 1024
print(f"GPU memory usage: {memory:.3f} GB")

image

Now I'm trying to load them again with the original configs. You're experienced — how do I load them back to produce an image? Is it the same approach as loading the transformer, repeated for each component?

@danieltudosiu
Copy link

@sayakpaul I would recommend setting the dtype to FP32 instead of BF16 so the code will run on pre-ampere architectures as well.

@al-swaiti
Copy link

al-swaiti commented Aug 23, 2024

import torch
from diffusers import FluxTransformer2DModel, DiffusionPipeline, AutoencoderKL
from transformers import T5EncoderModel, CLIPTextModel, CLIPTextConfig ,T5Config 
import os

dtype, device = torch.bfloat16, "cuda"
ckpt_id = "sayakpaul/FLUX.1-merged"
save_directory = "quantized_flux_pipeline"


def _load_state_dict(filename):
    """Load one component state dict saved by the quantization script.

    NOTE(review): the checkpoints hold torchao tensor subclasses, so they
    presumably cannot be loaded with `weights_only=True` — confirm.
    """
    return torch.load(os.path.join(save_directory, filename), map_location="cpu")


# Load transformer: build on the meta device (no real allocations), then
# materialize the quantized weights via load_state_dict(assign=True).
with torch.device("meta"):
    config = FluxTransformer2DModel.load_config(ckpt_id, subfolder="transformer")
    transformer = FluxTransformer2DModel.from_config(config).to(dtype)
transformer.load_state_dict(_load_state_dict("transformer.pt"), assign=True)

# Load VAE the same way.
with torch.device("meta"):
    vae_config = AutoencoderKL.load_config(ckpt_id, subfolder="vae")
    vae = AutoencoderKL.from_config(vae_config).to(dtype)
vae.load_state_dict(_load_state_dict("vae.pt"), assign=True)

# Load the text encoders.
# Fix: the configs must be fetched with `from_pretrained`; calling the config
# constructor directly (e.g. CLIPTextConfig(ckpt_id, subfolder=...)) passes the
# repo id into the first config field and crashes when the model is built
# (the TypeError in nn.Embedding seen in the reported traceback).
text_encoder_config = CLIPTextConfig.from_pretrained(ckpt_id, subfolder="text_encoder")
text_encoder = CLIPTextModel(text_encoder_config)
text_encoder.load_state_dict(_load_state_dict("text_encoder.pt"), assign=True)

text_encoder_2_config = T5Config.from_pretrained(ckpt_id, subfolder="text_encoder_2")
text_encoder_2 = T5EncoderModel(text_encoder_2_config)
text_encoder_2.load_state_dict(_load_state_dict("text_encoder_2.pt"), assign=True)

# Initialize the pipeline from the reloaded components.
pipeline = DiffusionPipeline.from_pretrained(
    ckpt_id,
    transformer=transformer,
    vae=vae,
    text_encoder=text_encoder,
    text_encoder_2=text_encoder_2,
    torch_dtype=dtype
).to(device)

# Generate and save an image.
image = pipeline(
    "cat",
    guidance_scale=0.0,
    num_inference_steps=4,
    max_sequence_length=256
).images[0]

image.save("flux_local_int8.png")
print("Image generated and saved as 'flux_local_int8.png'")

I failed with this! I'm not that experienced with diffusers.

@sayakpaul
Copy link
Author

What is your error?

@al-swaiti
Copy link

/home/abdallah/Desktop/webui/qunata-6h/runquanta.py:15: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.
  transformer_state_dict = torch.load(os.path.join(save_directory, "transformer.pt"), map_location="cpu")
/home/abdallah/Desktop/webui/qunata-6h/runquanta.py:23: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.
  vae_state_dict = torch.load(os.path.join(save_directory, "vae.pt"), map_location="cpu")
Traceback (most recent call last):
  File "/home/abdallah/Desktop/webui/qunata-6h/runquanta.py", line 27, in <module>
    text_encoder = CLIPTextModel(text_encoder_config)
                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/abdallah/Desktop/webui/qunata-6h/.venv/lib/python3.12/site-packages/transformers/models/clip/modeling_clip.py", line 946, in __init__
    self.text_model = CLIPTextTransformer(config)
                      ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/abdallah/Desktop/webui/qunata-6h/.venv/lib/python3.12/site-packages/transformers/models/clip/modeling_clip.py", line 840, in __init__
    self.embeddings = CLIPTextEmbeddings(config)
                      ^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/abdallah/Desktop/webui/qunata-6h/.venv/lib/python3.12/site-packages/transformers/models/clip/modeling_clip.py", line 205, in __init__
    self.token_embedding = nn.Embedding(config.vocab_size, embed_dim)
                           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/abdallah/Desktop/webui/qunata-6h/.venv/lib/python3.12/site-packages/torch/nn/modules/sparse.py", line 167, in __init__
    torch.empty((num_embeddings, embedding_dim), **factory_kwargs),
    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
TypeError: empty() received an invalid combination of arguments - got (tuple, dtype=NoneType, device=NoneType), but expected one of:
 * (tuple of ints size, *, tuple of names names, torch.memory_format memory_format = None, torch.dtype dtype = None, torch.layout layout = None, torch.device device = None, bool pin_memory = False, bool requires_grad = False)
 * (tuple of ints size, *, torch.memory_format memory_format = None, Tensor out = None, torch.dtype dtype = None, torch.layout layout = None, torch.device device = None, bool pin_memory = False, bool requires_grad = False)

@brurpo
Copy link

brurpo commented Oct 8, 2024

@tin2tin we don't support Windows really well simply because torch.compile() doesn't support Windows very well. Feel free to open an issue on our repo, though — we might choose to prioritize this depending on demand.

For now try installing ao from source using USE_CPP=0 pip install .

This works on windows, in case anyone wants to know
The commands on cmd prompt are actually:

set USE_CPP=0
pip install <path to downloaded torchao folder>

@MostHumble
Copy link

MostHumble commented Jan 12, 2025

Edit 12-01-2025:
as of now, even this is throwing an error:

AttributeError: Can't get attribute 'PlainAQTLayout' on <module 'torchao.dtypes.affine_quantized_tensor' from '...Python\\Python311\\site-packages\\torchao\\dtypes\\affine_quantized_tensor.py'>

For those trying to run the under-17GB version: the link is dead; you can use this instead:

import torch
from diffusers import FluxTransformer2DModel, DiffusionPipeline

# Load an already-quantized (int8 weight-only) Flux.1-schnell transformer from
# the Hub and run the full pipeline with it.
dtype, device = torch.bfloat16, "cuda"
ckpt_id = "black-forest-labs/FLUX.1-schnell"

# use_safetensors=False: the quantized checkpoint is stored as a pickle file,
# presumably because of huggingface/safetensors#515 — see the thread above.
model = FluxTransformer2DModel.from_pretrained(
    "sayakpaul/flux.1-schell-int8wo-improved", torch_dtype=dtype, use_safetensors=False
)
# Fix: use the `device` variable instead of re-hardcoding "cuda".
pipeline = DiffusionPipeline.from_pretrained(ckpt_id, transformer=model, torch_dtype=dtype).to(device)
image = pipeline(
    "cat", guidance_scale=0.0, num_inference_steps=4, max_sequence_length=256
).images[0]
image.save("flux_schnell_int8.png")

@tin2tin
Copy link

tin2tin commented Jan 12, 2025

@MostHumble
Getting an error running that code:
AttributeError: Can't get attribute 'PlainAQTLayout' on <module 'torchao.dtypes.affine_quantized_tensor' from '...Python\\Python311\\site-packages\\torchao\\dtypes\\affine_quantized_tensor.py'>
On torchao == 0.7.0

@MostHumble
Copy link

@sayakpaul
Copy link
Author

Please take the issue with torchao. Until it's resolved, either:

  1. Use torchao integration from Diffusers.
  2. Downgrade torchao installation.

Other than that, I can't provide additional suggestions.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment