-
-
Save sayakpaul/e1f28e86d0756d587c0b898c73822c47 to your computer and use it in GitHub Desktop.
# Load a FLUX.1-schnell transformer whose weights were already int8
# weight-only quantized with torchao and serialized with torch.save, then run
# the full pipeline with it.
# NOTE(review): later comments on this gist report a dead checkpoint link and
# point to "sayakpaul/flux.1-schell-int8wo-improved" instead — presumably this
# one; confirm before relying on the repo below.
import torch
from huggingface_hub import hf_hub_download
from diffusers import FluxTransformer2DModel, DiffusionPipeline

dtype, device = torch.bfloat16, "cuda"
ckpt_id = "black-forest-labs/FLUX.1-schnell"

# Build the transformer skeleton on the meta device so no memory is allocated
# for unquantized parameters; they are replaced by the loaded tensors below.
with torch.device("meta"):
    config = FluxTransformer2DModel.load_config(ckpt_id, subfolder="transformer")
    model = FluxTransformer2DModel.from_config(config).to(dtype)

ckpt_path = hf_hub_download(
    repo_id="sayakpaul/flux.1-schell-int8wo", filename="flux_schnell_int8wo.pt"
)
# Pass weights_only explicitly: torch warns that its default will flip to True,
# and this pickle stores torchao tensor subclasses that may not be on the
# safe-load allowlist (TODO confirm). Only load checkpoints you trust.
state_dict = torch.load(ckpt_path, map_location="cpu", weights_only=False)
# assign=True swaps the loaded tensors in instead of copying into the
# meta-device parameters, which have no storage to copy into.
model.load_state_dict(state_dict, assign=True)

pipeline = DiffusionPipeline.from_pretrained(
    ckpt_id, transformer=model, torch_dtype=dtype
).to(device)
image = pipeline(
    "cat", guidance_scale=0.0, num_inference_steps=4, max_sequence_length=256
).images[0]
image.save("flux_schnell_int8.png")
# Install `torchao` from source: https://github.com/pytorch/ao
# Install PyTorch nightly
from diffusers import DiffusionPipeline, FluxTransformer2DModel, AutoencoderKL
from transformers import T5EncoderModel, CLIPTextModel
from torchao.quantization import quantize_, int8_weight_only
import torch

ckpt_id = "black-forest-labs/FLUX.1-schnell"

# Quantize the components individually (int8 weight-only, via torchao).
# If quality is taking a hit, don't quantize all components: each
# `quantize_` call below is independent, so mix and match by dropping any.

############ Diffusion Transformer ############
transformer = FluxTransformer2DModel.from_pretrained(
    ckpt_id, subfolder="transformer", torch_dtype=torch.bfloat16
)
quantize_(transformer, int8_weight_only())

############ Text Encoder ############
text_encoder = CLIPTextModel.from_pretrained(
    ckpt_id, subfolder="text_encoder", torch_dtype=torch.bfloat16
)
quantize_(text_encoder, int8_weight_only())

############ Text Encoder 2 ############
text_encoder_2 = T5EncoderModel.from_pretrained(
    ckpt_id, subfolder="text_encoder_2", torch_dtype=torch.bfloat16
)
quantize_(text_encoder_2, int8_weight_only())

############ VAE ############
vae = AutoencoderKL.from_pretrained(
    ckpt_id, subfolder="vae", torch_dtype=torch.bfloat16
)
quantize_(vae, int8_weight_only())

# Initialize the pipeline now, reusing the quantized components; everything
# else (tokenizers, scheduler) is loaded from `ckpt_id`.
pipeline = DiffusionPipeline.from_pretrained(
    ckpt_id,
    transformer=transformer,
    vae=vae,
    text_encoder=text_encoder,
    text_encoder_2=text_encoder_2,
    torch_dtype=torch.bfloat16,
).to("cuda")

image = pipeline(
    "cat", guidance_scale=0.0, num_inference_steps=4, max_sequence_length=256
).images[0]

