Skip to content

Instantly share code, notes, and snippets.

_int8da_int8w_api
(2, 1536, 256): elapsed time: 0.26234880447387693, ref elapsed time: 0.2546819114685059, bf16 elapsed time: 0.2320425605773926
(2, 1536, 1536): elapsed time: 0.30359392166137694, ref elapsed time: 0.26366655349731444, bf16 elapsed time: 0.23699455261230468
(2, 1536, 2048): elapsed time: 0.3142892837524414, ref elapsed time: 0.26277664184570315, bf16 elapsed time: 0.22612543106079103
(666, 1536, 4096): elapsed time: 0.31355167388916017, ref elapsed time: 0.2893782424926758, bf16 elapsed time: 0.2186182403564453
(2, 1536, 9216): elapsed time: 0.28453664779663085, ref elapsed time: 0.2939740753173828, bf16 elapsed time: 0.21930879592895508
(8192, 1536, 1536): elapsed time: 0.28819999694824217, ref elapsed time: 0.2803721618652344, bf16 elapsed time: 0.25683423995971677
(666, 1536, 1536): elapsed time: 0.32346622467041014, ref elapsed time: 0.3024294471740723, bf16 elapsed time: 0.25059999465942384
(8192, 6144, 1536): elapsed time: 0.6899622344970703, ref elapsed time: 0.7737372589111328, bf16 e
index 778ad67..a7308d2 100644
--- a/inference/benchmark_pixart.py
+++ b/inference/benchmark_pixart.py
@@ -96,7 +96,7 @@ def load_pipeline(
if compile_vae:
pipeline.vae = quantize_to_float8(pipeline.vae, QuantConfig(ActivationCasting.DYNAMIC))
elif quantization == "autoquant":
- pipeline.transformer = autoquant(pipeline.transformer)
+ pipeline.transformer = autoquant(pipeline.transformer, error_on_unseen=False)
if compile_vae:
_int8da_int8w_api
(6, 4096, 11008): elapsed time: 0.2683568000793457, ref elapsed time: 0.3081395149230957, bf16 elapsed time: 0.2551568031311035
(1, 11008, 4096): elapsed time: 0.2620355224609375, ref elapsed time: 0.33718048095703124, bf16 elapsed time: 0.20895679473876952
(6, 32000, 4096): elapsed time: 0.29149824142456054, ref elapsed time: 0.26438495635986325, bf16 elapsed time: 0.22105215072631837
(1, 12288, 4096): elapsed time: 0.32363166809082033, ref elapsed time: 0.26286848068237306, bf16 elapsed time: 0.22746400833129882
(6, 4096, 4096): elapsed time: 0.2673535919189453, ref elapsed time: 0.25161792755126955, bf16 elapsed time: 0.21523712158203126
(1, 4096, 11008): elapsed time: 0.2712985610961914, ref elapsed time: 0.25464864730834963, bf16 elapsed time: 0.2168137550354004
(1, 4096, 4096): elapsed time: 0.2569926452636719, ref elapsed time: 0.2699363136291504, bf16 elapsed time: 0.21389440536499024
(6, 12288, 4096): elapsed time: 0.2512361526489258, ref elapsed time: 0.2580838394165039, bf16 elaps
_int8da_int8w_api
(4096, 3072, 12288): elapsed time: 1.0266156768798829, ref elapsed time: 1.0287254333496094, bf16 elapsed time: 1.5214883422851562
(512, 3072, 12288): elapsed time: 0.2772822380065918, ref elapsed time: 0.26691999435424807, bf16 elapsed time: 0.22141408920288086
(4096, 3072, 3072): elapsed time: 0.3714015960693359, ref elapsed time: 0.32135711669921874, bf16 elapsed time: 0.4453104019165039
(4608, 3072, 3072): elapsed time: 0.40113407135009765, ref elapsed time: 0.35113761901855467, bf16 elapsed time: 0.45243934631347654
(1, 3072, 768): elapsed time: 0.2531641578674316, ref elapsed time: 0.2605705642700195, bf16 elapsed time: 0.22195039749145506
(512, 3072, 3072): elapsed time: 0.2756582450866699, ref elapsed time: 0.29623712539672853, bf16 elapsed time: 0.22314495086669922
(4608, 3072, 15360): elapsed time: 1.370133514404297, ref elapsed time: 1.3693881225585938, bf16 elapsed time: 1.9846876525878907
(1, 3072, 3072): elapsed time: 0.2843340873718262, ref elapsed time: 0.2885798454284668, bf
This file has been truncated, but you can view the full file.
W0820 21:13:02.269044 2066771 torch/_logging/_internal.py:432] Using TORCH_LOGS environment variable for log settings, ignoring call to set_logs
V0820 21:13:06.770973 2066771 torch/_inductor/codecache.py:1084] [0/0] [__output_code] Output code written to: /tmp/torchinductor_jerryzh/ak/cakds3xuwam3xcl3fm2m4ulll46a7yqozwcauz5ixh442xybsdkd.py
V0820 21:13:06.772004 2066771 torch/_inductor/codecache.py:1085] [0/0] [__output_code] Output code:
V0820 21:13:06.772004 2066771 torch/_inductor/codecache.py:1085] [0/0] [__output_code]
V0820 21:13:06.772004 2066771 torch/_inductor/codecache.py:1085] [0/0] [__output_code] # AOT ID: ['0_inference']
V0820 21:13:06.772004 2066771 torch/_inductor/codecache.py:1085] [0/0] [__output_code] from ctypes import c_void_p, c_long
V0820 21:13:06.772004 2066771 torch/_inductor/codecache.py:1085] [0/0] [__output_code] import torch
V0820 21:13:06.772004 2066771 torch/_inductor/codecache.py:1085] [0/0] [__output_code] import math
V0820 21:13:06.772004 2066771 torch/_inductor/codecache.py
Package Version
------------------------ -----------------------
accelerate 0.33.0
aiohappyeyeballs 2.3.5
aiohttp 3.10.1
aiosignal 1.3.1
async-timeout 4.0.3
attrs 24.2.0
autocommand 2.2.2
backports.tarfile 1.2.0
"""Create an int8 weight-only quantized FLUX.1-schnell transformer with torchao."""
import torch
from diffusers import FluxTransformer2DModel
from torchao.quantization import int8_weight_only, quantize_

ckpt_id = "black-forest-labs/FLUX.1-schnell"
# Pull only the transformer sub-module of the pipeline checkpoint, in bfloat16.
transformer = FluxTransformer2DModel.from_pretrained(
    ckpt_id,
    subfolder="transformer",
    torch_dtype=torch.bfloat16,
)
# In-place, weight-only int8 quantization of the eligible linear layers.
quantize_(transformer, int8_weight_only())
"""Load an already-quantized (int8 weight-only) FLUX.1-schnell transformer."""
import torch
from diffusers import FluxTransformer2DModel

ckpt_id = "jerryzh168/flux-schnell-int8wo"
# NOTE(review): use_safetensors=False falls back to pickle-based weight files,
# which can execute arbitrary code on load — only use with a trusted checkpoint.
transformer = FluxTransformer2DModel.from_pretrained(
    ckpt_id,
    torch_dtype=torch.bfloat16,
    use_safetensors=False,
)
"""Load the FLUX.1-schnell transformer in preparation for torchao autoquant.

NOTE(review): this snippet appears truncated — `autoquant` (and `quantize_` /
`int8_weight_only`) are imported but never called in the visible lines.
"""
import torch
from diffusers import FluxTransformer2DModel
from torchao import autoquant
from torchao.quantization import int8_weight_only, quantize_

ckpt_id = "black-forest-labs/FLUX.1-schnell"
# Load just the transformer sub-module in bfloat16; quantization step not shown.
transformer = FluxTransformer2DModel.from_pretrained(
    ckpt_id,
    subfolder="transformer",
    torch_dtype=torch.bfloat16,
)
...........frames [('total', 1), ('ok', 1)]
inductor [('pattern_matcher_count', 4), ('pattern_matcher_nodes', 4), ('fxgraph_cache_miss', 1), ('extern_calls', 1)]
inline_call []
stats [('calls_captured', 1), ('unique_graphs', 1)]
aot_autograd [('total', 1), ('ok', 1)]
.frames [('total', 1), ('ok', 1)]
inductor [('pattern_matcher_count', 4), ('pattern_matcher_nodes', 4), ('fxgraph_cache_miss', 1), ('extern_calls', 1)]
inline_call []
stats [('calls_captured', 1), ('unique_graphs', 1)]
aot_autograd [('total', 1), ('ok', 1)]