# First, in your terminal.
#
# $ python3 -m virtualenv env
# $ source env/bin/activate
# $ pip install torch torchvision transformers sentencepiece protobuf accelerate
# $ pip install git+https://github.com/huggingface/diffusers.git
# $ pip install optimum-quanto
import torch
from optimum.quanto import freeze, qfloat8, quantize
from diffusers import FlowMatchEulerDiscreteScheduler, AutoencoderKL
from diffusers.models.transformers.transformer_flux import FluxTransformer2DModel
from diffusers.pipelines.flux.pipeline_flux import FluxPipeline
from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5TokenizerFast

dtype = torch.bfloat16

# schnell is the distilled turbo model. For the CFG-distilled model, use:
# bfl_repo = "black-forest-labs/FLUX.1-dev"
# revision = "refs/pr/3"
#
# The undistilled model that uses CFG ("pro"), which can use negative prompts,
# was not released.
bfl_repo = "black-forest-labs/FLUX.1-schnell"
revision = "refs/pr/1"

scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(bfl_repo, subfolder="scheduler", revision=revision)
text_encoder = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14", torch_dtype=dtype)
tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14", torch_dtype=dtype)
text_encoder_2 = T5EncoderModel.from_pretrained(bfl_repo, subfolder="text_encoder_2", torch_dtype=dtype, revision=revision)
tokenizer_2 = T5TokenizerFast.from_pretrained(bfl_repo, subfolder="tokenizer_2", torch_dtype=dtype, revision=revision)
vae = AutoencoderKL.from_pretrained(bfl_repo, subfolder="vae", torch_dtype=dtype, revision=revision)
transformer = FluxTransformer2DModel.from_pretrained(bfl_repo, subfolder="transformer", torch_dtype=dtype, revision=revision)

# Experimental: Try this to load in 4-bit for <16GB cards.
#
# from optimum.quanto import qint4
# quantize(transformer, weights=qint4, exclude=["proj_out", "x_embedder", "norm_out", "context_embedder"])
# freeze(transformer)
quantize(transformer, weights=qfloat8)
freeze(transformer)

quantize(text_encoder_2, weights=qfloat8)
freeze(text_encoder_2)

# Construct the pipeline without the quantized models, then attach them afterwards.
pipe = FluxPipeline(
    scheduler=scheduler,
    text_encoder=text_encoder,
    tokenizer=tokenizer,
    text_encoder_2=None,
    tokenizer_2=tokenizer_2,
    vae=vae,
    transformer=None,
)
pipe.text_encoder_2 = text_encoder_2
pipe.transformer = transformer
pipe.enable_model_cpu_offload()

generator = torch.Generator().manual_seed(12345)
image = pipe(
    prompt='nekomusume cat girl, digital painting',
    width=1024,
    height=1024,
    num_inference_steps=4,
    generator=generator,
    guidance_scale=3.5,
).images[0]
image.save('test_flux_distilled.png')
Got it finally running with some help from here
# quantize text_encoder_2 to qfloat8
import json
import os

from optimum.quanto import quantization_map
# Wrapper base classes; the exact import path may differ across optimum-quanto versions.
from optimum.quanto import QuantizedDiffusersModel, QuantizedTransformersModel

print("start loading text_encoder_2...")
text_encoder_2 = T5EncoderModel.from_pretrained(bfl_repo, subfolder="text_encoder_2", torch_dtype=dtype)  # , revision=revision)
print("start quantizing text_encoder_2...")
quantize(text_encoder_2, weights=qfloat8)
print("start freezing text_encoder_2...")
freeze(text_encoder_2)
print("start saving text_encoder_2...")
text_encoder_2.save_pretrained(f"{bfl_repo}/q_text_encoder_2")

# Save the quantization map so the model can be reloaded later
qmap_name = os.path.join(f"{bfl_repo}/q_text_encoder_2", f"{QuantizedTransformersModel.BASE_NAME}_qmap.json")
qmap = quantization_map(text_encoder_2)
with open(qmap_name, "w", encoding="utf8") as f:
    json.dump(qmap, f, indent=4)

print("start loading text_encoder_2...")
T5EncoderModel.from_config = lambda c: T5EncoderModel(c)  # Duct tape for quanto support.
text_encoder_2 = QuantizedT5Model.from_pretrained(f"{bfl_repo}/q_text_encoder_2")  # , torch_dtype=dtype)
I still need some help with one more thing: I cannot load the quantized transformer. My guess is that manual wrapper classes are necessary, since optimum.quanto does not ship them for these models yet (these are the QuantizedT5Model / QuantizedFlux2DModel classes used above):
# manual wrapper classes, since optimum.quanto does not provide these yet
class QuantizedFlux2DModel(QuantizedDiffusersModel):
    base_class = FluxTransformer2DModel

class QuantizedT5Model(QuantizedTransformersModel):
    auto_class = T5EncoderModel
# quantize transformer to qfloat8
transformer = FluxTransformer2DModel.from_pretrained(bfl_repo, subfolder="transformer", torch_dtype=dtype)  # , revision=revision)
quantize(transformer, weights=qfloat8)
freeze(transformer)

# save transformer qfloat8
transformer.save_pretrained(f"{bfl_repo}/q_transformer_qfloat8")

# Save the quantization map so the model can be reloaded later
qmap_name = os.path.join(f"{bfl_repo}/q_transformer_qfloat8", f"{QuantizedDiffusersModel.BASE_NAME}_qmap.json")
qmap = quantization_map(transformer)
with open(qmap_name, "w", encoding="utf8") as f:
    json.dump(qmap, f, indent=4)

# load transformer  # currently NOT working
transformer = QuantizedFlux2DModel.from_pretrained(f"{bfl_repo}/q_transformer_qfloat8", torch_dtype=dtype)  # this is what I would probably need, or something like (low_cpu_mem_usage=True, device_map='auto')
transformer = QuantizedFlux2DModel.from_pretrained(f"{bfl_repo}/q_transformer_qfloat8").to(dtype)  # this would probably run, but not on my machine
When I run this code:
transformer = QuantizedFlux2DModel.from_pretrained(f"{bfl_repo}/q_transformer_qfloat8", torch_dtype=dtype)
I get: "TypeError: QuantizedDiffusersModel.from_pretrained() got an unexpected keyword argument 'torch_dtype'" which arises since there's no torch_dtype as argument here
@sayakpaul Do you have an idea how to solve this? Thank you!
A simple torch.save() / torch.load() combo for saving the whole model works fine and is relatively fast (loading the transformer takes about 10 seconds and the text encoder ~4 seconds). You'll get security warnings from pickle but they can be ignored in this case. Also make sure to put model.eval() after torch.load() as stated here: https://pytorch.org/tutorials/beginner/saving_loading_models.html.
Code example:
...
# quantizing and saving
quantize(transformer, weights=qfloat8)
freeze(transformer)
torch.save(transformer, 'D:/models/transformer.pt')
# loading
transformer = torch.load('D:/models/transformer.pt')
transformer.eval()
...
Thanks, I tested it; it is indeed relatively fast. Code example:
# loading
transformer = torch.load('D:/models/transformer.pt')
transformer.eval()
pipe = FluxPipeline(...)
pipe.transformer = transformer
pipe.text_encoder_2 = text_encoder_2
A stupid question, but what exactly are we supposed to do with this code to get it working?
transformer = torch.load('D:/models/transformer.pt')
transformer.eval()
Thanks. That gets the run time down from 2m45s to 1m01s.
Using the same logic to save text_encoder_2 and load it gets the total runtime down to 25 seconds.
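For reference, a minimal sketch of that same approach applied to text_encoder_2 (the 'D:/models/...' path is just a placeholder):
# quantize, freeze, and pickle the T5 encoder once
quantize(text_encoder_2, weights=qfloat8)
freeze(text_encoder_2)
torch.save(text_encoder_2, 'D:/models/text_encoder_2.pt')

# on later runs, load the pickled encoder instead of re-quantizing
text_encoder_2 = torch.load('D:/models/text_encoder_2.pt')
text_encoder_2.eval()
pipe.text_encoder_2 = text_encoder_2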
Can you modify the script to support the various Flux LoRAs out there at the moment?
I tried adding in
pipeline.load_lora_weights(adapter_id)
from here
https://huggingface.co/zouzoumaki/flux-loras
but that has issues with optimum.
LoRA support would be good, especially now that loading the saved models brings the whole script down to only 25 seconds.
You need to
pipe.load_lora_weights("./pytorch_lora_weights.safetensors")
pipe.fuse_lora(lora_scale=0.125)
pipe.unload_lora_weights()
But it takes a long time. I am still waiting for a better way to do this with quanto.
Like this?
pipe = FluxPipeline(
scheduler=scheduler,
text_encoder=text_encoder,
tokenizer=tokenizer,
text_encoder_2=text_encoder_2,
tokenizer_2=tokenizer_2,
vae=vae,
transformer=transformer,
)
pipe.load_lora_weights("./pytorch_lora_weights.safetensors")
pipe.fuse_lora(lora_scale=0.125)
pipe.unload_lora_weights()
because that gives this error:
Traceback (most recent call last):
  File "flux_on_potato.py", line 146, in <module>
    pipe.load_lora_weights("./pytorch_lora_weights.safetensors")
  File "venv\lib\site-packages\diffusers\loaders\lora_pipeline.py", line 1620, in load_lora_weights
    self.load_lora_into_transformer(
  File "venv\lib\site-packages\diffusers\loaders\lora_pipeline.py", line 1700, in load_lora_into_transformer
    incompatible_keys = set_peft_model_state_dict(transformer, state_dict, adapter_name)
  File "venv\lib\site-packages\peft\utils\save_and_load.py", line 395, in set_peft_model_state_dict
    load_result = model.load_state_dict(peft_model_state_dict, strict=False)
  File "venv\lib\site-packages\torch\nn\modules\module.py", line 2201, in load_state_dict
    load(self, state_dict)
  File "venv\lib\site-packages\torch\nn\modules\module.py", line 2189, in load
    load(child, child_state_dict, child_prefix)  # noqa: F821
  File "venv\lib\site-packages\torch\nn\modules\module.py", line 2189, in load
    load(child, child_state_dict, child_prefix)  # noqa: F821
  File "venv\lib\site-packages\torch\nn\modules\module.py", line 2189, in load
    load(child, child_state_dict, child_prefix)  # noqa: F821
  File "venv\lib\site-packages\torch\nn\modules\module.py", line 2183, in load
    module._load_from_state_dict(
  File "venv\lib\site-packages\optimum\quanto\nn\qmodule.py", line 159, in _load_from_state_dict
    deserialized_weight = QBytesTensor.load_from_state_dict(
  File "venv\lib\site-packages\optimum\quanto\tensor\qbytes.py", line 90, in load_from_state_dict
    inner_tensors_dict[name] = state_dict.pop(prefix + name)
KeyError: 'time_text_embed.timestep_embedder.linear_1.weight._data'
Yeah, something about that lora isn't working :( You can open an issue at https://github.com/huggingface/diffusers/issues
So you can use other LoRAs with that syntax? Can you share one that works?
I tried making the quantize edit in the script, i.e.
# quantize(transformer, weights=qfloat8)
quantize(transformer, weights=qint4, exclude=["proj_out", "x_embedder", "norm_out", "context_embedder"])
You also need to edit the import for optimum.quanto to
from optimum.quanto import freeze, qfloat8, qint4, quantize
so it knows what qint4 is. But the script was even slower, and it seems to only use the CPU now; I gave up waiting after the iteration counter just sat at 0/4.
I noticed the same thing: 0% GPU usage and only CPU usage. Were you ever able to resolve this?
Edit: I should note that this seems to happen only on Windows. On my Ubuntu machine, the GPU does get utilized.
I never got an answer so I gave up on LoRA support.
You need to add the LoRA before quantizing.
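A rough sketch of that order of operations (my reading of the suggestion; the LoRA file and lora_scale are placeholders taken from earlier in the thread):
# 1. load the full-precision transformer and fuse the LoRA into it first
transformer = FluxTransformer2DModel.from_pretrained(bfl_repo, subfolder="transformer", torch_dtype=dtype, revision=revision)
pipe = FluxPipeline(
    scheduler=scheduler,
    text_encoder=text_encoder,
    tokenizer=tokenizer,
    text_encoder_2=text_encoder_2,
    tokenizer_2=tokenizer_2,
    vae=vae,
    transformer=transformer,
)
pipe.load_lora_weights("./pytorch_lora_weights.safetensors")
pipe.fuse_lora(lora_scale=0.125)
pipe.unload_lora_weights()

# 2. only then quantize and freeze the (now fused) transformer
quantize(pipe.transformer, weights=qfloat8)
freeze(pipe.transformer)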
In my case, qfloat8 works well on a V100, but the 4090 needs qfloat8_e5m2; I don't really know the reason, though :)
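For anyone who wants to try that, swapping the weight type is a small change, assuming your optimum-quanto version exports qfloat8_e5m2:
from optimum.quanto import freeze, qfloat8_e5m2, quantize

quantize(transformer, weights=qfloat8_e5m2)
freeze(transformer)
quantize(text_encoder_2, weights=qfloat8_e5m2)
freeze(text_encoder_2)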
Got it almost running by manually saving the qmap and also changing to "auto_class = T5EncoderModel".
Now it's saying "AttributeError: type object 'T5EncoderModel' has no attribute 'from_config'. Did you mean: '_from_config'?"
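The duct-tape patch shown earlier in this thread targets exactly that error: monkey-patch from_config onto T5EncoderModel before loading the quantized encoder.
# T5EncoderModel has no from_config classmethod, so provide one for quanto
T5EncoderModel.from_config = lambda c: T5EncoderModel(c)
text_encoder_2 = QuantizedT5Model.from_pretrained(f"{bfl_repo}/q_text_encoder_2")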