# Benchmark a 4096x4096x4096 Linear layer on CPU: Inductor BF16 vs.
# IPEX weight-only INT8 vs. Inductor/torchao weight-only INT8.
# Gist by @leslie-fang-intel, created September 26, 2024:
# https://gist.github.com/leslie-fang-intel/a9647f279069b3ec82524ebefbd9447c

import torch
import torch._inductor.config as config
from torchao.quantization import quant_api
from torchao.utils import unwrap_tensor_subclass
import copy
import time
import intel_extension_for_pytorch as ipex
from intel_extension_for_pytorch.quantization import (
    prepare,
    convert,
)
import random
import numpy as np
local_seed = 2024
torch.manual_seed(local_seed) # Set PyTorch seed
np.random.seed(seed=local_seed) # Set NumPy seed
random.seed(local_seed) # Set the Python seed
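
# Inductor tuning knobs: `freezing` folds module weights into constants so
# they can be constant-folded / pre-packed; `max_autotune` benchmarks candidate
# GEMM kernels at compile time, and "CPP,ATEN" lets Inductor's C++ GEMM
# template compete with the default ATen kernels on CPU.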
config.freezing = True
config.max_autotune = True
config.max_autotune_gemm_backends = "CPP,ATEN"

# Problem size: one M x K input through a K -> N linear layer.
M = 4096
N = 4096
K = 4096

class Model(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(K, N, bias=True)

    def forward(self, x, x2):
        # x2 is only used by the commented-out residual-add variant below.
        tmp = self.linear(x)
        return tmp
        # return tmp + x2

# skip_benchmark = True
skip_benchmark = False
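
# Average wall-clock latency over `iters` runs after a warm-up phase; the
# warm-up also absorbs one-time compilation and autotuning cost.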
def benchmark(m, input, input2):
    if skip_benchmark:
        return
    warm_up_iter = 50
    iters = 200
    for _ in range(warm_up_iter):
        m(input, input2)
    start = time.time()
    for _ in range(iters):
        m(input, input2)
    print("---- time is : {}".format((time.time() - start) / iters), flush=True)

if __name__ == "__main__":
    m = Model().eval()
    print("m.linear.weight: ", m.linear.weight, flush=True)
    input = torch.randn(M, K)
    input2 = torch.randn(M, N)
    with torch.autocast(device_type="cpu"), torch.no_grad():
        # BF16 baseline: compile the eager model with Inductor under autocast.
        bf16_m = copy.deepcopy(m)
        bf16_cm = torch.compile(bf16_m)
        bf16_cm(input, input2)
        print("---- benchmark Inductor BF16 ----", flush=True)
        benchmark(bf16_cm, input, input2)
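
        # IPEX WOQ INT8 path: `prepare` applies the weight-only-quant qconfig,
        # `convert` swaps in quantized linear modules, and jit.trace/jit.freeze
        # produce an optimized TorchScript graph with the INT8 weights folded in.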
        # Quantize a copy so the original m stays eager for the torchao path.
        ipex_m = copy.deepcopy(m)
        qconfig = ipex.quantization.get_weight_only_quant_qconfig_mapping()
        prepared_model = prepare(ipex_m, qconfig, example_inputs=(input, input2), inplace=False)
        converted_model = convert(prepared_model)
        traced_model = torch.jit.trace(converted_model, (input, input2))
        traced_model = torch.jit.freeze(traced_model)
        traced_model(input, input2)
        print("---- benchmark IPEX WOQ INT8 ----", flush=True)
        benchmark(traced_model, input, input2)
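
        # Inductor (torchao) WOQ INT8 path: quantize_ with int8_weight_only()
        # replaces the Linear weights with an INT8 tensor subclass in place;
        # unwrap_tensor_subclass converts the subclass parameters into plain
        # tensors so torch.compile can trace the module (typically needed on
        # older PyTorch releases).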
        quant_api.quantize_(m, quant_api.int8_weight_only(), set_inductor_config=False)
        unwrap_tensor_subclass(m)
        ref_res = m(input, input2)
        cm = torch.compile(m)
        res = cm(input, input2)
        print("---- benchmark Inductor WOQ INT8 ----", flush=True)
        benchmark(cm, input, input2)
        # print("ref_res is: {}".format(ref_res), flush=True)
        # print("res is: {}".format(res), flush=True)
        # Check that the compiled result matches the eager quantized reference.
        print(torch.allclose(ref_res, res, atol=1e-2, rtol=1e-2), flush=True)