# Differentiable top-k via entropy-regularized optimal transport (Sinkhorn).
# See https://arxiv.org/abs/2002.06504 for the derivation.
import torch
import torch.nn.functional as F
from torch.autograd import Function


def sinkhorn_forward(C, mu, nu, epsilon, max_iter):
    # Plain Sinkhorn iterations on the Gibbs kernel G = exp(-C/epsilon).
    bs, n, k_ = C.size()

    v = torch.ones([bs, 1, k_]) / k_
    G = torch.exp(-C / epsilon)
    if torch.cuda.is_available():
        v = v.cuda()

    for i in range(max_iter):
        u = mu / (G * v).sum(-1, keepdim=True)
        v = nu / (G * u).sum(-2, keepdim=True)

    Gamma = u * G * v  # transport plan, [bs, n, k+1]
    return Gamma


def sinkhorn_forward_stablized(C, mu, nu, epsilon, max_iter):
    # Log-domain Sinkhorn iterations; numerically stable for small epsilon.
    bs, n, k_ = C.size()
    k = k_ - 1

    f = torch.zeros([bs, n, 1])
    g = torch.zeros([bs, 1, k + 1])
    if torch.cuda.is_available():
        f = f.cuda()
        g = g.cuda()

    epsilon_log_mu = epsilon * torch.log(mu)
    epsilon_log_nu = epsilon * torch.log(nu)

    def min_epsilon_row(Z, epsilon):
        return -epsilon * torch.logsumexp((-Z) / epsilon, -1, keepdim=True)

    def min_epsilon_col(Z, epsilon):
        return -epsilon * torch.logsumexp((-Z) / epsilon, -2, keepdim=True)

    for i in range(max_iter):
        f = min_epsilon_row(C - g, epsilon) + epsilon_log_mu
        g = min_epsilon_col(C - f, epsilon) + epsilon_log_nu

    Gamma = torch.exp((-C + f + g) / epsilon)  # transport plan, [bs, n, k+1]
    return Gamma


def sinkhorn_backward(grad_output_Gamma, Gamma, mu, nu, epsilon):
    # Closed-form gradient of Gamma w.r.t. the cost C, obtained by implicitly
    # differentiating the optimality conditions of the entropic OT problem.
    nu_ = nu[:, :, :-1]
    Gamma_ = Gamma[:, :, :-1]

    bs, n, k_ = Gamma.size()

    inv_mu = 1. / (mu.view([1, -1]))  # [1, n]
    Kappa = torch.diag_embed(nu_.squeeze(-2)) \
        - torch.matmul(Gamma_.transpose(-1, -2) * inv_mu.unsqueeze(-2), Gamma_)  # [bs, k, k]

    inv_Kappa = torch.inverse(Kappa)  # [bs, k, k]

    Gamma_mu = inv_mu.unsqueeze(-1) * Gamma_
    L = Gamma_mu.matmul(inv_Kappa)  # [bs, n, k]
    G1 = grad_output_Gamma * Gamma  # [bs, n, k+1]

    g1 = G1.sum(-1)
    G21 = (g1 * inv_mu).unsqueeze(-1) * Gamma  # [bs, n, k+1]
    g1_L = g1.unsqueeze(-2).matmul(L)  # [bs, 1, k]
    G22 = g1_L.matmul(Gamma_mu.transpose(-1, -2)).transpose(-1, -2) * Gamma  # [bs, n, k+1]
    G23 = -F.pad(g1_L, pad=(0, 1), mode='constant', value=0) * Gamma  # [bs, n, k+1]
    G2 = G21 + G22 + G23  # [bs, n, k+1]
    del g1, G21, G22, G23, Gamma_mu

    g2 = G1.sum(-2).unsqueeze(-1)  # [bs, k+1, 1]
    g2 = g2[:, :-1, :]  # [bs, k, 1]
    G31 = -L.matmul(g2) * Gamma  # [bs, n, k+1]
    G32 = F.pad(inv_Kappa.matmul(g2).transpose(-1, -2), pad=(0, 1), mode='constant', value=0) * Gamma  # [bs, n, k+1]
    G3 = G31 + G32  # [bs, n, k+1]

    grad_C = (-G1 + G2 + G3) / epsilon  # [bs, n, k+1]
    return grad_C


class TopKFunc(Function):
    # Custom autograd Function: Sinkhorn in the forward pass,
    # the analytic gradient above in the backward pass.

    @staticmethod
    def forward(ctx, C, mu, nu, epsilon, max_iter):
        with torch.no_grad():
            if epsilon > 1e-2:
                Gamma = sinkhorn_forward(C, mu, nu, epsilon, max_iter)
                if bool(torch.any(Gamma != Gamma)):
                    print('NaN appeared in Gamma, re-computing...')
                    Gamma = sinkhorn_forward_stablized(C, mu, nu, epsilon, max_iter)
            else:
                Gamma = sinkhorn_forward_stablized(C, mu, nu, epsilon, max_iter)
            ctx.save_for_backward(mu, nu, Gamma)
            ctx.epsilon = epsilon
        return Gamma

    @staticmethod
    def backward(ctx, grad_output_Gamma):
        epsilon = ctx.epsilon
        mu, nu, Gamma = ctx.saved_tensors
        # mu    [1, n, 1]
        # nu    [1, 1, k+1]
        # Gamma [bs, n, k+1]
        with torch.no_grad():
            grad_C = sinkhorn_backward(grad_output_Gamma, Gamma, mu, nu, epsilon)
        return grad_C, None, None, None, None


class TopK_custom(torch.nn.Module):
    # Turns a batch of scores [bs, n] into a smoothed top-k assignment A of shape [bs, n, k].

    def __init__(self, k, epsilon=0.1, max_iter=200):
        super(TopK_custom, self).__init__()
        self.k = k
        self.epsilon = epsilon
        self.anchors = torch.FloatTensor([k - i for i in range(k + 1)]).view([1, 1, k + 1])
        self.max_iter = max_iter

        if torch.cuda.is_available():
            self.anchors = self.anchors.cuda()

    def forward(self, scores):
        bs, n = scores.size()
        scores = scores.view([bs, n, 1])

        # find the -inf values and replace them with a value below the current minimum
        scores_ = scores.clone().detach()
        max_scores = torch.max(scores_).detach()
        scores_[scores_ == float('-inf')] = float('inf')
        min_scores = torch.min(scores_).detach()
        filled_value = min_scores - (max_scores - min_scores)
        mask = scores == float('-inf')
        scores = scores.masked_fill(mask, filled_value)

        # cost between scores and the k+1 anchor values, scaled so the largest entry is 1
        C = (scores - self.anchors) ** 2
        C = C / (C.max().detach())

        # marginals: uniform over items; k slots of mass 1/n plus one "rest" slot of mass (n-k)/n
        mu = torch.ones([1, n, 1], requires_grad=False) / n
        nu = [1. / n for _ in range(self.k)]
        nu.append((n - self.k) / n)
        nu = torch.FloatTensor(nu).view([1, 1, self.k + 1])

        if torch.cuda.is_available():
            mu = mu.cuda()
            nu = nu.cuda()

        Gamma = TopKFunc.apply(C, mu, nu, self.epsilon, self.max_iter)

        A = Gamma[:, :, :self.k] * n
        return A, None
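A minimal usage sketch, assuming a CPU-only run (the module moves its internal tensors to CUDA whenever it is available, so on a GPU machine the scores should live on the GPU as well); the shapes and hyperparameters here are only illustrative:

import torch

scores = torch.randn(2, 10, requires_grad=True)     # [batch, n] real-valued scores
topk = TopK_custom(k=3, epsilon=0.1, max_iter=200)
A, _ = topk(scores)                                  # A: [batch, n, k] smoothed top-k assignment
A.sum().backward()                                   # gradients flow back to scores through Sinkhorn
print(A.shape, scores.grad.shape)                    # torch.Size([2, 10, 3]) torch.Size([2, 10])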
Maybe torch.autograd.Function?
Yes, that seemed to work. Thank you very much.
I am actually interested in the top-k indices and not the top-k values. However, this function returns None for the indices on line 146. Do you know by any chance if it is possible to get the indices in a differentiable manner as well?
The way I understand it, the method returns a matrix of (item, position) probabilities.
For example, if I take the probabilities p = torch.tensor([[0.01, 0.1, 0.04, 0.5, 0.24]])
and compute the top-3:
ps2, _ = TopK_custom(3)(torch.log(p))
print(f'{ps2=}')
print(ps2.sum(dim=2))
I get
ps2=tensor([[[0.0307, 0.0824, 0.1877],
[0.1803, 0.2183, 0.2244],
[0.0960, 0.1596, 0.2252],
[0.4005, 0.2780, 0.1638],
[0.2926, 0.2618, 0.1988]]])
tensor([[0.3008, 0.6230, 0.4807, 0.8423, 0.7532]])
The matrix says that the first item (with initial probability 0.01) has a 3% chance of being first (index 1), an 8% chance of being at index 2, and so on.
If you sum along the last axis (dim=2) you get the inclusion probabilities.
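To make that concrete, here is a small wrapper (the helper name inclusion_probabilities is just for illustration) that turns the (item, position) matrix from TopK_custom above into per-item inclusion probabilities:

import torch

def inclusion_probabilities(scores, k, epsilon=0.1, max_iter=200):
    # scores: [bs, n]; returns [bs, n] where entry i is the soft P(item i is in the top-k)
    A, _ = TopK_custom(k, epsilon=epsilon, max_iter=max_iter)(scores)  # [bs, n, k]
    return A.sum(dim=2)

p = torch.tensor([[0.01, 0.1, 0.04, 0.5, 0.24]])
incl = inclusion_probabilities(torch.log(p), k=3)
print(incl)        # roughly tensor([[0.30, 0.62, 0.48, 0.84, 0.75]]), as in the example above
print(incl.sum())  # ~3: one unit of probability mass per selected slot

These are still soft scores; taking hard indices from them (e.g. with torch.topk) is not differentiable, which is why the module only returns the smoothed assignment.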
If you are just interested in differentiable inclusion probabilities, you can use my simple differentiable top-k code here: https://gist.github.com/thomasahle/4c1e85e5842d01b007a8d10f5fed3a18
Great, thank you!
Is this function invertible?
Hi,
Thank you very much for providing this code.
Unfortunately, I don't see any imports, like in the paper (https://arxiv.org/pdf/2002.06504.pdf). Do you know how to import "Function" on line 77?
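For anyone else hitting this: Function is torch.autograd.Function (as suggested earlier in the thread), and the code also uses F.pad, so the imports it needs are:

import torch
import torch.nn.functional as F
from torch.autograd import Function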