
@thomasahle
Created August 5, 2022 19:16
Simple Differentiable TopK for PyTorch
import torch
from torch.autograd import Function, gradcheck
from torch.func import vmap, grad  # functorch was merged into torch.func in PyTorch 2.0

sigmoid = torch.sigmoid
# Elementwise derivative sigmoid'(z) for a 2-D (batch, n) tensor.
sigmoid_grad = vmap(vmap(grad(sigmoid)))


class TopK(Function):
    @staticmethod
    def forward(ctx, xs, k):
        ts, ps = _find_ts(xs, k)
        ctx.save_for_backward(xs, ts)
        return ps

    @staticmethod
    def backward(ctx, grad_output):
        # Compute the vjp, that is grad_output.T @ J.
        xs, ts = ctx.saved_tensors
        # Let v = sigmoid'(x + t)
        v = sigmoid_grad(xs + ts)
        s = v.sum(dim=1, keepdim=True)
        # The Jacobian is diag(v) - v v.T / s (symmetric, so J.T == J).
        uv = grad_output * v
        t1 = -uv.sum(dim=1, keepdim=True) * v / s
        return t1 + uv, None  # no gradient with respect to k


@torch.no_grad()
def _find_ts(xs, k):
    """Per row, find t such that sigmoid(xs + t).sum() == k, by bisection."""
    b, n = xs.shape
    assert 0 < k < n
    # lo is small enough that all sigmoids are near 0;
    # similarly, hi is large enough that all are near 1.
    lo = -xs.max(dim=1, keepdim=True).values - 10
    hi = -xs.min(dim=1, keepdim=True).values + 10
    for _ in range(64):
        mid = (hi + lo) / 2
        # The sum increases with t, so move lo up where the sum is still below k.
        mask = sigmoid(xs + mid).sum(dim=1) < k
        lo[mask] = mid[mask]
        hi[~mask] = mid[~mask]
    ts = (lo + hi) / 2
    return ts, sigmoid(xs + ts)


topk = TopK.apply

xs = torch.randn(2, 3)
ps = topk(xs, 2)
print(xs, ps, ps.sum(dim=1))  # each row of ps sums to (approximately) 2

# Compare the analytic backward with finite differences.
x = torch.randn(20, 10, dtype=torch.double, requires_grad=True)
for k in range(1, 10):
    print(k, gradcheck(topk, (x, k), eps=1e-6, atol=1e-4))
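The backward pass depends on the Jacobian noted in the comments, `diag(v) - v v.T / s`. That formula follows from implicitly differentiating the constraint that each row sums to k. The derivation below uses the code's own notation (v = σ'(x + t), s = Σ v) and treats a single row.

```latex
\begin{aligned}
p_i &= \sigma\bigl(x_i + t(x)\bigr), \qquad \sum_i p_i = k, \qquad
  v_i = \sigma'(x_i + t), \qquad s = \sum_i v_i \\
\frac{\partial p_i}{\partial x_j} &= v_i \left(\delta_{ij} + \frac{\partial t}{\partial x_j}\right) \\
0 = \frac{\partial}{\partial x_j} \sum_i p_i &= v_j + s \, \frac{\partial t}{\partial x_j}
  \;\Longrightarrow\; \frac{\partial t}{\partial x_j} = -\frac{v_j}{s} \\
J_{ij} = \frac{\partial p_i}{\partial x_j} &= v_i \, \delta_{ij} - \frac{v_i v_j}{s},
  \qquad J = \operatorname{diag}(v) - \frac{v v^\top}{s} \\
(g^\top J)_j &= g_j v_j - \frac{v_j}{s} \sum_i g_i v_i
\end{aligned}
```

J is symmetric, so the last line is exactly `uv + t1` in `backward`, where `uv = grad_output * v` and `t1 = -uv.sum(...) * v / s`.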
@debraj135

Is there any place I can read about the logic behind the implementation?

@qwe1256

qwe1256 commented Oct 15, 2023

> Is there any place I can read about the logic behind the implementation?

I believe this post will be helpful: https://math.stackexchange.com/questions/3280757/differentiable-top-k-function
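In short, the forward pass never sorts. For each row it finds a shift t such that the shifted sigmoids sum to exactly k, i.e. p_i = σ(x_i + t) with Σ p_i = k. The sum is strictly increasing in t, so `_find_ts` can locate t by bisection. Below is a minimal pure-Python sketch of the same idea for a single row. The name `find_shift` is hypothetical and does not appear in the gist.

```python
import math


def sigmoid(z):
    # Numerically stable logistic function.
    if z >= 0:
        return 1.0 / (1.0 + math.exp(-z))
    e = math.exp(z)
    return e / (1.0 + e)


def find_shift(xs, k, iters=64):
    # Hypothetical single-row version of _find_ts: bisection for t with
    # sum(sigmoid(x + t) for x in xs) == k. The sum increases with t.
    assert 0 < k < len(xs)
    lo = -max(xs) - 10.0  # every sigmoid(x + lo) <= sigmoid(-10), near 0
    hi = -min(xs) + 10.0  # every sigmoid(x + hi) >= sigmoid(10), near 1
    for _ in range(iters):
        mid = (lo + hi) / 2
        if sum(sigmoid(x + mid) for x in xs) < k:
            lo = mid  # too little mass: shift right
        else:
            hi = mid  # too much mass: shift left
    return (lo + hi) / 2


xs = [2.0, -1.0, 0.5, 3.0]
t = find_shift(xs, 2)
ps = [sigmoid(x + t) for x in xs]
print(round(sum(ps), 6))  # 2.0
```

Because σ is monotone, larger inputs always receive larger p_i, so the k largest entries get most of the mass. Every p_i still stays strictly between 0 and 1, and that is what keeps the result differentiable.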

@thomasahle
Author

If you are using this Soft TopK function, you may also want to combine it with BCE loss.
I have an updated gist here that does exactly that: https://gist.github.com/thomasahle/c72d11a5bd62f5f6187764f6a9bb4319
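The updated gist is not reproduced here, so what follows is only an illustration of one way to do this. Because p = σ(x + t), the values x + t are already logits. That means `F.binary_cross_entropy_with_logits` can be applied to them directly, which PyTorch documents as more numerically stable than a sigmoid followed by BCE. This sketch assumes that approach and may differ from the author's version. It also replaces the custom `Function` with a swapped-in trick: one Newton step from the detached bisection root, which makes autograd produce the implicit gradient. `find_shift` and `soft_topk_logits` are hypothetical names.

```python
import torch
import torch.nn.functional as F


@torch.no_grad()
def find_shift(x, k, iters=64):
    # Row-wise bisection for t with sigmoid(x + t).sum(dim=1) == k,
    # using the same bracket as the gist's _find_ts.
    lo = -x.max(dim=1, keepdim=True).values - 10
    hi = -x.min(dim=1, keepdim=True).values + 10
    for _ in range(iters):
        mid = (lo + hi) / 2
        too_low = torch.sigmoid(x + mid).sum(dim=1, keepdim=True) < k
        lo = torch.where(too_low, mid, lo)
        hi = torch.where(too_low, hi, mid)
    return (lo + hi) / 2


def soft_topk_logits(x, k):
    # Returns logits x + t(x) whose sigmoids sum to k in each row.
    t0 = find_shift(x, k)  # detached root of the constraint
    p = torch.sigmoid(x + t0)
    v = p * (1 - p)  # sigmoid'(x + t0)
    # One Newton step from the root: the value moves only by the (tiny) residual,
    # but autograd now sees dt/dx_j = -v_j / s, the implicit-function gradient.
    t = t0 + (k - p.sum(dim=1, keepdim=True)) / v.sum(dim=1, keepdim=True)
    return x + t


scores = torch.randn(4, 10, requires_grad=True)
target = torch.zeros(4, 10)
target[:, :3] = 1.0  # multi-hot labels, 3 positives per row
logits = soft_topk_logits(scores, k=3)
loss = F.binary_cross_entropy_with_logits(logits, target)
loss.backward()
```

The Newton step shifts the logits only by the bisection residual, which sits at floating-point level. The forward values are therefore essentially those of the gist's `topk`, and the gradient matches the `diag(v) - v v.T / s` Jacobian up to that same residual.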

@Liu0329

Liu0329 commented Apr 9, 2024

Can I use this for hard selection?

@thomasahle
Author

Tell me more?

@Liu0329

Liu0329 commented Apr 11, 2024

Suppose there are n points with shape (n, c), and a score function that maps them to scores with shape (n, 1). I then want to select the top k points based on those scores.
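The thread leaves this question open. One common option is a straight-through estimator, shown here as a hedged sketch rather than something the gist's author proposed. The forward pass uses the exact 0/1 top-k mask, and gradients flow as if the soft probabilities had been used. σ(x + t) is monotone in x, so the top k entries of the soft output p are the top k scores. The sketch assumes p is the (1, n) output of the gist's `topk` applied to `scores.T`; here a random stand-in is used. `straight_through_topk` is a hypothetical name.

```python
import torch


def straight_through_topk(p, k):
    """Hypothetical helper: hard top-k mask forward, soft gradient backward."""
    idx = torch.topk(p, k, dim=-1).indices  # positions of the k largest p
    hard = torch.zeros_like(p).scatter(-1, idx, 1.0)  # 0/1 mask, no gradient
    # Forward value equals `hard` (up to float rounding); gradient flows to p unchanged.
    return hard + p - p.detach()


n, c, k = 6, 3, 2
points = torch.randn(n, c)
# Stand-in for the soft top-k output over the n scores, shape (1, n);
# with the gist this would be topk(scores.T, k).
p = torch.rand(1, n, requires_grad=True)

mask = straight_through_topk(p, k)  # (1, n): k entries ~1, the rest 0
selected = points * mask.T  # (n, c), unselected points zeroed
compact = selected[mask[0] > 0.5]  # (k, c), only the chosen points
```

Multiplying by the mask, instead of indexing `points` with the hard indices alone, keeps a gradient path back to the scores. The straight-through gradient is biased, however, because it ignores the jump in the hard mask.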
