This code snippet shows how to split the Flux transformer across two 16GB GPUs and run inference with the full pipeline.
from diffusers import FluxPipeline, AutoencoderKL, FluxTransformer2DModel
from diffusers.image_processor import VaeImageProcessor
import torch
import gc
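# Helper to reclaim GPU memory between the encoding, denoising, and decoding stages.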
def flush():
    gc.collect()
    torch.cuda.empty_cache()
    torch.cuda.reset_max_memory_allocated()
    torch.cuda.reset_peak_memory_stats()
ckpt_id = "black-forest-labs/FLUX.1-dev"
prompt = "a photo of a cat with tiger-like look"
pipeline = FluxPipeline.from_pretrained(
    ckpt_id,
    transformer=None,
    vae=None,
    device_map="balanced",
    max_memory={0: "16GB", 1: "16GB"},
    torch_dtype=torch.bfloat16
)
print(pipeline.hf_device_map)
with torch.no_grad():
print("Encoding prompts.")
prompt_embeds, pooled_prompt_embeds, text_ids = pipeline.encode_prompt(
prompt=prompt, prompt_2=None, max_sequence_length=512
)
print(prompt_embeds.shape)
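# The text encoders and tokenizers are no longer needed; delete them and flush
# GPU memory to make room for the transformer.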
del pipeline.text_encoder
del pipeline.text_encoder_2
del pipeline.tokenizer
del pipeline.tokenizer_2
del pipeline
flush()
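# Shard the transformer across both GPUs with `device_map="auto"` and a 16GB cap per card.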
transformer = FluxTransformer2DModel.from_pretrained(
    ckpt_id,
    subfolder="transformer",
    device_map="auto",
    max_memory={0: "16GB", 1: "16GB"},
    torch_dtype=torch.bfloat16
)
print(transformer.hf_device_map)
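# Rebuild the pipeline with only the sharded transformer; the prompt embeddings
# are already computed, and the VAE is loaded later for decoding.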
pipeline = FluxPipeline.from_pretrained(
    ckpt_id,
    text_encoder=None,
    text_encoder_2=None,
    tokenizer=None,
    tokenizer_2=None,
    vae=None,
    transformer=transformer,
    torch_dtype=torch.bfloat16
)
print("Running denoising.")
height, width = 768, 1360
# No need to wrap this call in `torch.no_grad()`; the pipeline's call method
# is already decorated with it.
latents = pipeline(
    prompt_embeds=prompt_embeds,
    pooled_prompt_embeds=pooled_prompt_embeds,
    num_inference_steps=50,
    guidance_scale=3.5,
    height=height,
    width=width,
    output_type="latent",
).images
print(latents.shape)
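# Denoising is done; free the transformer before loading the VAE.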
del pipeline.transformer
del pipeline
flush()
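# Decode the packed latents with the VAE, here placed on a single GPU.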
vae = AutoencoderKL.from_pretrained(ckpt_id, subfolder="vae", torch_dtype=torch.bfloat16).to("cuda")
vae_scale_factor = 2 ** (len(vae.config.block_out_channels))
image_processor = VaeImageProcessor(vae_scale_factor=vae_scale_factor)
with torch.no_grad():
print("Running decoding.")
latents = FluxPipeline._unpack_latents(latents, height, width, vae_scale_factor)
latents = (latents / vae.config.scaling_factor) + vae.config.shift_factor
image = vae.decode(latents, return_dict=False)[0]
image = image_processor.postprocess(image, output_type="pil")
image[0].save("split_transformer.png")
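# Optional sketch (not part of the original gist): report how much memory each
# GPU is holding at the end of the run. Uses only standard `torch.cuda` calls;
# adjust or drop this if you track memory differently.
for device_id in range(torch.cuda.device_count()):
    allocated_gb = torch.cuda.memory_allocated(device_id) / 1024**3
    print(f"GPU {device_id}: {allocated_gb:.2f} GB currently allocated")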