import torch
import torch._inductor.config as config

config.freezing = True

# Repro configuration: shapes, int8 quantization range, and pattern toggles.
in_feature = 32
out_feature = 64
q_min, q_max = -32, 31
reshape_a = True
expand_a_scale = False
inplace_add = True
test_for_pointwise_binary = True
M = 1
dtype = torch.bfloat16
has_bias = False


class Mod(torch.nn.Module):
    def __init__(self, dtype: torch.dtype, has_bias: bool):
        super().__init__()
        self.dtype = dtype
        self.has_bias = has_bias
        # int8 weight plus per-row activation and per-column weight scales.
        self.b = torch.randint(
            q_min, q_max, [in_feature, out_feature], dtype=torch.int8
        )
        self.a_scale = torch.rand([M, 1], dtype=dtype) * 0.01 + 0.01
        self.b_scale = torch.rand([out_feature]) * 0.01 + 0.01
        self.b_scale = self.b_scale.to(dtype)
        self.bias = torch.rand([out_feature], dtype=dtype) if has_bias else None
        self.additive = torch.rand([M, out_feature], dtype=dtype)

    def forward(self, a):
        if reshape_a:
            a_reshaped = a.reshape(-1, a.size(-1))
        else:
            a_reshaped = a
        # int8 x int8 -> int32 matmul, then dequantize via the two scales.
        c = torch._int_mm(a_reshaped, self.b)
        c = c.to(self.dtype)
        if expand_a_scale:
            a_scale = self.a_scale.expand(c.shape)
        else:
            a_scale = self.a_scale
        c = c * a_scale
        c = c * self.b_scale
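        # Added note, not in the original gist: because _int_mm accumulates
        # exactly in int32, the two multiplies above compute the dequantized
        # GEMM (a * a_scale) @ (b * b_scale) up to bfloat16 rounding, with
        # a_scale applied per row and b_scale per output column.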
        if self.has_bias:
            c = c + self.bias
        elif inplace_add and test_for_pointwise_binary:
            # When M is 1, dynamic shapes are enabled via torch.compile,
            # has_bias is False, expand_a_scale is False, and inplace_add is
            # True, the stride of the output's outermost dim can't be
            # determined due to an Inductor bug.
            c.add_(self.additive)
        return c


mod = Mod(dtype, has_bias).eval()
a = torch.randint(q_min, q_max, [M, in_feature], dtype=torch.int8)
with torch.no_grad():
    cmod = torch.compile(mod, dynamic=True)
    cmod(a)
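
# Hedged follow-up sketch, not part of the original gist: one way to surface a
# wrong-stride or wrong-value result from the Inductor bug is to compare the
# compiled output (and its strides) against an eager reference.
# torch.testing.assert_close uses dtype-aware default tolerances.
with torch.no_grad():
    eager_out = mod(a)
    compiled_out = cmod(a)
    print("eager stride:", eager_out.stride())
    print("compiled stride:", compiled_out.stride())
    torch.testing.assert_close(compiled_out, eager_out)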