Last active
July 20, 2020 09:45
-
-
Save avinashselvam/c9704c36301d257e7b835ac7b29794cd to your computer and use it in GitHub Desktop.
Hungarian Maximum Matching Algorithm
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
class Node:
    """
    A vertex of a bipartite graph.

    Attributes
    ----------
    id : int
        Identifier, unique within the node's own set
    set : int
        Side of the graph the node lives on: 0 = left set, 1 = right set
    label : int
        Labelling value used by the matching algorithm (starts at 0)
    match : Node
        Node of the opposite set this node is currently matched to, or None
    visited : bool
        Flag used to remember which nodes a depth-first search has reached

    Methods
    -------
    is_left()
        True when the node belongs to the left set
    is_right()
        True when the node belongs to the right set
    is_matched()
        True when the node currently has a partner
    set_label(label)
        Overwrite the node's label
    visit()
        Flag the node as visited during a traversal
    unvisit()
        Clear the visited flag
    """

    def __init__(self, uid, left_or_right):
        self.id = uid
        # Every node starts with a zero label, no partner, and unvisited.
        self.label = 0
        self.set = left_or_right
        self.match = None
        self.visited = False

    def __hash__(self):
        # Nodes are identified by (side, id); ids alone repeat across sides.
        identity = (self.set, self.id)
        return hash(identity)

    def __repr__(self):
        return f"<l/r:{self.set}, id:{self.id}, label:{self.label}>"

    def is_left(self):
        left_side = self.set == 0
        return left_side

    def is_right(self):
        right_side = self.set == 1
        return right_side

    def is_matched(self):
        partner = self.match
        return partner is not None

    def set_label(self, label):
        self.label = label

    def visit(self):
        self.visited = True

    def unvisit(self):
        self.visited = False
class Edge:
    """
    Weighted edge joining a left node to a right node of a bipartite graph.

    Attributes
    ----------
    left : Node
        Endpoint in the left set
    right : Node
        Endpoint in the right set
    weight : int
        Weight of the edge in graph theory terms

    Methods
    -------
    is_tight()
        Check whether the edge belongs to the equality graph,
        i.e. label of left node + label of right node == weight
    """

    def __init__(self, left_node, right_node, weight):
        self.left = left_node
        self.right = right_node
        self.weight = weight

    def __hash__(self):
        # __hash__ must return an int: hash the identifying (left id, right id)
        # pair instead of returning the tuple itself (which raises TypeError).
        return hash((self.left.id, self.right.id))

    def __repr__(self):
        return "<l_id:{}, r_id:{}, w:{}>".format(self.left.id, self.right.id, self.weight)

    def is_tight(self):
        """Return True when the endpoint labels sum exactly to the weight."""
        return self.left.label + self.right.label == self.weight
class Hungarian:
    """
    Hungarian (Kuhn-Munkres) algorithm for a maximum-weight perfect matching
    in a complete bipartite graph.

    Attributes
    ----------
    cost : [[int]]
        Square weight matrix; cost[i][j] is the weight of the edge between
        left node i and right node j. The algorithm maximizes the total
        weight of the matching (despite the attribute name "cost").
    N : int
        Number of nodes in either of the sets of the bipartite graph
    X : [Node]
        Left set of the bipartite graph
    Y : [Node]
        Right set of the bipartite graph
    E : [[Edge]]
        Adjacency matrix of the bipartite graph; E[i][j] joins X[i] and Y[j]
    verbose : bool
        When True (the default), progress messages are printed to stdout

    Note : method names starting with _ should only be called on self

    Methods
    -------
    _add_edges(N, X, Y, cost)
        Constructs E from N, X, Y
    _init_labels()
        Initial labelling: left label = max incident weight, right label = 0
    _reset_visit_status()
        Set all nodes' visited as False to prepare for the next DFS traversal
    _alternating_dfs(root, path, augmenting_path, candidate_path)
        DFS over tight edges (left -> right) and matched edges (right -> left)
        looking for an augmenting path
    _augment(path)
        Flips the matching along a newly found augmenting path
    _find_augmenting_path()
        Searches from a free left node; returns the augmenting path (or None)
        and the sets of left / right nodes reached by the search
    _perfect_match_not_found()
        Checks whether some left node is still unmatched
    _update_node_labels(S, T)
        Shifts labels so the equality graph gains an edge leaving the tree
    match()
        Main function that runs the algorithm
    """

    def __init__(self, cost, verbose=True):
        # Every row must be as long as the matrix is tall. An empty matrix is
        # accepted and simply produces an empty matching.
        assert all(len(row) == len(cost) for row in cost), "Only square cost matrix is supported"
        self.cost = cost
        self.N = len(cost)
        self.verbose = verbose
        self.X = [Node(i, 0) for i in range(self.N)]
        self.Y = [Node(i, 1) for i in range(self.N)]
        self._add_edges(self.N, self.X, self.Y, self.cost)
        self._init_labels()

    def _add_edges(self, N, X, Y, cost):
        self.E = [[Edge(X[i], Y[j], cost[i][j]) for j in range(N)] for i in range(N)]

    def _init_labels(self):
        # Left label = heaviest incident edge and right label = 0 makes every
        # edge satisfy label(x) + label(y) >= weight.
        for left, right, row in zip(self.X, self.Y, self.E):
            left.set_label(max(edge.weight for edge in row))
            right.set_label(0)

    def _reset_visit_status(self):
        for node in self.X:
            node.unvisit()
        for node in self.Y:
            node.unvisit()

    def _alternating_dfs(self, root, path, augmenting_path, candidate_path):
        """
        Depth-first search for an augmenting path starting at `root`.

        Left nodes follow tight edges to right nodes. Right nodes follow
        their matched edge back to a left node. `augmenting_path` and
        `candidate_path` are one-element lists used as out-parameters.
        The search stops as soon as an augmenting path (a path ending on an
        unmatched right node) is stored in augmenting_path[0].

        Recursion depth grows with the path length, so very large N may hit
        Python's recursion limit.
        """
        if root.visited or augmenting_path[0] is not None:
            return
        root.visit()
        if root.is_left():
            for edge in self.E[root.id]:
                if not edge.is_tight():
                    continue
                if edge.right.is_matched():
                    self._alternating_dfs(edge.right, path + [edge.right], augmenting_path, candidate_path)
                else:
                    augmenting_path[0] = path + [edge.right]
                # Stop as soon as any augmenting path exists; don't overwrite it.
                if augmenting_path[0] is not None:
                    return
        elif root.is_right():
            # Only matched right nodes are visited, so root.match is a Node.
            candidate_path[0] = path + [root.match]
            self._alternating_dfs(root.match, path + [root.match], augmenting_path, candidate_path)

    def _augment(self, path):
        """Re-pair nodes along `path` = [x0, y1, x1, y2, ...] as (x0, y1), (x1, y2), ..."""
        if self.verbose:
            print("augmenting with: ", path)
        for left, right in zip(path[0::2], path[1::2]):
            left.match = right
            right.match = left

    def _find_augmenting_path(self):
        """
        Run one alternating DFS from the first unmatched left node.

        Returns
        -------
        (augmenting_path, S, T)
            augmenting_path is a list of nodes alternating left/right that
            starts and ends on unmatched nodes, or None if none exists.
            S and T are the sets of left and right nodes the search reached.
            When no augmenting path exists, S and T together form the whole
            alternating tree used for the label update.
        """
        root = next(node for node in self.X if not node.is_matched())
        augmenting_path = [None]  # pass by reference array trick
        candidate_path = [None]   # pass by reference array trick
        self._alternating_dfs(root, [root], augmenting_path, candidate_path)
        # Collect the full tree before the visit flags are cleared. A single
        # candidate path can miss tree nodes, which gives delta == 0 and an
        # infinite loop in match().
        S = {node for node in self.X if node.visited}
        T = {node for node in self.Y if node.visited}
        self._reset_visit_status()
        return augmenting_path[0], S, T

    def _perfect_match_not_found(self):
        return not all(node.is_matched() for node in self.X)

    def _update_node_labels(self, S, T):
        """
        Lower labels in S and raise labels in T by the smallest slack of an
        edge from S to a right node outside T.

        S is never empty (it contains the root), and |T| = |S| - 1 < N, so the
        min() below always has at least one candidate.
        """
        not_t = set(self.Y) - T
        delta = min(
            left.label + right.label - self.E[left.id][right.id].weight
            for left in S
            for right in not_t
        )
        if self.verbose:
            print("updating labels of: ", S, T, "with: ", delta)
        for node in S:
            node.set_label(node.label - delta)
        for node in T:
            node.set_label(node.label + delta)

    def match(self):
        """
        Run the algorithm until every left node is matched.

        Returns
        -------
        [Node]
            The left set X; each node's `match` is its partner in Y.
        """
        while self._perfect_match_not_found():
            augmenting_path, S, T = self._find_augmenting_path()
            if augmenting_path:
                self._augment(augmenting_path)
            else:
                self._update_node_labels(S, T)
        return self.X
# TESTING
if __name__ == "__main__":
    # Small demo, run only when executed as a script (not on import).
    # By enumeration, the maximum total weight for this matrix is 25,
    # achieved by pairing rows to columns as 0->2, 1->3, 2->1, 3->0.
    cost = [
        [2, 3, 4, 5],
        [6, 5, 4, 8],
        [5, 9, 2, 8],
        [4, 6, 3, 1],
    ]
    h = Hungarian(cost)
    X = h.match()
    print([(node.id, node.match.id) for node in X])
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment