Gist by @tin2tin, created January 3, 2026, 19:50
Qwen Layered: decompose an image into RGBA layers with Qwen-Image-Layered (4-bit quantized)
import torch
import os
from PIL import Image
from diffusers import (
    QwenImageTransformer2DModel,
    QwenImageLayeredPipeline,
    AutoencoderKLQwenImage,  # Using the specialized 3D VAE class
    BitsAndBytesConfig as DiffusersBitsAndBytesConfig,
)
from transformers import (
    Qwen2_5_VLForConditionalGeneration,
    BitsAndBytesConfig as TransformersBitsAndBytesConfig,
)
print("Loading: Qwen-Image-Layered (Unified BFloat16)")
# --- 1. Model IDs ---
base_model_id = "Qwen/Qwen-Image-Layered"
transformer_id = "OzzyGT/qwen-image-layered-bnb-4bit-transformer"
text_encoder_id = "OzzyGT/qwen-image-layered-bnb-4bit-text-encoder"
# Unified dtype for everything
torch_dtype = torch.bfloat16
# --- 2. Configure 4-bit Loading ---
quantization_config_transformer = DiffusersBitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch_dtype,
    llm_int8_skip_modules=["transformer_blocks.0.img_mod"],
)
quantization_config_te = TransformersBitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch_dtype,
)
# --- 3. Load Components ---
print("Loading Transformer (4-bit)...")
transformer = QwenImageTransformer2DModel.from_pretrained(
    transformer_id,
    quantization_config=quantization_config_transformer,
    torch_dtype=torch_dtype,
)
print("Loading Text Encoder (4-bit)...")
text_encoder = Qwen2_5_VLForConditionalGeneration.from_pretrained(
    text_encoder_id,
    quantization_config=quantization_config_te,
    torch_dtype=torch_dtype,
)
print("Loading VAE (BFloat16)...")
# FIX: Load VAE in bfloat16 to match the input image dtype.
# bfloat16 is usually stable enough to avoid black images.
vae = AutoencoderKLQwenImage.from_pretrained(
    base_model_id,
    subfolder="vae",
    torch_dtype=torch_dtype,
)
# --- 4. Assemble Pipeline ---
print("Assembling Pipeline...")
pipeline = QwenImageLayeredPipeline.from_pretrained(
    base_model_id,
    transformer=transformer,
    text_encoder=text_encoder,
    vae=vae,
    torch_dtype=torch_dtype,
)
pipeline.enable_model_cpu_offload()
# --- 5. Run Inference ---
image_path = r"C:\Users\peter\Downloads\You_Disappear-282932284-large.jpg"
if not os.path.exists(image_path):
    raise FileNotFoundError(f"Could not find image at: {image_path}")
print(f"Processing image: {image_path}")
image = Image.open(image_path).convert("RGBA")
inputs = {
    "image": image,
    "generator": torch.Generator(device="cuda").manual_seed(777),
    "true_cfg_scale": 4.0,
    "negative_prompt": " ",
    "num_inference_steps": 50,
    "num_images_per_prompt": 1,
    "layers": 4,
    "resolution": 640,
    "cfg_normalize": True,
    "use_en_prompt": True,
}
print("Generating layers...")
with torch.inference_mode():
    output = pipeline(**inputs)
# Get the list of PIL images
output_layers = output.images[0]
print(f"Saving {len(output_layers)} layers...")
for i, layer_img in enumerate(output_layers):
    filename = f"layer_{i}.png"
    layer_img.save(filename)
    print(f" - Saved {filename}")
print("Done!")