# Report GPU memory still held by live tensors after generation; cached but
# unused blocks are released first so they are not counted.
torch.cuda.empty_cache()
memory = torch.cuda.memory_allocated() / 1024 / 1024 / 1024
print(f"{memory=:.3f} GB")
image.save("quantized_image.png")
from diffusers import FluxTransformer2DModel
from torchao.quantization import quantize_, int8_weight_only
import torch

ckpt_id = "black-forest-labs/FLUX.1-schnell"

# Load the bf16 transformer and quantize its weights to int8 in place.
transformer = FluxTransformer2DModel.from_pretrained(
    ckpt_id, subfolder="transformer", torch_dtype=torch.bfloat16
)
quantize_(transformer, int8_weight_only())

# Serialize with torch.save (pickle). This should ideally be possible with
# safetensors, but see https://github.com/huggingface/safetensors/issues/515.
# The resulting checkpoint is about 12GB instead of 23GB.
torch.save(transformer.state_dict(), "flux_schnell_int8wo.pt")
sayakpaul
commented
Aug 19, 2024
| Component | Memory (GB) |
|---|---|
| Transformer | 20.494 |
| T5 | 16.155 |
| CLIP | 16.072 |
| VAE | 16.069 |
Has this been tested on Windows? I can't seem to find a version of torchao which includes quantization.
@tin2tin we don't support Windows really well simply because `torch.compile()` doesn't support Windows very well. Feel free to open an issue on our repo, though we might choose to prioritize this depending on demand.
For now, try installing ao from source using `USE_CPP=0 pip install .`
Is it really possible to do this for flux.1-dev?
I don't see any reason not to. The only extra thing present in FLUX.1-dev is `guidance_in`, so I am not sure what makes you think it might not be possible for Dev.
I succeeded with the merged Flux checkpoint; now I'm trying to apply this to a regular single-file checkpoint (ComfyUI / WebUI style). Any suggestions?
from diffusers import DiffusionPipeline, FluxTransformer2DModel, AutoencoderKL
from transformers import T5EncoderModel, CLIPTextModel
from torchao.quantization import quantize_, int8_weight_only
import torch
import os
ckpt_id = "sayakpaul/FLUX.1-merged"


# Function to load and quantize a component
def load_and_quantize(model_class, subfolder):
    """Load one pipeline component from ``ckpt_id`` and int8-quantize it.

    Args:
        model_class: A model class exposing ``from_pretrained`` (e.g. a
            diffusers or transformers model class).
        subfolder: Subfolder of the ``ckpt_id`` repo holding this component.

    Returns:
        The model loaded in float16, with ``quantize_(..., int8_weight_only())``
        applied to it.
    """
    model = model_class.from_pretrained(
        ckpt_id, subfolder=subfolder, torch_dtype=torch.float16
    )
    quantize_(model, int8_weight_only())
    return model
# Load and quantize components
transformer = load_and_quantize(FluxTransformer2DModel, "transformer")
text_encoder = load_and_quantize(CLIPTextModel, "text_encoder")
text_encoder_2 = load_and_quantize(T5EncoderModel, "text_encoder_2")
vae = load_and_quantize(AutoencoderKL, "vae")

# Initialize the pipeline
pipeline = DiffusionPipeline.from_pretrained(
    ckpt_id,
    transformer=transformer,
    vae=vae,
    text_encoder=text_encoder,
    text_encoder_2=text_encoder_2,
    torch_dtype=torch.float16,
)

# Save the quantized pipeline.
# safetensors cannot serialize torchao's quantized tensor subclasses
# (https://github.com/huggingface/safetensors/issues/515), so request
# pickle-based checkpoints instead.
# NOTE(review): with some huggingface_hub versions the shard-splitting step
# inside save_pretrained still fails ("Attempted to access the data pointer on
# an invalid python storage") — TODO confirm with your versions. The fallback
# is torch.save(component.state_dict(), ...) for each quantized component.
save_directory = "quantized_flux_pipeline"
os.makedirs(save_directory, exist_ok=True)
pipeline.save_pretrained(save_directory, safe_serialization=False)
print(f"Quantized pipeline saved to {save_directory}")
I'm trying to save the pipeline in diffusers format so I can reuse it directly, but I get this error:
python xyz.py Fetching 3 files: 100%|██████████████| 3/3 [00:00<00:00, 15630.95it/s] Downloading shards: 100%|██████████████| 2/2 [00:00<00:00, 569.88it/s] Loading checkpoint shards: 100%|████████| 2/2 [00:03<00:00, 1.66s/it] Couldn't connect to the Hub: (MaxRetryError("HTTPSConnectionPool(host='huggingface.co', port=443): Max retries exceeded with url: /api/models/sayakpaul/FLUX.1-merged (Caused by NewConnectionError('<urllib3.connection.HTTPSConnection object at 0x70ca0e6b4ad0>: Failed to establish a new connection: [Errno -3] Temporary failure in name resolution'))"), '(Request ID: 8775cf40-67d8-4ecd-9320-ee8eb06b9029)'). Will try to load from local cache. Loading pipeline components...: 40%|█▏ | 2/5 [00:00<00:00, 19.27it/s]You set
add_prefix_space`. The tokenizer needs to be converted from the slow tokenizers
Loading pipeline components...: 100%|███| 5/5 [00:00<00:00, 18.88it/s]
/home/abdallah/Desktop/webui/qunata-6h/.venv/lib/python3.12/site-packages/torchao-0.4.0+git99644e9-py3.12-linux-x86_64.egg/torchao/dtypes/utils.py:57: UserWarning: TypedStorage is deprecated. It will be removed in the future and UntypedStorage will be the only storage class. This should only matter to you if you are using storages directly. To access UntypedStorage directly, use tensor.untyped_storage() instead of tensor.storage()
return func(*args, **kwargs)
Traceback (most recent call last):
File "/home/abdallah/Desktop/webui/qunata-6h/.venv/lib/python3.12/site-packages/huggingface_hub/serialization/_torch.py", line 406, in storage_ptr
return tensor.untyped_storage().data_ptr()
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
RuntimeError: Attempted to access the data pointer on an invalid python storage.
During handling of the above exception, another exception occurred:
Traceback (most recent call last):
File "/home/abdallah/Desktop/webui/qunata-6h/xyz.py", line 39, in
pipeline.save_pretrained(save_directory)
File "/home/abdallah/Desktop/webui/qunata-6h/.venv/lib/python3.12/site-packages/diffusers/pipelines/pipeline_utils.py", line 288, in save_pretrained
save_method(os.path.join(save_directory, pipeline_component_name), **save_kwargs)
File "/home/abdallah/Desktop/webui/qunata-6h/.venv/lib/python3.12/site-packages/diffusers/models/modeling_utils.py", line 344, in save_pretrained
state_dict_split = split_torch_state_dict_into_shards(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/abdallah/Desktop/webui/qunata-6h/.venv/lib/python3.12/site-packages/huggingface_hub/serialization/_torch.py", line 330, in split_torch_state_dict_into_shards
return split_state_dict_into_shards_factory(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/abdallah/Desktop/webui/qunata-6h/.venv/lib/python3.12/site-packages/huggingface_hub/serialization/_base.py", line 108, in split_state_dict_into_shards_factory
storage_id = get_storage_id(tensor)
^^^^^^^^^^^^^^^^^^^^^^
File "/home/abdallah/Desktop/webui/qunata-6h/.venv/lib/python3.12/site-packages/huggingface_hub/serialization/_torch.py", line 359, in get_torch_storage_id
unique_id = storage_ptr(tensor)
^^^^^^^^^^^^^^^^^^^
File "/home/abdallah/Desktop/webui/qunata-6h/.venv/lib/python3.12/site-packages/huggingface_hub/serialization/_torch.py", line 410, in storage_ptr
return tensor.storage().data_ptr()
^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/abdallah/Desktop/webui/qunata-6h/.venv/lib/python3.12/site-packages/torch/storage.py", line 1220, in data_ptr
return self._data_ptr()
^^^^^^^^^^^^^^^^
File "/home/abdallah/Desktop/webui/qunata-6h/.venv/lib/python3.12/site-packages/torch/storage.py", line 1224, in _data_ptr
return self._untyped_storage.data_ptr()
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
RuntimeError: Attempted to access the data pointer on an invalid python storage.`
Why are you calling `enable_sequential_cpu_offload()` before `save_pretrained()`?
That was from old code; I've commented it out, so it isn't used.
I used it to generate an image, but that failed!
Are you using torch nightly?
Right, is this windows?
No, it's Arch Linux with an xmonad (Haskell) desktop that I compiled and configured myself.
Oh okay, the error you're facing is a known issue: huggingface/safetensors#515. Cc: @jerryzh168.
Okay, I tried saving each component separately:
from diffusers import DiffusionPipeline, FluxTransformer2DModel, AutoencoderKL
from transformers import T5EncoderModel, CLIPTextModel
from torchao.quantization import quantize_, int8_weight_only
import torch
import os
ckpt_id = "sayakpaul/FLUX.1-merged"


def load_and_quantize(model_class, subfolder):
    """Load one pipeline component from ``ckpt_id`` and int8-quantize it.

    Args:
        model_class: A model class exposing ``from_pretrained`` (e.g. a
            diffusers or transformers model class).
        subfolder: Subfolder of the ``ckpt_id`` repo holding this component.

    Returns:
        The model loaded in float16, with ``quantize_(..., int8_weight_only())``
        applied to it.
    """
    model = model_class.from_pretrained(
        ckpt_id, subfolder=subfolder, torch_dtype=torch.float16
    )
    quantize_(model, int8_weight_only())
    return model
# Load and quantize components
transformer = load_and_quantize(FluxTransformer2DModel, "transformer")
text_encoder = load_and_quantize(CLIPTextModel, "text_encoder")
text_encoder_2 = load_and_quantize(T5EncoderModel, "text_encoder_2")
vae = load_and_quantize(AutoencoderKL, "vae")

# Initialize the pipeline
pipeline = DiffusionPipeline.from_pretrained(
    ckpt_id,
    transformer=transformer,
    vae=vae,
    text_encoder=text_encoder,
    text_encoder_2=text_encoder_2,
    torch_dtype=torch.float16,
)

# Save the quantized pipeline components separately
save_directory = "quantized_flux_pipeline"
os.makedirs(save_directory, exist_ok=True)

# Move all components to CPU before saving
pipeline.to("cpu")

# Each component is written as <name>.pt with torch.save (pickle), because
# safetensors cannot serialize torchao's quantized tensors
# (https://github.com/huggingface/safetensors/issues/515). Only weights are
# saved, so loading requires rebuilding each module from its config first.
for component_name in ("transformer", "vae", "text_encoder", "text_encoder_2"):
    component = getattr(pipeline, component_name)
    torch.save(
        component.state_dict(),
        os.path.join(save_directory, f"{component_name}.pt"),
    )

print(f"Quantized pipeline components saved to {save_directory}")

# Optional: Clear CUDA cache and print memory usage.
# Nothing in this script is moved to the GPU, so this presumably reports ~0.
torch.cuda.empty_cache()
memory = (torch.cuda.memory_allocated() / 1024 / 1024 / 1024)
print(f"GPU memory usage: {memory:.3f} GB")
Now I'm trying to load them again with the original configs. You're experienced with this: how do I load them again to produce an image? Is it the same as loading the transformer, repeated for each component?
@sayakpaul I would recommend setting the dtype to FP32 instead of BF16 so the code will run on pre-ampere architectures as well.
import torch
from diffusers import FluxTransformer2DModel, DiffusionPipeline, AutoencoderKL
from transformers import T5EncoderModel, CLIPTextModel, CLIPTextConfig, T5Config
import os

# Should match the torch_dtype the components were quantized and saved with
# (the per-component save script uses torch.float16). With assign=True the
# loaded tensors keep their saved dtype whatever this value is.
dtype, device = torch.float16, "cuda"
ckpt_id = "sayakpaul/FLUX.1-merged"
save_directory = "quantized_flux_pipeline"


def _load_weights(module, filename):
    """Load a torch.save'd state dict from ``save_directory`` into *module*.

    Returns *module* so calls can be chained.
    """
    # Explicit weights_only=False: torch warns that its default will flip to
    # True, and these pickles hold torchao tensor subclasses. Only load files
    # you created or otherwise trust — pickle can execute arbitrary code.
    state_dict = torch.load(
        os.path.join(save_directory, filename),
        map_location="cpu",
        weights_only=False,
    )
    # assign=True keeps the loaded (quantized) tensors as the module's
    # parameters instead of copying values into the existing ones.
    module.load_state_dict(state_dict, assign=True)
    return module


# Transformer and VAE: build empty skeletons on the meta device (no memory for
# unquantized weights), then swap in the saved tensors.
with torch.device("meta"):
    config = FluxTransformer2DModel.load_config(ckpt_id, subfolder="transformer")
    transformer = FluxTransformer2DModel.from_config(config).to(dtype)
transformer = _load_weights(transformer, "transformer.pt")

with torch.device("meta"):
    vae_config = AutoencoderKL.load_config(ckpt_id, subfolder="vae")
    vae = AutoencoderKL.from_config(vae_config).to(dtype)
vae = _load_weights(vae, "vae.pt")

# Text encoders: configs must be read with `from_pretrained`. Passing the repo
# id positionally to the config constructor (`CLIPTextConfig(ckpt_id, ...)`)
# sets its first field, `vocab_size`, to that string, which is why
# `nn.Embedding` failed with "empty() received an invalid combination of
# arguments". They are built on CPU rather than meta: buffers absent from the
# state dict would presumably be left on the meta device — TODO confirm.
text_encoder_config = CLIPTextConfig.from_pretrained(
    ckpt_id, subfolder="text_encoder"
)
text_encoder = _load_weights(CLIPTextModel(text_encoder_config), "text_encoder.pt")

text_encoder_2_config = T5Config.from_pretrained(
    ckpt_id, subfolder="text_encoder_2"
)
text_encoder_2 = _load_weights(
    T5EncoderModel(text_encoder_2_config), "text_encoder_2.pt"
)

# Initialize the pipeline; tokenizers and scheduler come from `ckpt_id`.
pipeline = DiffusionPipeline.from_pretrained(
    ckpt_id,
    transformer=transformer,
    vae=vae,
    text_encoder=text_encoder,
    text_encoder_2=text_encoder_2,
    torch_dtype=dtype,
).to(device)

# Generate image
image = pipeline(
    "cat",
    guidance_scale=0.0,
    num_inference_steps=4,
    max_sequence_length=256,
).images[0]
image.save("flux_local_int8.png")
print("Image generated and saved as 'flux_local_int8.png'")
This failed for me! I'm not very experienced with diffusers.
What is your error?
/home/abdallah/Desktop/webui/qunata-6h/runquanta.py:15: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.
transformer_state_dict = torch.load(os.path.join(save_directory, "transformer.pt"), map_location="cpu")
/home/abdallah/Desktop/webui/qunata-6h/runquanta.py:23: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.
vae_state_dict = torch.load(os.path.join(save_directory, "vae.pt"), map_location="cpu")
Traceback (most recent call last):
File "/home/abdallah/Desktop/webui/qunata-6h/runquanta.py", line 27, in <module>
text_encoder = CLIPTextModel(text_encoder_config)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/abdallah/Desktop/webui/qunata-6h/.venv/lib/python3.12/site-packages/transformers/models/clip/modeling_clip.py", line 946, in __init__
self.text_model = CLIPTextTransformer(config)
^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/abdallah/Desktop/webui/qunata-6h/.venv/lib/python3.12/site-packages/transformers/models/clip/modeling_clip.py", line 840, in __init__
self.embeddings = CLIPTextEmbeddings(config)
^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/abdallah/Desktop/webui/qunata-6h/.venv/lib/python3.12/site-packages/transformers/models/clip/modeling_clip.py", line 205, in __init__
self.token_embedding = nn.Embedding(config.vocab_size, embed_dim)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/abdallah/Desktop/webui/qunata-6h/.venv/lib/python3.12/site-packages/torch/nn/modules/sparse.py", line 167, in __init__
torch.empty((num_embeddings, embedding_dim), **factory_kwargs),
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
TypeError: empty() received an invalid combination of arguments - got (tuple, dtype=NoneType, device=NoneType), but expected one of:
* (tuple of ints size, *, tuple of names names, torch.memory_format memory_format = None, torch.dtype dtype = None, torch.layout layout = None, torch.device device = None, bool pin_memory = False, bool requires_grad = False)
* (tuple of ints size, *, torch.memory_format memory_format = None, Tensor out = None, torch.dtype dtype = None, torch.layout layout = None, torch.device device = None, bool pin_memory = False, bool requires_grad = False)
@tin2tin we don't support Windows really well simply because `torch.compile()` doesn't support Windows very well. Feel free to open an issue on our repo, though we might choose to prioritize this depending on demand. For now, try installing ao from source using `USE_CPP=0 pip install .`
This works on Windows, in case anyone wants to know.
The commands in the cmd prompt are actually:
set USE_CPP=0
pip install <path to downloaded torchao folder>
Edit (12-01-2025):
As of now, even this throws an error:
AttributeError: Can't get attribute 'PlainAQTLayout' on <module 'torchao.dtypes.affine_quantized_tensor' from '...Python\\Python311\\site-packages\\torchao\\dtypes\\affine_quantized_tensor.py'>
For those trying to use the under-17 GB version: the link is dead, but you can use this instead:
import torch
from diffusers import FluxTransformer2DModel, DiffusionPipeline

dtype, device = torch.bfloat16, "cuda"
ckpt_id = "black-forest-labs/FLUX.1-schnell"

# Pre-quantized (torchao int8 weight-only) FLUX.1-schnell transformer.
# use_safetensors=False because the repo presumably ships a pickle checkpoint
# (safetensors cannot hold torchao tensor subclasses).
# NOTE(review): loading is reported to fail on torchao 0.7.0 with
# "Can't get attribute 'PlainAQTLayout'"; the suggested workarounds are the
# Diffusers torchao integration or downgrading torchao.
model = FluxTransformer2DModel.from_pretrained(
    "sayakpaul/flux.1-schell-int8wo-improved",
    torch_dtype=dtype,
    use_safetensors=False,
)
pipeline = DiffusionPipeline.from_pretrained(
    ckpt_id, transformer=model, torch_dtype=dtype
).to(device)
image = pipeline(
    "cat", guidance_scale=0.0, num_inference_steps=4, max_sequence_length=256
).images[0]
image.save("flux_schnell_int8.png")
@MostHumble
Getting an error running that code:
AttributeError: Can't get attribute 'PlainAQTLayout' on <module 'torchao.dtypes.affine_quantized_tensor' from '...Python\\Python311\\site-packages\\torchao\\dtypes\\affine_quantized_tensor.py'>
On torchao == 0.7.0
@tin2tin
Ended up having the same issue, reported it here: https://huggingface.co/sayakpaul/flux.1-schell-int8wo-improved/discussions/2#6783e9e49c52f42c530ff7c1
Please take the issue up with `torchao`. Until it's resolved, either:
- use the `torchao` integration from Diffusers, or
- downgrade your `torchao` installation.
Other than that, I can't provide additional suggestions.