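"""Microbenchmark of a single (M=N=K=4096) torch.nn.Linear on CPU, comparing:

1. Inductor-compiled BF16 (torch.compile under torch.autocast),
2. IPEX weight-only INT8 (prepare/convert + torch.jit.trace/freeze),
3. torchao weight-only INT8 compiled with Inductor.

Each variant is warmed up and timed by benchmark(); a final torch.allclose
check compares the eager and compiled torchao outputs.
"""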
import copy
import random
import time

import numpy as np
import torch
import torch._inductor.config as config

import intel_extension_for_pytorch as ipex
from intel_extension_for_pytorch.quantization import convert, prepare
from torchao.quantization import quant_api
from torchao.utils import unwrap_tensor_subclass
local_seed = 2024
torch.manual_seed(local_seed)  # PyTorch seed
np.random.seed(seed=local_seed)  # NumPy seed
random.seed(local_seed)  # Python RNG seed
# Inductor settings: freeze weights and autotune GEMMs, allowing both the
# C++ template backend and ATen kernels.
config.freezing = True
config.max_autotune = True
config.max_autotune_gemm_backends = "CPP,ATEN"
# GEMM problem size: the input is (M, K), the Linear weight is (N, K).
M = 4096
N = 4096
K = 4096

class Model(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(K, N, bias=True)

    def forward(self, x, x2):
        tmp = self.linear(x)
        return tmp
        # return tmp + x2  # optional epilogue: uncomment to benchmark linear + add

# skip_benchmark = True  # toggle to skip the timing loops
skip_benchmark = False


def benchmark(m, input, input2):
    if skip_benchmark:
        return
    warm_up_iter = 50
    num_iter = 200
    # Warm up so compilation/autotuning cost is excluded from the timing.
    for _ in range(warm_up_iter):
        m(input, input2)
    start = time.time()
    for _ in range(num_iter):
        m(input, input2)
    print("---- time is : {}".format((time.time() - start) / num_iter), flush=True)

if __name__ == "__main__":
    m = Model().eval()
    print("m.linear.weight: ", m.linear.weight, flush=True)
    input = torch.randn(M, K)
    input2 = torch.randn(M, N)
    with torch.autocast(device_type="cpu"), torch.no_grad():
        # Baseline: Inductor-compiled BF16 run (autocast handles the dtype).
        bf16_m = copy.deepcopy(m)
        bf16_cm = torch.compile(bf16_m)
        bf16_cm(input, input2)  # trigger compilation before timing
        print("---- benchmark Inductor BF16 ----", flush=True)
        benchmark(bf16_cm, input, input2)

        # IPEX weight-only INT8: prepare/convert, then trace and freeze with
        # TorchScript. Quantize the deepcopy so `m` stays fp32 for the torchao
        # step below (the original passed `m` here and left `ipex_m` unused).
        ipex_m = copy.deepcopy(m)
        qconfig = ipex.quantization.get_weight_only_quant_qconfig_mapping()
        prepared_model = prepare(ipex_m, qconfig, example_inputs=(input, input2), inplace=False)
        converted_model = convert(prepared_model)
        traced_model = torch.jit.trace(converted_model, (input, input2))
        traced_model = torch.jit.freeze(traced_model)
        traced_model(input, input2)  # first run finalizes the freezing passes
        print("---- benchmark IPEX WOQ INT8 ----", flush=True)
        benchmark(traced_model, input, input2)

        # torchao weight-only INT8, compiled with Inductor.
        quant_api.quantize_(m, quant_api.int8_weight_only(), set_inductor_config=False)
        unwrap_tensor_subclass(m)
        ref_res = m(input, input2)  # eager reference result
        cm = torch.compile(m)
        res = cm(input, input2)
        print("---- benchmark Inductor WOQ INT8 ----", flush=True)
        benchmark(cm, input, input2)
        # print("ref_res is: {}".format(ref_res), flush=True)
        # print("res is: {}".format(res), flush=True)
        # Sanity check: the compiled result should match the eager reference.
        print(torch.allclose(ref_res, res, atol=1e-2, rtol=1e-2), flush=True)
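
# Note: for stable CPU timings it is common (an assumption, not part of the
# original gist) to pin threads and cores when launching, e.g.:
#   OMP_NUM_THREADS=<n_cores> numactl -C <core-list> python <this_script>.py
# where <this_script> is whatever filename you saved this gist under.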