Created
March 5, 2016 16:36
-
-
Save h0tk3y/e2a820d61dd07d214fef to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
/** Node id of the source of the bipartite flow network built by [IntEmd]. */
private const val SOURCE_ID: Int = 0

/** Node id of the sink of that network; item nodes are numbered starting right after it. */
private const val SINK_ID: Int = 1
/**
 * Earth Mover's Distance between two histograms with integer counts and integer ground distances.
 *
 * The distance is computed as a min-cost max-flow problem on a bipartite network
 * `source -> items of h1 -> items of h2 -> sink`, solved by successive cheapest augmenting paths.
 *
 * @param distance ground distance between two items. When `null`, items are cast to [DistanceMeasurable]
 *   and their `distanceTo` is used (the cast throws [ClassCastException] if they do not implement it).
 *   Within one [histogramDistance] call the distance is treated as symmetric: `d(a, b)` is reused for `d(b, a)`.
 */
class IntEmd<TItem>(val distance: ((TItem, TItem) -> Int)? = null)
    : HistogramDistance<TItem, Int> {

    /** Infix addition of two [Int] values. */
    infix fun Int.plus(other: Int) = this + other

    /** Infix subtraction of [other] from this [Int]. */
    infix fun Int.minus(other: Int) = this - other

    /** Infix multiplication of two [Int] values. */
    infix fun Int.times(other: Int) = this * other

    /** A directed edge of the flow network with a per-unit [cost], a [capacity] and the current [flow]. */
    protected inner class Edge(val cost: Int,
                               val capacity: Int,
                               var flow: Int = 0)

    /**
     * Min-cost max-flow solver over [graph], an adjacency map `from -> (to -> edge)`.
     *
     * Augments along cheapest residual paths found by Bellman-Ford-style relaxation. The network must not
     * contain antiparallel edges (`u -> v` and `v -> u`): a residual step between two nodes is resolved by
     * looking up the direct edge first and the reverse edge second.
     */
    private inner class FlowNetwork(val graph: Map<Int, Map<Int, Edge>>) {
        /** Number of distinct nodes, including nodes without outgoing edges (such as the sink). */
        val nodeCount = graph.keys.union(graph.values.flatMap { it.keys }).size

        /**
         * Finds the cheapest path from [from] to [to] in the residual network.
         *
         * A forward residual edge (`flow < capacity`) costs `edge.cost`; a reverse residual edge
         * (`flow > 0`) costs `-edge.cost`. Unreached nodes have no entry in the cost map and are never
         * relaxed from, so no "infinite cost" sentinel is needed.
         *
         * @return node ids of the path from [from] to [to] inclusive, or `null` if [to] is unreachable.
         */
        private fun cheapestResidualPath(from: Int, to: Int): List<Int>? {
            // prev[k] = node preceding k on the cheapest known path to k
            val prev = HashMap<Int, Int>()
            // cost[k] = cheapest known path cost from `from` to k; absent = not reached yet
            val cost = HashMap<Int, Int>()
            cost[from] = 0

            // Tries to improve the cost of `v` by stepping from `u` with cost `w`; returns true if improved.
            fun relax(u: Int, v: Int, w: Int): Boolean {
                val viaU = (cost[u] ?: return false) + w
                val current = cost[v]
                if (current != null && current <= viaU) return false
                cost[v] = viaU
                prev[v] = u
                return true
            }

            // nodeCount rounds are enough (|V| - 1 are needed); stop early once a round changes nothing.
            var rounds = 0
            var changed = true
            while (changed && rounds < nodeCount) {
                changed = false
                for ((nFrom, edges) in graph) {
                    for ((nTo, e) in edges) {
                        if (e.flow < e.capacity && relax(nFrom, nTo, e.cost)) changed = true
                        if (e.flow > 0 && relax(nTo, nFrom, -e.cost)) changed = true
                    }
                }
                rounds++
            }

            if (cost[to] == null) return null
            // A simple path visits at most nodeCount nodes (the old limit, the largest node id, dropped the
            // source from paths that visit every node); the limit only guards against a malformed prev chain.
            val path = generateSequence(to) { if (it == from) null else prev[it] }
                    .take(nodeCount)
                    .toList()
                    .reversed()
            return if (path.firstOrNull() == from) path else null
        }

        /**
         * Resolves the residual step `u -> v` to an edge of [graph].
         *
         * @return the edge, and `true` if the step follows it forward or `false` if it undoes flow on `v -> u`.
         * @throws IllegalStateException if neither `u -> v` nor `v -> u` exists.
         */
        private fun residualEdge(u: Int, v: Int): Pair<Edge, Boolean> {
            graph[u]?.get(v)?.let { return Pair(it, true) }
            val reverse = checkNotNull(graph[v]?.get(u)) { "no edge between $u and $v" }
            return Pair(reverse, false)
        }

        /**
         * Pushes flow along cheapest residual paths from [SOURCE_ID] to [SINK_ID] until none remains,
         * mutating the [Edge.flow] of edges in [graph].
         */
        fun findMaxFlowMinCost() {
            while (true) {
                val path = cheapestResidualPath(SOURCE_ID, SINK_ID) ?: break
                val steps = path.zip(path.drop(1)).map { residualEdge(it.first, it.second) }
                // Bottleneck: spare capacity on forward steps, existing flow on reverse steps.
                var increase = Int.MAX_VALUE
                for ((edge, forward) in steps) {
                    val available = if (forward) edge.capacity - edge.flow else edge.flow
                    if (available < increase) increase = available
                }
                for ((edge, forward) in steps) {
                    edge.flow += if (forward) increase else -increase
                }
            }
        }
    }

    /**
     * Returns the minimal total cost of moving mass from [h1] to [h2], whose values are item counts.
     *
     * The amount moved is the maximum flow of the network, i.e. `min(sum(h1), sum(h2))` for non-negative
     * counts; each unit moved from item `a` to item `b` costs the ground distance `d(a, b)`. The result is
     * not normalized by the amount moved.
     */
    override fun histogramDistance(h1: Map<TItem, Int>,
                                   h2: Map<TItem, Int>): Int {
        val groundDistance = distance
        val distancesCache = HashMap<Pair<TItem, TItem>, Int>()

        // Ground distance between two items; both key orders are cached, so one lookup suffices.
        fun itemDistance(t1: TItem, t2: TItem): Int {
            distancesCache[Pair(t1, t2)]?.let { return it }
            val d = if (groundDistance != null) {
                groundDistance(t1, t2)
            } else {
                @Suppress("UNCHECKED_CAST")
                val measurable = t1 as DistanceMeasurable<in TItem, Int>
                measurable distanceTo t2
            }
            distancesCache[Pair(t1, t2)] = d
            distancesCache[Pair(t2, t1)] = d
            return d
        }

        // Bipartite network: left part = items of h1, right part = items of h2.
        var nextNodeId = SINK_ID + 1
        val leftIds = h1.keys.associateBy({ it }, { nextNodeId++ })
        val rightIds = h2.keys.associateBy({ it }, { nextNodeId++ })
        val graph = HashMap<Int, HashMap<Int, Edge>>()

        // source -> left node of each item: cost 0, capacity = its count in h1
        // (`!!` is safe: every key of h1 / h2 was assigned an id above)
        val sourceEdges = HashMap<Int, Edge>()
        graph[SOURCE_ID] = sourceEdges
        for ((item, count) in h1) {
            sourceEdges[leftIds[item]!!] = Edge(0, count)
        }

        // right node of each item -> sink: cost 0, capacity = its count in h2
        for ((item, count) in h2) {
            graph[rightIds[item]!!] = hashMapOf(SINK_ID to Edge(0, count))
        }

        // left -> right for every pair: cost = ground distance; a left node never receives more than
        // sum(h1), so that capacity never limits the flow
        val unboundedCapacity = h1.values.sum()
        for (tl in h1.keys) {
            val edges = HashMap<Int, Edge>()
            graph[leftIds[tl]!!] = edges
            for (tr in h2.keys) {
                edges[rightIds[tr]!!] = Edge(itemDistance(tl, tr), unboundedCapacity)
            }
        }

        FlowNetwork(graph).findMaxFlowMinCost()
        // Source and sink edges cost 0, so only left -> right edges contribute.
        return graph.values.flatMap { it.values }.map { it.cost * it.flow }.sum()
    }
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment