@ariG23498
Created October 20, 2024 09:01
Run FLUX.1 Dev under 8 GB of VRAM.
# Taken from: https://gist.github.com/sayakpaul/23862a2e7f5ab73dfdcc513751289bea
from diffusers import FluxPipeline, FluxTransformer2DModel
from transformers import T5EncoderModel
import torch
import gc


def flush():
    """Release cached GPU memory and reset CUDA memory statistics."""
    gc.collect()
    torch.cuda.empty_cache()
    torch.cuda.reset_max_memory_allocated()
    torch.cuda.reset_peak_memory_stats()


def bytes_to_giga_bytes(bytes):
    return bytes / 1024 / 1024 / 1024


flush()

ckpt_id = "black-forest-labs/FLUX.1-dev"
ckpt_4bit_id = "sayakpaul/flux.1-dev-nf4-pkg"
prompt = "a cute dog in paris photoshoot"

# Stage 1: load only the 4-bit (NF4) T5 text encoder and encode the prompt,
# so the transformer never has to share VRAM with the text encoders.
text_encoder_2_4bit = T5EncoderModel.from_pretrained(
    ckpt_4bit_id,
    subfolder="text_encoder_2",
)
pipeline = FluxPipeline.from_pretrained(
    ckpt_id,
    text_encoder_2=text_encoder_2_4bit,
    transformer=None,
    vae=None,
    torch_dtype=torch.float16,
)
pipeline.enable_model_cpu_offload()

with torch.no_grad():
    print("Encoding prompts.")
    prompt_embeds, pooled_prompt_embeds, text_ids = pipeline.encode_prompt(
        prompt=prompt, prompt_2=None, max_sequence_length=256
    )

# Free the text-encoder pipeline before loading the transformer.
pipeline = pipeline.to("cpu")
del pipeline
flush()

# Stage 2: load the 4-bit (NF4) transformer and rebuild the pipeline without
# the text encoders, then denoise using the precomputed prompt embeddings.
transformer_4bit = FluxTransformer2DModel.from_pretrained(ckpt_4bit_id, subfolder="transformer")
pipeline = FluxPipeline.from_pretrained(
    ckpt_id,
    text_encoder=None,
    text_encoder_2=None,
    tokenizer=None,
    tokenizer_2=None,
    transformer=transformer_4bit,
    torch_dtype=torch.float16,
)
pipeline.enable_model_cpu_offload()

print("Running denoising.")
height, width = 512, 768
images = pipeline(
    prompt_embeds=prompt_embeds,
    pooled_prompt_embeds=pooled_prompt_embeds,
    num_inference_steps=50,
    guidance_scale=5.5,
    height=height,
    width=width,
    output_type="pil",
).images
images[0].save("output.png")
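
A small optional follow-up (not part of the original gist): since the helpers above already reset CUDA memory statistics, you can report the peak VRAM used by the run right after saving the image. Only bytes_to_giga_bytes (defined above) and the standard torch.cuda.max_memory_allocated call are assumed here.

# Optional: report peak VRAM of the run (assumes the script above has just executed).
peak_gb = bytes_to_giga_bytes(torch.cuda.max_memory_allocated())
print(f"Peak VRAM allocated: {peak_gb:.2f} GB")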
@tin2tin commented Oct 20, 2024

Getting this error:
python\Lib\site-packages\transformers\modeling_utils.py", line 2826, in to
    raise ValueError(
ValueError: `.to` is not supported for `4-bit` or `8-bit` bitsandbytes models. Please use the model as it is, since the model has already been set to the correct devices and casted to the correct `dtype`.

@ariG23498 (Author) commented

Can you upgrade your transformers package?

pip install --upgrade transformers

I am using 4.46.0.dev0 (yours does not have to be the dev version).
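
If you want to confirm which transformers build you have before re-running, here is a quick check (a sketch, not from the original thread):

# Print the installed transformers version; the author reports 4.46.0.dev0 works,
# while older releases raise the `.to` ValueError shown above.
import transformers
print(transformers.__version__)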

@tin2tin commented Oct 21, 2024

Same error here with:
[screenshot attached]

Trying to update CUDA.

Edit: Same problem with CUDA 12.6
