Skip to content

Instantly share code, notes, and snippets.

@alexshtf
Created August 15, 2024 11:23
Show Gist options
  • Save alexshtf/d9f14fd5af21fa34ba550f85f2bfc1ed to your computer and use it in GitHub Desktop.
Save alexshtf/d9f14fd5af21fa34ba550f85f2bfc1ed to your computer and use it in GitHub Desktop.
pytorch ops for ranking
import torch
import torch.nn.functional as F
import copy
import math
def ipw_crossentropy(weights, scores, label):
    """Per-sample click log-loss under a propensity-scaled sigmoid model.

    With ``p = 1 / weights`` (the propensity) the returned value equals
    ``-log(q)`` for ``label == 1`` and ``-log(1 - q)`` for ``label == 0``,
    where ``q = p * sigmoid(scores)``; fractional labels give the
    label-weighted mix of the two terms.

    Args:
        weights: inverse propensities. Presumably ``>= 1`` so that ``p`` lies
            in (0, 1] -- TODO confirm with callers; ``log(1 - p)`` is NaN for
            ``weights < 1``.
        scores: model logits.
        label: click labels, typically 0/1.

    Returns:
        Element-wise loss with the broadcast shape of the inputs (no reduction).
    """

    def log1pexp(x):
        # log(1 + exp(x)) == logaddexp(0, x); torch.logaddexp needs two operands.
        return torch.logaddexp(torch.zeros_like(x), x)

    propensities = torch.reciprocal(weights)
    return (
        -label * (scores + torch.log(propensities))
        # log(1 - q) + log(1 + e^s) == log(1 + (1 - p) * e^s)
        - (1 - label) * log1pexp(scores + torch.log(1 - propensities))
        # shared normalizer log(1 + e^s) of the sigmoid
        + log1pexp(scores)
    )
def ipw_logsoftmax(qid, score, label, weights, return_se=False):
    """Inverse-propensity-weighted listwise softmax loss, averaged over queries.

    Rows of the same query must be contiguous in ``qid`` (queries are found
    with ``unique_consecutive``). For each query the loss is::

        log(sum_i w_i * exp(s_i)) - (sum_i y_i * w_i * s_i) / (sum_i y_i * w_i)

    where the denominator is replaced by 1 when ``sum_i y_i * w_i <= 0``
    (e.g. a query with no positive labels, whose second term is then 0).

    Args:
        qid: 1-D tensor of query ids, one per row.
        score: 1-D tensor of model scores, one per row.
        label: per-row labels.
        weights: per-row weights applied inside both sums.
        return_se: if True, also return the standard error of the mean.

    Returns:
        Mean per-query loss, or ``(mean, standard_error)`` if ``return_se``.
        The standard error is NaN when there is only one query (sample std).
    """
    qid_unique, qid_enumeration = qid.unique_consecutive(return_inverse=True)
    zeros = torch.zeros_like(qid_unique, dtype=score.dtype)

    # Weighted logsumexp per query, shifted by the per-query max score so exp()
    # cannot overflow. The shift cancels exactly, so it is detached from autograd.
    qid_max = zeros.scatter_reduce(
        -1, qid_enumeration, score.detach(), reduce="amax", include_self=False
    )
    # Queries whose max is +-inf: fall back to no shift (same result as unshifted).
    qid_max = torch.where(torch.isfinite(qid_max), qid_max, torch.zeros_like(qid_max))
    shifted_exp = torch.exp(score - qid_max[qid_enumeration]) * weights
    qid_logsumexp = zeros.scatter_add(-1, qid_enumeration, shifted_exp).log() + qid_max

    # Label- and weight-normalized score sum per query.
    qid_score_sum = zeros.scatter_add(-1, qid_enumeration, label * weights * score)
    qid_normalization = zeros.scatter_add(-1, qid_enumeration, label * weights)
    qid_normalization = torch.where(qid_normalization > 0, qid_normalization, 1.)
    qid_losses = qid_logsumexp - qid_score_sum / qid_normalization

    if return_se:
        return qid_losses.mean(), qid_losses.std() / math.sqrt(len(qid_losses))
    return qid_losses.mean()
