
@sayakpaul
Last active April 7, 2025 21:44
This gist shows how to run Flux on a 24GB 4090 card with Diffusers.
from diffusers import FluxPipeline, AutoencoderKL
from diffusers.image_processor import VaeImageProcessor
from transformers import T5EncoderModel, T5TokenizerFast, CLIPTokenizer, CLIPTextModel
import torch
import gc
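
# Run Flux in three stages -- prompt encoding, denoising, and VAE decoding --
# loading only the components each stage needs and flushing GPU memory in between.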
def flush():
    gc.collect()
    torch.cuda.empty_cache()
    torch.cuda.reset_max_memory_allocated()
    torch.cuda.reset_peak_memory_stats()


def bytes_to_giga_bytes(bytes):
    return bytes / 1024 / 1024 / 1024

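# Start from a clean slate before loading anything onto the GPU.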
flush()
ckpt_id = "black-forest-labs/FLUX.1-schnell"
prompt = "a photo of a dog with cat-like look"
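
# Stage 1: load only the two text encoders (CLIP and T5) and their tokenizers in bfloat16.
# The transformer and VAE are left unloaded so the 24GB card never holds everything at once.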
text_encoder = CLIPTextModel.from_pretrained(
    ckpt_id, subfolder="text_encoder", torch_dtype=torch.bfloat16
)
text_encoder_2 = T5EncoderModel.from_pretrained(
    ckpt_id, subfolder="text_encoder_2", torch_dtype=torch.bfloat16
)
tokenizer = CLIPTokenizer.from_pretrained(ckpt_id, subfolder="tokenizer")
tokenizer_2 = T5TokenizerFast.from_pretrained(ckpt_id, subfolder="tokenizer_2")
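
# Build a text-encoding-only pipeline: passing transformer=None and vae=None keeps
# those components out of memory entirely.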
pipeline = FluxPipeline.from_pretrained(
    ckpt_id,
    text_encoder=text_encoder,
    text_encoder_2=text_encoder_2,
    tokenizer=tokenizer,
    tokenizer_2=tokenizer_2,
    transformer=None,
    vae=None,
).to("cuda")
with torch.no_grad():
    print("Encoding prompts.")
    prompt_embeds, pooled_prompt_embeds, text_ids = pipeline.encode_prompt(
        prompt=prompt, prompt_2=None, max_sequence_length=256
    )
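
# The prompt embeddings are all we need from the text encoders; delete them and
# flush so the transformer has the GPU to itself in the next stage.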
del text_encoder
del text_encoder_2
del tokenizer
del tokenizer_2
del pipeline
flush()
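
# Stage 2: reload the pipeline with only the transformer (and scheduler) in bfloat16;
# the text encoders and VAE are skipped this time.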
pipeline = FluxPipeline.from_pretrained(
    ckpt_id,
    text_encoder=None,
    text_encoder_2=None,
    tokenizer=None,
    tokenizer_2=None,
    vae=None,
    torch_dtype=torch.bfloat16,
).to("cuda")
print("Running denoising.")
height, width = 768, 1360
# No need to wrap this call in `torch.no_grad()`; the pipeline's call method already
# runs under it.
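# FLUX.1-schnell is a few-step distilled model, hence 4 steps and guidance_scale=0.0.
# Requesting output_type="latent" defers VAE decoding to the final stage.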
latents = pipeline(
    prompt_embeds=prompt_embeds,
    pooled_prompt_embeds=pooled_prompt_embeds,
    num_inference_steps=4,
    guidance_scale=0.0,
    height=height,
    width=width,
    output_type="latent",
).images
print(f"{latents.shape=}")
del pipeline.transformer
del pipeline
flush()
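
# Stage 3: load only the VAE (bfloat16) and decode the packed latents into an image.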
vae = AutoencoderKL.from_pretrained(
    ckpt_id, revision="refs/pr/1", subfolder="vae", torch_dtype=torch.bfloat16
).to("cuda")
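# Effective pixel-to-latent scale: the VAE's 8x spatial downsampling times the 2x2
# packing Flux applies to its latents (16 for the Flux VAE's four block_out_channels entries).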
vae_scale_factor = 2 ** (len(vae.config.block_out_channels))
image_processor = VaeImageProcessor(vae_scale_factor=vae_scale_factor)
with torch.no_grad():
    print("Running decoding.")
    latents = FluxPipeline._unpack_latents(latents, height, width, vae_scale_factor)
    latents = (latents / vae.config.scaling_factor) + vae.config.shift_factor
    image = vae.decode(latents, return_dict=False)[0]
    image = image_processor.postprocess(image, output_type="pil")
    image[0].save("image.png")
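
# Optional: report the peak VRAM used since the last flush() (i.e., the decoding stage),
# using the helper defined above.
print(f"Max memory allocated: {bytes_to_giga_bytes(torch.cuda.max_memory_allocated())} GiB")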
@Vladimir-Urik

torch.OutOfMemoryError: CUDA out of memory. Tried to allocate 90.00 MiB. GPU 0 has a total capacity of 19.56 GiB of which 6.25 MiB is free. Including non-PyTorch memory, this process has 19.54 GiB memory in use. Of the allocated memory 19.35 GiB is allocated by PyTorch, and 6.14 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation. See documentation for Memory Management (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)
