Inspired by this example for FLUX.1-dev, with code taken from QwenImagePipeline in Diffusers.
import torch
from diffusers import AutoencoderKLQwenImage, QwenImagePipeline, QwenImageTransformer2DModel
from diffusers.image_processor import VaeImageProcessor
from torchao.quantization import float8_weight_only, quantize_


class QwenImagePromptEncoder:
    def __init__(self, pretrained_model_name_or_path: str, device: str, torch_dtype: torch.dtype = torch.bfloat16):
        self.encoder_ = QwenImagePipeline.from_pretrained(
            pretrained_model_name_or_path,
            torch_dtype=torch_dtype,
            scheduler=None,
            transformer=None,
            vae=None,
        ).to(device)

    def __call__(self, prompt: str, **kwargs):
        with torch.no_grad():
            prompt_embeds, prompt_embeds_mask = self.encoder_.encode_prompt(prompt=prompt, **kwargs)
        return prompt_embeds, prompt_embeds_mask

class QwenImageDecoder:
    """
    Based off QwenImagePipeline: loads only the VAE and reproduces its latent post-processing.
    """

    def __init__(self, pretrained_model_name_or_path: str, device: str, torch_dtype: torch.dtype = torch.bfloat16):
        self.vae_ = AutoencoderKLQwenImage.from_pretrained(
            pretrained_model_name_or_path,
            torch_dtype=torch_dtype,
            subfolder="vae",
        ).to(device)
        self.vae_scale_factor_ = 2 ** len(self.vae_.temperal_downsample)
        self.image_processor_ = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor_ * 2)

    @property
    def vae_scale_factor(self):
        return self.vae_scale_factor_

    def __call__(self, latents: torch.Tensor, output_type: str = "pil"):
        vae = self.vae_
        latents = latents.to(vae.device)
        with torch.no_grad():
            # De-normalize the latents with the statistics stored in the VAE config before decoding.
            latents_mean = (
                torch.tensor(vae.config.latents_mean)
                .view(1, vae.config.z_dim, 1, 1, 1)
                .to(latents.device, latents.dtype)
            )
            latents_std = 1.0 / torch.tensor(vae.config.latents_std).view(1, vae.config.z_dim, 1, 1, 1).to(
                latents.device, latents.dtype
            )
            latents = latents / latents_std + latents_mean
            # The VAE is a video VAE, so drop the singleton frame dimension after decoding.
            image = vae.decode(latents, return_dict=False)[0][:, :, 0]
        image = self.image_processor_.postprocess(image, output_type=output_type)
        return image

class QwenImageTransformer:
    def __init__(
        self,
        pretrained_model_name_or_path: str,
        vae_scale_factor: float,
        device: str,
        torch_dtype: torch.dtype = torch.bfloat16,
    ):
        transformer = QwenImageTransformer2DModel.from_pretrained(
            pretrained_model_name_or_path,
            torch_dtype=torch_dtype,
            subfolder="transformer",
        )
        # Quantize the transformer weights to float8 with TorchAO before moving it onto the GPU.
        quantize_(transformer, float8_weight_only())
        transformer = transformer.to(device)
        transformer_pipeline = QwenImagePipeline.from_pretrained(
            pretrained_model_name_or_path,
            torch_dtype=torch_dtype,
            text_encoder=None,
            tokenizer=None,
            transformer=transformer,
            vae=None,
        )
        self.transformer_pipeline_ = transformer_pipeline
        self.vae_scale_factor_ = vae_scale_factor
        self.device_ = device

    @property
    def device(self):
        return self.device_

    def __call__(self, height: int, width: int, **kwargs) -> torch.Tensor:
        # The pipeline returns packed latents when output_type="latent"; unpack them for the VAE.
        latents = self.transformer_pipeline_(output_type="latent", height=height, width=width, **kwargs)
        latents = self.transformer_pipeline_._unpack_latents(latents[0], height, width, self.vae_scale_factor_)
        return latents

if __name__ == "__main__":
    pretrained_model_name_or_path = "/home/ubuntu/models/Qwen-Image"
    torch_dtype = torch.bfloat16

    # Spread the three stages over separate GPUs.
    encoder = QwenImagePromptEncoder(pretrained_model_name_or_path, "cuda:0", torch_dtype)
    decoder = QwenImageDecoder(pretrained_model_name_or_path, "cuda:1", torch_dtype)
    transformer = QwenImageTransformer(pretrained_model_name_or_path, decoder.vae_scale_factor, "cuda:2", torch_dtype)

    prompt = "a majestic dragon flying over a medieval castle, fantasy art, highly detailed"
    negative_prompt = "blurry, low quality, cartoonish"

    positive_prompt_embeds, positive_prompt_embeds_mask = encoder(prompt)
    negative_prompt_embeds, negative_prompt_embeds_mask = encoder(negative_prompt)

    latents = transformer(
        prompt_embeds=positive_prompt_embeds.to(transformer.device),
        prompt_embeds_mask=positive_prompt_embeds_mask.to(transformer.device),
        negative_prompt_embeds=negative_prompt_embeds.to(transformer.device),
        negative_prompt_embeds_mask=negative_prompt_embeds_mask.to(transformer.device),
        num_inference_steps=30,
        height=1024,
        width=1024,
    )
    image = decoder(latents)
    image[0].save("qwen_t2i.jpg")
Qwen-Image is made of a few parts:
- Text encoder and tokenizer: ~16 GiB
- Transformer: ~20 GiB after float8 weight-only quantization with TorchAO
- VAE: ~500 MiB
The objective was to run Qwen-Image on an EC2 instance with 4 A10G GPUs (23 GiB VRAM each) and generate a large quantity of images from a base prompt randomized with wildcards. Splitting up the T2I process made it possible to use the hardware efficiently:
- Encode prompts and save the embeddings to disk from an encoder process on each GPU
- Load the embeddings with a transformer process on each GPU, with the decoder running on CPU (a sketch of this two-stage split follows this list)
- Even though the Transformer and VAE could both be loaded onto a single A10G, there did not seem to be enough memory left to allocate during forward passes
- On a card with less than 24 GB, it may be necessary to load models with the max_memory kwarg (see e.g. this thread; a sketch follows below)
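A minimal sketch of that split, assuming the classes above live in a hypothetical module named qwen_parts, that embeddings are exchanged as .pt files in a shared directory, and that each stage runs as its own process (one per GPU, each handling a shard of the prompts):

from pathlib import Path

import torch

from qwen_parts import QwenImageDecoder, QwenImagePromptEncoder, QwenImageTransformer  # hypothetical module

MODEL_PATH = "/home/ubuntu/models/Qwen-Image"
EMBED_DIR = Path("embeddings")  # hypothetical shared directory


def encode_stage(prompts: list[str], device: str) -> None:
    # Stage 1: encode prompts and persist the embeddings, so the text encoder
    # can be shut down before the transformer stage starts.
    encoder = QwenImagePromptEncoder(MODEL_PATH, device)
    EMBED_DIR.mkdir(exist_ok=True)
    for i, prompt in enumerate(prompts):
        embeds, mask = encoder(prompt)
        torch.save({"embeds": embeds.cpu(), "mask": mask.cpu()}, EMBED_DIR / f"{i:06d}.pt")


def generate_stage(device: str) -> None:
    # Stage 2: load the saved embeddings, denoise on the GPU, decode on CPU.
    decoder = QwenImageDecoder(MODEL_PATH, "cpu")
    transformer = QwenImageTransformer(MODEL_PATH, decoder.vae_scale_factor, device)
    for path in sorted(EMBED_DIR.glob("*.pt")):
        saved = torch.load(path)
        # Negative prompt embeddings could be saved and passed the same way as in the script above.
        latents = transformer(
            prompt_embeds=saved["embeds"].to(device),
            prompt_embeds_mask=saved["mask"].to(device),
            num_inference_steps=30,
            height=1024,
            width=1024,
        )
        decoder(latents)[0].save(path.stem + ".jpg")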
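And a hedged sketch of the max_memory route, using the device_map/max_memory arguments that diffusers model loading passes through to Accelerate; the memory budgets below are placeholders, not measured values:

import torch
from diffusers import QwenImageTransformer2DModel

# Cap GPU usage while loading the transformer, spilling the rest to CPU RAM.
transformer = QwenImageTransformer2DModel.from_pretrained(
    "/home/ubuntu/models/Qwen-Image",
    subfolder="transformer",
    torch_dtype=torch.bfloat16,
    device_map="auto",
    max_memory={0: "18GiB", "cpu": "48GiB"},  # placeholder budgets for a <24 GB card
)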