def listwide_loss(qid, score, label):
    """Query-level logistic loss on the log-mean-exp of each query's scores.

    Rows of the same query must be contiguous in ``qid``. A query is labeled
    ``y = +1`` if its labels sum to a positive value and ``y = -1`` otherwise.
    With ``m = log(mean_i exp(s_i))`` over the query's rows, the query loss is
    ``log(1 + exp(-y * m))``, i.e. ``ln(1 + mean(exp(score)) ** (-y))``.

    Args:
        qid: 1-D tensor of query ids, one per row.
        score: 1-D tensor of model scores, one per row.
        label: per-row labels (presumably non-negative -- TODO confirm).

    Returns:
        The mean of the per-query losses (0-d tensor).
    """
    qid_unique, qid_enumeration, qid_counts = qid.unique_consecutive(
        return_inverse=True, return_counts=True
    )

    # Per-query label in {-1, 1}: sign(sum) is in {0, 1} for non-negative labels.
    zeros = torch.zeros_like(qid_unique, dtype=label.dtype)
    qid_label_sums = zeros.scatter_add(-1, qid_enumeration, label)
    qid_labels = torch.sign(torch.sign(qid_label_sums) - 0.5)

    # Stable log-mean-exp per query: shift by the (detached) per-query max so
    # exp() cannot overflow; the shift cancels exactly.
    zeros = torch.zeros_like(qid_unique, dtype=score.dtype)
    qid_max = zeros.scatter_reduce(
        -1, qid_enumeration, score.detach(), reduce="amax", include_self=False
    )
    qid_max = torch.where(torch.isfinite(qid_max), qid_max, torch.zeros_like(qid_max))
    sum_exp = zeros.scatter_add(-1, qid_enumeration, torch.exp(score - qid_max[qid_enumeration]))
    log_mean_exp = sum_exp.log() + qid_max - qid_counts.to(score.dtype).log()

    # log1p(mean^(-y)) == softplus(-y * log_mean_exp), evaluated without overflow.
    signed = torch.where(qid_labels > 0, -log_mean_exp, log_mean_exp)
    query_losses = torch.logaddexp(torch.zeros_like(signed), signed)
    return query_losses.mean()
class IPWLogSoftMax:
    """Stateless callable wrapper around ``ipw_logsoftmax``.

    Calling an instance returns ``ipw_logsoftmax(qid, score, label, weights)``
    with ``return_se`` left at False, i.e. only the mean per-query loss.
    """

    def __call__(self, qid, score, label, weights):
        loss = ipw_logsoftmax(
            qid=qid, score=score, label=label, weights=weights, return_se=False
        )
        return loss
class IPWListWideLogSoftmax:
    """Convex combination of ``ipw_logsoftmax`` and ``listwide_loss``.

    The loss is ``t * ipw_logsoftmax(...) + (1 - t) * listwide_loss(...)``
    with ``t = listwide_weight``. Despite its name, ``listwide_weight`` is the
    coefficient of the IPW softmax term; the list-wide term gets ``1 - t``
    (so the default of 1 uses the IPW softmax loss only).
    """

    def __init__(self, listwide_weight=1.):
        self.listwide_weight = listwide_weight

    def __call__(self, qid, score, label, weights):
        t = self.listwide_weight
        # For a plain-number weight of exactly 1 or 0, one term has coefficient
        # zero. Skip it so a non-finite value or gradient there cannot produce
        # NaN via 0 * inf, and so no work is wasted. Tensor weights (possibly
        # learnable) always take the general path to keep their gradient.
        if isinstance(t, (int, float)):
            if t == 1:
                return ipw_logsoftmax(qid, score, label, weights)
            if t == 0:
                return listwide_loss(qid, score, label)
        return t * ipw_logsoftmax(qid, score, label, weights) + \
            (1 - t) * listwide_loss(qid, score, label)
def lexsort(keys, dim=-1):
    """Return indices that sort by several keys, with the LAST key as primary.

    Mirrors ``numpy.lexsort`` ordering: rows are stably sorted by ``keys[0]``,
    then stably re-sorted by each following key, so later keys take precedence
    and earlier keys break ties.

    Args:
        keys: non-empty sequence of equally shaped tensors.
        dim: dimension along which to sort.

    Returns:
        Tensor of indices (as returned by ``argsort``) along ``dim``.

    Raises:
        ValueError: if ``keys`` is empty.
    """
    if not len(keys):
        raise ValueError(f"Must have at least 1 key, but {len(keys)=}.")
    first, *remaining = keys
    order = first.argsort(dim=dim, stable=True)
    for key in remaining:
        # Sort the current order by the next key; stability keeps earlier ties.
        key_in_order = key.gather(dim, order)
        refinement = key_in_order.argsort(dim=dim, stable=True)
        order = order.gather(dim, refinement)
    return order
def split_int_to_bytes(input_tensor):
    """Split each element of an integer tensor into its bytes, least significant first.

    Args:
        input_tensor: tensor with an integer dtype (e.g. int16, int32, int64).
            It need not be contiguous.

    Returns:
        A tuple with one ``torch.uint8`` tensor per byte of the dtype. Each has
        the same shape as ``input_tensor``, and entry ``i`` holds bits
        ``8*i .. 8*i+7`` of the two's-complement representation.
    """
    # torch.iinfo gives the bit width of any integer dtype.
    num_bytes = torch.iinfo(input_tensor.dtype).bits // 8
    # Shift first, then mask with 0xFF. 0xFF fits every integer dtype, unlike a
    # pre-shifted mask such as 0xFF << 56, which overflows signed dtypes.
    # Arithmetic right shift of negative values is harmless: only the low byte
    # is kept. Element-wise ops also preserve the shape, so no reshape is needed.
    return tuple(
        ((input_tensor >> (8 * i)) & 0xFF).to(torch.uint8)
        for i in range(num_bytes)
    )
def fnv_hash(tensor):
    """
    Computes the 32-bit FNV-1 hash of each element of an integer tensor.

    Each element is hashed over the bytes of its dtype, least significant
    byte first (FNV-1: multiply by the prime, then XOR the byte).

    Args:
        tensor: A PyTorch tensor with an integer dtype (e.g. int16, int32, int64).

    Returns:
        A tensor of the same shape and dtype as the input. The 32-bit hash is
        cast to the input dtype: int64 inputs get values in [0, 2**32), int32
        inputs get the same bits reinterpreted as signed, and narrower dtypes
        keep the low bits.
    """
    FNV_PRIME = 0x01000193
    FNV_OFFSET = 0x811c9dc5
    MASK32 = 0xFFFFFFFF
    num_bytes = torch.iinfo(tensor.dtype).bits // 8
    # Work in int64 so the 32-bit constants fit and (hash * prime) cannot
    # overflow before masking. Sign extension does not change the low bytes.
    wide = tensor.to(torch.int64)
    hash_value = torch.full_like(wide, FNV_OFFSET)
    for i in range(num_bytes):
        byte = (wide >> (8 * i)) & 0xFF
        hash_value = ((hash_value * FNV_PRIME) & MASK32) ^ byte
    return hash_value.to(tensor.dtype)
def chunk_idx(x):
    """Return the boundaries of runs of equal consecutive values in ``x``.

    The result is ``[0, c_1, c_1 + c_2, ...]``, where ``c_k`` is the length
    of the k-th run, so run ``k`` occupies ``out[k]:out[k + 1]``.
    """
    _, run_lengths = x.unique_consecutive(return_counts=True)
    run_ends = run_lengths.cumsum(dim=-1)
    # Prepend the leading 0 boundary.
    return F.pad(run_ends, (1, 0))
class BatchIter:
    """Iterate over row-aligned tensors in mini-batches.

    Args:
        *tensors: feature tensors, each with shape ``num_instances x *``. All are
            indexed with the same row indices, created on the first tensor's device.
        batch_size: maximum rows per batch (the last batch may be smaller).
        shuffle: if True, rows are visited in a random permutation drawn once at
            construction with ``torch.randperm``; otherwise in their original order.

    Iterating yields one tuple per batch with the selected rows of each tensor.
    ``len()`` is the number of batches.
    """

    def __init__(self, *tensors, batch_size, shuffle=True):
        self.tensors = tensors
        lead = tensors[0]
        make_order = torch.randperm if shuffle else torch.arange
        self.idxs = make_order(lead.size(0), device=lead.device).split(batch_size)

    def __len__(self):
        # Number of batches, not rows.
        return len(self.idxs)

    def __iter__(self):
        features = self.tensors
        for rows in self.idxs:
            yield tuple(t[rows, ...] for t in features)
class QueryBatchIter:
    """Iterate over ranking data in batches of whole queries.

    Rows are grouped into queries by ``qid``. Each batch holds the rows of up
    to ``batch_size`` consecutive queries, so a query is never split across
    batches.

    Args:
        qid: 1-D integer tensor of query ids, one per row.
        pos: per-row values used to order rows inside a query when shuffling
            (presumably the display position -- TODO confirm).
        *tensors: feature tensors, each with shape ``num_instances x *``.
        batch_size: number of queries per batch (the last batch may have fewer).
        shuffle: if True, queries are visited in a pseudo-random order given by
            ``fnv_hash(qid + shuffle_seed)``, with rows sorted by ``pos`` inside
            each query. If False, rows are visited in their given order, so
            rows of one query must already be contiguous.
        shuffle_seed: integer added to ``qid`` before hashing; changing it
            changes the query order.

    Iterating yields tuples ``(qid, pos, *tensors)`` restricted to each batch's
    rows. ``len()`` is the number of batches.
    """

    def __init__(self, qid, pos, *tensors, batch_size, shuffle=True, shuffle_seed=42):
        self.qid = qid
        self.pos = pos
        self.tensors = tensors
        if shuffle:
            # lexsort: the last key is primary. Order by hash, then by qid so
            # hash collisions cannot interleave different queries, then by pos.
            self.idxs = lexsort([pos, qid, fnv_hash(qid + shuffle_seed)])
        else:
            self.idxs = torch.arange(len(qid), device=qid.device)
        chunk_endpoints = chunk_idx(qid[self.idxs])
        self.chunk_start, self.chunk_end = self.batches_with_overlap(chunk_endpoints, batch_size)

    @staticmethod
    def batches_with_overlap(idx, batch_size):
        """Group consecutive chunks into batches of ``batch_size`` chunks.

        Args:
            idx: 1-D nondecreasing tensor of chunk endpoints ``[0, e_1, ..., e_K]``
                (as produced by ``chunk_idx``); chunk ``j`` spans ``idx[j]:idx[j+1]``.
            batch_size: number of chunks per batch.

        Returns:
            Lists ``(starts, ends)``; batch ``b`` covers rows ``starts[b]:ends[b]``.
            Consecutive batches share an endpoint. Every chunk is covered, and
            the last batch may contain fewer than ``batch_size`` chunks.
        """
        num_chunks = len(idx) - 1
        first_chunk = torch.arange(0, max(num_chunks, 0), batch_size, device=idx.device)
        # The final batch is clipped to the last endpoint instead of being dropped.
        end_chunk = torch.clamp(first_chunk + batch_size, max=num_chunks)
        return idx[first_chunk].tolist(), idx[end_chunk].tolist()

    def __len__(self):
        # Number of batches yielded by iteration (each covers >= 1 query).
        return len(self.chunk_start)

    def __iter__(self):
        tensors = (self.qid, self.pos) + self.tensors
        for start, end in zip(self.chunk_start, self.chunk_end):
            batch_idxs = self.idxs[start:end]
            if len(batch_idxs) > 0:
                yield tuple(x[batch_idxs, ...] for x in tensors)
class CheckpointingEarlyStopper:
    """Early-stopping helper that also keeps a copy of the best model weights.

    Call ``update`` once per validation round. A round whose loss exceeds
    ``best_loss + delta`` (or is NaN) is a bad round, and ``patience``
    consecutive bad rounds signal that training should stop. A round within
    the tolerance resets the counter. Only a loss no greater than ``best_loss``
    replaces the stored checkpoint, so ``best_model`` always matches ``best_loss``.

    Args:
        patience: number of consecutive bad rounds that trigger stopping.
        delta: tolerance above ``best_loss`` that is not counted as a bad round.
    """

    def __init__(self, patience=3, delta=0.):
        self.patience = patience
        self.delta = delta
        self.counter = 0  # consecutive bad rounds so far
        self.best_loss = None  # lowest validation loss seen; None before the first update
        self.best_model = None  # deep copy of model.state_dict() taken at best_loss

    def update(self, val_loss, model):
        """Record one validation result.

        Args:
            val_loss: validation loss for this round (lower is better).
            model: object exposing ``state_dict()`` (e.g. a ``torch.nn.Module``).

        Returns:
            True if training should stop, False otherwise.
        """
        first_round = self.best_loss is None
        # Written as "not <=" so a NaN loss counts as a bad round rather than
        # resetting patience and becoming the new best.
        if not first_round and not (val_loss <= self.best_loss + self.delta):
            self.counter += 1
            return self.counter >= self.patience
        self.counter = 0
        if first_round or val_loss <= self.best_loss:
            self.best_loss = val_loss
            self.best_model = copy.deepcopy(model.state_dict())
        return False
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment