Skip to content

Instantly share code, notes, and snippets.

@NegatioN
Last active July 11, 2018 09:09
Show Gist options
  • Save NegatioN/e1bf23ea1bd3bc999d6052b18c058c48 to your computer and use it in GitHub Desktop.
Save NegatioN/e1bf23ea1bd3bc999d6052b18c058c48 to your computer and use it in GitHub Desktop.
warp loss
def num_tries_gt_zero(scores, batch_size, max_trials, max_num, device):
    """Return, per row, the 1-based column index of the first score > 0.

    Rows that contain no positive score yield ``max_num`` instead; callers
    choose ``max_num`` large enough that it can never be a real index + 1,
    so it acts as a "never succeeded" sentinel under the min() below.

    Args:
        scores: [batch_size x (max_trials + 1)] float tensor.
        batch_size: number of rows in ``scores``.
        max_trials: number of real trial columns (``scores`` carries one
            extra appended column, hence ``max_trials + 1`` widths below).
        max_num: sentinel value for rows with no positive score.
        device: torch.device on which the dense index tensor is allocated
            (assumed to match the device of ``scores``).

    Returns:
        [batch_size] long tensor: (first positive column index + 1) per row,
        or ``max_num`` where the row has no positive score.
    """
    rows, cols = torch.nonzero(scores.gt(0), as_tuple=True)
    # Resolves the old TODO: fill a plain dense zero tensor directly instead
    # of round-tripping through a device-specific sparse LongTensor. Each
    # (row, col) pair from nonzero() is unique, so plain index assignment is
    # exactly equivalent to the former sparse-to-dense construction.
    t = torch.zeros(batch_size, max_trials + 1, dtype=torch.long, device=device)
    # Offset column indices by 1 so an untouched 0 still means "no positive".
    t[rows, cols] = cols + 1
    # Push unset entries to the sentinel so they never win the min() call.
    t[t == 0] = max_num
    return torch.min(t, dim=1)[0]
def warp_loss(positive_predictions, negative_predictions, num_labels, device):
    """WARP (Weighted Approximate-Rank Pairwise) loss, summed over the batch.

    Args:
        positive_predictions: [batch_size x 1] scores of the true label.
        negative_predictions: [batch_size x N] scores of sampled negatives
            (N is the maximum number of sampling trials).
        num_labels: total number of labels in the dataset (int).
        device: torch.device used for the intermediate tensors.

    Returns:
        Scalar tensor: the summed WARP loss over the batch.
    """
    batch_size = negative_predictions.size(0)
    max_trials = negative_predictions.size(1)
    # Each row of the flattened predictions starts at row * (max_trials + 1).
    offsets = torch.arange(0, batch_size, 1).long().to(device) * (max_trials + 1)
    ones = torch.ones(batch_size, 1).float().to(device)
    max_num = batch_size * (max_trials + 1)
    # Margin-rank scores; broadcasting [B x N] - [B x 1] already yields
    # [B x N], so no squeeze() is needed (the previous bare .squeeze()
    # also collapsed the batch dimension when batch_size == 1, breaking
    # the torch.cat below).
    sample_scores = 1 + negative_predictions - positive_predictions
    # Append a column of ones so every row has at least one "positive"
    # score; hitting that column means all real trials failed, which is
    # detected via should_count_loss below.
    sample_scores = torch.cat([sample_scores, ones], dim=1)
    negative_predictions = torch.cat([negative_predictions, ones], dim=1)
    tries = num_tries_gt_zero(sample_scores, batch_size, max_trials, max_num, device)
    attempts, trial_offset = tries.float(), (tries - 1) + offsets
    # WARP rank weight: log(floor((num_labels - 1) / attempts)) approximates
    # log(rank of the positive label).
    loss_weights = torch.log(torch.floor((num_labels - 1) / (attempts + 1)))
    # Don't count the loss for rows that only "succeeded" on the appended
    # ones column, i.e. used up all real attempts.
    should_count_loss = (attempts <= max_trials).float()
    losses = (loss_weights
              * ((1 - positive_predictions.view(-1)) + negative_predictions.view(-1)[trial_offset])
              * should_count_loss)
    return losses.sum()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment