"""Plackett-Luce loss function for ranking tasks. The rankings may be partial and include ties."""
from itertools import chain
from typing import List, Optional, Tuple
import torch
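
# Background note (added for reference): under a Plackett-Luce model with
# scores s, the probability of the full ranking i_1 > i_2 > ... > i_k is
#
#     prod_{j=1}^{k} exp(s_{i_j}) / sum_{m=j}^{k} exp(s_{i_m}),
#
# i.e. items are drawn one at a time without replacement, each with
# probability proportional to exp(score).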


def plackett_luce_loss(
    scores: torch.Tensor,
    numerator_mask: torch.Tensor,
    denominator_mask: torch.Tensor,
    weights: torch.Tensor,
    *,
    eps: float = 0.0,
) -> torch.Tensor:
"""Plackett-Luce loss function for ranking tasks.
If `beta * (logp - logp_ref)` is input as `scores`, where `logp` is the log probability
of a completion given its prompt, and `logp_ref` is the log probability of the completion
given its prompt under a reference model, then this function computes the Plackett-Luce DPO
loss (Appendix A.3 of https://arxiv.org/abs/2305.18290).
Args:
scores (torch.Tensor): Model output. Shape (n_scores).
numerator_mask (torch.Tensor): From `make_inputs_for_plackett_luce_loss()`.
Shape (n_groups, n_scores).
denominator_mask (torch.Tensor): From `make_inputs_for_plackett_luce_loss()`.
Shape (n_groups, n_scores).
weights: From `make_inputs_for_plackett_luce_loss()`. Shape (n_groups).
eps (float, optional): Epsilon for conservative Plackett-Luce. Defaults to `0.0`.
Returns:
torch.Tensor: Scalar loss value.
"""
    n1 = torch.logsumexp(torch.where(numerator_mask, scores, float("-inf")), dim=-1)
    d1 = torch.logsumexp(torch.where(denominator_mask, scores, float("-inf")), dim=-1)
    n2 = torch.logsumexp(torch.where(numerator_mask, -scores, float("-inf")), dim=-1)
    d2 = torch.logsumexp(torch.where(denominator_mask, -scores, float("-inf")), dim=-1)
    log_likelihood_parts = torch.where(n1 == float("-inf"), 0.0, torch.lerp(n1 - d1, n2 - d2, eps))
    return -torch.sum(weights * log_likelihood_parts)
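

# A minimal usage sketch for the DPO case described in the docstring
# (`logp`, `logp_ref`, and `beta` are hypothetical, not part of this gist):
#
#   beta = 0.1
#   scores = beta * (logp - logp_ref)  # shape (n_scores)
#   masks_and_weights = make_inputs_for_plackett_luce_loss([[[1], [2], [0]]], 3)
#   loss = plackett_luce_loss(scores, *masks_and_weights)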


def make_inputs_for_plackett_luce_loss(
    rankings: List[List[List[int]]],
    n_scores: int,
    weights: Optional[List[float]] = None,
    device: Optional[torch.device] = None,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""Make inputs for the Plackett-Luce loss function.
This function accepts a list of rankings in the following format:
```
rankings = [
[[1], [2], [0]],
[[2], [0]],
[[0, 2], [1]],
]
```
The above example has 3 rankings. The first ranking prefers 1 to 2 to 0. The second ranking
prefers 2 to 0 and is indifferent to 1. The third ranking prefers either 0 or 2 to 1.
Args:
rankings (List[List[List[int]]]): List of rankings.
n_scores (int): Number of scores.
weights (List[float], optional): Weights for the rankings. Defaults to `1 / len(rankings)`
for all rankings.
device (torch.device, optional): Device for the inputs.
Returns:
Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: Tuple of inputs.
"""
    n_groups = sum(len(ranking) - 1 for ranking in rankings)
    weights = [1 / len(rankings)] * len(rankings) if weights is None else weights
    numerator_mask = torch.zeros(n_groups, n_scores, device="cpu", dtype=torch.bool)
    denominator_mask = torch.zeros(n_groups, n_scores, device="cpu", dtype=torch.bool)
    weights_out = torch.empty(n_groups, device="cpu")
    i = 0
    for ranking, weight in zip(rankings, weights):
        remaining_items = set(chain.from_iterable(ranking))
        for group in ranking[:-1]:
            items = set(group)
            for item in items:
                numerator_mask[i, item] = True
            for item in remaining_items:
                denominator_mask[i, item] = True
            weights_out[i] = weight
            remaining_items -= items
            i += 1
    return numerator_mask.to(device), denominator_mask.to(device), weights_out.to(device)
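

# Worked example: rankings = [[[1], [2], [0]]] with n_scores=3 produces two
# groups. Group 0 has numerator {1} and denominator {0, 1, 2}; group 1 has
# numerator {2} and denominator {0, 2}. Both groups carry the ranking's weight.
#
#   masks_and_weights = make_inputs_for_plackett_luce_loss([[[1], [2], [0]]], 3)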


def simple_plackett_luce_loss(
    scores: torch.Tensor,
    rankings: torch.Tensor,
    weights: Optional[torch.Tensor] = None,
    eps: float = 0.0,
) -> torch.Tensor:
"""Simple Plackett-Luce loss function for ranking tasks.
This function is an easier to use version of `plackett_luce_loss()`. It accepts a batched tensor
of scores, a batched tensor of rankings, and an optional tensor of weights. Ties are not
supported, all rankings must rank the same number of items, and all rankings are indifferent to
items not in the ranking. For example, if the rankings for `plackett_luce_loss()` are:
```
rankings = [
[[0], [1], [2]],
[[4], [3], [5]],
[[8], [6], [7]],
]
```
with a shape (9) tensor of scores, then the rankings for `simple_plackett_luce_loss()` are:
```
rankings = torch.tensor([[0, 1, 2], [1, 0, 2], [2, 0, 1]])
```
with a shape (3, 3) tensor of scores.
Args:
scores (torch.Tensor): Model output. Shape (n_rankings, n_scores_per_ranking).
rankings (torch.Tensor): Rankings. Shape (n_rankings, n_scores_per_ranking).
weights (torch.Tensor, optional): Weights for the rankings. Defaults to `1 / n_rankings` for
all rankings.
eps (float, optional): Epsilon for conservative Plackett-Luce. Defaults to `0.0`.
Returns:
torch.Tensor: Scalar loss value.
"""
    if weights is None:
        weights = scores.new_full(rankings.shape[:-1], 1 / rankings.shape[-2])
    n1 = torch.gather(scores, -1, rankings)
    d1 = n1.flip(-1).logcumsumexp(dim=-1).flip(-1)
    n2 = -n1
    d2 = n2.flip(-1).logcumsumexp(dim=-1).flip(-1)
    log_likelihoods = torch.lerp(torch.sum(n1 - d1, dim=-1), torch.sum(n2 - d2, dim=-1), eps)
    return -torch.sum(weights * log_likelihoods, dim=-1)
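

# A minimal usage sketch (three rankings over three items each):
#
#   scores = torch.randn(3, 3)
#   rankings = torch.tensor([[0, 1, 2], [1, 0, 2], [2, 0, 1]])
#   loss = simple_plackett_luce_loss(scores, rankings)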


def sample_plackett_luce(scores: torch.Tensor, shape: Tuple[int, ...] = ()) -> torch.Tensor:
    """Sample from a Plackett-Luce model.

    Args:
        scores (torch.Tensor): Model output. Shape (n_scores).
        shape (Tuple[int, ...], optional): Batch shape of the samples. Defaults to `()`.

    Returns:
        torch.Tensor: Samples. Shape (*shape, n_scores).
    """
    gumbel = scores.new_empty(*shape, *scores.shape).exponential_().log_().neg_()
    _, indices = torch.sort(scores + gumbel, dim=-1, descending=True)
    return indices


def is_before(permutation: torch.Tensor, a: int, b: int) -> torch.Tensor:
    """Check if `a` comes before `b` in a permutation.

    Args:
        permutation (torch.Tensor): Permutation. Shape (*, n_scores).
        a (int): The first element.
        b (int): The second element.

    Returns:
        torch.Tensor: Boolean tensor. Shape (*).
    """
    a_mask = permutation == a
    b_mask = permutation == b
    a_present = torch.any(a_mask, dim=-1)
    b_present = torch.any(b_mask, dim=-1)
    pos_a = torch.argmax(a_mask.byte(), dim=-1)
    pos_b = torch.argmax(b_mask.byte(), dim=-1)
    return a_present & b_present & (pos_a < pos_b)
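

# A minimal Monte Carlo sketch combining these helpers: estimate the
# probability that item 2 is ranked before item 0 under a random
# Plackett-Luce model (the item indices are arbitrary).
#
#   samples = sample_plackett_luce(torch.randn(3), (100_000,))
#   p_2_before_0 = is_before(samples, 2, 0).float().mean()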


def main():
    """Run tests."""
    # Fit a model.
    rankings = [
        [[1], [2], [0]],
        [[2], [0, 1]],
        [[0], [1]],
    ]
    weights = [0.5, 0.2, 0.3]
    inputs = make_inputs_for_plackett_luce_loss(rankings, 3, weights)
    scores = torch.zeros(3, requires_grad=True)
    opt = torch.optim.SGD([scores], lr=1)
    for i in range(20):
        loss = plackett_luce_loss(scores, *inputs)
        print(f"step: {i}, loss: {loss:.6f}")
        loss.backward()
        opt.step()
        opt.zero_grad()
    scores = scores.detach()
    # Sample from the model.
    samples = sample_plackett_luce(scores, (1_000_000,))
    # Check the probability of a specific ranking.
    rankings = [[[2], [0], [1]]]
    inputs = make_inputs_for_plackett_luce_loss(rankings, 3)
    expected = torch.exp(-plackett_luce_loss(scores, *inputs))
    cond = torch.all(samples == torch.tensor([2, 0, 1]), dim=-1)
    sampled = torch.mean(cond.float())
    print(f"expected: {expected:.4f}")
    print(f" sampled: {sampled:.4f}")
    # Check the probability of a ranking where we are indifferent to one value.
    rankings = [[[2], [0]]]
    inputs = make_inputs_for_plackett_luce_loss(rankings, 3)
    expected = torch.exp(-plackett_luce_loss(scores, *inputs))
    cond = is_before(samples, 2, 0)
    sampled = torch.mean(cond.float())
    print(f"expected: {expected:.4f}")
    print(f" sampled: {sampled:.4f}")
    # Check the probability of a ranking where we are indifferent between two values.
    rankings = [[[1], [0, 2]]]
    inputs = make_inputs_for_plackett_luce_loss(rankings, 3)
    expected = torch.exp(-plackett_luce_loss(scores, *inputs))
    cond = is_before(samples, 1, 0) & is_before(samples, 1, 2)
    sampled = torch.mean(cond.float())
    print(f"expected: {expected:.4f}")
    print(f" sampled: {sampled:.4f}")
    # Check a more complicated example. Make up a Plackett-Luce model for this one.
    scores = torch.linspace(0, 2, 5)
    rankings = [[[0, 1], [2], [3]]]
    inputs = make_inputs_for_plackett_luce_loss(rankings, 5)
    expected = torch.exp(-plackett_luce_loss(scores, *inputs))
    samples = sample_plackett_luce(scores, (1_000_000,))
    cond1 = is_before(samples, 0, 2) & is_before(samples, 0, 3)
    cond2 = is_before(samples, 1, 2) & is_before(samples, 1, 3)
    cond3 = is_before(samples, 2, 3)
    cond = (cond1 | cond2) & cond3
    sampled = torch.mean(cond.float())
    print(f"expected: {expected:.4f}")
    print(f" sampled: {sampled:.4f}")
    # Check that conservative Plackett-Luce does the right thing.
    scores = torch.randn(3)
    rankings = [[[0], [1], [2]]]
    inputs = make_inputs_for_plackett_luce_loss(rankings, 3)
    loss = plackett_luce_loss(scores, *inputs, eps=0.1)
    expected_1 = plackett_luce_loss(scores, *inputs)
    expected_2 = plackett_luce_loss(-scores, *inputs)
    expected = 0.9 * expected_1 + 0.1 * expected_2
    if torch.allclose(loss, expected):
        print("Conservative Plackett-Luce test passed.")
    else:
        print("Conservative Plackett-Luce test failed.")
    # Check the simple Plackett-Luce loss function.
    scores = torch.randn(3, 3)
    rankings_simple = torch.tensor([[0, 1, 2], [1, 0, 2], [2, 0, 1]])
    loss = simple_plackett_luce_loss(scores, rankings_simple, eps=0.1)
    rankings = [
        [[0], [1], [2]],
        [[4], [3], [5]],
        [[8], [6], [7]],
    ]
    inputs = make_inputs_for_plackett_luce_loss(rankings, 9)
    expected = plackett_luce_loss(scores.flatten(), *inputs, eps=0.1)
    if torch.allclose(loss, expected):
        print("Simple Plackett-Luce loss function test passed.")
    else:
        print("Simple Plackett-Luce loss function test failed.")


if __name__ == "__main__":
    main()