Skip to content

Instantly share code, notes, and snippets.

@jerryzh168
Created September 9, 2024 18:38
Show Gist options
  • Save jerryzh168/c2d4ce9c95d25b037a4c636a05f84fb7 to your computer and use it in GitHub Desktop.
@common_utils.parametrize("device", COMMON_DEVICES)
@common_utils.parametrize("dtype", COMMON_DTYPES)
def test_linear_compile(self, device, dtype):
    """Check a ``torch.compile``-d ``nn.Linear`` with a low-precision weight.

    Steps:

    1. Convert a random ``(4, 128)`` high-precision weight with
       ``self.FACTORY_FN(hp_tensor, **self.kwargs)``.
    2. Install the result as the weight of a bias-free
       ``nn.Linear(128, 4)``.
    3. Run the compiled module on a random ``(32, 128)`` activation.
    4. Compare its output against ``torch.nn.functional.linear`` applied
       with the original high-precision weight.

    The value returned by ``torchao.quantization.utils.compute_error`` must
    be greater than ``self.LINEAR_MIN_SQNR``. Judging by the attribute
    name, this is presumably an SQNR threshold; confirm on the class.

    Args:
        device: Device on which the tensors and the module are allocated
            (parametrized over ``COMMON_DEVICES``).
        dtype: dtype of the high-precision tensors and the module
            (parametrized over ``COMMON_DTYPES``).
    """
    hp_tensor = torch.randn(4, 128, device=device, dtype=dtype)
    # NOTE(review): FACTORY_FN / kwargs are supplied by the concrete test
    # class; presumably they build the low-precision tensor subclass under test.
    lp_tensor = self.FACTORY_FN(hp_tensor, **self.kwargs)

    hp_act_tensor = torch.randn(32, 128, device=device, dtype=dtype)
    # Reference output computed with the original high-precision weight.
    hp_res = torch.nn.functional.linear(hp_act_tensor, hp_tensor)

    # nn.Linear stores its weight as (out_features, in_features), so the
    # (4, 128) weight above matches Linear(in_features=128, out_features=4).
    linear = torch.nn.Linear(128, 4, bias=False, device=device, dtype=dtype)
    linear.weight = torch.nn.Parameter(lp_tensor)
    lp_res = torch.compile(linear)(hp_act_tensor)

    sqnr = torchao.quantization.utils.compute_error(hp_res, lp_res)
    self.assertGreater(sqnr, self.LINEAR_MIN_SQNR)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment