Last active
February 28, 2022 18:20
-
-
Save mberr/7f08a37a56addb083258adfbca12b837 to your computer and use it in GitHub Desktop.
Several Similarity Matrix Normalization Methods written in PyTorch
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
"""Several similarity matrix normalization methods.""" | |
from typing import Optional

import torch
def csls(
    sim: torch.FloatTensor,
    k: Optional[int] = 1,
) -> torch.FloatTensor:
    """
    Apply CSLS normalization to a similarity matrix.

    For a 2D matrix this computes

    .. math::
        csls[i, j] = 2 \\cdot sim[i, j] - avg(top_k(sim[i, :])) - avg(top_k(sim[:, j]))

    and generalizes to higher-order tensors by subtracting the top-k average
    along every mode (hence the multiplication by the number of dimensions).

    :param sim: shape: (d1, ..., dk)
        Similarity matrix.
    :param k:
        The number of top-k elements to use for the hubness correction.
        If None, the input is returned unchanged.
    :return:
        The normalized similarity matrix.
    """
    if k is None:
        return sim
    # Nothing to normalize for an empty tensor.
    if sim.numel() < 1:
        return sim
    old_sim = sim
    # Scale by the number of modes to compensate for one subtraction per mode.
    sim = sim.ndimension() * sim
    # Subtract the average over the top-k similarities along each mode;
    # clamp k to the mode's size so small tensors remain valid inputs.
    for dim, size in enumerate(sim.size()):
        sim = sim - old_sim.topk(k=min(k, size), dim=dim, largest=True, sorted=False).values.mean(dim=dim, keepdim=True)
    return sim
def sinkhorn_knopp(
    similarities: torch.FloatTensor,
    eps: float = 1.0e-04,
    max_iter: int = 1000,
) -> torch.FloatTensor:
    """
    Normalize similarities to be double stochastic using the Sinkhorn-Knopp algorithm.

    The input is interpreted as log-similarities, and the result stays in log
    space: after convergence, exponentiating yields a doubly stochastic matrix.

    :param similarities: shape: (n, n)
        The similarities.
    :param eps:
        A tolerance for convergence check.
    :param max_iter:
        A maximum number of iterations.
    :return:
        The normalized similarities (in log space!).

    .. seealso ::
        http://www.cerfacs.fr/algor/reports/2006/TR_PA_06_42.pdf
        https://github.com/HeddaCohenIndelman/Learning-Gumbel-Sinkhorn-Permutations-w-Pytorch/blob/8fbc8cf4b97f5bafd18776b5497e3f724d60cc0a/my_sinkhorn_ops.py#L36
    :raises ValueError:
        If the input is not a square matrix.
    """
    # Reject anything that is not a square 2D matrix.
    num_rows = similarities.shape[0]
    is_square = similarities.ndimension() == 2 and similarities.shape[1] == num_rows
    if not is_square:
        raise ValueError(f'similarities have to be a square matrix, but have shape: {similarities.shape}')
    # Fixed-point iteration: alternately normalize rows and columns in log space.
    for _ in range(max_iter):
        previous = similarities
        row_log_sums = similarities.logsumexp(dim=-1, keepdim=True)
        similarities = similarities - row_log_sums
        col_log_sums = similarities.logsumexp(dim=-2, keepdim=True)
        similarities = similarities - col_log_sums
        # Stop once successive iterates are (almost) identical.
        if (previous - similarities).norm() < eps:
            break
    return similarities
def bidirectional_alignment(
    similarities: torch.FloatTensor,
    normalize: bool = False,
) -> torch.FloatTensor:
    """
    Compute bi-directional alignment scores.

    .. note ::
        This operation is non-differentiable.

    .. note(review) ::
        ``argsort`` yields the sorting permutation, not per-element ranks;
        if per-element ranks are intended, ``argsort(argsort(...))`` would be
        required — confirm against the referenced paper.

    .. seealso ::
        https://www.aclweb.org/anthology/D19-1075.pdf

    :param similarities: shape: (n, m)
        The similarity scores.
    :param normalize:
        Use the normalized rank instead of the rank; also use mean instead of sum over both directions. Guarantees
        that the output has value range [0, 1].
    :return: shape: (n, m)
        The new similarity scores.
    """
    # Sorting order along each direction, as float for arithmetic below.
    left_to_right = similarities.argsort(dim=0).float()
    right_to_left = similarities.argsort(dim=1).float()
    if normalize:
        # Divide by the respective extent and halve so the sum stays in [0, 1].
        num_rows, num_cols = similarities.shape
        left_to_right = 0.5 * left_to_right / num_rows
        right_to_left = 0.5 * right_to_left / num_cols
    return left_to_right + right_to_left
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment