Hacky PyTorch Batch-Hard Triplet Loss and PK samplers
import torch
from torch import nn
import torch.nn.functional as F
from collections import OrderedDict


def pdist(v):
    # pairwise Euclidean distance matrix, shape (n, n)
    dist = torch.norm(v[:, None] - v, dim=2, p=2)
    return dist


class TripletLoss(nn.Module):
    def __init__(self, margin=1.0, sample=True):
        super(TripletLoss, self).__init__()
        self.margin = margin
        self.sample = sample

    def forward(self, inputs, targets):
        n = inputs.size(0)

        # pairwise distances
        dist = pdist(inputs)

        # find the hardest positive and negative
        mask_pos = targets.expand(n, n).eq(targets.expand(n, n).t())
        mask_neg = ~mask_pos
        mask_pos[torch.eye(n, dtype=torch.bool, device=inputs.device)] = False

        if self.sample:
            # weighted sample pos and negative to avoid outliers causing collapse
            posw = (dist + 1e-12) * mask_pos.float()
            posi = torch.multinomial(posw, 1)
            dist_p = dist.gather(0, posi.view(1, -1))
            # There is likely a much better way of sampling negatives in proportion to their
            # difficulty based on distance; this was a quick hack that ended up working better
            # for some datasets than hard negative mining.
            negw = (1 / (dist + 1e-12)) * mask_neg.float()
            negi = torch.multinomial(negw, 1)
            dist_n = dist.gather(0, negi.view(1, -1))
        else:
            # hard negative
            ninf = torch.ones_like(dist) * float('-inf')
            dist_p = torch.max(dist * mask_pos.float(), dim=1)[0]
            nindex = torch.max(torch.where(mask_neg, -dist, ninf), dim=1)[1]
            dist_n = dist.gather(0, nindex.unsqueeze(0))

        # calc loss
        diff = dist_p - dist_n
        if isinstance(self.margin, str) and self.margin == 'soft':
            diff = F.softplus(diff)
        else:
            diff = torch.clamp(diff + self.margin, min=0.)
        loss = diff.mean()

        # calculate metrics, no impact on loss
        metrics = OrderedDict()
        with torch.no_grad():
            _, top_idx = torch.topk(dist, k=2, largest=False)
            top_idx = top_idx[:, 1:]
            flat_idx = top_idx.squeeze() + n * torch.arange(n, device=inputs.device)
            top1_is_same = torch.take(mask_pos, flat_idx)
            metrics['prec'] = top1_is_same.float().mean().item()
            metrics['dist_acc'] = (dist_n > dist_p).float().mean().item()
            if not isinstance(self.margin, str):
                metrics['dist_sm'] = (dist_n > dist_p + self.margin).float().mean().item()
            metrics['nonzero_count'] = torch.nonzero(diff).size(0)
            metrics['dist_p'] = dist_p.mean().item()
            metrics['dist_n'] = dist_n.mean().item()
            metrics['rel_dist'] = ((dist_n - dist_p) / torch.max(dist_p, dist_n)).mean().item()
        return loss, metrics
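A quick usage sketch for the loss above; the batch shape, embedding size and margin here are assumptions, not part of the gist. A P*K batch of embeddings plus matching integer labels go in, and a scalar loss plus an OrderedDict of monitoring metrics come out.

import torch

p, k, dim = 8, 4, 128
embeddings = torch.randn(p * k, dim, requires_grad=True)   # stand-in for a model's output
labels = torch.arange(p).repeat_interleave(k)              # k consecutive samples per label
criterion = TripletLoss(margin=0.3, sample=False)          # batch-hard variant
loss, metrics = criterion(embeddings, labels)
print(loss.item(), metrics['prec'], metrics['dist_acc'])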
import itertools

import numpy as np
from torch.utils.data.sampler import Sampler


# Both samplers are passed a data_source (likely your dataset) that has the following member:
# * label_to_samples - mapping of label ids (zero-based integers) to the sample indices for that label

class PKSampler(Sampler):
    def __init__(self, data_source, p=64, k=16):
        super().__init__(data_source)
        self.p = p
        self.k = k
        self.data_source = data_source

    def __iter__(self):
        pk_count = len(self) // (self.p * self.k)
        for _ in range(pk_count):
            # draw p distinct labels, then k samples (with replacement if needed) for each
            labels = np.random.choice(
                np.arange(len(self.data_source.label_to_samples)), self.p, replace=False)
            for l in labels:
                indices = self.data_source.label_to_samples[l]
                replace = True if len(indices) < self.k else False
                for i in np.random.choice(indices, self.k, replace=replace):
                    yield i

    def __len__(self):
        pk = self.p * self.k
        samples = ((len(self.data_source) - 1) // pk + 1) * pk
        return samples


def grouper(iterable, n):
    # yield successive chunks of size n, cycling back to the start to pad the last chunk
    it = itertools.cycle(iter(iterable))
    for _ in range((len(iterable) - 1) // n + 1):
        yield list(itertools.islice(it, n))


# full label coverage per 'epoch'
class PKSampler2(Sampler):
    def __init__(self, data_source, p=64, k=16):
        super().__init__(data_source)
        self.p = p
        self.k = k
        self.data_source = data_source

    def __iter__(self):
        rand_labels = np.random.permutation(
            np.arange(len(self.data_source.label_to_samples)))
        for labels in grouper(rand_labels, self.p):
            for l in labels:
                indices = self.data_source.label_to_samples[l]
                replace = True if len(indices) < self.k else False
                for j in np.random.choice(indices, self.k, replace=replace):
                    yield j

    def __len__(self):
        num_labels = len(self.data_source.label_to_samples)
        samples = ((num_labels - 1) // self.p + 1) * self.p * self.k
        return samples
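A rough sketch of how the samplers plug into a DataLoader. The ToyDataset below is an assumption used only to illustrate the expected label_to_samples interface; it is not part of the gist.

import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader

class ToyDataset(Dataset):
    def __init__(self, labels):
        self.labels = np.asarray(labels)
        # label id -> array of sample indices carrying that label
        self.label_to_samples = {
            int(l): np.where(self.labels == l)[0] for l in np.unique(self.labels)
        }

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        # stand-in "image": just return the index and its label
        return torch.tensor(idx), torch.tensor(int(self.labels[idx]))

dataset = ToyDataset(np.repeat(np.arange(10), 20))       # 10 labels, 20 samples each
sampler = PKSampler(dataset, p=4, k=8)                   # each batch: 4 labels x 8 samples
loader = DataLoader(dataset, batch_size=4 * 8, sampler=sampler, drop_last=True)
for idx, labels in loader:
    pass  # feed each P*K batch to the embedding model and TripletLoss here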
Hello! Thanks a lot.
How would you handle the semi-hard positive in the triplet loss? I tried to get my head around the indexes but got lost...
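A hedged sketch of one way FaceNet-style semi-hard selection could be expressed with the same pdist/mask layout used above; the function name, shapes and fallback behaviour are assumptions, not part of the gist, and the classic definition is for negatives (d_p < d_n < d_p + margin). The positive case could be mirrored by using mask_pos and flipping the inequalities.

import torch

def semi_hard_negatives(dist, dist_p, mask_neg, margin):
    # dist: (n, n) pairwise distances, dist_p: (n,) anchor-positive distances,
    # mask_neg: (n, n) bool mask of valid negatives, margin: float
    inf = torch.full_like(dist, float('inf'))
    # semi-hard candidates: farther than the positive but still inside the margin
    semi = mask_neg & (dist > dist_p.unsqueeze(1)) & (dist < dist_p.unsqueeze(1) + margin)
    # among candidates take the closest one per anchor; rows with no candidate stay at inf
    cand = torch.where(semi, dist, inf)
    dist_n, _ = cand.min(dim=1)
    # fall back to the hardest (closest) plain negative where nothing qualified
    hard = torch.where(mask_neg, dist, inf).min(dim=1)[0]
    dist_n = torch.where(torch.isinf(dist_n), hard, dist_n)
    return dist_n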