Gist: amqdn/f3ba1ea30e4e21c24617f6d7aec75212
Great explanation, but I don't think the `Counter` usage works, because the `targets` are tensors, not plain numbers. That's why your majority and minority indices are the same. Here is some of the preliminary work I did that fixes this using `torch.unique` and generalizes up to `SortMinorityClass`:

```python
import typing as t
from dataclasses import dataclass
from typing import Any

import torch
from torch import Tensor, tensor
from fastai.basic_train import Learner, LearnerCallback  # fastai v1 callback API


def indices_of(occurences, given_class):
    "extract a tensor of indices of the given instance class"
    return tensor([
        i for i, encoding in enumerate(occurences)
        if torch.all(torch.eq(encoding, given_class))
    ])


@dataclass
class BatchClassification:
    __slots__ = ['encoding', 'indices', 'frequency']
    encoding: Tensor  # class encoding as a tensor
    indices: Tensor   # indices within the batch
    frequency: int    # number of occurrences in the class


def classifications_of(targets: Tensor, descending_frequency=False) -> t.Iterable[BatchClassification]:
    class_encodings, class_indices, class_counts = torch.unique(
        targets, dim=0, return_counts=True, return_inverse=True)
    # class_indices maps each sample to its row in class_encodings,
    # so match samples against the row position, not the encoding itself
    return sorted([
        BatchClassification(encoding, indices_of(class_indices, position), int(frequency))
        for position, (encoding, frequency) in enumerate(zip(class_encodings, class_counts))
    ], key=lambda bc: bc.frequency, reverse=descending_frequency)


class SortMinorityClass(LearnerCallback):
    def __init__(self, learn: Learner):
        super().__init__(learn)
        self.iters = 0  # Manage the number of printouts we get

    def on_batch_begin(self, last_target: Tensor, **kwargs: Any) -> None:
        if self.iters < 2:
            for index, bc in enumerate(classifications_of(last_target)):
                print(f'frequency group {index}, count {bc.frequency}: {bc}')
        self.iters += 1
```

Also, since the mining results are a function of outputs and targets anyway, couldn't we implement the entire mining operation within the loss function? I'm kinda new to ML, so maybe I'm missing something.
Actually, `Counter` works well here, because when passing the tensors to `Counter`, the author calls the `tolist()` method of the `torch.Tensor` object.
Check the demo below 👇:
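An illustrative sketch of the behavior being discussed (not the demo referenced above), assuming 1-D integer labels with made-up values: iterating a tensor yields separate 0-d tensor objects, and `torch.Tensor` hashes by object identity, so `Counter` over a raw tensor gives every sample its own key. Calling `tolist()` first produces plain Python ints, which hash by value and are counted as expected.

```python
from collections import Counter

import torch

targets = torch.tensor([0, 1, 0, 0, 2])

# Each element is a distinct 0-d tensor hashed by identity,
# so equal labels do not collapse into a single key.
raw_counts = Counter(targets)

# .tolist() converts to Python ints, which hash by value.
value_counts = Counter(targets.tolist())
print(value_counts)  # Counter({0: 3, 1: 1, 2: 1})
```

With `raw_counts`, every sample appears exactly once, so all classes look equally frequent; that would be consistent with majority and minority indices coming out identical when `tolist()` is skipped.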
@2foil Ah, you're right, but it seems we're both right: I was looking at the output of `SortMinorityClass`, which uses `Counter(targets)`; that is fixed later in the final `ClassLevelHardMining` callback.
@micimize Okay, I got it. Thanks for your explanation 😊.
@amqdn Thanks for your tutorial ❤️, it has helped me a lot in implementing CRLloss.
I have one question: I'm working with a training dataset that has multiple majority and minority classes. How should I compute the omega in CRLloss?
@amqdn Got it 😊, thanks for your explanation. ❤️
@micimize Thanks for taking the time to do that! I like the idea of using `torch.unique` for the sample counting. As you can see, I published this over a year ago, and my inexperience with PyTorch and ML at the time shows.

Re: mining results being a function of outputs and targets... If I understand your question, you're asking why I put the mining operation in a `callback_fn` instead of inside `loss.forward`. Looking back at my code and trying to remember why, I think I couldn't come up with a way to retain the indices of the majority/minority classes using just the loss module. Since the majority/minority class designation changes dynamically from batch to batch in this paper, it's necessary (I think) to keep track of those indices in order to calculate the loss properly. That doesn't mean it isn't possible, but I hadn't found a way.