Skip to content

Instantly share code, notes, and snippets.

Show Gist options
  • Save vanbasten23/ddfe570363785b1c8c54f49816ac3eb8 to your computer and use it in GitHub Desktop.
Save vanbasten23/ddfe570363785b1c8c54f49816ac3eb8 to your computer and use it in GitHub Desktop.
# Sanity checks that the einsum contractions used for quantized matmuls agree
# with their explicit broadcast-multiply-and-reduce reference formulations.
import torch

# Fixed seed so the random inputs (and therefore any failure) are reproducible.
torch.manual_seed(0)

# einsum may accumulate products in a different order than the explicit
# reference; in float32 an output that cancels to near zero can then differ by
# more than allclose's default atol=1e-8, so use a small explicit tolerance.
ATOL = 1e-6

# Per-channel quant zero point
# x: [*, in_channel]   zero_point: [out_channel]   ->   zp_out: [*, out_channel]
# zero_point does not depend on c, so sum_c x[..., c] * zp[z] == x.sum(-1) * zp[z].
x = torch.randn(3, 6)
zero_point = torch.randn(8)
zp_out = torch.einsum("...c,z->...z", x, zero_point)
zp_out_ref = x.sum(dim=-1, keepdim=True) * zero_point
assert zp_out.shape == zp_out_ref.shape == (3, 8)
assert torch.allclose(zp_out, zp_out_ref, atol=ATOL)

# Block-wise case
# w:   [in_channel // block_size, block_size, out_channel]   (s, c, n)
# x:   [*, in_channel // block_size, block_size]             (..., s, c)
# out: [*, in_channel // block_size, out_channel]            (..., s, n)
# The batch size (4) is deliberately different from the block count (3) so the
# shape assertions can tell the batch axis and the block axis apart.
w = torch.randn(3, 2, 8)
x = torch.randn(4, 3, 2)
out = torch.einsum("scn,...sc->...sn", w, x)
out_ref = (x.unsqueeze(-1) * w).sum(dim=-2)  # broadcast to (..., s, c, n), reduce c
assert out.shape == out_ref.shape == (4, 3, 8)
assert torch.allclose(out, out_ref, atol=ATOL)

# Weight each block by `scaler` and reduce over the block axis s.
# scaler: [in_channel // block_size, out_channel]        (s, n)
# out:    [*, in_channel // block_size, out_channel]     (..., s, n); random stand-in
#         with the same shape as the block-wise output above
# res:    [*, out_channel]                               (..., n)
scaler = torch.randn(3, 8)
out = torch.randn(4, 3, 8)
res = torch.einsum("sn,...sn->...n", scaler, out)
res_ref = (scaler * out).sum(dim=-2)
assert res.shape == res_ref.shape == (4, 8)
assert torch.allclose(res, res_ref, atol=ATOL)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment