Sinkhorn algorithm with PyTorch and NumPy
import torch
import numpy
np = numpy
from geomloss import SamplesLoss # See also ImagesLoss, VolumesLoss
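# Aside (added sketch, not part of the original gist): SamplesLoss is imported
# above but never used below. geomloss operates on point clouds rather than an
# explicit cost matrix; a typical call, on illustrative random data, is:
sinkhorn_loss = SamplesLoss(loss="sinkhorn", p=2, blur=0.05)
_ = sinkhorn_loss(torch.randn(8, 3), torch.randn(5, 3))  # scalar Sinkhorn divergence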
# preferences, need to be converted to costs
# row i = cost of moving each item from c to place i
# making the cost non-negative will not change the solution matrix P
preference = numpy.asarray([[ 2,   2,  1,  0,  0],
                            [ 0,  -2, -2, -2,  2],
                            [ 1,   2,  2,  2, -1],
                            [ 2,   1,  0,  1, -1],
                            [ 0.5, 2,  2,  1,  0],
                            [ 0,   1,  1,  1, -1],
                            [-2,   2,  2,  1,  1],
                            [ 2,   1,  2,  1, -1]])
# how much space is available at each place (row marginals)
r = (3, 3, 3, 4, 2, 2, 2, 1)
r = torch.from_numpy(numpy.asarray(r)).float()
# how much we need to transfer from each source (column marginals)
c = (4, 2, 6, 4, 4)
c = torch.from_numpy(numpy.asarray(c)).float()
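# Sanity check (added note): balanced optimal transport requires total supply
# to equal total demand; both marginals here sum to 20.
assert float(r.sum()) == float(c.sum())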
x = torch.from_numpy(preference).float()
# from here https://michielstock.github.io/OptimalTransport/
def compute_optimal_transport(M, r, c, lam, epsilon=1e-8):
    """
    Computes the optimal transport matrix and Sinkhorn distance using the
    Sinkhorn-Knopp algorithm.

    Inputs:
        - M : cost matrix (n x m)
        - r : vector of marginals (n, )
        - c : vector of marginals (m, )
        - lam : strength of the entropic regularization
        - epsilon : convergence parameter

    Outputs:
        - P : optimal transport matrix (n x m)
        - dist : Sinkhorn distance
    """
    n, m = M.shape
    P = torch.exp(-lam * M)
    P = P / P.sum()
    u = torch.ones(n)
    # rescale rows and columns of P until the row sums stop changing
    i = 0
    while torch.max(torch.abs(u - P.sum(dim=1))) > epsilon:
        u = P.sum(dim=1)
        P *= (r / u).reshape((-1, 1))
        P *= (c / P.sum(0)).reshape((1, -1))
        i += 1
    print(i)
    return P, torch.sum(P * M)
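# Note (added): larger lam brings P closer to the exact, unregularized optimal
# plan, but torch.exp(-lam * M) underflows for large lam * M; smaller lam
# converges in fewer iterations at the price of a more blurred (higher
# entropy) plan.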
def optimal_transport(M, r, c, lam, epsilon=1e-8):
    n, m = M.shape
    Kinit = torch.exp(-M * lam)
    K = torch.diag(1. / r).mm(Kinit)
    # somehow faster than rescaling P directly
    u = r
    v = c
    vprev = v * 2
    i = 0
    while torch.abs(v - vprev).sum() > epsilon:
        vprev = v
        # changing the update order affects convergence a little bit
        v = c / K.T.matmul(u)
        u = r / K.matmul(v)
        i += 1
    print(i)
    P = torch.diag(u) @ K @ torch.diag(v)
    return P, torch.sum(P * M)
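# Note (added): the two solvers find the same plan. compute_optimal_transport
# rescales the full n x m matrix P on every iteration, while optimal_transport
# keeps the kernel K fixed and only updates the scaling vectors u and v,
# recovering the plan at the end as P = diag(u) @ K @ diag(v).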
# see https://arxiv.org/pdf/1612.02273.pdf
# and https://arxiv.org/pdf/1712.03082.pdf
# but here the cost is multiplied by lam, as in the code above
def optimal_transport_np(M, r, c, lam, epsilon=1e-8):
    n, m = M.shape
    Kinit = np.exp(-M * lam)
    K = np.diag(1. / r).dot(Kinit)
    u = r
    v = c
    vprev = v * 2
    i = 0
    while np.abs(v - vprev).sum() > epsilon:
        vprev = v
        v = c / K.T.dot(u)
        u = r / K.dot(v)
        i += 1
    print(i)
    P = np.diag(u) @ K @ np.diag(v)
    return P, np.sum(P * M)
P, cost = compute_optimal_transport(x * -1, r, c, 5)
print(P)
P, cost = optimal_transport(x * -1, r, c, 5)
print(P)
# shifting cost above zero will not change the solution P
x = x * -1
x = x - x.min()
P, cost = optimal_transport_np(x.numpy(), r.numpy(), c.numpy(), 5)
print(P)
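# Verification sketch (added, not part of the original gist): a valid plan
# should reproduce the marginals r and c up to the solver tolerance.
P, cost = optimal_transport(x, r, c, 5)
print(torch.allclose(P.sum(dim=1), r, atol=1e-4))
print(torch.allclose(P.sum(dim=0), c, atol=1e-4))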