@thomasahle
Created December 6, 2023 20:58
Soft TopK with BCE loss
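"""
A differentiable ("soft") top-k operator and a fused binary cross-entropy loss.

_find_ts solves sum_i sigmoid(x_i + t) = k for t with a batched binary search
followed by Newton refinement. TopK returns the resulting probabilities
p_i = sigmoid(x_i + t), which lie in (0, 1) and sum to k along the last axis;
its backward pass differentiates through t implicitly. TopK_BCE computes the
binary cross-entropy between these probabilities and binary targets directly
from the logits x_i + t, with a matching hand-written backward pass.
"""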
import torch
from torch.autograd import Function
import torch.nn.functional as F


@torch.no_grad()
def _find_ts(xs, ks, binary_iter=16, newton_iter=1):
    n = xs.shape[-1]
    assert torch.all((0 < ks) & (ks < n)), "We don't support k=0 or k=n"
    # lo should be small enough that all sigmoids are in their 0 region;
    # similarly, hi should be large enough that all are in their 1 region.
    lo = -xs.max(dim=-1, keepdims=True).values - 10
    hi = -xs.min(dim=-1, keepdims=True).values + 10
    assert torch.all(torch.sigmoid(xs + lo).sum(dim=-1) < 1)
    assert torch.all(torch.sigmoid(xs + hi).sum(dim=-1) > n - 1)
    # Batched binary search, solving sigmoid(xs + ts).sum(-1) = ks for ts
    for _ in range(binary_iter):
        mid = (hi + lo) / 2
        mask = torch.sigmoid(xs + mid).sum(dim=-1) < ks
        lo[mask] = mid[mask]
        hi[~mask] = mid[~mask]
    ts = (lo + hi) / 2
    # Fine-tune with a few Newton steps on f(t) = sigmoid(xs + t).sum(-1) - k
    for _ in range(newton_iter):
        sig = torch.sigmoid(xs + ts)
        residual = sig.sum(dim=-1, keepdims=True) - ks[..., None]
        deriv = (sig * (1 - sig)).sum(dim=-1, keepdims=True)
        ts -= residual / deriv
    # Check that the root was actually found
    assert torch.allclose(torch.sigmoid(xs + ts).sum(dim=-1), ks.double())
    return ts


class TopK(Function):
    @staticmethod
    def forward(ctx, xs, ks):
        ts = _find_ts(xs, ks)
        ps = torch.sigmoid(xs + ts)
        ctx.save_for_backward(ps)
        return ps

    @staticmethod
    def backward(ctx, grad_output):
        # Compute vjp, that is grad_output.T @ J.
        (ps,) = ctx.saved_tensors
        # Let v = sigmoid'(x + t)
        v = ps * (1 - ps)  # sigmoid' = sigmoid * (1 - sigmoid)
        s = v.sum(dim=-1, keepdims=True)
        t_d = v / s
        # Jacobian is -vv.T/s + diag(v)
        uv = grad_output * v
        t1 = uv.sum(dim=-1, keepdims=True) * t_d
        return uv - t1, None


class TopK_BCE(Function):
    @staticmethod
    def forward(ctx, xs, ks, ys):
        xs = xs + _find_ts(xs, ks)
        ctx.save_for_backward(xs, ks, ys)
        # Binary cross-entropy on sigmoid(xs + t), written directly in terms of
        # the logits: the returned loss is (1 - y) * x - logsigmoid(x).
        loss = (ys - 1) * xs + F.logsigmoid(xs)
        return -loss

    @staticmethod
    def backward(ctx, grad_output):
        xts, ks, ys = ctx.saved_tensors
        # d/dx_i t = -sig'(x_i + t) / sum_j sig'(x_j + t)
        sig = torch.sigmoid(xts)
        sig_d = sig * (1 - sig)  # sigmoid' = sigmoid * (1 - sigmoid)
        den = sig_d.sum(dim=-1, keepdims=True)
        t_d = -sig_d / den
        # With e = ys - sig, the Jacobian w.r.t. xs is -(diag(e) + e t_d^T),
        # so the vjp is -(e * g + t_d * sum(e * g)).
        e = ys - sig
        ev = e * grad_output
        b = ev + t_d * ev.sum(dim=-1, keepdims=True)
        # The derivative of the loss w.r.t. the targets ys is -xts.
        return -b, None, -grad_output * xts


soft_topk = TopK.apply
bce_topk = TopK_BCE.apply


def main():
    from torch.autograd import gradcheck
    import tqdm

    n1, n2, d = 20, 2, 10
    xs = torch.randn(n1, n2, d, dtype=torch.double, requires_grad=True)

    # Test the TopK backward method against finite differences
    for _ in tqdm.tqdm(range(2)):
        ks = torch.randint(1, d, size=(n1, n2))
        assert gradcheck(soft_topk, (xs, ks), eps=1e-6, atol=1e-4)

    for _ in tqdm.tqdm(range(10)):
        ks = torch.randint(1, d, size=(n1, n2), dtype=torch.double)
        ys = torch.randint(0, 2, size=(n1, n2, d), dtype=torch.double)
        # Test the TopK_BCE forward method against plain BCE on soft_topk
        torch.testing.assert_close(
            F.binary_cross_entropy(soft_topk(xs, ks), ys, reduction="none"),
            bce_topk(xs, ks, ys),
        )
        # Test the TopK_BCE backward method
        assert gradcheck(bce_topk, (xs, ks, ys), eps=1e-6, atol=1e-4)


if __name__ == '__main__':
    main()
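
A minimal usage sketch, not part of the original gist: it assumes soft_topk and bce_topk from the code above are in scope, and the batch shape, the choice k = 3, and the random labels are arbitrary, chosen only to illustrate how the two functions might be called.

import torch

# Hypothetical setup: a batch of 4 score vectors of length 10, soft top-3 per row.
scores = torch.randn(4, 10, dtype=torch.double, requires_grad=True)
ks = torch.full((4,), 3.0, dtype=torch.double)
labels = (torch.rand(4, 10) < 0.3).double()  # arbitrary binary targets

probs = soft_topk(scores, ks)  # entries in (0, 1); each row sums to ~3
print(probs.sum(dim=-1))       # approximately tensor([3., 3., 3., 3.])

# Equivalent to F.binary_cross_entropy(soft_topk(scores, ks), labels, reduction="none"),
# as verified in main(), but computed directly from the logits.
loss = bce_topk(scores, ks, labels).mean()
loss.backward()
print(scores.grad.shape)       # torch.Size([4, 10])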