@jamestkpoon
Created September 7, 2025 00:05
Running Qwen Image on multiple GPUs (diffusers)

Inspired by this example for FLUX.1-dev, with code taken from QwenImagePipeline in Diffusers.

import torch
from diffusers import AutoencoderKLQwenImage, QwenImagePipeline, QwenImageTransformer2DModel
from diffusers.image_processor import VaeImageProcessor
from torchao.quantization import float8_weight_only, quantize_


class QwenImagePromptEncoder:
    def __init__(self, pretrained_model_name_or_path: str, device: str, torch_dtype: torch.dtype = torch.bfloat16):
        self.encoder_ = QwenImagePipeline.from_pretrained(
            pretrained_model_name_or_path,
            torch_dtype=torch_dtype,
            scheduler=None,
            transformer=None,
            vae=None,
        ).to(device)

    def __call__(self, prompt: str, **kwargs):
        with torch.no_grad():
            prompt_embeds, prompt_embeds_mask = self.encoder_.encode_prompt(prompt=prompt, **kwargs)
            return prompt_embeds, prompt_embeds_mask


class QwenImageDecoder:
    """
    Based off QwenImagePipeline
    """

    def __init__(self, pretrained_model_name_or_path: str, device: str, torch_dtype: torch.dtype = torch.bfloat16):
        self.vae_ = AutoencoderKLQwenImage.from_pretrained(
            pretrained_model_name_or_path,
            torch_dtype=torch_dtype,
            subfolder="vae",
        ).to(device)

        self.vae_scale_factor_ = 2 ** len(self.vae_.temperal_downsample)
        self.image_processor_ = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor_ * 2)

    @property
    def vae_scale_factor(self):
        return self.vae_scale_factor_

    def __call__(self, latents: torch.Tensor, output_type: str = "pil"):
        vae = self.vae_
        latents = latents.to(vae.device)

        with torch.no_grad():
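            # undo the per-channel latent normalization (mean/std from the VAE config) before decoding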
            latents_mean = (
                torch.tensor(vae.config.latents_mean)
                .view(1, vae.config.z_dim, 1, 1, 1)
                .to(latents.device, latents.dtype)
            )
            latents_std = 1.0 / torch.tensor(vae.config.latents_std).view(1, vae.config.z_dim, 1, 1, 1).to(
                latents.device, latents.dtype
            )
            latents = latents / latents_std + latents_mean
            image = vae.decode(latents, return_dict=False)[0][:, :, 0]

            image = self.image_processor_.postprocess(image, output_type=output_type)
            return image


class QwenImageTransformer:
    def __init__(
        self,
        pretrained_model_name_or_path: str,
        vae_scale_factor: float,
        device: str,
        torch_dtype: torch.dtype = torch.bfloat16,
    ):
        transformer = QwenImageTransformer2DModel.from_pretrained(
            pretrained_model_name_or_path,
            torch_dtype=torch_dtype,
            subfolder="transformer",
        )
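        # FP8 weight-only quantization (TorchAO) roughly halves the bf16 transformer to ~20 GiB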
        quantize_(transformer, float8_weight_only())
        transformer = transformer.to(device)

        transformer_pipeline = QwenImagePipeline.from_pretrained(
            pretrained_model_name_or_path,
            torch_dtype=torch_dtype,
            text_encoder=None,
            tokenizer=None,
            transformer=transformer,
            vae=None,
        )
        self.transformer_pipeline_ = transformer_pipeline

        self.vae_scale_factor_ = vae_scale_factor
        self.device_ = device

    @property
    def device(self):
        return self.device_

    def __call__(self, height: int, width: int, **kwargs) -> torch.Tensor:
        latents = self.transformer_pipeline_(output_type="latent", height=height, width=width, **kwargs)
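        # unpack the patchified latent sequence back into a (B, C, 1, H, W) tensor for the VAE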
        latents = self.transformer_pipeline_._unpack_latents(latents[0], height, width, self.vae_scale_factor_)
        return latents


if __name__ == "__main__":
    pretrained_model_name_or_path = "/home/ubuntu/models/Qwen-Image"
    torch_dtype = torch.bfloat16

    encoder = QwenImagePromptEncoder(pretrained_model_name_or_path, "cuda:0", torch_dtype)
    decoder = QwenImageDecoder(pretrained_model_name_or_path, "cuda:1", torch_dtype)
    transformer = QwenImageTransformer(pretrained_model_name_or_path, decoder.vae_scale_factor, "cuda:2", torch_dtype)

    prompt = "a majestic dragon flying over a medieval castle, fantasy art, highly detailed"
    negative_prompt = "blurry, low quality, cartoonish"

    positive_prompt_embeds, positive_prompt_embeds_mask = encoder(prompt)
    negative_prompt_embeds, negative_prompt_embeds_mask = encoder(negative_prompt)

    latents = transformer(
        prompt_embeds=positive_prompt_embeds.to(transformer.device),
        prompt_embeds_mask=positive_prompt_embeds_mask.to(transformer.device),
        negative_prompt_embeds=negative_prompt_embeds.to(transformer.device),
        negative_prompt_embeds_mask=negative_prompt_embeds_mask.to(transformer.device),
        num_inference_steps=30,
        height=1024,
        width=1024,
    )

    image = decoder(latents)
    image[0].save("qwen_t2i.jpg")

Background

Qwen Image is made of a few parts:

  • Text encoder and tokenizer: ~16 GiB
  • Transformer: ~20 GiB after FP8 weight-only quantization with TorchAO
  • VAE: ~500 MiB
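
These numbers can be sanity-checked by summing parameter sizes. A minimal sketch (weights only, ignoring activations and allocator overhead; the attribute paths reuse the classes from the script above):

import torch


def approx_param_gib(module: torch.nn.Module) -> float:
    # Weights only: numel * bytes-per-element, summed over all parameters.
    return sum(p.numel() * p.element_size() for p in module.parameters()) / 1024**3


# e.g. with the objects constructed in the script above (measure the transformer
# before quantize_; torchao tensor subclasses may not report element_size() the same way):
# approx_param_gib(encoder.encoder_.text_encoder)                   # ~16 GiB in bf16
# approx_param_gib(transformer.transformer_pipeline_.transformer)   # ~40 GiB in bf16, ~20 GiB after FP8
# approx_param_gib(decoder.vae_)                                    # ~0.5 GiB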

The objective was to run Qwen-Image on an EC2 instance with 4 A10G GPUs (23 GiB of VRAM each) to generate a large number of images from a base prompt randomized with wildcards. Splitting the T2I process into stages made it possible to use the hardware efficiently:

  1. Encode prompts and save the embeddings to disk, using an encoder process on each GPU
  2. Load the embeddings and run denoising with a transformer process on each GPU, plus a decoder on the CPU (sketched below)
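
A minimal sketch of that two-stage split, reusing the classes defined above (helper names and file paths are illustrative; the orchestration across processes is not shown):

import torch


def encode_and_save(encoder: QwenImagePromptEncoder, prompt: str, path: str):
    # Stage 1: encoder process -- compute embeddings and persist them to disk.
    prompt_embeds, prompt_embeds_mask = encoder(prompt)
    torch.save({"embeds": prompt_embeds.cpu(), "mask": prompt_embeds_mask.cpu()}, path)


def generate_from_saved(
    transformer: QwenImageTransformer, decoder: QwenImageDecoder, pos_path: str, neg_path: str
):
    # Stage 2: transformer process (the decoder can live on the CPU, e.g. QwenImageDecoder(..., "cpu")).
    pos, neg = torch.load(pos_path), torch.load(neg_path)
    latents = transformer(
        prompt_embeds=pos["embeds"].to(transformer.device),
        prompt_embeds_mask=pos["mask"].to(transformer.device),
        negative_prompt_embeds=neg["embeds"].to(transformer.device),
        negative_prompt_embeds_mask=neg["mask"].to(transformer.device),
        num_inference_steps=30,
        height=1024,
        width=1024,
    )
    return decoder(latents)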

Notes

  • Even though the transformer and VAE could both be loaded onto a single A10G, there did not appear to be enough memory left over for activations during forward passes
  • On cards with less than 24 GB of VRAM, it may be necessary to load models with the max_memory kwarg (see e.g. this thread); a sketch follows below
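
For reference, a hedged sketch of the max_memory route via diffusers' pipeline-level device placement (device_map="balanced" is the documented pipeline option; the budgets shown are illustrative, not tested values):

pipe = QwenImagePipeline.from_pretrained(
    pretrained_model_name_or_path,
    torch_dtype=torch.bfloat16,
    device_map="balanced",               # let accelerate spread the components across GPUs
    max_memory={0: "20GB", 1: "20GB"},   # per-device caps; tune to the actual cards
)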