Skip to content

Instantly share code, notes, and snippets.

Show Gist options
  • Save sumitmamoria/c4912c4311c73e62d11bd4410c5f4730 to your computer and use it in GitHub Desktop.
Save sumitmamoria/c4912c4311c73e62d11bd4410c5f4730 to your computer and use it in GitHub Desktop.
Shows how to run FLUX.1-schnell in under 17 GB of GPU memory without extra bells and whistles. It also shows how to serialize the quantized checkpoint and load it back.
# Load a serialized int8 weight-only quantized Flux schnell transformer from the
# Hub into an empty (meta-device) model, then run a single generation with it.
import torch
from huggingface_hub import hf_hub_download
from diffusers import FluxTransformer2DModel, DiffusionPipeline

dtype, device = torch.bfloat16, "cuda"
ckpt_id = "black-forest-labs/FLUX.1-schnell"

# Build the transformer skeleton on the meta device: parameters get shapes and
# dtypes but no storage, so the full-precision weights are never materialized.
with torch.device("meta"):
    config = FluxTransformer2DModel.load_config(ckpt_id, subfolder="transformer")
    model = FluxTransformer2DModel.from_config(config).to(dtype)

# NOTE: "schell" (sic) is the repository's actual id on the Hub; do not "fix" it.
ckpt_path = hf_hub_download(repo_id="sayakpaul/flux.1-schell-int8wo", filename="flux_schnell_int8wo.pt")
# The file is a state dict written with torch.save, i.e. a pickle.
# NOTE(review): torch.load unpickles objects, so only load checkpoints you trust.
# Depending on the installed torch/torchao versions, the `weights_only` default
# may reject the quantized tensor subclasses — confirm before upgrading.
state_dict = torch.load(ckpt_path, map_location="cpu")
# assign=True replaces the meta-device parameters with the loaded tensors
# (copying into meta tensors is impossible because they have no storage).
model.load_state_dict(state_dict, assign=True)

pipeline = DiffusionPipeline.from_pretrained(ckpt_id, transformer=model, torch_dtype=dtype).to(device)
image = pipeline(
    "cat", guidance_scale=0.0, num_inference_steps=4, max_sequence_length=256
).images[0]
image.save("flux_schnell_int8.png")
# Install `torchao` from source: https://github.com/pytorch/ao
# Install PyTorch nightly
# Quantize each Flux schnell component with int8 weight-only quantization,
# assemble the pipeline, run one generation, and report GPU memory usage.
from diffusers import DiffusionPipeline, FluxTransformer2DModel, AutoencoderKL
from transformers import T5EncoderModel, CLIPTextModel
from torchao.quantization import quantize_, int8_weight_only
import torch

ckpt_id = "black-forest-labs/FLUX.1-schnell"
dtype = torch.bfloat16

# Quantize the components individually.
# If quality is taking a hit, don't quantize all components — mix and match.

############ Diffusion Transformer ############
transformer = FluxTransformer2DModel.from_pretrained(
    ckpt_id, subfolder="transformer", torch_dtype=dtype
)
quantize_(transformer, int8_weight_only())

############ Text Encoder ############
text_encoder = CLIPTextModel.from_pretrained(
    ckpt_id, subfolder="text_encoder", torch_dtype=dtype
)
quantize_(text_encoder, int8_weight_only())

############ Text Encoder 2 ############
text_encoder_2 = T5EncoderModel.from_pretrained(
    ckpt_id, subfolder="text_encoder_2", torch_dtype=dtype
)
quantize_(text_encoder_2, int8_weight_only())

############ VAE ############
vae = AutoencoderKL.from_pretrained(
    ckpt_id, subfolder="vae", torch_dtype=dtype
)
quantize_(vae, int8_weight_only())

# Initialize the pipeline with the pre-quantized components.
pipeline = DiffusionPipeline.from_pretrained(
    ckpt_id,
    transformer=transformer,
    vae=vae,
    text_encoder=text_encoder,
    text_encoder_2=text_encoder_2,
    torch_dtype=dtype,
).to("cuda")
image = pipeline(
    "cat", guidance_scale=0.0, num_inference_steps=4, max_sequence_length=256
).images[0]

# empty_cache() returns unused cached blocks to the driver; it does not change
# memory_allocated(), which counts memory currently held by live tensors.
torch.cuda.empty_cache()
memory = torch.cuda.memory_allocated() / 1024 / 1024 / 1024
print(f"{memory=:.3f} GB")
# memory_allocated() is only what is held *after* the run. The high-water mark
# reached while loading and denoising decides whether the run fits on a GPU.
peak_memory = torch.cuda.max_memory_allocated() / 1024 / 1024 / 1024
print(f"{peak_memory=:.3f} GB")
image.save("quantized_image.png")
# Quantize the Flux schnell transformer with int8 weight-only quantization and
# write its state dict to disk so it can be reloaded without re-quantizing.
from diffusers import FluxTransformer2DModel
from torchao.quantization import quantize_, int8_weight_only
import torch

ckpt_id = "black-forest-labs/FLUX.1-schnell"

transformer = FluxTransformer2DModel.from_pretrained(
    ckpt_id,
    subfolder="transformer",
    torch_dtype=torch.bfloat16,
)
quantize_(transformer, int8_weight_only())

# Ideally this would be saved with safetensors, but that is currently blocked:
# https://github.com/huggingface/safetensors/issues/515
# The quantized checkpoint is about 12GB, versus 23GB unquantized.
quantized_state = transformer.state_dict()
torch.save(quantized_state, "flux_schnell_int8wo.pt")
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment