# PyTorch implementation of the Hungarian Algorithm
# Inspired by: https://python.plainenglish.io/hungarian-algorithm-introduction-python-implementation-93e7c0890e15
# Despite my effort to parallelize the code, there are still some sequential workflows in it
from typing import Tuple

import torch
from torch import Tensor

device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
def min_zero_row(zero_mat: Tensor) -> Tuple[Tensor, Tensor]:
    # Pick the row that contains the fewest (but at least one) unmarked zeros
    sum_zero_mat = zero_mat.sum(1)
    sum_zero_mat[sum_zero_mat == 0] = 9999
    zero_row = sum_zero_mat.min(0)[1]

    # Mark the first zero in that row, then clear its row and column
    zero_column = zero_mat[zero_row].nonzero()[0]
    zero_mat[zero_row, :] = False
    zero_mat[:, zero_column] = False

    mark_zero = torch.tensor([[zero_row, zero_column]], device=device)
    return zero_mat, mark_zero
def mark_matrix(mat: Tensor) -> Tuple[Tensor, Tensor, Tensor]:
    # Locations of all zeros in the reduced cost matrix
    zero_bool_mat = (mat == 0)
    zero_bool_mat_copy = zero_bool_mat.clone()

    # Greedily mark independent zeros until none are left
    marked_zero = torch.tensor([], device=device)
    while zero_bool_mat_copy.any():
        zero_bool_mat_copy, mark_zero = min_zero_row(zero_bool_mat_copy)
        marked_zero = torch.concat([marked_zero, mark_zero], dim=0)

    marked_zero_row = marked_zero[:, 0]
    marked_zero_col = marked_zero[:, 1]

    # Rows that contain no marked zero
    arange_index_row = torch.arange(mat.shape[0], dtype=torch.float, device=device).unsqueeze(1)
    repeated_marked_row = marked_zero_row.repeat(mat.shape[0], 1)
    bool_non_marked_row = torch.all(arange_index_row != repeated_marked_row, dim=1)
    non_marked_row = arange_index_row[bool_non_marked_row].squeeze()

    # Columns that contain a zero in a non-marked row
    non_marked_mat = zero_bool_mat[non_marked_row.long(), :]
    marked_cols = non_marked_mat.nonzero().unique()

    # Move rows whose marked zero lies in a marked column over to the non-marked set
    is_need_add_row = True
    while is_need_add_row:
        repeated_non_marked_row = non_marked_row.repeat(marked_zero_row.shape[0], 1)
        repeated_marked_cols = marked_cols.repeat(marked_zero_col.shape[0], 1)

        first_bool = torch.all(marked_zero_row.unsqueeze(1) != repeated_non_marked_row, dim=1)
        second_bool = torch.any(marked_zero_col.unsqueeze(1) == repeated_marked_cols, dim=1)

        addit_non_marked_row = marked_zero_row[first_bool & second_bool]
        if addit_non_marked_row.shape[0] > 0:
            non_marked_row = torch.concat([non_marked_row.reshape(-1), addit_non_marked_row[0].reshape(-1)])
        else:
            is_need_add_row = False

    # Covered rows are the complement of the non-marked rows
    repeated_non_marked_row = non_marked_row.repeat(mat.shape[0], 1)
    bool_marked_row = torch.all(arange_index_row != repeated_non_marked_row, dim=1)
    marked_rows = arange_index_row[bool_marked_row].squeeze(0)

    return marked_zero, marked_rows, marked_cols
def adjust_matrix(mat: Tensor, cover_rows: Tensor, cover_cols: Tensor) -> Tensor:
    # Mask of all covered entries (covered rows plus covered columns)
    bool_cover = torch.zeros_like(mat)
    bool_cover[cover_rows.long()] = True
    bool_cover[:, cover_cols.long()] = True

    # Subtract the smallest uncovered value from every uncovered entry
    non_cover = mat[bool_cover != True]
    min_non_cover = non_cover.min()
    mat[bool_cover != True] = mat[bool_cover != True] - min_non_cover

    # Add it back to doubly covered entries (row/column intersections)
    double_bool_cover = torch.zeros_like(mat)
    double_bool_cover[cover_rows.long(), cover_cols.long()] = True
    mat[double_bool_cover == True] = mat[double_bool_cover == True] + min_non_cover

    return mat
def hungarian_algorithm(mat: Tensor) -> Tensor:
    dim = mat.shape[0]
    cur_mat = mat

    # Step 1: row reduction, then column reduction
    cur_mat = cur_mat - cur_mat.min(1, keepdim=True)[0]
    cur_mat = cur_mat - cur_mat.min(0, keepdim=True)[0]

    # Steps 2-3: cover all zeros with a minimum number of lines;
    # adjust the matrix until dim lines are needed
    zero_count = 0
    while zero_count < dim:
        ans_pos, marked_rows, marked_cols = mark_matrix(cur_mat)
        zero_count = len(marked_rows) + len(marked_cols)

        if zero_count < dim:
            cur_mat = adjust_matrix(cur_mat, marked_rows, marked_cols)

    return ans_pos
# Example 1
mat = torch.tensor(
    [[7, 6, 2, 9, 2],
     [6, 2, 1, 3, 9],
     [5, 6, 8, 9, 5],
     [6, 8, 5, 8, 6],
     [9, 5, 6, 4, 7]], device=device)

ans_pos = hungarian_algorithm(mat)
print(ans_pos)

res = mat[ans_pos[:, 0].long(), ans_pos[:, 1].long()]
print(res)
print(res.sum())
print('==============')

# Example 2
mat = torch.tensor(
    [[108, 125, 150],
     [150, 135, 175],
     [122, 148, 250]], device=device)

ans_pos = hungarian_algorithm(mat)
print(ans_pos)

res = mat[ans_pos[:, 0].long(), ans_pos[:, 1].long()]
print(res)
print(res.sum())
print('==============')

# Example 3
mat = torch.tensor(
    [[1500, 4000, 4500],
     [2000, 6000, 3500],
     [2000, 4000, 2500]], device=device)

ans_pos = hungarian_algorithm(mat)
print(ans_pos)

res = mat[ans_pos[:, 0].long(), ans_pos[:, 1].long()]
print(res)
print(res.sum())
print('==============')

# Example 4
mat = torch.tensor(
    [[5, 9, 3, 6],
     [8, 7, 8, 2],
     [6, 10, 12, 7],
     [3, 10, 8, 6]], device=device)

ans_pos = hungarian_algorithm(mat)
print(ans_pos)

res = mat[ans_pos[:, 0].long(), ans_pos[:, 1].long()]
print(res)
print(res.sum())
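# Optional cross-check (an addition, assuming SciPy is installed): the total
# cost found for Example 4 should match scipy.optimize.linear_sum_assignment.
from scipy.optimize import linear_sum_assignment
row_ind, col_ind = linear_sum_assignment(mat.cpu().numpy())
print(mat.cpu().numpy()[row_ind, col_ind].sum())  # same value as res.sum() above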
I extracted the Ultralytics and DETR code and passed 1000 lines to GPT-4 to refactor. So here is my PyTorch implementation of a Hungarian loss function, using SciPy's assignment-problem solver, in 12 lines of code.
@ivanstepanovftw scipy.optimize's implementation of linear sum assignment requires a NumPy array. Moving the IoU tensor to the CPU to get it as a NumPy array will detach it from the computation graph, meaning you can't calculate gradients on it anymore.
@LivesayMe, while this is true, linear sum assignment only returns indices, which the Hungarian loss uses as a mask to permute the predicted and target tensors. Take a look at a minimal version of my Hungarian loss implementation on GitHub Gist:
```python
import torch
import torch.nn.functional as F
from scipy.optimize import linear_sum_assignment

def hungarian_loss(outputs, targets):
    # Pairwise L1 cost between every output and every target
    cost_matrix = torch.cdist(outputs, targets, p=1)
    # Solve the assignment on CPU; only the indices are needed,
    # so detaching does not break the gradient of the loss below
    row_ind, col_ind = linear_sum_assignment(cost_matrix.cpu().detach().numpy())
    matched_outputs = outputs[row_ind]
    matched_targets = targets[col_ind]
    loss = F.l1_loss(matched_outputs, matched_targets)
    return loss
```
Here you can see that `matched_outputs` is located on the same device as `outputs`, and `matched_targets` on the same device as `targets`, while both are permuted by `row_ind` and `col_ind` of type `numpy.ndarray`.
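For instance, a minimal sanity check (with made-up shapes) shows that gradients still flow back to `outputs` even though the assignment itself was solved on detached NumPy data:

```python
outputs = torch.randn(5, 4, requires_grad=True)
targets = torch.randn(5, 4)

loss = hungarian_loss(outputs, targets)
loss.backward()
print(outputs.grad is not None)  # True: the indices carry no gradient, the L1 loss does
```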
Note that this `hungarian_loss` implementation requires both `outputs` and `targets` to be unbatched, i.e. the tensor dims are `(set_length, set_features)`. For a batched implementation you need to adapt the code, e.g.:
```python
def criterion(x, y, x_lengths, y_lengths):
    hungarian = torch.tensor(0.0)
    for i in range(x.shape[0]):
        hungarian += hungarian_loss(x[i, :x_lengths[i]], y[i, :y_lengths[i]])
    hungarian /= x.shape[0]  # batchmean
    return hungarian
```
If you are still not sure whether it works in production, look at how the Ultralytics YOLO code works; it even uses alternative solvers, such as `lap.lapjv` from https://github.com/gatagat/lap. Search for `repo:ultralytics/ultralytics linear_sum_assignment`.
If you want alternatives to the Hungarian loss, take a look at the Chamfer distance, which is implemented in both CPU and CUDA code:
```python
from pytorch3d.loss import chamfer_distance

def criterion(x, y, x_lengths, y_lengths):
    cham, cham_norm = chamfer_distance(x, y, x_lengths, y_lengths,
                                       point_reduction='mean', single_directional=False, abs_cosine=True)
    return cham
```
If you have any questions, feel free to ask; I will be happy to answer them.
SciPy's routine does not run on the GPU. If you're using it for detection, that's fine, but if you're dealing with other tasks, e.g. point cloud matching where the distance matrix is something like 100_000 x 100_000, it is just too slow.
I just tested the code on my distance matrix, and it ran into a dead loop.
Dear author,
Thank you for your excellent programming, which has been very helpful to me. If I may, I'd like to offer a suggestion.
I'm afraid the algorithm you adapted from https://python.plainenglish.io/hungarian-algorithm-introduction-python-implementation-93e7c0890e15 is incorrect. Specifically, the function `adjust_matrix()` is expected to modify the matrix by adding and subtracting `min_non_cover` so that `zero_count` can reach `dim`. However, the function `mark_matrix()` does not ensure that no zero is left uncovered, since `non_marked_row` may contain zeros that do not lie in `marked_cols`. In that case `adjust_matrix()` does nothing, and the loop in `hungarian_algorithm` never terminates.
The algorithm in https://brc2.com/the-algorithm-workshop/ should be correct, since I have not found any contradiction in it. I'd be very happy to see you update your code.
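Until the marking step is fixed, here is a minimal sketch of a guard (reusing `mark_matrix` and `adjust_matrix` from the gist; `max_rounds` is my own addition, not part of the original code) that raises instead of spinning forever when the cover stops making progress:

```python
import torch
from torch import Tensor

def hungarian_algorithm_guarded(mat: Tensor, max_rounds: int = 1000) -> Tensor:
    # Row and column reduction, as in the gist
    dim = mat.shape[0]
    cur_mat = mat - mat.min(1, keepdim=True)[0]
    cur_mat = cur_mat - cur_mat.min(0, keepdim=True)[0]

    zero_count = 0
    for _ in range(max_rounds):
        ans_pos, marked_rows, marked_cols = mark_matrix(cur_mat)
        zero_count = len(marked_rows) + len(marked_cols)
        if zero_count >= dim:
            return ans_pos
        cur_mat = adjust_matrix(cur_mat, marked_rows, marked_cols)

    # adjust_matrix() stopped making progress: the dead loop described above
    raise RuntimeError(f'cover stuck at {zero_count} < {dim} lines after {max_rounds} rounds')
```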
If you have any ideas, my email is [email protected]