@amqdn
Last active July 26, 2022 14:17
Implementing Class Rectification Loss in fast.ai
@amqdn
Author

amqdn commented Jul 17, 2020

@micimize

Thanks for taking the time to do that! I like the idea of using torch.unique to accomplish the sample counting. As you can see, I published this over a year ago, and I notice my relative inexperience with PyTorch and ML back then shows.

Re: mining results being a function of outputs and targets... If I understand your question: You're asking why it is that I included the mining operation in a callback_fn instead of inside the loss.forward. In looking at my code and trying to remember why, I think I couldn't come up with a way to retain the indices of the majority/minority classes using just the loss module. Since the majority/minority class designation changes from batch-to-batch dynamically in this paper, it's necessary to keep track of the indices (I think) in order to calculate the loss properly. That doesn't mean it's not possible, but I hadn't found a way.
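For what it's worth, the per-batch indices can at least be derived from the targets alone, which is the part a loss module would need. The sketch below is not from the tutorial: `minority_mask` and `max_share` are hypothetical names, and the rule it applies (take the rarest classes while their combined share of the batch stays at or below a threshold) is an assumption for illustration, not necessarily the paper's exact definition.

```python
import torch

def minority_mask(targets: torch.Tensor, max_share: float = 0.5) -> torch.Tensor:
    """Boolean mask over the batch marking samples of minority classes.

    Assumption for illustration: classes are taken rarest-first while their
    combined share of the batch stays at or below `max_share`.
    """
    # inverse[i] is the position in `classes` of the class of targets[i]
    classes, inverse, counts = torch.unique(
        targets, return_inverse=True, return_counts=True)
    order = torch.argsort(counts)                    # rarest class first
    cumulative = torch.cumsum(counts[order], dim=0)  # running sample total
    is_minority_class = torch.zeros_like(counts, dtype=torch.bool)
    is_minority_class[order[cumulative <= max_share * len(targets)]] = True
    return is_minority_class[inverse]                # one flag per sample

targets = torch.tensor([0, 0, 0, 0, 1, 2, 0, 1])     # made-up batch of labels
mask = minority_mask(targets)
minority_idx = torch.nonzero(mask).flatten()         # positions of minority samples
majority_idx = torch.nonzero(~mask).flatten()        # positions of majority samples
```

Because the mask is recomputed from each batch's targets, the split can change from batch to batch without any state kept between calls.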

@2foil

2foil commented Sep 21, 2020

Great explanation, but I don't think the counter usage works, as the targets are tensors and not plain numbers. That's why your Majority and Minority Indices are the same.

Here is some of the preliminary work I did that fixes this using torch.unique and generalizes up to SortMinorityClass:

from fastai.basics import *  # assumes fastai v1: Learner, LearnerCallback, Tensor, tensor, Any

import typing as t
from dataclasses import dataclass

import torch

def indices_of(occurrences, given_class):
    "extract a tensor of indices of the given instance class"
    return tensor([
        i for i, encoding in enumerate(occurrences)
        if torch.all(torch.eq(encoding, given_class))
    ])

@dataclass
class BatchClassification:
    __slots__ = ['encoding', 'indices', 'frequency']
    encoding: Tensor # class encoding as a tensor
    indices: Tensor # indices within the batch
    frequency: int # number of occurrences of the class in the batch

def classifications_of(targets: Tensor, descending_frequency=False) -> t.List[BatchClassification]:
    "group a batch by class, sorted by frequency (rarest first by default)"
    class_encodings, class_counts = torch.unique(targets, dim=0, return_counts=True)
    # match each encoding against the targets themselves to recover its batch indices
    return sorted([
        BatchClassification(encoding, indices_of(targets, encoding), int(frequency))
        for encoding, frequency in zip(class_encodings, class_counts)
    ], key=lambda bc: bc.frequency, reverse=descending_frequency)

class SortMinorityClass(LearnerCallback):
    def __init__(self, learn:Learner):
        super().__init__(learn)
        self.iters = 0  # Manage the number of printouts we get
    def on_batch_begin(self, last_target:Tensor, **kwargs:Any) -> None:
        if self.iters < 2:
            for index, bc in enumerate(classifications_of(last_target)):
                print(f'frequency group {index}, count {bc.frequency}: {bc}')
            self.iters += 1

Also, since the mining results are a function of outputs and targets anyway, couldn't we implement the entire mining operation within the loss function? I'm kinda new to ML, so maybe I'm missing something.

Actually, the Counter works well here, because the author calls the tensor's tolist() method before passing the targets to Counter.
Check the demo below 👇:

Screenshot 2020-09-21 23 08 18
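The same point can be reproduced in a few lines. This is a minimal sketch rather than the tutorial's code, and the tensor values are made up:

```python
from collections import Counter

import torch

targets = torch.tensor([0, 1, 1, 2, 1])  # made-up batch of class labels

# Iterating a tensor yields 0-d tensors, which hash by identity,
# so every element becomes its own key with a count of 1.
by_tensor = Counter(targets)

# Converting to a list of Python ints first gives the real class counts.
by_value = Counter(targets.tolist())

print(len(by_tensor))          # 5 distinct keys, one per sample
print(by_value.most_common())  # [(1, 3), (0, 1), (2, 1)]
```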

@micimize

@2foil ah, you're right. It seems we're both right, though: I was looking at the output of SortMinorityClass, which uses Counter(targets); that is fixed later in the final ClassLevelHardMining callback.

@2foil

2foil commented Sep 22, 2020

@micimize Okay, I got it. Thanks for your explanation 😊.

@2foil

2foil commented Sep 22, 2020

@amqdn Thanks for your tutorial ❤️, it helped me a lot when implementing CRLloss.

I have one question.
I'm working with a training dataset that has multiple majority and minority classes.
How should I compute the omega in CRLloss in that case?

@amqdn
Author

amqdn commented Sep 23, 2020 via email

@2foil

2foil commented Sep 24, 2020

@amqdn Got it 😊, thanks for your explanation. ❤️
