@twobob
Created February 11, 2025 04:37
A low-calorie version of the very fine https://huggingface.co/bleepybloops/flux-collage-v1, for machines with 16 GB and up (probably).
#!/usr/bin/env python
"""
This script demonstrates how to load a quantized Flux.1-dev model (4-bit, NF4)
using bitsandbytes with Diffusers and then apply a LoRA adapter.
It loads:
(a) the T5EncoderModel from the "text_encoder_2" subfolder in 4-bit mode, and
(b) the FluxTransformer2DModel from the "transformer" subfolder in 4-bit mode.
Then it instantiates a FluxPipeline with device_map="balanced" and applies LoRA.
"""
import time
import torch
from diffusers import (
    FluxPipeline,
    FluxTransformer2DModel,
    BitsAndBytesConfig as DiffusersBitsAndBytesConfig,
)
from transformers import (
    T5EncoderModel,
    BitsAndBytesConfig as TransformersBitsAndBytesConfig,
)
from IPython.display import Image, display
# -----------------------------------------------------------------------------
# Step 1: Define quantization configurations for 4-bit using NF4.
# For Ada+ GPUs, using torch.bfloat16 for compute is recommended.
# You may change torch_dtype to torch.float16 if preferred.
# -----------------------------------------------------------------------------
quant_config_text = TransformersBitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
)
quant_config_transformer = DiffusersBitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
)
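# Note: bfloat16 compute is native only on Ampere-class (SM80) GPUs and newer.
# Quick sanity check (sketch): warn if this GPU lacks bf16 support, in which
# case you should switch the two configs above to bnb_4bit_compute_dtype=torch.float16.
if torch.cuda.is_available() and not torch.cuda.is_bf16_supported():
    print("Warning: GPU lacks native bfloat16; consider bnb_4bit_compute_dtype=torch.float16.")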
# -----------------------------------------------------------------------------
# Step 2: Load quantized models from the Flux.1-dev repository.
# -----------------------------------------------------------------------------
# Load the quantized T5EncoderModel for text_encoder_2.
text_encoder_2_4bit = T5EncoderModel.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    subfolder="text_encoder_2",
    quantization_config=quant_config_text,
    torch_dtype=torch.float16,  # dtype for the non-quantized layers; compute uses bf16 as configured above
)
# Load the quantized transformer (the diffusion module).
transformer_4bit = FluxTransformer2DModel.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    subfolder="transformer",
    quantization_config=quant_config_transformer,
    torch_dtype=torch.float16,
)
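# Optional sanity check (sketch): report how much memory the quantized weights
# actually occupy. get_memory_footprint() exists on recent transformers and
# diffusers model classes; the hasattr guard covers older versions.
for name, module in (("text_encoder_2", text_encoder_2_4bit), ("transformer", transformer_4bit)):
    if hasattr(module, "get_memory_footprint"):
        print(f"{name}: {module.get_memory_footprint() / 1e9:.2f} GB")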
# -----------------------------------------------------------------------------
# Step 3: Instantiate the FluxPipeline using the quantized models.
# -----------------------------------------------------------------------------
pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    transformer=transformer_4bit,
    text_encoder_2=text_encoder_2_4bit,
    torch_dtype=torch.float16,
    device_map="balanced",  # distributes pipeline components across the available devices
)
# (Optional) You can enable CPU offload if desired. Since the pipeline was built
# with a device_map, you may need to reset the map first:
# pipe.reset_device_map()
# pipe.enable_model_cpu_offload()
# (Optional) Set channels-last memory format on the diffusion module. FluxPipeline
# does not expose a "unet"; its diffusion module is the "transformer" attribute.
if hasattr(pipe, "transformer"):
    pipe.transformer.to(memory_format=torch.channels_last)
else:
    print("No transformer attribute found to set channels-last.")
# -----------------------------------------------------------------------------
# Step 4: Apply the LoRA adapter.
# -----------------------------------------------------------------------------
# Load and activate the LoRA adapter; adjust the repository name and weight as needed.
lora_weight = 1.2
pipe.load_lora_weights(
    "bleepybloops/flux-collage-v1",  # repository or local path holding the LoRA adapter
    adapter_name="collage_v1",
)
pipe.set_adapters("collage_v1", adapter_weights=[lora_weight])
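# (Optional) Sketch: the adapter strength can be retuned without reloading, and
# the LoRA can be unloaded entirely to compare against the base model:
# pipe.set_adapters("collage_v1", adapter_weights=[0.8])  # weaker collage effect
# pipe.unload_lora_weights()                              # revert to base FLUX.1-dev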
# -----------------------------------------------------------------------------
# Step 5: Generate an image.
# -----------------------------------------------------------------------------
pipe_kwargs = {
    "prompt": "multimedia collage, cut and paste, animatronics",
    "height": 1024,
    "width": 1024,
    "guidance_scale": 3.5,
    "num_inference_steps": 30,
    "max_sequence_length": 512,
}
# Set a manual seed for reproducibility.
generator = torch.manual_seed(0)
# Generate the image.
result = pipe(**pipe_kwargs, generator=generator)
image = result.images[0]
# Save and display the generated image.
filename = f"{time.time()}.png"
image.save(filename)
print(f"Image saved as {filename}")
display(Image(filename=filename))
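
# (Optional) Sketch: sweep a few seeds to explore variations; the pipeline is
# already loaded, so each extra image only costs generation time.
# for seed in (1, 2, 3):
#     gen = torch.manual_seed(seed)
#     pipe(**pipe_kwargs, generator=gen).images[0].save(f"collage_{seed}.png")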