@sayakpaul
Last active April 7, 2025 21:44
This gist shows how to run Flux on a 24GB 4090 card with Diffusers, by splitting inference into three stages (text encoding, denoising, VAE decoding) and freeing each set of components before loading the next.
from diffusers import FluxPipeline, AutoencoderKL
from diffusers.image_processor import VaeImageProcessor
from transformers import T5EncoderModel, T5TokenizerFast, CLIPTokenizer, CLIPTextModel
import torch
import gc


def flush():
    gc.collect()
    torch.cuda.empty_cache()
    torch.cuda.reset_max_memory_allocated()
    torch.cuda.reset_peak_memory_stats()


def bytes_to_giga_bytes(bytes):
    return bytes / 1024 / 1024 / 1024


flush()

ckpt_id = "black-forest-labs/FLUX.1-schnell"
prompt = "a photo of a dog with cat-like look"

# Stage 1: load only the text encoders and tokenizers to compute the prompt embeddings.
text_encoder = CLIPTextModel.from_pretrained(
    ckpt_id, subfolder="text_encoder", torch_dtype=torch.bfloat16
)
text_encoder_2 = T5EncoderModel.from_pretrained(
    ckpt_id, subfolder="text_encoder_2", torch_dtype=torch.bfloat16
)
tokenizer = CLIPTokenizer.from_pretrained(ckpt_id, subfolder="tokenizer")
tokenizer_2 = T5TokenizerFast.from_pretrained(ckpt_id, subfolder="tokenizer_2")

pipeline = FluxPipeline.from_pretrained(
    ckpt_id,
    text_encoder=text_encoder,
    text_encoder_2=text_encoder_2,
    tokenizer=tokenizer,
    tokenizer_2=tokenizer_2,
    transformer=None,
    vae=None,
).to("cuda")

with torch.no_grad():
    print("Encoding prompts.")
    prompt_embeds, pooled_prompt_embeds, text_ids = pipeline.encode_prompt(
        prompt=prompt, prompt_2=None, max_sequence_length=256
    )

# Free the text encoders before loading the transformer.
del text_encoder
del text_encoder_2
del tokenizer
del tokenizer_2
del pipeline
flush()

# Stage 2: load only the transformer and run the denoising loop in latent space.
pipeline = FluxPipeline.from_pretrained(
    ckpt_id,
    text_encoder=None,
    text_encoder_2=None,
    tokenizer=None,
    tokenizer_2=None,
    vae=None,
    torch_dtype=torch.bfloat16,
).to("cuda")

print("Running denoising.")
height, width = 768, 1360
# No need to wrap this in `torch.no_grad()`: the pipeline's call method
# is already decorated with it.
latents = pipeline(
    prompt_embeds=prompt_embeds,
    pooled_prompt_embeds=pooled_prompt_embeds,
    num_inference_steps=4,
    guidance_scale=0.0,
    height=height,
    width=width,
    output_type="latent",
).images
print(f"{latents.shape=}")

# Free the transformer before loading the VAE.
del pipeline.transformer
del pipeline
flush()

# Stage 3: load only the VAE and decode the latents into an image.
vae = AutoencoderKL.from_pretrained(ckpt_id, revision="refs/pr/1", subfolder="vae", torch_dtype=torch.bfloat16).to(
    "cuda"
)
vae_scale_factor = 2 ** (len(vae.config.block_out_channels))
image_processor = VaeImageProcessor(vae_scale_factor=vae_scale_factor)

with torch.no_grad():
    print("Running decoding.")
    latents = FluxPipeline._unpack_latents(latents, height, width, vae_scale_factor)
    latents = (latents / vae.config.scaling_factor) + vae.config.shift_factor

    image = vae.decode(latents, return_dict=False)[0]
    image = image_processor.postprocess(image, output_type="pil")
    image[0].save("image.png")
@toilaluan

Simpler version that works well on an RTX 4090. Avg 4.5 to 5 s for the text encoders, 1 to 1.2 it/s for the denoising model.

from transformers import T5EncoderModel
import time
import gc
import torch
import diffusers

def flush():
    gc.collect()
    torch.cuda.empty_cache()

t5_encoder = T5EncoderModel.from_pretrained(
    "black-forest-labs/FLUX.1-schnell", subfolder="text_encoder_2", revision="refs/pr/7", torch_dtype=torch.bfloat16
)
text_encoder = diffusers.DiffusionPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-schnell",
    text_encoder_2=t5_encoder,
    transformer=None,
    vae=None,
    revision="refs/pr/7",
)
pipeline = diffusers.DiffusionPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-schnell", 
    torch_dtype=torch.bfloat16,
    revision="refs/pr/1",
    text_encoder_2=None,
    text_encoder=None,
)
pipeline.enable_model_cpu_offload()

@torch.inference_mode()
def inference(self, prompt, num_inference_steps=4, guidance_scale=0.0, width=1024, height=1024):
    self.text_encoder.to("cuda")
    start = time.time()
    (
        prompt_embeds,
        pooled_prompt_embeds,
        _,
    ) = self.text_encoder.encode_prompt(prompt=prompt, prompt_2=None, max_sequence_length=256)
    self.text_encoder.to("cpu")
    flush()
    print(f"Prompt encoding time: {time.time() - start}")
    output = self.pipeline(
        prompt_embeds=prompt_embeds.bfloat16(),
        pooled_prompt_embeds=pooled_prompt_embeds.bfloat16(),
        width=width,
        height=height,
        guidance_scale=guidance_scale,
        num_inference_steps=num_inference_steps
    )
    image = output.images[0]
    return image
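
As posted, inference expects self to be any object exposing text_encoder and pipeline attributes. A minimal calling sketch under that assumption, reusing the two pipelines built above:

import types

# Hypothetical holder object: bundles the text-encoder pipeline and the
# denoising pipeline so the self-based signature works as written.
holder = types.SimpleNamespace(text_encoder=text_encoder, pipeline=pipeline)
image = inference(holder, "a photo of a dog with cat-like look")
image.save("image.png")
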

@sayakpaul

Thanks! This is a lot cleaner.

@bil-ash

bil-ash commented Aug 3, 2024


@toilaluan Can you make it work under 16GB of VRAM so that it can run on an RTX A4000?

@james-imi

@bil-ash why is inference using a self parameter?

@mortenmoulder

@toilaluan I've been trying to run your example, but without success. I keep getting errors such as:

AttributeError: 'NoneType' object has no attribute 'to'

I'm guessing that's because self.text_encoder is None.

Removing the self references does make the code run, and on my 4090 it pegs the GPU at 100% for both memory and compute, but what takes you seconds takes minutes on my end. Any clue what I'm doing wrong?

Oh and I just added this at the bottom:

if __name__ == "__main__":
    prompt = "A cat holding a sign that says hello world"
    generated_image = inference(pipeline, prompt)

    generated_image.save("output.png")
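
For reference, a minimal sketch of the "remove the self portion" variant, assuming the module-level text_encoder and pipeline objects from @toilaluan's snippet are in scope (the call at the bottom then becomes inference(prompt)):

@torch.inference_mode()
def inference(prompt, num_inference_steps=4, guidance_scale=0.0, width=1024, height=1024):
    # Use the module-level `text_encoder` pipeline (CLIP + T5) and `pipeline`
    # (transformer + VAE) instead of attributes on `self`.
    text_encoder.to("cuda")
    prompt_embeds, pooled_prompt_embeds, _ = text_encoder.encode_prompt(
        prompt=prompt, prompt_2=None, max_sequence_length=256
    )
    text_encoder.to("cpu")
    flush()
    output = pipeline(
        prompt_embeds=prompt_embeds.bfloat16(),
        pooled_prompt_embeds=pooled_prompt_embeds.bfloat16(),
        width=width,
        height=height,
        guidance_scale=guidance_scale,
        num_inference_steps=num_inference_steps,
    )
    return output.images[0]
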

@sumitmamoria


@toilaluan How do I use a LoRA with this code? Thanks!

@jason-engage

@toilaluan Your refactor works great. Does anybody know how to get image-to-image working on a 4090? I haven't been able to modify this script successfully.

@jordyBonnet

Same issue on my side (Windows 11) as @mortenmoulder describes above.

@fnauman

fnauman commented Dec 3, 2024

This only works at 512x512 on a 4090 for me; 1024x1024 goes out of memory.

Example:
python main.py "a photo of a dog with cat-like look"

from transformers import T5EncoderModel
import time
import gc
import torch
import diffusers
import argparse
from PIL import Image
import os

def flush():
    gc.collect()
    torch.cuda.empty_cache()

class FluxSchnell:
    def __init__(self):
        self.t5_encoder = T5EncoderModel.from_pretrained(
            "black-forest-labs/FLUX.1-schnell", subfolder="text_encoder_2", revision="refs/pr/7", torch_dtype=torch.bfloat16
        )
        self.text_encoder = diffusers.DiffusionPipeline.from_pretrained(
            "black-forest-labs/FLUX.1-schnell",
            text_encoder_2=self.t5_encoder,
            transformer=None,
            vae=None,
            revision="refs/pr/7",
        )
        self.pipeline = diffusers.DiffusionPipeline.from_pretrained(
            "black-forest-labs/FLUX.1-schnell", 
            torch_dtype=torch.bfloat16,
            revision="refs/pr/1",
            text_encoder_2=None,
            text_encoder=None,
        )
        self.pipeline.enable_model_cpu_offload()

    @torch.inference_mode()
    def inference(self, prompt, num_inference_steps=4, guidance_scale=0.0, width=1024, height=1024):
        self.text_encoder.to("cuda")
        start = time.time()
        (
            prompt_embeds,
            pooled_prompt_embeds,
            _,
        ) = self.text_encoder.encode_prompt(prompt=prompt, prompt_2=None, max_sequence_length=256)
        self.text_encoder.to("cpu")

        flush()
        
        print(f"Prompt encoding time: {time.time() - start}")
        
        output = self.pipeline(
            prompt_embeds=prompt_embeds.bfloat16(),
            pooled_prompt_embeds=pooled_prompt_embeds.bfloat16(),
            width=width,
            height=height,
            guidance_scale=guidance_scale,
            num_inference_steps=num_inference_steps
        )
        image = output.images[0]
        return image


def main():
    parser = argparse.ArgumentParser(description='Generate images using FLUX.1-schnell model')
    parser.add_argument('prompt', type=str, help='Enter your prompt here')
    parser.add_argument('--steps', type=int, default=4, help='Number of inference steps (default: 4)')
    parser.add_argument('--guidance-scale', type=float, default=0.0, help='Guidance scale (default: 0.0)')
    parser.add_argument('--width', type=int, default=512, help='Image width (default: 512)')
    parser.add_argument('--height', type=int, default=512, help='Image height (default: 512)')
    parser.add_argument('--output', type=str, default='output.png', help='Output image path (default: output.png)')
    
    args = parser.parse_args()
    
    model = FluxSchnell()
    image = model.inference(
        prompt=args.prompt,
        num_inference_steps=args.steps,
        guidance_scale=args.guidance_scale,
        width=args.width,
        height=args.height
    )
    
    # Save the image
    image.save(args.output)
    print(f"Image saved to: {os.path.abspath(args.output)}")

if __name__ == '__main__':
    main()
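
One knob that may help with the 1024x1024 OOM, if the failure happens during the VAE decode (untested here): AutoencoderKL supports sliced and tiled decoding, which lowers the peak memory of that step. A sketch, added after enable_model_cpu_offload() in __init__:

        # Decode the latents in slices/tiles to reduce peak VRAM at large resolutions.
        self.pipeline.vae.enable_slicing()
        self.pipeline.vae.enable_tiling()
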

@jayendramadaram

Awesome! Works smoothly on a 4090! Any idea how to run it with a LoRA? Loading a LoRA into the pipeline and running it is not working for me.

@thomsonm685

@jayendramadaram Here's my code for loading my LoRA (you'll need to pip install peft first):

from transformers import T5EncoderModel
import time
import gc
import torch
import diffusers

def flush():
    gc.collect()
    torch.cuda.empty_cache()
    
class FluxSchnell:
    def __init__(self):
        self.t5_encoder = T5EncoderModel.from_pretrained(
            "black-forest-labs/FLUX.1-schnell", subfolder="text_encoder_2", revision="refs/pr/7", torch_dtype=torch.bfloat16
        )
        self.text_encoder = diffusers.DiffusionPipeline.from_pretrained(
            "black-forest-labs/FLUX.1-schnell",
            text_encoder_2=self.t5_encoder,
            transformer=None,
            vae=None,
            revision="refs/pr/7",
        )
        self.pipeline = diffusers.DiffusionPipeline.from_pretrained(
            "black-forest-labs/FLUX.1-schnell", 
            torch_dtype=torch.bfloat16,
            revision="refs/pr/1",
            text_encoder_2=None,
            text_encoder=None,
        )
        self.pipeline.enable_model_cpu_offload()
        self.pipeline.load_lora_weights("./path/to/lora")

    print("HERE")

    @torch.inference_mode()
    def inference(self, prompt, filename, num_inference_steps=4, guidance_scale=0.0, width=1024, height=1024):
        self.text_encoder.to("cuda")
        start = time.time()
        (
            prompt_embeds,
            pooled_prompt_embeds,
            _,
        ) = self.text_encoder.encode_prompt(prompt=prompt, prompt_2=None, max_sequence_length=256)
        self.text_encoder.to("cpu")
        flush()
        print(f"Prompt encoding time: {time.time() - start}")
        output = self.pipeline(
            prompt_embeds=prompt_embeds.bfloat16(),
            pooled_prompt_embeds=pooled_prompt_embeds.bfloat16(),
            width=width,
            height=height,
            guidance_scale=guidance_scale,
            num_inference_steps=num_inference_steps
        )
        image = output.images[0]
    
        image.save(filename)
        print("FINISHED")
        return image

model = FluxSchnell()
generated_image = model.inference(prompt="some really cool stuff", filename="output2.png")

@Vladimir-Urik

torch.OutOfMemoryError: CUDA out of memory. Tried to allocate 90.00 MiB. GPU 0 has a total capacity of 19.56 GiB of which 6.25 MiB is free. Including non-PyTorch memory, this process has 19.54 GiB memory in use. Of the allocated memory 19.35 GiB is allocated by PyTorch, and 6.14 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation. See documentation for Memory Management (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)
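
A minimal sketch of the mitigation the error message itself points at (it addresses allocator fragmentation, not total VRAM usage), assuming it runs before any CUDA memory is allocated:

import os

# Must be set before the first CUDA allocation (i.e., before the pipelines are built).
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
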
