sumitmamoria / inference_with_torchao_serialized.py
Created August 22, 2024 04:22 — forked from sayakpaul/inference_with_torchao_serialized.py
Shows how to run Flux schnell in under 17 GB without bells and whistles. It additionally shows how to serialize the quantized checkpoint and load it back.
import torch
from huggingface_hub import hf_hub_download
from diffusers import FluxTransformer2DModel, DiffusionPipeline

dtype, device = torch.bfloat16, "cuda"
ckpt_id = "black-forest-labs/FLUX.1-schnell"

# Build the transformer on the "meta" device: the config is parsed and the
# module skeleton is created, but no real memory is allocated for weights.
# The quantized state dict can then be loaded into it afterwards.
with torch.device("meta"):
    config = FluxTransformer2DModel.load_config(ckpt_id, subfolder="transformer")
    model = FluxTransformer2DModel.from_config(config).to(dtype)
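The snippet above stops at the meta-device skeleton. A minimal sketch of the quantize, serialize, and reload flow the description refers to might look like the following; it is not the gist's exact code. It assumes torchao's quantize_ / int8_weight_only API, and the file path "flux_schnell_int8wo.pt" and the prompt are placeholders.

from torchao.quantization import quantize_, int8_weight_only

# 1) Load the full-precision transformer and apply int8 weight-only quantization.
transformer = FluxTransformer2DModel.from_pretrained(
    ckpt_id, subfolder="transformer", torch_dtype=dtype
)
quantize_(transformer, int8_weight_only())

# 2) Serialize the quantized state dict (placeholder path).
torch.save(transformer.state_dict(), "flux_schnell_int8wo.pt")

# 3) Load it back into the meta-initialized model. weights_only=False is
#    needed because torchao stores tensor subclasses; assign=True swaps the
#    meta tensors for the real quantized ones instead of copying into them.
state_dict = torch.load("flux_schnell_int8wo.pt", map_location="cpu", weights_only=False)
model.load_state_dict(state_dict, assign=True)

# 4) Build the pipeline around the quantized transformer and run it.
#    Flux schnell uses few steps and no classifier-free guidance.
pipe = DiffusionPipeline.from_pretrained(ckpt_id, transformer=model, torch_dtype=dtype).to(device)
image = pipe("a photo of a dog", num_inference_steps=4, guidance_scale=0.0).images[0]

The meta-device detour is what keeps peak memory low: the full bf16 weights are never materialized on the loading side, only the already-quantized tensors from the checkpoint.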