pytorch ops for ranking
import torch
import torch.nn.functional as F
import copy
import math

def ipw_crossentropy(weights, scores, label):
    # Inverse-propensity-weighted binary cross entropy: the negative log-likelihood of
    # the observed 0/1 label under P(observe positive) = propensity * sigmoid(score),
    # where propensity = 1 / weight. Note that softplus(x) = log(1 + exp(x)).
    propensities = torch.reciprocal(weights)
    return (
        -label * (scores + torch.log(propensities))
        - (1 - label) * F.softplus(scores + torch.log(1 - propensities))
        + F.softplus(scores)
    )
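
# Usage sketch (not part of the original gist): made-up logits, 0/1 labels and
# inverse-propensity weights (>= 1, so propensities = 1 / weights lie in (0, 1]).
# The per-item losses are reduced to a scalar before backpropagation.
ce_scores = torch.tensor([0.3, -1.2, 2.0])
ce_labels = torch.tensor([1.0, 0.0, 1.0])
ce_weights = torch.tensor([2.0, 5.0, 1.25])
ce_loss = ipw_crossentropy(ce_weights, ce_scores, ce_labels).mean()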

def ipw_logsoftmax(qid, score, label, weights, return_se=False):
    # group rows by query; equal qids must be consecutive
    qid_unique, qid_enumeration = qid.unique_consecutive(return_inverse=True)
    zeros = torch.zeros_like(qid_unique, dtype=score.dtype)
    # weighted log-sum-exp of the scores of each query
    qid_logsumexp = zeros.scatter_add(-1, qid_enumeration, score.exp() * weights).log()
    # weighted average of the scores of the positive items in each query
    qid_score_sum = zeros.scatter_add(-1, qid_enumeration, label * weights * score)
    qid_normalization = zeros.scatter_add(-1, qid_enumeration, label * weights)
    qid_normalization = torch.where(qid_normalization > 0, qid_normalization, 1.)
    qid_losses = qid_logsumexp - qid_score_sum / qid_normalization
    if return_se:
        return qid_losses.mean(), qid_losses.std() / math.sqrt(len(qid_losses))
    else:
        return qid_losses.mean()
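
# Usage sketch (not part of the original gist): ipw_logsoftmax groups rows with
# unique_consecutive, so all rows of a query must be consecutive. The tensors below
# are made-up examples with two queries.
lsm_qid = torch.tensor([0, 0, 0, 1, 1])
lsm_score = torch.tensor([0.5, -0.2, 1.3, 0.1, 0.9])
lsm_label = torch.tensor([1.0, 0.0, 0.0, 0.0, 1.0])
lsm_weights = torch.tensor([2.0, 1.0, 1.0, 1.0, 3.0])
lsm_loss, lsm_se = ipw_logsoftmax(lsm_qid, lsm_score, lsm_label, lsm_weights, return_se=True)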

def listwide_loss(qid, score, label):
    qid_unique, qid_enumeration, qid_counts = qid.unique_consecutive(return_inverse=True, return_counts=True)
    # compute a query-level label: +1 if the query has at least one positive item, -1 otherwise
    zeros = torch.zeros_like(qid_unique, dtype=label.dtype)
    qid_label_sums = zeros.scatter_add(-1, qid_enumeration, label)
    qid_labels = torch.sign(torch.sign(qid_label_sums) - 0.5)  # labels in {-1, 1}
    # compute ln(1 + mean(exp(score)) ** (-y)) for each query, and average the results
    zeros = torch.zeros_like(qid_unique, dtype=score.dtype)
    exp_score_mean = zeros.scatter_add(-1, qid_enumeration, torch.exp(score)) / qid_counts
    query_losses = torch.log1p(torch.where(
        qid_labels > 0, exp_score_mean.reciprocal(), exp_score_mean
    ))
    return query_losses.mean()
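
# Usage sketch (not part of the original gist): a query with at least one positive
# item gets query label +1, an all-negative query gets -1, and minimizing the loss
# pushes the mean exponentiated score of each query up or down accordingly.
lw_loss = listwide_loss(
    torch.tensor([0, 0, 1, 1, 1]),
    torch.tensor([0.2, -0.4, 1.0, 0.3, -0.7]),
    torch.tensor([1.0, 0.0, 0.0, 0.0, 0.0]),
)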

class IPWLogSoftMax:
    def __call__(self, qid, score, label, weights):
        return ipw_logsoftmax(qid, score, label, weights)


class IPWListWideLogSoftmax:
    def __init__(self, listwide_weight=1.):
        self.listwide_weight = listwide_weight

    def __call__(self, qid, score, label, weights):
        # blend the IPW list-wise loss with the list-wide loss
        t = self.listwide_weight
        return t * ipw_logsoftmax(qid, score, label, weights) + \
            (1 - t) * listwide_loss(qid, score, label)
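
# Usage sketch (not part of the original gist): the callables above can be used as
# drop-in ranking criteria; with listwide_weight t the combined loss is
# t * ipw_logsoftmax + (1 - t) * listwide_loss.
def _combined_loss_example(qid, score, label, weights):
    criterion = IPWListWideLogSoftmax(listwide_weight=0.7)
    return criterion(qid, score, label, weights)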

def lexsort(keys, dim=-1):
    if len(keys) == 0:
        raise ValueError(f"Must have at least 1 key, but {len(keys)=}.")
    idx = keys[0].argsort(dim=dim, stable=True)
    for k in keys[1:]:
        idx = idx.gather(dim, k.gather(dim, idx).argsort(dim=dim, stable=True))
    return idx
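
# Usage sketch (not part of the original gist): like numpy.lexsort, the last key is
# the primary sort key and earlier keys break ties.
lex_primary = torch.tensor([1, 0, 1, 0])
lex_secondary = torch.tensor([3, 1, 2, 0])
lex_order = lexsort([lex_secondary, lex_primary])  # -> tensor([3, 1, 2, 0])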

def split_int_to_bytes(input_tensor):
    # Number of bytes required to represent each integer
    num_bytes = input_tensor.dtype.itemsize
    orig_shape = input_tensor.shape
    # Extract each byte by shifting then masking; shifting first avoids building mask
    # constants such as 0xFF000000 that overflow signed integer dtypes.
    byte_tensors = []
    for i in range(num_bytes):
        byte = (input_tensor >> (8 * i)) & 0xFF
        byte = byte.to(torch.uint8).reshape(orig_shape)
        byte_tensors.append(byte)
    # Return the tuple of byte tensors, least-significant byte first
    return tuple(byte_tensors)
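
# Usage sketch (not part of the original gist): bytes are returned least-significant
# byte first, each as a uint8 tensor with the input's shape.
example_bytes = split_int_to_bytes(torch.tensor([0x01020304], dtype=torch.int32))
# -> (tensor([4]), tensor([3]), tensor([2]), tensor([1])), all uint8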

def fnv_hash(tensor):
    """
    Computes an FNV-style hash of each element of a PyTorch tensor of integers.

    Args:
        tensor: a PyTorch tensor of integers (e.g. int32 or int64).

    Returns:
        A PyTorch tensor of the same shape and dtype as the input, containing the hash of each element.
    """
    # FNV prime and offset basis; the offset basis is written as its signed 32-bit
    # value so that it fits into signed integer dtypes.
    FNV_PRIME = 0x01000193
    FNV_OFFSET = 0x811c9dc5 - (1 << 32)
    # Start from the offset basis and fold the input in one byte at a time
    hash_value = torch.full_like(tensor, FNV_OFFSET)
    for byte in split_int_to_bytes(tensor):
        hash_value = torch.bitwise_xor(hash_value * FNV_PRIME, byte)
    # The output already has the same shape and dtype as the input
    return hash_value
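
# Usage sketch (not part of the original gist): hashing (qid + seed) gives a
# deterministic pseudo-random reordering of query ids; changing the seed changes
# the order. This is how QueryBatchIter below shuffles queries.
hashed_qids = fnv_hash(torch.tensor([1, 2, 3], dtype=torch.int64) + 42)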

def chunk_idx(x):
    # boundaries of runs of equal consecutive values, starting at 0
    values, counts = x.unique_consecutive(return_counts=True)
    idx = torch.cumsum(counts, dim=-1)
    return F.pad(idx, (1, 0))
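
# Usage sketch (not part of the original gist): chunk_idx returns the boundaries of
# runs of equal consecutive values, starting at 0.
run_boundaries = chunk_idx(torch.tensor([7, 7, 7, 9, 9, 3]))
# -> tensor([0, 3, 5, 6]): rows 0..2, rows 3..4, and row 5 form the three runs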

class BatchIter:
    """
    Iterates over mini-batches of instances.

    tensors: feature tensors (each with shape: num_instances x *)
    batch_size: number of instances per batch
    shuffle: whether to randomly permute the instances before batching
    """
    def __init__(self, *tensors, batch_size, shuffle=True):
        self.tensors = tensors
        device = tensors[0].device
        n = tensors[0].size(0)
        if shuffle:
            idxs = torch.randperm(n, device=device)
        else:
            idxs = torch.arange(n, device=device)
        self.idxs = idxs.split(batch_size)

    def __len__(self):
        return len(self.idxs)

    def __iter__(self):
        tensors = self.tensors
        for batch_idxs in self.idxs:
            yield tuple(x[batch_idxs, ...] for x in tensors)
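
# Usage sketch (not part of the original gist): a typical pointwise training loop.
# The model and optimizer are hypothetical placeholders, not defined in this gist.
def _batch_iter_example(model, optimizer, features, labels):
    for batch_features, batch_labels in BatchIter(features, labels, batch_size=128):
        optimizer.zero_grad()
        loss = F.binary_cross_entropy_with_logits(model(batch_features), batch_labels)
        loss.backward()
        optimizer.step()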

class QueryBatchIter:
    """
    Iterates over mini-batches of whole queries: each batch spans batch_size
    consecutive queries, so rows of the same query are never split across batches.

    qid: query id of each instance (equal ids must be consecutive when shuffle=False)
    pos: within-query position of each instance
    tensors: feature tensors (each with shape: num_instances x *)
    """
    def __init__(self, qid, pos, *tensors, batch_size, shuffle=True, shuffle_seed=42):
        self.qid = qid
        self.pos = pos
        self.tensors = tensors
        device = qid.device
        if shuffle:
            # pseudo-shuffle queries by hashing their ids; rows of a query stay
            # together, ordered by position
            self.idxs = lexsort([pos, fnv_hash(qid + shuffle_seed)])
        else:
            self.idxs = torch.arange(len(qid), device=device)
        chunk_endpoints = chunk_idx(qid[self.idxs])
        self.chunk_start, self.chunk_end = self.batches_with_overlap(chunk_endpoints, batch_size)

    @staticmethod
    def batches_with_overlap(idx, batch_size):
        # pad so that the unfolded windows also cover the boundary of the last query
        pad_size = (batch_size - (len(idx) - 1) % batch_size) % batch_size
        upper_bound = idx.max().item() + 1
        lower_bound = idx.min().item() - 1
        pad_for_start = F.pad(idx, (0, pad_size), value=upper_bound)
        pad_for_end = F.pad(idx, (0, pad_size), value=lower_bound)
        start = pad_for_start.unfold(0, 1 + batch_size, batch_size).min(dim=1)
        stop = pad_for_end.unfold(0, 1 + batch_size, batch_size).max(dim=1)
        return start.values.tolist(), stop.values.tolist()

    def __len__(self):
        return len(self.chunk_start)

    def __iter__(self):
        tensors = self.tensors
        for start, end in zip(self.chunk_start, self.chunk_end):
            batch_idxs = self.idxs[start:end]
            if len(batch_idxs) > 0:
                batch = (x[batch_idxs, ...] for x in (self.qid, self.pos) + tensors)
                yield tuple(batch)
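
# Usage sketch (not part of the original gist): a listwise training loop. Each batch
# spans whole queries, so per-query losses such as ipw_logsoftmax can be applied to
# it directly. The model, optimizer and criterion are hypothetical placeholders.
def _query_batch_iter_example(model, optimizer, criterion, qid, pos, features, labels, weights):
    for b_qid, b_pos, b_features, b_labels, b_weights in QueryBatchIter(
            qid, pos, features, labels, weights, batch_size=32):
        optimizer.zero_grad()
        loss = criterion(b_qid, model(b_features), b_labels, b_weights)
        loss.backward()
        optimizer.step()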

class CheckpointingEarlyStopper:
    """
    Tracks the best validation loss, keeps a copy of the best model's state_dict,
    and signals stopping after `patience` consecutive non-improving updates.
    """
    def __init__(self, patience=3, delta=0.):
        self.patience = patience
        self.delta = delta
        self.counter = 0
        self.best_loss = None
        self.best_model = None

    def update(self, val_loss, model):
        if self.best_loss is None:
            self.best_loss = val_loss
            self.best_model = copy.deepcopy(model.state_dict())
        if val_loss > self.best_loss + self.delta:
            # no improvement: count towards the patience budget
            self.counter += 1
            if self.counter >= self.patience:
                return True
        else:
            # improvement (within delta): reset the counter and checkpoint the model
            self.counter = 0
            self.best_loss = val_loss
            self.best_model = copy.deepcopy(model.state_dict())
        return False
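
# Usage sketch (not part of the original gist): update() returns True once the
# validation loss has failed to improve for `patience` consecutive epochs; the best
# state_dict is kept in best_model. train_one_epoch and compute_val_loss are
# hypothetical placeholders.
def _early_stopping_example(model, optimizer, train_one_epoch, compute_val_loss, max_epochs=100):
    stopper = CheckpointingEarlyStopper(patience=3)
    for epoch in range(max_epochs):
        train_one_epoch(model, optimizer)
        if stopper.update(compute_val_loss(model), model):
            break
    model.load_state_dict(stopper.best_model)
    return model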