Skip to content

Instantly share code, notes, and snippets.

@bkj
Last active December 19, 2017 23:15
Show Gist options
  • Save bkj/e559d6cb01e4515a5aa5cd690a3d65a8 to your computer and use it in GitHub Desktop.
Save bkj/e559d6cb01e4515a5aa5cd690a3d65a8 to your computer and use it in GitHub Desktop.
#!/usr/bin/env python
"""
auction-lap.py
From
https://dspace.mit.edu/bitstream/handle/1721.1/3265/P-2108-26912652.pdf;sequence=1
"""
from __future__ import print_function, division
import torch
import numpy as np
from time import time
from lap import lapjv # gatagat
from lapjv import lapjv as lapjv2 # src-d
def auction_lap(X, eps=None):
    """Solve a linear assignment problem with a vectorised auction algorithm.

    Every row of ``X`` ("bidder") is assigned a distinct column ("item") so
    that the summed benefit ``X[i, assignment[i]]`` is (approximately)
    maximised.

    Args:
        X: 2-D tensor of benefits with ``n_rows <= n_cols``. All working
            tensors are created on ``X.device``.
        eps: Minimum bid increment. Defaults to ``1 / n_rows``. Should be
            positive; larger values finish in fewer rounds at the price of
            a less accurate score.

    Returns:
        A ``(score, curr_ass)`` tuple: ``score`` is a 0-dim tensor holding
        the sum of the selected entries of ``X``; ``curr_ass`` is a 1-D long
        tensor where ``curr_ass[i]`` is the column assigned to row ``i``.

    Raises:
        ValueError: If ``X`` has more rows than columns, in which case some
            row could never be assigned and the loop would not terminate.
    """
    n_rows, n_cols = X.shape
    if n_rows > n_cols:
        raise ValueError(
            "auction_lap needs n_rows <= n_cols, got shape "
            "({}, {})".format(n_rows, n_cols)
        )

    eps = 1 / n_rows if eps is None else eps

    # Prices/bids are kept in a floating dtype even when X holds integers.
    work_dtype = X.dtype if X.is_floating_point() else torch.get_default_dtype()
    device = X.device

    # --
    # Init

    cost = torch.zeros((1, n_cols), dtype=work_dtype, device=device)
    curr_ass = torch.full((n_rows,), -1, dtype=torch.long, device=device)

    while (curr_ass == -1).any():

        # --
        # Bidding

        # view(-1) (rather than squeeze()) keeps this 1-D even when a single
        # row is left unassigned.
        unassigned = (curr_ass == -1).nonzero().view(-1)

        value = X[unassigned] - cost
        # With a single column there is no runner-up; the best value is then
        # reused as "second best", so the bid increment is just eps.
        top_value, top_idx = value.topk(min(2, n_cols), dim=1)

        first_idx = top_idx[:, 0]
        first_value, second_value = top_value[:, 0], top_value[:, -1]
        bid_increments = first_value - second_value + eps

        # One row per unassigned bidder; only the column it bids on is set.
        bids_ = torch.zeros(
            (unassigned.shape[0], n_cols), dtype=work_dtype, device=device
        )
        bids_.scatter_(
            dim=1,
            index=first_idx.contiguous().view(-1, 1),
            src=bid_increments.view(-1, 1),
        )

        # --
        # Assignment

        # Columns that received at least one bid this round.
        have_bidder = (bids_ > 0).any(dim=0).nonzero().view(-1)

        high_bids, high_bidders = bids_[:, have_bidder].max(dim=0)
        high_bidders = unassigned[high_bidders]

        cost[:, have_bidder] += high_bids

        # Rows that currently hold a column which was just bid on lose it.
        # A boolean mask is required here: an integer tensor would be
        # interpreted as row positions instead.
        outbid = (curr_ass.view(-1, 1) == have_bidder.view(1, -1)).any(dim=1)
        curr_ass[outbid] = -1
        curr_ass[high_bidders] = have_bidder

    score = X.gather(dim=1, index=curr_ass.view(-1, 1)).sum()
    return score, curr_ass
# --
# Compare
#
# Benchmark the auction solver against two Jonker-Volgenant (JV) solvers on
# one random integer matrix. The JV solvers are handed `X.max() - X`,
# presumably because they minimise cost whereas `auction_lap` picks the
# largest values -- confirm against the solver documentation.

np.random.seed(123)

N = 20000
X = np.random.choice(1000, (N, N))
# NOTE: `X_` must stay a module-level name; keep it defined before
# `auction_lap` is called.
X_ = torch.from_numpy(X).float().cuda()

# Run JV solver
t = time()
_, lap_ass, _ = lapjv(X.max() - X)
lap_score = X[(np.arange(X.shape[0]), lap_ass)].sum()
lap_time = time() - t

# Run other JV solver
t = time()
lap_ass2, _, _ = lapjv2(X.max() - X)
lap_score2 = X[(np.arange(X.shape[0]), lap_ass2)].sum()
lap_time2 = time() - t

# Run auction solver
# CUDA kernels are launched asynchronously, so the device is synchronised
# before starting and before stopping the clock; otherwise pending GPU work
# would be left out of (or leak into) the measured interval.
torch.cuda.synchronize()
t = time()
auction_score, auction_ass = auction_lap(X_, eps=10) # Score is accurate to within n * eps
torch.cuda.synchronize()
auction_time = time() - t

print((lap_score, lap_score2, auction_score))
print((lap_time, lap_time2, auction_time))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment