Skip to content

Instantly share code, notes, and snippets.

@jerryzh168
Created October 5, 2024 01:22
Show Gist options
  • Save jerryzh168/ec82a8d14b4aa848e29b56eeba750852 to your computer and use it in GitHub Desktop.
from torchvision import models
import torch
## compilation configs
# NOTE(review): the three flags below are private torch._dynamo / torch._inductor
# settings; their exact effect depends on the installed PyTorch version — confirm
# against the version this script is pinned to.
# Presumably keeps Dynamo from switching to dynamic shapes on recompilation.
torch._dynamo.config.automatic_dynamic_shapes = False
# Presumably forces Inductor to fuse int matmul with the following multiply
# (relevant for int8 quantized linears) — TODO confirm.
torch._inductor.config.force_fuse_int_mm_with_mul = True
# Presumably enables Inductor's mixed-dtype matmul path — TODO confirm.
torch._inductor.config.use_mixed_mm = True
## compilation configs end
# temporary workaround to recover the perf with quantized model under torch.compile
# (turns off the multi-head-attention fast path globally for this process)
torch.backends.mha.set_fastpath_enabled(False)
import torch
from diffusers import FluxPipeline, FluxTransformer2DModel
import torch.utils.benchmark as benchmark
from functools import partial
def get_example_inputs(*, batch_size=1, dtype=torch.bfloat16, device="cuda"):
    """Build the positional-argument tuple used to trace and time the model.

    The defaults reproduce the original hard-coded input: one random
    ``(1, 3, 224, 224)`` bfloat16 tensor allocated on CUDA.

    Args:
        batch_size: Size of the leading (batch) dimension.
        dtype: Element type of the generated tensor.
        device: Device on which the tensor is allocated.

    Returns:
        A 1-tuple containing the random tensor, ready to be splatted into a
        call as ``model(*inputs)``.
    """
    image_batch = torch.randn(
        batch_size, 3, 224, 224, dtype=dtype, device=device
    )
    return (image_batch,)
def benchmark_fn(f, *args, **kwargs):
    """Time ``f(*args, **kwargs)`` with ``torch.utils.benchmark``.

    Args:
        f: Callable to measure.
        *args: Positional arguments forwarded to ``f`` on every run.
        **kwargs: Keyword arguments forwarded to ``f`` on every run.

    Returns:
        The mean of ``Timer.blocked_autorange()`` formatted as a string with
        three decimal places (the Timer reports times in seconds per the
        ``torch.utils.benchmark`` documentation).
    """
    timer = benchmark.Timer(
        stmt="f(*args, **kwargs)",
        # The statement is evaluated as source text, so every name it uses
        # has to be injected through ``globals``.
        globals={"args": args, "kwargs": kwargs, "f": f},
        num_threads=torch.get_num_threads(),
    )
    measurement = timer.blocked_autorange()
    return f"{measurement.mean:.3f}"
def load_model():
    """Load the torchvision ViT-B/16 model prepared for bfloat16 CUDA inference.

    Side effects:
        Sets the process-wide float32 matmul precision to ``"high"``.

    Returns:
        The ``vit_b_16`` model with ``IMAGENET1K_V1`` weights, switched to
        eval mode, moved to CUDA and cast to ``torch.bfloat16``.
    """
    torch.set_float32_matmul_precision("high")
    model = models.vit_b_16(weights=models.ViT_B_16_Weights.IMAGENET1K_V1)
    # These nn.Module calls modify the module in place and return it, so the
    # chained result does not need to be re-assigned.
    model.eval().cuda().to(torch.bfloat16)
    return model
def aot_compile(name, fn, sample_args):
    """Ahead-of-time compile ``fn`` into ``./{name}.so`` via AOTInductor.

    Args:
        name: Base name (without extension) of the shared library to write.
        fn: Callable to compile.
        sample_args: Tuple of example positional arguments for ``fn``.

    Returns:
        The path string ``./{name}.so`` that was passed to the compiler as
        ``aot_inductor.output_path``.
    """
    path = f"./{name}.so"
    print(f"{path=}")
    options = {
        "aot_inductor.output_path": path,
        "max_autotune": True,
        "triton.cudagraphs": True,
    }
    torch._export.aot_compile(
        fn,
        sample_args,
        {},  # no keyword arguments are passed to ``fn``
        options=options,
        disable_constraint_solver=True,
    )
    return path
def aot_load(path):
    """Load an AOT-compiled artifact from ``path`` for the ``"cuda"`` device.

    Args:
        path: Location of the compiled shared library (as returned by
            ``aot_compile``).

    Returns:
        Whatever ``torch._export.aot_load`` returns; the script calls it like
        the original function.
    """
    return torch._export.aot_load(path, "cuda")
@torch.no_grad()
def f(model, *args):
    """Call ``model`` on ``args`` with gradient tracking disabled.

    Args:
        model: Any callable (typically an ``nn.Module`` or compiled function).
        *args: Positional arguments forwarded to ``model``.

    Returns:
        The value returned by ``model(*args)``.
    """
    return model(*args)
model = load_model()
from torchao.quantization import autoquant
from torchao.quantization import quantize_, int4_weight_only
from torchao.utils import unwrap_tensor_subclass

# Compile first, then wrap with autoquant.
model = autoquant(torch.compile(model, mode="max-autotune"))
# Alternative quantization path, kept for reference:
# quantize_(model, int4_weight_only())
inputs1 = get_example_inputs()
import torch
torch._dynamo.config.verbose = True
# Run the model once on the example inputs before exporting.
# NOTE(review): presumably this is what lets autoquant pick its kernels —
# confirm against the torchao autoquant documentation.
model(*inputs1)
unwrap_tensor_subclass(model)
path1 = aot_compile("bs_1_1024", partial(f, model), inputs1)
compiled_func_1 = aot_load(path1)
print(f"{compiled_func_1(*inputs1)[0].shape=}")
# Warm-up runs so the timing below excludes one-time start-up cost.
for _ in range(5):
    _ = compiled_func_1(*inputs1)[0]
time = benchmark_fn(f, compiled_func_1, *inputs1)
print(time)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment