Skip to content

Instantly share code, notes, and snippets.

@h0tk3y
Created March 5, 2016 16:36
Show Gist options
  • Select an option

  • Save h0tk3y/e2a820d61dd07d214fef to your computer and use it in GitHub Desktop.

Select an option

Save h0tk3y/e2a820d61dd07d214fef to your computer and use it in GitHub Desktop.
/** Node id of the flow network's source; it feeds the nodes built from the first histogram. */
private const val SOURCE_ID: Int = 0

/** Node id of the flow network's sink; the nodes built from the second histogram drain into it. */
private const val SINK_ID: Int = 1
/**
 * Earth Mover's Distance between two histograms with [Int] bin weights.
 *
 * The distance is the total cost of a min-cost max-flow through a bipartite network:
 * - the source feeds one node per item of the first histogram (capacity = that item's weight);
 * - one node per item of the second histogram drains into the sink (capacity = that item's weight);
 * - every left/right pair is joined by an edge whose per-unit cost is the ground distance between
 *   the two items.
 *
 * @param distance ground distance between two items. When `null`, the items are cast to
 *   `DistanceMeasurable<in TItem, Int>` and its `distanceTo` is used instead.
 */
class IntEmd<TItem>(val distance: ((TItem, TItem) -> Int)? = null)
    : HistogramDistance<TItem, Int> {

    /** Infix alias for `+` on [Int]. */
    infix fun Int.plus(other: Int) = this + other

    /** Infix alias for `-` on [Int]. */
    infix fun Int.minus(other: Int) = this - other

    /** Infix alias for `*` on [Int]. */
    infix fun Int.times(other: Int) = this * other

    /** A directed flow-network edge with a per-unit [cost], a [capacity] and the current [flow]. */
    private class Edge(val cost: Int,
                       val capacity: Int,
                       var flow: Int = 0) {
        /** Capacity still available in the edge's own direction. */
        val residual: Int
            get() = capacity - flow
    }

    /**
     * Min-cost max-flow solver (successive shortest paths) over [graph], where `graph[u][v]` is
     * the edge `u -> v`. Flow is pushed from [SOURCE_ID] to [SINK_ID] by mutating [Edge.flow]
     * in place.
     *
     * Assumes there is never both an edge `u -> v` and an edge `v -> u`. This holds for the
     * bipartite network built by [histogramDistance]. A residual step `u -> v` is resolved as the
     * forward edge when it exists, and as the reverse of `v -> u` otherwise.
     */
    private class FlowNetwork(private val graph: Map<Int, Map<Int, Edge>>) {
        /** Upper bound on the number of nodes: one past the largest node id used anywhere in [graph]. */
        private val nodeCount: Int = graph.entries.fold(maxOf(SOURCE_ID, SINK_ID)) { acc, entry ->
            entry.value.keys.fold(maxOf(acc, entry.key)) { m, target -> maxOf(m, target) }
        } + 1

        /**
         * Bellman-Ford search for the cheapest [from] -> [to] path in the residual network.
         *
         * The forward direction of an edge is usable while `flow < capacity`, at cost `+cost`.
         * Its reverse is usable while `flow > 0`, at cost `-cost`.
         *
         * @return the node ids along the path, from [from] to [to] inclusive, or `null` if [to] is
         *   unreachable or the predecessor chain does not lead back to [from].
         */
        private fun cheapestResidualPath(from: Int, to: Int): List<Int>? {
            // prev[k] = v such that the residual step (v -> k) ends the cheapest known path to k
            val prev = HashMap<Int, Int>()
            // cost[k] = cheapest known path cost to k; a missing key means "not reached yet"
            val cost = HashMap<Int, Int>()
            cost[from] = 0

            // Relaxes the residual step u -> v; returns true if it strictly improved cost[v].
            fun relax(u: Int, v: Int, stepCost: Int): Boolean {
                // Unreached nodes are skipped explicitly rather than given a finite "infinity"
                // cost, which fails when all costs are 0 and can overflow for large costs.
                val base = cost[u] ?: return false
                val candidate = base + stepCost
                val known = cost[v]
                if (known != null && known <= candidate) return false
                cost[v] = candidate
                prev[v] = u
                return true
            }

            // Without negative cycles a cheapest path has at most nodeCount - 1 edges.
            // Stop early once a full round changes nothing.
            for (round in 1 until nodeCount) {
                var changed = false
                for ((u, out) in graph) {
                    for ((v, edge) in out) {
                        if (edge.flow < edge.capacity && relax(u, v, edge.cost)) changed = true
                        if (edge.flow > 0 && relax(v, u, -edge.cost)) changed = true
                    }
                }
                if (!changed) break
            }
            if (to !in cost) return null

            // Walk predecessors back from `to`. A simple path visits at most nodeCount nodes, so
            // the walk is capped there; this also bounds it if `prev` ever contained a cycle.
            val path = ArrayList<Int>()
            var node: Int? = to
            while (node != null && path.size < nodeCount) {
                path.add(node)
                if (node == from) break
                node = prev[node]
            }
            path.reverse()
            return if (path.first() == from) path else null
        }

        /** Augments along cheapest residual [SOURCE_ID] -> [SINK_ID] paths until none remains. */
        fun findMaxFlowMinCost() {
            while (true) {
                val path = cheapestResidualPath(SOURCE_ID, SINK_ID) ?: break
                val steps = path.zip(path.drop(1))
                // Bottleneck: the smallest residual capacity among the path's steps.
                val increase = steps.fold(Int.MAX_VALUE) { acc, (u, v) -> minOf(acc, residualCapacity(u, v)) }
                // Defensive: an empty path or non-positive bottleneck would repeat forever.
                if (steps.isEmpty() || increase <= 0) break
                for ((u, v) in steps) {
                    val direct = graph[u]?.get(v)
                    if (direct != null) {
                        direct.flow += increase
                    } else {
                        val reverse = reverseEdge(u, v)
                        reverse.flow -= increase
                    }
                }
            }
        }

        /** Residual capacity of step u -> v: free room on edge u -> v, else the flow on edge v -> u. */
        private fun residualCapacity(u: Int, v: Int): Int =
            graph[u]?.get(v)?.residual ?: reverseEdge(u, v).flow

        /** The edge v -> u whose flow the residual step u -> v cancels. */
        private fun reverseEdge(u: Int, v: Int): Edge =
            checkNotNull(graph[v]?.get(u)) { "no edge between $u and $v in the flow network" }
    }

    /**
     * Earth Mover's Distance between [h1] and [h2], each mapping an item to its integer weight.
     *
     * Moves as much weight as the network allows along the cheapest routes. That is presumably
     * `min(sum(h1), sum(h2))` for non-negative weights.
     * NOTE(review): the result is the raw total cost and is not normalized by the amount moved.
     *
     * @return the minimal total cost: the sum over all edges of `cost * flow`.
     * @throws ClassCastException if [distance] is `null` and an item of [h1] is not a `DistanceMeasurable`.
     */
    override fun histogramDistance(h1: Map<TItem, Int>,
                                   h2: Map<TItem, Int>): Int {
        // A value computed for (a, b) is reused for (b, a): the ground distance is treated as symmetric.
        val distancesCache = HashMap<Pair<TItem, TItem>, Int>()
        fun cachedDistance(t1: TItem, t2: TItem): Int =
            distancesCache[Pair(t1, t2)] ?: distancesCache[Pair(t2, t1)] ?: groundDistance(t1, t2).also {
                distancesCache[Pair(t1, t2)] = it
                distancesCache[Pair(t2, t1)] = it
            }

        // Bipartite network: one left node per item of h1, one right node per item of h2.
        var nextNodeId = SINK_ID + 1
        val leftIds = h1.keys.associateBy({ it }, { nextNodeId++ })
        val rightIds = h2.keys.associateBy({ it }, { nextNodeId++ })
        val graph = HashMap<Int, HashMap<Int, Edge>>()

        // source -> left node of t: cost 0, capacity h1[t]
        val sourceEdges = HashMap<Int, Edge>()
        for ((item, weight) in h1) sourceEdges[leftIds.getValue(item)] = Edge(0, weight)
        graph[SOURCE_ID] = sourceEdges

        // right node of t -> sink: cost 0, capacity h2[t]
        for ((item, weight) in h2) graph[rightIds.getValue(item)] = hashMapOf(SINK_ID to Edge(0, weight))

        // left -> right: cost d(tl, tr), with a capacity large enough never to limit the flow.
        // Assuming non-negative weights, a left node never carries more than sum(h1).
        val capacity = h1.values.sum() * 2
        for (tl in h1.keys) {
            val edges = HashMap<Int, Edge>()
            for (tr in h2.keys) edges[rightIds.getValue(tr)] = Edge(cachedDistance(tl, tr), capacity)
            graph[leftIds.getValue(tl)] = edges
        }

        FlowNetwork(graph).findMaxFlowMinCost()
        // The solver stored the optimal flow in the edges; the distance is sum(cost * flow).
        return graph.values.flatMap { it.values }.map { it.cost * it.flow }.sum()
    }

    /** Ground distance between two items: [distance] if set, otherwise `DistanceMeasurable.distanceTo`. */
    @Suppress("UNCHECKED_CAST")
    private fun groundDistance(t1: TItem, t2: TItem): Int =
        distance?.invoke(t1, t2) ?: (t1 as DistanceMeasurable<in TItem, Int>) distanceTo t2
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment