@sayakpaul · Last active August 12, 2024 07:52
This code snippet shows how to run Stable Diffusion 3 (Medium) with an 8-bit T5-XXL text encoder, drastically reducing the memory requirements. It requires `diffusers`, `transformers`, `accelerate`, and `bitsandbytes`.
from diffusers import StableDiffusion3Pipeline
from transformers import T5EncoderModel
import torch
import time
import gc


def flush():
    gc.collect()
    torch.cuda.empty_cache()


def bytes_to_giga_bytes(bytes):
    return bytes / 1024 / 1024 / 1024

id = "stabilityai/stable-diffusion-3-medium-diffusers"

# Load only the T5-XXL text encoder in 8-bit precision (requires `bitsandbytes`).
text_encoder = T5EncoderModel.from_pretrained(
    id,
    subfolder="text_encoder_3",
    load_in_8bit=True,
    device_map="auto",
)

# Build a text-encoding-only pipeline: the transformer and VAE are skipped.
pipeline = StableDiffusion3Pipeline.from_pretrained(
    id,
    text_encoder_3=text_encoder,
    transformer=None,
    vae=None,
    device_map="balanced",
)
with torch.no_grad():
    # Warm-up runs.
    for _ in range(3):
        prompt = "a photo of a cat"
        (
            prompt_embeds,
            negative_prompt_embeds,
            pooled_prompt_embeds,
            negative_pooled_prompt_embeds,
        ) = pipeline.encode_prompt(prompt=prompt, prompt_2=None, prompt_3=None)

    # Timed runs.
    start = time.time()
    for _ in range(10):
        (
            prompt_embeds,
            negative_prompt_embeds,
            pooled_prompt_embeds,
            negative_pooled_prompt_embeds,
        ) = pipeline.encode_prompt(prompt=prompt, prompt_2=None, prompt_3=None)
    end = time.time()

avg_prompt_encoding_time = (end - start) / 10

# Free the text encoder and the encoding pipeline before loading the denoiser and VAE.
del text_encoder
del pipeline
flush()
# Reload the pipeline with only the transformer and VAE, in fp16.
pipeline = StableDiffusion3Pipeline.from_pretrained(
    id,
    text_encoder=None,
    text_encoder_2=None,
    text_encoder_3=None,
    tokenizer=None,
    tokenizer_2=None,
    tokenizer_3=None,
    torch_dtype=torch.float16,
).to("cuda")
pipeline.set_progress_bar_config(disable=True)
# Warm-up runs.
for _ in range(3):
    _ = pipeline(
        prompt_embeds=prompt_embeds.half(),
        negative_prompt_embeds=negative_prompt_embeds.half(),
        pooled_prompt_embeds=pooled_prompt_embeds.half(),
        negative_pooled_prompt_embeds=negative_pooled_prompt_embeds.half(),
    )

# Timed runs.
start = time.time()
for _ in range(10):
    _ = pipeline(
        prompt_embeds=prompt_embeds.half(),
        negative_prompt_embeds=negative_prompt_embeds.half(),
        pooled_prompt_embeds=pooled_prompt_embeds.half(),
        negative_pooled_prompt_embeds=negative_pooled_prompt_embeds.half(),
    )
end = time.time()
avg_inference_time = (end - start) / 10
print(f"Average prompt encoding time: {avg_prompt_encoding_time:.3f} seconds.")
print(f"Average inference time: {avg_inference_time:.3f} seconds.")
print(f"Total time: {(avg_prompt_encoding_time + avg_inference_time):.3f} seconds.")
print(
f"Max memory allocated: {bytes_to_giga_bytes(torch.cuda.max_memory_allocated())} GB"
)
# Generate and save the final image.
image = pipeline(
    prompt_embeds=prompt_embeds.half(),
    negative_prompt_embeds=negative_prompt_embeds.half(),
    pooled_prompt_embeds=pooled_prompt_embeds.half(),
    negative_pooled_prompt_embeds=negative_pooled_prompt_embeds.half(),
).images[0]
image.save("output_8bit.png")
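
Note: newer versions of `transformers` deprecate the `load_in_8bit=True` argument of `from_pretrained` in favor of passing a `BitsAndBytesConfig`. A minimal sketch of the equivalent encoder load, assuming a `transformers` version that ships `BitsAndBytesConfig` and an installed `bitsandbytes`:

from transformers import BitsAndBytesConfig, T5EncoderModel

# Same 8-bit quantization as above, expressed through the config object
# instead of the deprecated flag.
quantization_config = BitsAndBytesConfig(load_in_8bit=True)
text_encoder = T5EncoderModel.from_pretrained(
    "stabilityai/stable-diffusion-3-medium-diffusers",
    subfolder="text_encoder_3",
    quantization_config=quantization_config,
    device_map="auto",
)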