Created February 2, 2024 15:13
Save tiandiao123/71acbcaee7968a630dd7e175a3987071 to your computer and use it in GitHub Desktop.
test_sample_reweights.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import torch | |
import torch.nn.functional as F | |
def _pct_rank(x, dim=0):
    """Return 1-based ordinal ranks of ``x`` along ``dim``, divided by that dimension's length.

    Results lie in ``(0, 1]`` and are never zero. This mirrors pandas
    ``rank(pct=True)``, except that ties are broken by ``argsort`` order
    instead of being averaged.
    """
    ranks = x.argsort(dim=dim).argsort(dim=dim)
    return (ranks.float() + 1.0) / x.shape[dim]


def _equal_width_bins(values, n_bins):
    """Assign each element of 1-D ``values`` to one of ``n_bins`` equal-width bins.

    The bins span ``[values.min(), values.max()]``. The last bin is closed,
    so the maximum lands in bin ``n_bins - 1`` rather than falling outside
    every bin. If all values are equal, everything goes to bin 0.
    """
    lo, hi = values.min(), values.max()
    span = hi - lo
    if span > 0:
        scaled = (values - lo) / span * n_bins
        return scaled.floor().long().clamp_(0, n_bins - 1)
    return torch.zeros(values.shape[0], dtype=torch.long, device=values.device)


def sample_reweight(loss_curve, loss_values, k_th, alpha1=1.0, alpha2=1.0, bins_sr=10, decay=0.9):
    """
    The SR (sample re-weighting) module of Double Ensemble using PyTorch.

    Args:
    - loss_curve: Tensor, shape (N, T), the loss curve for each sample over training iterations.
    - loss_values: Tensor, shape (N,), the loss of the current ensemble on each sample.
    - k_th: int, the index of the current sub-model, starting from 1.
    - alpha1: float, weight for the h1 term (rank of the current ensemble loss).
    - alpha2: float, weight for the h2 term (rank of the l_end / l_start ratio).
    - bins_sr: int >= 1, number of equal-width bins used to discretize h-values.
    - decay: float, decay rate; the h-term in the weight formula is scaled by decay ** k_th.

    Returns:
    - weights: float32 Tensor, shape (N,), on the same device as the inputs.
      Every sample gets weight 1 / (decay ** k_th * mean_h_of_its_bin + 0.1).

    Raises:
    - ValueError: if loss_curve is not 2-D, loss_values is not shape (N,),
      or bins_sr < 1.
    """
    if loss_curve.dim() != 2:
        raise ValueError(f"loss_curve must be 2-D (N, T), got shape {tuple(loss_curve.shape)}")
    N, T = loss_curve.shape
    if loss_values.shape != (N,):
        raise ValueError(f"loss_values must have shape ({N},), got {tuple(loss_values.shape)}")
    if bins_sr < 1:
        raise ValueError(f"bins_sr must be >= 1, got {bins_sr}")

    # Normalize with percentile ranks.
    # - Each iteration (column) of the loss curve is ranked across samples.
    # - The current ensemble loss is ranked across samples after negation, so
    #   the sample with the largest loss gets the smallest h1.
    loss_curve_rank = _pct_rank(loss_curve, dim=0)
    loss_values_rank = _pct_rank(-loss_values, dim=0)

    # Average rank over the first / last 10% of iterations (at least one).
    part = max(int(T * 0.1), 1)
    l_start = loss_curve_rank[:, :part].mean(dim=1)
    l_end = loss_curve_rank[:, -part:].mean(dim=1)

    # h-value per sample. Percentile ranks are >= 1/N, so l_start > 0 and
    # the ratio below is always finite.
    h1 = loss_values_rank
    h2 = _pct_rank(l_end / l_start, dim=0)
    h_value = alpha1 * h1 + alpha2 * h2

    # Discretize h into equal-width bins. Every sample shares the weight
    # derived from its bin's mean h. Empty bins are never indexed, so
    # clamping their count to 1 only avoids a 0/0 that would go unused.
    bin_idx = _equal_width_bins(h_value, bins_sr)
    counts = torch.bincount(bin_idx, minlength=bins_sr).to(h_value.dtype)
    sums = torch.zeros(bins_sr, dtype=h_value.dtype, device=h_value.device)
    sums.index_add_(0, bin_idx, h_value)
    bin_mean = sums / counts.clamp(min=1)

    return 1.0 / (decay ** k_th * bin_mean[bin_idx] + 0.1)
if __name__ == "__main__":
    # Example usage on random data. Guarded so importing this module has no
    # side effects, and seeded so the printed weights are reproducible.
    torch.manual_seed(0)
    N, T = 100, 20
    loss_curve = torch.randn(N, T)
    loss_values = torch.randn(N)
    k_th = 1
    weights = sample_reweight(loss_curve, loss_values, k_th)
    print(weights)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.