# Source: GitHub gist leslie-fang-intel/f9686fa15b181d82294861aac1d5abad
# Author: @leslie-fang-intel
# Created: February 28, 2025 10:03
import torch
import torch._inductor.config as config
config.freezing = True
in_feature = 32
out_feature = 64
q_min, q_max = -32, 31
reshape_a = True
expand_a_scale = False
inplace_add = True
test_for_pointwise_binary = True
M = 1
dtype = torch.bfloat16
has_bias = False
class Mod(torch.nn.Module):
    """Int8 matmul followed by scaling and an optional bias / in-place add.

    The module reads the module-level settings ``q_min``, ``q_max``,
    ``in_feature``, ``out_feature`` and ``M`` when it is constructed, and
    ``reshape_a``, ``expand_a_scale``, ``inplace_add`` and
    ``test_for_pointwise_binary`` on every forward call.
    """

    def __init__(self, dtype: torch.dtype, has_bias: bool):
        """Create the random weight, scales and additive tensors.

        Args:
            dtype: Floating-point dtype for the scales, bias and additive
                tensors, and the dtype the int matmul result is cast to.
            has_bias: When true, a random bias of shape ``[out_feature]`` is
                created and added in ``forward``; otherwise ``bias`` is None.
        """
        super().__init__()
        self.dtype = dtype
        self.has_bias = has_bias
        # Int8 weight of shape [in_feature, out_feature].
        self.b = torch.randint(
            q_min, q_max, [in_feature, out_feature], dtype=torch.int8
        )
        # Activation scale of shape [M, 1]; broadcasts over the output columns.
        self.a_scale = torch.rand([M, 1], dtype=dtype) * 0.01 + 0.01
        # Weight scale of shape [out_feature], created in the default dtype
        # and then cast to `dtype`.
        self.b_scale = torch.rand([out_feature]) * 0.01 + 0.01
        self.b_scale = self.b_scale.to(dtype)
        self.bias = torch.rand([out_feature], dtype=dtype) if has_bias else None
        # Tensor added in place to the result when there is no bias.
        self.additive = torch.rand([M, out_feature], dtype=dtype)

    def forward(self, a):
        """Multiply ``a`` by the int8 weight, apply both scales, then add.

        Args:
            a: Tensor passed to ``torch._int_mm`` as the left operand; when
                ``reshape_a`` is set it is first flattened to 2-D, keeping
                the last dimension.

        Returns:
            The scaled matmul result in ``self.dtype``, with either the bias
            added (``has_bias``) or ``self.additive`` added in place.
        """
        if reshape_a:
            a_reshaped = a.reshape(-1, a.size(-1))
        else:
            a_reshaped = a
        c = torch._int_mm(a_reshaped, self.b)
        c = c.to(self.dtype)
        if expand_a_scale:
            a_scale = self.a_scale.expand(c.shape)
        else:
            a_scale = self.a_scale
        c = c * a_scale
        c = c * self.b_scale
        if self.has_bias:
            c = c + self.bias
        elif inplace_add and test_for_pointwise_binary:
            # When M is 1, dynamic shapes are enabled with torch.compile,
            # has_bias is False, expand_a_scale is False and inplace_add is
            # True, the stride of the output's outermost dim can't be
            # determined due to an Inductor bug.
            c.add_(self.additive)
        return c
# Build the module in eval mode and a random int8 activation of shape
# [M, in_feature].
mod = Mod(dtype, has_bias).eval()
a = torch.randint(q_min, q_max, [M, in_feature], dtype=torch.int8)

# Compile with dynamic shapes and run once with gradients disabled; the first
# call is what triggers compilation.
with torch.no_grad():
    cmod = torch.compile(mod, dynamic=True)
    cmod(a)
# (GitHub page footer: sign-up / sign-in prompt for commenting on the gist.)