Created
June 21, 2025 13:47
-
-
Save ruccho/55601a91e900b1a749f2f6c250e707a1 to your computer and use it in GitHub Desktop.
Allocation-free Hungarian algorithm implementation written in C#
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
// Demo: solve 100 random square assignment problems and print each original matrix
// with the chosen cell of every row marked by '*'.
const int MatrixSize = 4;
const int Trials = 100;

Span<float> cost = stackalloc float[MatrixSize * MatrixSize];
// HungarianMatcher.Solve reduces `cost` in place, so keep an untouched copy for printing.
Span<float> originalCost = stackalloc float[MatrixSize * MatrixSize];
Span<int> result = stackalloc int[MatrixSize];
var random = Random.Shared;
for (var i = 0; i < Trials; i++)
{
    for (var j = 0; j < cost.Length; j++) cost[j] = random.NextSingle();
    cost.CopyTo(originalCost);
    HungarianMatcher.Solve(cost, MatrixSize, result);
    PrintMatrix(originalCost, MatrixSize, result);
    Console.WriteLine();
}
// Prints a size x size row-major matrix one row per line; the column assigned to each row
// (result[y]) is followed by '*'.
void PrintMatrix(ReadOnlySpan<float> matrix, int size, ReadOnlySpan<int> result)
{
    Console.WriteLine("Matrix:");
    for (var y = 0; y < size; y++)
    {
        var assignedColumn = result[y];
        for (var x = 0; x < size; x++)
        {
            Console.Write($"{matrix[y * size + x]:F2}");
            // Both suffixes are two characters wide so the columns line up across rows.
            Console.Write(assignedColumn == x ? "* " : "  ");
        }
        Console.WriteLine();
    }
}
/// <summary>
/// Solves the square assignment problem: chooses one distinct column for every row so that the
/// sum of the chosen costs is minimal (Hungarian algorithm, star-and-prime formulation).
/// </summary>
/// <remarks>
/// All scratch memory is <c>stackalloc</c>-ed. The mark matrix alone takes <c>width * width</c> bytes,
/// so very large matrices may exhaust the stack.
/// </remarks>
public ref struct HungarianMatcher
{
    // Values stored in the mark matrix used by MarkAndSolve.
    private const byte Unmarked = 0;
    private const byte Starred = 1;
    private const byte Primed = 2;

    // Row-major cost matrix (the caller's span); reduced in place while solving.
    private readonly Span<float> _data;

    // Number of rows (and columns) of the square matrix.
    private readonly int _size;

    private HungarianMatcher(Span<float> data, int size)
    {
        _data = data;
        _size = size;
    }

    /// <summary>
    /// Computes a minimum-cost assignment of rows to columns.
    /// </summary>
    /// <param name="costMatrix">
    /// Row-major <paramref name="width"/> x <paramref name="width"/> cost matrix. It is used as working
    /// storage and is overwritten in place; copy it first if the original costs are still needed.
    /// Costs are not validated; behavior with NaN or infinite values is untested.
    /// </param>
    /// <param name="width">Number of rows (and columns).</param>
    /// <param name="result">Receives, for each row <c>y</c>, the index of the column assigned to it.</param>
    /// <exception cref="ArgumentException">The span lengths do not match <paramref name="width"/>.</exception>
    /// <exception cref="InvalidOperationException">An internal invariant of the algorithm was violated.</exception>
    public static void Solve(Span<float> costMatrix, int width, Span<int> result)
    {
        if (costMatrix.Length != width * width) throw new ArgumentException("Matrix size mismatch", nameof(costMatrix));
        if (result.Length != width) throw new ArgumentException("Result size mismatch", nameof(result));

        var matcher = new HungarianMatcher(costMatrix, width);
        matcher.Solve(result);
    }

    /// <summary>Reduces the matrix, tries the cheap forced/greedy matching, and falls back to the full algorithm.</summary>
    private void Solve(Span<int> result)
    {
        Reduce();
        if (TryMatch(result)) return;
        MarkAndSolve(result);
    }

    /// <summary>
    /// Subtracts each row's minimum from that row, then each column's minimum from that column.
    /// For finite costs this leaves a non-negative matrix with a zero in every row and column.
    /// Later steps rely on exact <c>== 0</c> checks, which are sound because every zero is produced by
    /// subtracting a value from itself.
    /// </summary>
    private void Reduce()
    {
        for (var y = 0; y < _size; y++)
        {
            var min = float.MaxValue;
            for (var x = 0; x < _size; x++)
            {
                var value = _data[y * _size + x];
                if (value < min)
                    min = value;
            }
            for (var x = 0; x < _size; x++) _data[y * _size + x] -= min;
        }

        for (var x = 0; x < _size; x++)
        {
            var min = float.MaxValue;
            for (var y = 0; y < _size; y++)
            {
                var value = _data[y * _size + x];
                if (value < min)
                    min = value;
            }
            for (var y = 0; y < _size; y++) _data[y * _size + x] -= min;
        }
    }

    /// <summary>
    /// Tries to build a complete assignment using only zeros of the reduced matrix. It first makes
    /// forced choices (a free row or column with exactly one free zero). When no forced choice remains,
    /// it greedily takes the first free zero.
    /// </summary>
    /// <returns>
    /// <c>true</c> with <paramref name="result"/> holding 0-based column indices when every row was matched;
    /// <c>false</c> (leaving <paramref name="result"/> in an unspecified state) when the greedy search gets stuck.
    /// </returns>
    private bool TryMatch(Span<int> result)
    {
        // Assignments are stored 1-based so that 0 means "unassigned":
        // row[y] == x + 1 means row y is matched to column x, and column[x] == y + 1 the reverse.
        result.Clear();
        var row = result;
        Span<int> column = stackalloc int[_size];
        do
        {
            var any = false;

            // Match every free row that has exactly one zero among the free columns.
            for (var y = 0; y < _size; y++)
            {
                if (row[y] != 0) continue;
                var singleZero = -1;
                for (var x = 0; x < _size; x++)
                {
                    if (column[x] != 0) continue;
                    if (_data[y * _size + x] == 0)
                    {
                        if (singleZero == -1)
                        {
                            singleZero = x;
                        }
                        else
                        {
                            singleZero = -1;
                            break;
                        }
                    }
                }
                if (singleZero >= 0)
                {
                    row[y] = singleZero + 1;
                    column[singleZero] = y + 1;
                    any = true;
                }
            }

            // Match every free column that has exactly one zero among the free rows.
            for (var x = 0; x < _size; x++)
            {
                if (column[x] != 0) continue;
                var singleZero = -1;
                for (var y = 0; y < _size; y++)
                {
                    if (row[y] != 0) continue;
                    if (_data[y * _size + x] == 0)
                    {
                        if (singleZero == -1)
                        {
                            singleZero = y;
                        }
                        else
                        {
                            singleZero = -1;
                            break;
                        }
                    }
                }
                if (singleZero >= 0)
                {
                    column[x] = singleZero + 1;
                    // Same 1-based encoding as the row pass. Storing plain `x` made column 0 look
                    // unassigned and shifted every other column by one after the final `-= 1`.
                    row[singleZero] = x + 1;
                    any = true;
                }
            }

            if (!any)
            {
                for (var i = 0; i < _size; i++)
                {
                    if (row[i] == 0 || column[i] == 0)
                    {
                        // No forced choice left: take the first zero whose row and column are both free,
                        // then run the forced passes again.
                        for (var y = 0; y < _size; y++)
                            for (var x = 0; x < _size; x++)
                                if (_data[y * _size + x] == 0 && row[y] == 0 && column[x] == 0)
                                {
                                    row[y] = x + 1;
                                    column[x] = y + 1;
                                    goto NextPass;
                                }
                        // Stuck: let MarkAndSolve work from the reduced matrix.
                        return false;
                    }
                }

                // Every row is matched; convert to 0-based column indices.
                foreach (ref var assigned in result) assigned -= 1;
                return true;
            }
            NextPass: ;
        } while (true);
    }

    /// <summary>
    /// Full star-and-prime assignment search on the reduced matrix. The matrix is adjusted in place until
    /// the starred zeros cover every row; those stars are then written to <paramref name="result"/>.
    /// </summary>
    /// <exception cref="InvalidOperationException">An internal invariant was violated.</exception>
    private void MarkAndSolve(Span<int> result)
    {
        Span<bool> columnCovered = stackalloc bool[_size];
        Span<bool> rowCovered = stackalloc bool[_size];
        Span<byte> mark = stackalloc byte[_size * _size];

        // Star one zero per row in a column that has no star yet, and cover the starred columns.
        for (var y = 0; y < _size; y++)
            for (var x = 0; x < _size; x++)
                if (!columnCovered[x] && _data[y * _size + x] == 0)
                {
                    columnCovered[x] = true;
                    mark[y * _size + x] = Starred;
                    break;
                }

        // Search for an uncovered zero and prime it.
        Refind:
        for (var y = 0; y < _size; y++)
        {
            if (rowCovered[y]) continue;
            for (var x = 0; x < _size; x++)
            {
                if (columnCovered[x]) continue;
                if (_data[y * _size + x] == 0)
                {
                    mark[y * _size + x] = Primed;

                    // If this row already has a star: cover the row, uncover the star's column, search again.
                    for (var x1 = 0; x1 < _size; x1++)
                        if (mark[y * _size + x1] == Starred)
                        {
                            columnCovered[x1] = false;
                            rowCovered[y] = true;
                            goto Refind;
                        }

                    // No star in this row: follow the alternating prime/star path that starts at this
                    // prime, starring each prime and unstarring each star on it.
                    AugmentPath:
                    mark[y * _size + x] = Starred;
                    for (var y1 = 0; y1 < _size; y1++)
                    {
                        if (y1 == y) continue;
                        if (mark[y1 * _size + x] == Starred)
                        {
                            mark[y1 * _size + x] = Unmarked;
                            // The next element of the path is the prime in the unstarred cell's row.
                            for (var x1 = 0; x1 < _size; x1++)
                                if (mark[y1 * _size + x1] == Primed)
                                {
                                    x = x1;
                                    y = y1;
                                    goto AugmentPath;
                                }
                            throw new InvalidOperationException("Something went wrong");
                        }
                    }

                    // Path finished: drop all covers and primes, then cover every starred column.
                    columnCovered.Clear();
                    rowCovered.Clear();
                    for (var i = 0; i < mark.Length; i++)
                    {
                        if (mark[i] == Primed) mark[i] = Unmarked;
                        if (mark[i] == Starred) columnCovered[i % _size] = true;
                    }
                    goto Refind;
                }
            }
        }

        // No uncovered zero remains. If there is one star per row, the stars are the assignment.
        var count = 0;
        for (var i = 0; i < mark.Length; i++)
            if (mark[i] == Starred)
                count++;
        if (count == _size)
        {
            for (var i = 0; i < mark.Length; i++)
                if (mark[i] == Starred)
                    result[i / _size] = i % _size;
            return;
        }

        // Otherwise create a new zero: subtract the smallest uncovered value from every uncovered cell
        // and add it to every cell covered by both a row and a column.
        var minUncovered = float.MaxValue;
        for (var y = 0; y < _size; y++)
        {
            if (rowCovered[y]) continue;
            for (var x = 0; x < _size; x++)
            {
                if (columnCovered[x]) continue;
                var value = _data[y * _size + x];
                if (value < minUncovered) minUncovered = value;
            }
        }
        // ReSharper disable once CompareOfFloatsByEqualityOperator
        if (minUncovered == float.MaxValue) throw new InvalidOperationException("No minimum uncovered value found");
        for (var y = 0; y < _size; y++)
        {
            var add = rowCovered[y] ? 0 : -minUncovered;
            for (var x = 0; x < _size; x++) _data[y * _size + x] += add + (columnCovered[x] ? minUncovered : 0);
        }
        goto Refind;
    }
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment