@Pangoraw
Last active February 20, 2023 10:47
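```python
import torch
from torch import Tensor


def batch_sinkhorn(
    a: Tensor,
    b: Tensor,
    C: Tensor,
    reg: float,
    max_iters: int = 10,
) -> Tensor:
    """
    Solve a batch of entropically regularized optimal transport problems
    using the Sinkhorn-Knopp algorithm.

    Parameters
    ==========
    a: Tensor - size (B, n1), source marginals
    b: Tensor - size (B, n2), target marginals (each row with the same
        total mass as the matching row of a)
    C: Tensor - size (B, n1, n2), cost matrices
    reg: float - entropic regularization strength (lambda), must be > 0
    max_iters: int - number of Sinkhorn iterations (no convergence check)

    Returns
    =======
    plans: Tensor - size (B, n1, n2), entropically regularized transport
        plans diag(u) K diag(v)
    """
    # Gibbs kernel K = exp(-C / reg), one per batch element
    K = (-C / reg).exp()
    u = torch.ones_like(a)
    for _ in range(max_iters):
        # v = b / (K^T u): rescale columns so they sum to b
        v = b / torch.einsum("...ij,...i->...j", K, u)
        # u = a / (K v): rescale rows so they sum to a
        u = a / torch.einsum("...ij,...j->...i", K, v)
    # P = diag(u) K diag(v) via broadcasting; rows match a exactly (u is
    # updated last), columns approach b as max_iters grows
    return u.unsqueeze(-1) * K * v.unsqueeze(-2)
```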
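With the exp-domain kernel `K = (-C / reg).exp()`, entries underflow to zero once `C / reg` exceeds roughly 100 in float32 (roughly 745 in float64); if an entire column or row of the scaled kernel underflows, the divisions can produce `inf` and then `nan`. A common remedy, which is not part of the original gist, is to run the same updates on the log-scalings `f = log u` and `g = log v` using `torch.logsumexp`. The sketch below assumes strictly positive marginals; `batch_sinkhorn_log` and the toy problem in the usage section are hypothetical and only meant for illustration.

```python
import torch
from torch import Tensor


def batch_sinkhorn_log(
    a: Tensor,
    b: Tensor,
    C: Tensor,
    reg: float,
    max_iters: int = 10,
) -> Tensor:
    """
    Log-domain version of batch_sinkhorn (hypothetical name).

    Runs the same Sinkhorn-Knopp updates on f = log u and g = log v,
    so exp(-C / reg) is never materialized and cannot underflow.
    Shapes: a (B, n1), b (B, n2), C (B, n1, n2) -> plans (B, n1, n2).
    Assumes every entry of a and b is strictly positive.
    """
    log_a, log_b = a.log(), b.log()
    log_K = -C / reg              # log of the Gibbs kernel
    f = torch.zeros_like(a)       # log u, matching u = 1 at the start
    g = torch.zeros_like(b)       # log v
    for _ in range(max_iters):
        # log v_j = log b_j - logsumexp_i(log K_ij + f_i)
        g = log_b - torch.logsumexp(log_K + f.unsqueeze(-1), dim=-2)
        # log u_i = log a_i - logsumexp_j(log K_ij + g_j)
        f = log_a - torch.logsumexp(log_K + g.unsqueeze(-2), dim=-1)
    # P_ij = u_i * K_ij * v_j, assembled in log space and exponentiated once
    return (f.unsqueeze(-1) + log_K + g.unsqueeze(-2)).exp()


# Usage (hypothetical toy data): 4 problems with uniform marginals and
# squared Euclidean costs between random 2-D point clouds.
torch.manual_seed(0)
B, n1, n2 = 4, 5, 6
x = torch.rand(B, n1, 2, dtype=torch.float64)
y = torch.rand(B, n2, 2, dtype=torch.float64)
C = torch.cdist(x, y) ** 2                                # (B, n1, n2)
a = torch.full((B, n1), 1.0 / n1, dtype=torch.float64)
b = torch.full((B, n2), 1.0 / n2, dtype=torch.float64)

# A reg this small can make exp(-C / reg) underflow in the exp-domain version
P = batch_sinkhorn_log(a, b, C, reg=1e-3, max_iters=200)
col_err = (P.sum(-2) - b).abs().max()   # distance of the column sums from b
print(P.shape)  # torch.Size([4, 5, 6])
print(f"max column-marginal error: {col_err.item():.2e}")
```

Because `f` is updated last, the row sums of the returned plan match `a` up to rounding, as in the exp-domain version; the column sums only approach `b` as `max_iters` grows, which is what the printed error measures. For moderate `reg` the two versions compute the same iterates, so the log-domain one mainly costs extra `logsumexp` work in exchange for stability.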