Created August 5, 2022 19:16
Save thomasahle/4c1e85e5842d01b007a8d10f5fed3a18 to your computer and use it in GitHub Desktop.
Simple Differentiable TopK for PyTorch
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import math

import torch
from functorch import vmap, grad
from torch.autograd import Function
sigmoid = torch.sigmoid


def sigmoid_grad(x):
    """Return the elementwise derivative of the logistic sigmoid at ``x``.

    Uses the closed form ``sigmoid'(x) = sigmoid(x) * (1 - sigmoid(x))``.
    This is the same expression autograd's sigmoid backward evaluates. It
    avoids differentiating through nested ``vmap(vmap(grad(...)))`` transforms
    on every call, and it accepts tensors of any shape (the nested-vmap form
    only accepted exactly 2-D inputs).
    """
    s = sigmoid(x)
    return s * (1 - s)
class TopK(Function):
    """Differentiable soft top-k applied row-wise to a 2-D tensor.

    ``forward`` returns ``ps = sigmoid(xs + ts)``, where the per-row shift
    ``ts`` is found by ``_find_ts`` so that each row of ``ps`` sums to ``k``
    (up to bisection precision). Every entry in a row shares one shift, and
    sigmoid is increasing, so larger entries of ``xs`` get values closer to 1.
    """

    @staticmethod
    def forward(ctx, xs, k):
        """Return soft top-k weights with the same shape as ``xs``."""
        _, ps = _find_ts(xs, k)
        # sigmoid'(xs + ts) == ps * (1 - ps), so the output alone is enough
        # for backward; there is no need to keep xs/ts or re-run sigmoid.
        ctx.save_for_backward(ps)
        return ps

    @staticmethod
    def backward(ctx, grad_output):
        """Return the vector-Jacobian product ``grad_output^T @ J``; ``k`` gets None.

        Let ``v = sigmoid'(xs + ts)`` and ``s = sum(v)`` per row. Implicitly
        differentiating the constraint ``sum(sigmoid(xs + ts)) == k`` gives
        ``dts/dxs = -v / s``. Hence ``J = diag(v) - v v^T / s``, which is
        symmetric.
        """
        (ps,) = ctx.saved_tensors
        v = ps * (1 - ps)
        s = v.sum(dim=1, keepdim=True)
        uv = grad_output * v
        return uv - uv.sum(dim=1, keepdim=True) * v / s, None
@torch.no_grad() | |
def _find_ts(xs, k): | |
b, n = xs.shape | |
assert 0 < k < n | |
# Lo should be small enough that all sigmoids are in the 0 area. | |
# Similarly Hi is large enough that all are in their 1 area. | |
lo = -xs.max(dim=1, keepdims=True).values - 10 | |
hi = -xs.min(dim=1, keepdims=True).values + 10 | |
for _ in range(64): | |
mid = (hi + lo)/2 | |
mask = sigmoid(xs + mid).sum(dim=1) < k | |
lo[mask] = mid[mask] | |
hi[~mask] = mid[~mask] | |
ts = (lo + hi)/2 | |
return ts, sigmoid(xs + ts) | |
# Public entry point: topk(xs, k) runs TopK.forward and registers TopK.backward.
topk = TopK.apply


if __name__ == "__main__":
    from torch.autograd import gradcheck

    # Quick sanity check: each row of the soft top-k output should sum to k.
    xs = torch.randn(2, 3)
    ps = topk(xs, 2)
    print(xs, ps, ps.sum(dim=1))

    # Compare the analytic backward with finite differences for every valid
    # k (0 < k < n). gradcheck needs double precision to be reliable.
    inputs = torch.randn(20, 10, dtype=torch.double, requires_grad=True)
    for k in range(1, inputs.shape[1]):
        print(k, gradcheck(topk, (inputs, k), eps=1e-6, atol=1e-4))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.
Suppose there are n points, given as a tensor of shape (n, c), and a score function that maps them to scores of shape (n, 1). The goal is then to select the top k points based on those scores.