Soft TopK with BCE loss: a differentiable top-k relaxation that returns ps = sigmoid(xs + t), with the threshold t chosen per row so that ps sums to k, plus a fused BCE loss computed directly from the shifted logits (in the spirit of BCEWithLogitsLoss).
import torch
from torch.autograd import Function
import torch.nn.functional as F


@torch.no_grad()
def _find_ts(xs, ks, binary_iter=16, newton_iter=1):
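    """Find per-row thresholds ts such that sigmoid(xs + ts).sum(dim=-1) == ks.

    A batched bisection brackets each threshold, then a few Newton steps refine
    it. Gradients are not tracked here; the Function classes below differentiate
    through the threshold analytically instead.
    """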
    n = xs.shape[-1]
    assert torch.all((0 < ks) & (ks < n)), "We don't support k=0 or k=n"
    # Lo should be small enough that all sigmoids are in their 0 area.
    # Similarly Hi is large enough that all are in their 1 area.
    lo = -xs.max(dim=-1, keepdims=True).values - 10
    hi = -xs.min(dim=-1, keepdims=True).values + 10
    assert torch.all(torch.sigmoid(xs + lo).sum(dim=-1) < 1)
    assert torch.all(torch.sigmoid(xs + hi).sum(dim=-1) > n - 1)
    # Batched binary search, solving sigmoid(xs + ts).sum(-1) = ks
    for _ in range(binary_iter):
        mid = (hi + lo) / 2
        mask = torch.sigmoid(xs + mid).sum(dim=-1) < ks
        lo[mask] = mid[mask]
        hi[~mask] = mid[~mask]
    ts = (lo + hi) / 2
    # Fine-tune using some Newton iterations
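    # Newton's method on f(t) = sigmoid(xs + t).sum(-1) - k:
    # f'(t) = sum_i sigmoid'(x_i + t), and the update is t <- t - f(t) / f'(t).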
    for _ in range(newton_iter):
        sig = torch.sigmoid(xs + ts)
        err = sig.sum(dim=-1, keepdims=True) - ks[..., None]  # f(t)
        slope = (sig * (1 - sig)).sum(dim=-1, keepdims=True)  # f'(t)
        ts -= err / slope
    # Test for success
    assert torch.allclose(torch.sigmoid(xs + ts).sum(dim=-1), ks.double())
    return ts


class TopK(Function):
    @staticmethod
    def forward(ctx, xs, ks):
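        # The result ps = sigmoid(xs + ts) lies in (0, 1) and sums to ks along
        # the last dim, acting as a soft top-k indicator over xs.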
        ts = _find_ts(xs, ks)
        ps = torch.sigmoid(xs + ts)
        ctx.save_for_backward(ps)
        return ps

    @staticmethod
    def backward(ctx, grad_output):
        # Compute vjp, that is grad_output.T @ J.
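        # With p_i = sigmoid(x_i + t(x)) constrained by sum_i p_i = k, implicit
        # differentiation gives dt/dx_j = -v_j / s, where v = p * (1 - p) and
        # s = sum_j v_j. Hence dp_i/dx_j = v_i * (delta_ij - v_j / s).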
        (ps,) = ctx.saved_tensors
        # Let v = sigmoid'(x + t)
        v = ps * (1 - ps)  # sigmoid' = sigmoid * (1 - sigmoid)
        s = v.sum(dim=-1, keepdims=True)
        t_d = v / s
        # Jacobian is -vv.T/s + diag(v)
        uv = grad_output * v
        t1 = uv.sum(dim=-1, keepdims=True) * t_d
        return uv - t1, None


class TopK_BCE(Function):
    @staticmethod
    def forward(ctx, xs, ks, ys):
        xs = xs + _find_ts(xs, ks)
        ctx.save_for_backward(xs, ks, ys)
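        # With z = xs + t (the shifted logits stored in xs above) and
        # p = sigmoid(z), the log-likelihood is
        #   y * log(p) + (1 - y) * log(1 - p) = (y - 1) * z + logsigmoid(z),
        # using log(1 - p) = logsigmoid(z) - z. The BCE loss is its negation.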
        loss = (ys - 1) * xs + F.logsigmoid(xs)
        return -loss

    @staticmethod
    def backward(ctx, grad_output):
        xts, ks, ys = ctx.saved_tensors
        # Compute d/dx_i t = -sig'(x_i + t) / sum_j sig'(x_j + t)
        sig = torch.sigmoid(xts)
        sig_d = sig * (1 - sig)  # sigmoid' = sigmoid * (1 - sigmoid)
        num = sig_d.sum(dim=-1, keepdims=True)
        t_d = -sig_d / num
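        # Per element, dL_i/dz_i = -(y_i - sigmoid(z_i)) = -e_i, and z_j = x_j + t,
        # so the chain rule gives dL_i/dx_j = -e_i * (delta_ij + dt/dx_j).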
        # Jacobian is -(diag(e) + e t_d^T)
        e = ys - sig
        ev = e * grad_output
        b = ev + t_d * ev.sum(dim=-1, keepdims=True)
        # Gradient w.r.t. the labels: dL_i/dy_i = -z_i, so the vjp is -grad_output * z.
        return -b, None, -grad_output * xts


soft_topk = TopK.apply
bce_topk = TopK_BCE.apply


def main():
    from torch.autograd import gradcheck
    import tqdm

    n1, n2, d = 20, 2, 10
    xs = torch.randn(n1, n2, d, dtype=torch.double, requires_grad=True)
    # Test TopK function
    for _ in tqdm.tqdm(range(2)):
        ks = torch.randint(1, d, size=(n1, n2))
        assert gradcheck(soft_topk, (xs, ks), eps=1e-6, atol=1e-4)
    for _ in tqdm.tqdm(range(10)):
        ks = torch.randint(1, d, size=(n1, n2), dtype=torch.double)
        ys = torch.randint(0, 2, size=(n1, n2, d), dtype=torch.double)
        # Test forward method
        torch.testing.assert_close(
            F.binary_cross_entropy(soft_topk(xs, ks), ys, reduction="none"),
            bce_topk(xs, ks, ys),
        )
        # Test backward method
        assert gradcheck(bce_topk, (xs, ks, ys), eps=1e-6, atol=1e-4)


if __name__ == '__main__':
    main()
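
A minimal usage sketch, assuming the functions above are importable; the shapes, k values, and targets here are arbitrary illustrations and not part of the original gist:

# Hypothetical usage sketch: batched soft top-k over the last dimension.
xs = torch.randn(4, 16, dtype=torch.double, requires_grad=True)
ks = torch.full((4,), 3)
ps = soft_topk(xs, ks)             # shape (4, 16); each row sums to ~3
print(ps.sum(dim=-1))              # ~tensor([3., 3., 3., 3.])
ys = (xs > 0).double()             # arbitrary 0/1 targets, detached from the graph
loss = bce_topk(xs, ks.double(), ys).mean()
loss.backward()                    # fills xs.grad via the custom backward passes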