Last active
August 12, 2024 07:52
-
-
Save sayakpaul/82acb5976509851f2db1a83456e504f1 to your computer and use it in GitHub Desktop.
The code snippet shows how to run Stable Diffusion 3 with a 8bit T5-xxl, drastically reducing the memory requirements.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from diffusers import StableDiffusion3Pipeline | |
from transformers import T5EncoderModel | |
import torch | |
import time | |
import gc | |
def flush(): | |
gc.collect() | |
torch.cuda.empty_cache() | |
def bytes_to_giga_bytes(bytes): | |
return bytes / 1024 / 1024 / 1024 | |
id = "stabilityai/stable-diffusion-3-medium-diffusers" | |
text_encoder = T5EncoderModel.from_pretrained( | |
id, | |
subfolder="text_encoder_3", | |
load_in_8bit=True, | |
device_map="auto" | |
) | |
pipeline = StableDiffusion3Pipeline.from_pretrained( | |
id, | |
text_encoder_3=text_encoder, | |
transformer=None, | |
vae=None, | |
device_map="balanced", | |
) | |
with torch.no_grad(): | |
for _ in range(3): | |
prompt = "a photo of a cat" | |
( | |
prompt_embeds, | |
negative_prompt_embeds, | |
pooled_prompt_embeds, | |
negative_pooled_prompt_embeds, | |
) = pipeline.encode_prompt(prompt=prompt, prompt_2=None, prompt_3=None) | |
start = time.time() | |
for _ in range(10): | |
( | |
prompt_embeds, | |
negative_prompt_embeds, | |
pooled_prompt_embeds, | |
negative_pooled_prompt_embeds, | |
) = pipeline.encode_prompt(prompt=prompt, prompt_2=None, prompt_3=None) | |
end = time.time() | |
avg_prompt_encoding_time = (end - start) / 10 | |
del text_encoder | |
del pipeline | |
flush() | |
pipeline = StableDiffusion3Pipeline.from_pretrained( | |
id, | |
text_encoder=None, | |
text_encoder_2=None, | |
text_encoder_3=None, | |
tokenizer=None, | |
tokenizer_2=None, | |
tokenizer_3=None, | |
torch_dtype=torch.float16 | |
).to("cuda") | |
pipeline.set_progress_bar_config(disable=True) | |
for _ in range(3): | |
_ = pipeline( | |
prompt_embeds=prompt_embeds.half(), | |
negative_prompt_embeds=negative_prompt_embeds.half(), | |
pooled_prompt_embeds=pooled_prompt_embeds.half(), | |
negative_pooled_prompt_embeds=negative_pooled_prompt_embeds.half(), | |
) | |
start = time.time() | |
for _ in range(10): | |
_ = pipeline( | |
prompt_embeds=prompt_embeds.half(), | |
negative_prompt_embeds=negative_prompt_embeds.half(), | |
pooled_prompt_embeds=pooled_prompt_embeds.half(), | |
negative_pooled_prompt_embeds=negative_pooled_prompt_embeds.half(), | |
) | |
end = time.time() | |
avg_inference_time = (end - start) / 10 | |
print(f"Average prompt encoding time: {avg_prompt_encoding_time:.3f} seconds.") | |
print(f"Average inference time: {avg_inference_time:.3f} seconds.") | |
print(f"Total time: {(avg_prompt_encoding_time + avg_inference_time):.3f} seconds.") | |
print( | |
f"Max memory allocated: {bytes_to_giga_bytes(torch.cuda.max_memory_allocated())} GB" | |
) | |
image = pipeline( | |
prompt_embeds=prompt_embeds.half(), | |
negative_prompt_embeds=negative_prompt_embeds.half(), | |
pooled_prompt_embeds=pooled_prompt_embeds.half(), | |
negative_pooled_prompt_embeds=negative_pooled_prompt_embeds.half(), | |
).images[0] | |
image.save("output_8bit.png") |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment