Skip to content

Instantly share code, notes, and snippets.

@spdskatr
Created October 10, 2020 11:08
Show Gist options
  • Save spdskatr/eeb6e7cb4a85b23f963566187df930fd to your computer and use it in GitHub Desktop.
Save spdskatr/eeb6e7cb4a85b23f963566187df930fd to your computer and use it in GitHub Desktop.
Hungarian Algorithm O(N^4) implementation in C++
#include <cstdio>
#include <cstdlib>
#include <cassert>
#include <algorithm>
#include <vector>
#include <utility>
#include <set>
using namespace std;
// Solver state. All arrays are 1-indexed (slot 0 is never used) and are
// sized to handle N up to 100 with some slack.
int N;                // dimension of the square cost matrix
int a[105][105];      // a[u][v]: cost of pairing row u with column v
int in[105][105];     // in[u][v] == 1 iff a[u][v] == potx[u] + poty[v] (tight edge)
int potx[105];        // row potentials
int poty[105];        // column potentials
int match[105];       // match[v]: row currently matched to column v (0 = none)
int matchs[105];      // matchs[u]: column currently matched to row u (0 = none)
int matchings;        // number of matched pairs so far
int inS[105];         // inS[u] != 0: row u belongs to the alternating tree
int inT[105];         // inT[v] != 0: column v belongs to the alternating tree
int from[105];        // from[v]: row through which column v was reached (0 = not reached)
int Tsize;            // number of columns marked in inT
int NSsize;           // number of columns with from[v] != 0
int unmatched;        // root row of the current alternating tree
void recalc() {
for (int i = 1; i <= N; i++) for (int j = 1; j <= N; j++) {
in[i][j] = a[i][j] == potx[i] + poty[j];
}
}
// Extends the tight neighbourhood N(S): for every row u in S and every column
// v joined to u by a tight edge that has not been reached yet (from[v] == 0),
// records u as the row v was reached from and increments NSsize. Columns that
// were already reached keep their existing from[] entry.
void calcNS() {
    for (int u = 1; u <= N; ++u) {
        if (!inS[u]) continue;
        for (int v = 1; v <= N; ++v) {
            if (!in[u][v] || from[v] != 0) continue;
            from[v] = u;
            ++NSsize;
        }
    }
}
// Reads N followed by an N x N integer cost matrix (row by row) from stdin and
// prints the minimum total cost of a perfect matching of rows to columns,
// computed with the O(N^4) Hungarian algorithm.
//
// Invariants maintained by the loop below:
//  * potx[u] + poty[v] <= a[u][v] for every u, v (feasible potentials);
//  * every matched edge is tight (in[u][v] == 1);
//  * match[v] is the row matched to column v, matchs[u] the column matched to
//    row u (0 = unmatched);
//  * inS / inT mark the rows / columns of the alternating tree rooted at row
//    `unmatched`; from[v] != 0 means column v is in N(S), reached from row
//    from[v].
//
// Returns 1 (after a message on stderr) if the input is malformed or N does
// not fit in the fixed-size arrays.
int main() {
    // Read in input. Arrays are 1-indexed, so usable capacity is length - 1.
    const int capacity = static_cast<int>(sizeof(potx) / sizeof(potx[0])) - 1;
    if (scanf("%d", &N) != 1 || N < 0 || N > capacity) {
        fprintf(stderr, "expected N in [0, %d]\n", capacity);
        return 1;
    }
    for (int i = 1; i <= N; i++) for (int j = 1; j <= N; j++) {
        if (scanf("%d", &a[i][j]) != 1) {
            fprintf(stderr, "expected %d x %d matrix entries\n", N, N);
            return 1;
        }
    }
    // Initialise potentials: potx = 0 and poty[j] = min of column j is
    // feasible and gives every column at least one tight edge. Seeding the
    // minimum from row 1 avoids depending on a magic sentinel value.
    for (int j = 1; j <= N; j++) {
        poty[j] = a[1][j];
        for (int i = 2; i <= N; i++) poty[j] = min(poty[j], a[i][j]);
    }
    recalc();
    unmatched = 1;
    inS[unmatched] = 1;
    calcNS();
    while (matchings < N) {
        int success = 0;
        // Grow the alternating tree while N(S) contains a column not yet in T.
        while (NSsize > Tsize) {
            int notInT = 0;
            for (int x = 1; x <= N; x++) if (from[x] != 0) {
                if (!inT[x]) {
                    notInT = x;
                    break;
                }
            }
            assert(notInT != 0);
            inT[notInT] = 1;
            Tsize++;
            if (match[notInT]) {
                // Matched column: pull its row into S and extend N(S) with
                // that row's tight edges.
                inS[match[notInT]] = 1;
                for (int v = 1; v <= N; v++) if (in[match[notInT]][v] && from[v] == 0) {
                    from[v] = match[notInT];
                    NSsize++;
                }
            } else {
                // Free column reached: flip the alternating path back to the
                // root, following from[] (column -> row) and matchs[]
                // (row -> the column it was previously matched to).
                int cur = notInT;
                while (cur != 0) {
                    match[cur] = from[cur];
                    int temp = matchs[from[cur]];
                    matchs[from[cur]] = cur;
                    cur = temp;
                }
                matchings++;
                success = 1;
                break;
            }
        }
        if (success) {
            if (matchings == N) break;
            // Reset the tree and restart from a row that is still unmatched.
            Tsize = 0, NSsize = 0;
            fill(inS+1, inS+N+1, 0);
            fill(inT+1, inT+N+1, 0);
            fill(from+1, from+N+1, 0);
            unmatched = 0;
            for (int u = 1; u <= N; u++) if (!matchs[u]) {
                unmatched = u;
            }
            assert(unmatched != 0);
            inS[unmatched] = 1;
            calcNS();
        } else {
            // N(S) == T: shift potentials by the minimum slack between rows in
            // S and columns outside T. This keeps tree and matched edges tight
            // and makes at least one edge leaving the tree tight. The minimum
            // is seeded from the first candidate (not a sentinel) so that a
            // large but valid slack cannot trip the assertion.
            int diff = 0;
            bool found = false;
            for (int u = 1; u <= N; u++)
                for (int v = 1; v <= N; v++)
                    if (inS[u] && !inT[v]) {
                        int slack = a[u][v] - potx[u] - poty[v];
                        if (!found || slack < diff) {
                            diff = slack;
                            found = true;
                        }
                    }
            assert(found && diff > 0);
            for (int u = 1; u <= N; u++) {
                if (inS[u]) potx[u] += diff;
                if (inT[u]) poty[u] -= diff;
            }
            recalc();
            calcNS();
        }
    }
    // By LP duality the optimal cost equals the sum of the final potentials.
    // Accumulate in long long: N int-sized terms can overflow int.
    long long ans = 0;
    for (int u = 1; u <= N; u++) ans += static_cast<long long>(potx[u]) + poty[u];
    printf("%lld\n", ans);
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment