@amqdn
Last active July 26, 2022 14:17
Implementing Class Rectification Loss in fast.ai
{
"cells": [
{
"metadata": {},
"cell_type": "markdown",
"source": "## Class Rectification Loss"
},
{
"metadata": {},
"cell_type": "markdown",
"source": "https://arxiv.org/abs/1804.10851\n\nLet's try to implement Class Rectification Loss. First, we need to understand what a loss function is doing -- what goes in, what comes out? Let's make a study of different loss functions and see what we can learn. "
},
{
"metadata": {
"trusted": true
},
"cell_type": "code",
"source": "%reload_ext autoreload\n%autoreload 2\n%matplotlib inline",
"execution_count": 1,
"outputs": []
},
{
"metadata": {
"trusted": true
},
"cell_type": "code",
"source": "from fastai.vision import *",
"execution_count": 2,
"outputs": []
},
{
"metadata": {
"trusted": true
},
"cell_type": "code",
"source": "np.random.seed(4)",
"execution_count": 3,
"outputs": []
},
{
"metadata": {},
"cell_type": "markdown",
"source": "### L1 Loss"
},
{
"metadata": {
"trusted": true
},
"cell_type": "code",
"source": "loss = nn.L1Loss()\nloss_per_sample = nn.L1Loss(reduction='none')",
"execution_count": 4,
"outputs": []
},
{
"metadata": {},
"cell_type": "markdown",
"source": "In some hypothetical scenario, we can imagine that we have a batch of 15 images, and we have 15 different labels for those images. In practice, we don't use L1 Loss this way, but since we're thinking about batch-wise class rectification, we can think of L1 Loss in these terms to help us grasp the problem in a simple way. \n\nThe loss function expects the \"output\" of the model (i.e., its predictions), and then it will compare those predictions against the ground truth. As we've said, we can imagine this as 15 predictions from the model and 15 ground truth labels. "
},
{
"metadata": {
"trusted": true
},
"cell_type": "code",
"source": "input = torch.randn(3, 5, requires_grad=True)\ninput # 15 \"predictions\" from the model",
"execution_count": 5,
"outputs": [
{
"data": {
"text/plain": "tensor([[-0.0289, -1.4645, 0.4495, 0.6747, -0.8097],\n [-1.2833, -0.8002, -0.4095, 0.3529, 0.4514],\n [ 1.0435, -1.6748, 1.1579, 1.0776, 0.2149]], requires_grad=True)"
},
"output_type": "execute_result",
"execution_count": 5,
"metadata": {}
}
]
},
{
"metadata": {
"trusted": true
},
"cell_type": "code",
"source": "target = torch.eye(3, 5)\ntarget # 15 \"ground truth\" labels",
"execution_count": 6,
"outputs": [
{
"data": {
"text/plain": "tensor([[1., 0., 0., 0., 0.],\n [0., 1., 0., 0., 0.],\n [0., 0., 1., 0., 0.]])"
},
"output_type": "execute_result",
"execution_count": 6,
"metadata": {}
}
]
},
{
"metadata": {
"trusted": true
},
"cell_type": "code",
"source": "output = loss_per_sample(input, target)\noutput # L1 loss first calculates the absolute value of the distances between input x and target y...",
"execution_count": 7,
"outputs": [
{
"data": {
"text/plain": "tensor([[1.0289, 1.4645, 0.4495, 0.6747, 0.8097],\n [1.2833, 1.8002, 0.4095, 0.3529, 0.4514],\n [1.0435, 1.6748, 0.1579, 1.0776, 0.2149]], grad_fn=<L1LossBackward>)"
},
"output_type": "execute_result",
"execution_count": 7,
"metadata": {}
}
]
},
{
"metadata": {
"trusted": true
},
"cell_type": "code",
"source": "output = loss(input, target)\noutput # ...then calculates the mean; in other words, the mean absolute error",
"execution_count": 8,
"outputs": [
{
"data": {
"text/plain": "tensor(0.8596, grad_fn=<L1LossBackward>)"
},
"output_type": "execute_result",
"execution_count": 8,
"metadata": {}
}
]
},
{
"metadata": {},
"cell_type": "markdown",
"source": "Remember that the objective of SGD is to find the minimum for this loss. "
},
{
"metadata": {},
"cell_type": "markdown",
"source": "### Triplet Loss"
},
{
"metadata": {},
"cell_type": "markdown",
"source": "PyTorch includes Triplet Loss in their available loss functions, and since this is the base for Class Rectification Loss after we've hard-mined our samples, let's also take a look at how it works. "
},
{
"metadata": {
"trusted": true
},
"cell_type": "code",
"source": "triplet_loss = nn.TripletMarginLoss(margin=1.0, p=2)",
"execution_count": 21,
"outputs": []
},
{
"metadata": {},
"cell_type": "markdown",
"source": "The \"triplet\" in triplet loss is composed of three samples: an anchor sample (a random sample), a positive sample (a random other sample of the same class as the anchor), and a negative sample (a random sample that is not of the same class as the anchor). The similarities (or distances) between the anchor sample $a$ and the positive sample $p$, as well as between anchor $a$ and the negative sample $n$, are calculated using Euclidean Norm: $d_{+} = ||f(a) - f(p)||_{2}$ and $d_{-} = ||f(a) - f(n)||_{2}$. \n\nThe \"margin\" in `TripletMarginLoss` refers to \"margin ranking loss\", which aims to rank the values of these distances such that $d_{-}$ > $d_{+}$ + _µ_. Why? If we enforce a ranking such that $d_{-}$ is always greater than $d_{+}$ plus some margin _µ_, then we are making sure that all negative samples $n$ are farther away from $a$ than our positive samples $p$. And the way we tell the model we want it to optimize for this condition is by a loss function _λ_$(d_{+}, d_{-}) =$ _max_(0, _µ_ + $d_{+}$ - $d_{-}$), otherwise known as _ReLU_(_µ_ + $d_{+}$ - $d_{-}$). \n\nNote that `TripletMarginLoss` expects (a, p, n), and so the selection of these samples is left up to the user."
},
{
"metadata": {
"trusted": true
},
"cell_type": "code",
"source": "a = torch.randn(3, 5, requires_grad=True)\na # Some pre-chosen sample of a class",
"execution_count": 28,
"outputs": [
{
"data": {
"text/plain": "tensor([[ 0.3072, -0.0931, 1.5000, 0.6430, -0.3071],\n [ 0.0776, -0.7049, -0.1978, -0.8530, -0.9692],\n [-1.4200, -0.8154, -1.2950, -0.3744, -0.6828]], requires_grad=True)"
},
"output_type": "execute_result",
"execution_count": 28,
"metadata": {}
}
]
},
{
"metadata": {
"trusted": true
},
"cell_type": "code",
"source": "p = torch.randn(3, 5, requires_grad=True)\np # Some pre-chosen sample of the same class as a",
"execution_count": 29,
"outputs": [
{
"data": {
"text/plain": "tensor([[-1.7797, -0.3953, -0.2961, -0.6337, -0.4672],\n [-0.9017, 1.3158, 1.6002, -0.5278, 0.0329],\n [-0.9086, 0.7668, 0.8678, -1.7287, 0.6816]], requires_grad=True)"
},
"output_type": "execute_result",
"execution_count": 29,
"metadata": {}
}
]
},
{
"metadata": {
"trusted": true
},
"cell_type": "code",
"source": "n = torch.randn(3, 5, requires_grad=True)\nn # Some pre-chosen sample of a different class from a",
"execution_count": 30,
"outputs": [
{
"data": {
"text/plain": "tensor([[-0.9549, 0.0301, -1.1895, 0.4843, 0.8937],\n [-1.5880, 0.3522, -0.1486, 1.2759, -0.7453],\n [-1.1284, 0.2808, -0.3532, -0.0723, 0.3850]], requires_grad=True)"
},
"output_type": "execute_result",
"execution_count": 30,
"metadata": {}
}
]
},
{
"metadata": {
"trusted": true
},
"cell_type": "code",
"source": "output = triplet_loss(a, p, n)\noutput",
"execution_count": 31,
"outputs": [
{
"data": {
"text/plain": "tensor(1.4959, grad_fn=<MeanBackward1>)"
},
"output_type": "execute_result",
"execution_count": 31,
"metadata": {}
}
]
},
{
"metadata": {},
"cell_type": "markdown",
"source": "### Class Rectification Loss (Class+Rel)"
},
{
"metadata": {},
"cell_type": "markdown",
"source": "Class Rectification Loss (Class+Rel) can be thought of as an expanded Triplet Margin Loss. \n\nWhat is Class + Rel? The authors explore multiple variants of Class Rectification Loss (CRL), but the one that produces the best results is a combination of _class_-level hard mining and _relative_ comparison. \n\nClass-level hard mining looks for those samples that the model is most wrong about. If the sample is of the same class as the anchor $a$ but the model provided a low probability score for that class, then that sample is considered a hard-positive. Similarly, if the sample is of a different class from $a$ but the model provided a high probability score, then that sample is considered a hard-negative. Class-level hard mining sorts the samples according to the worst offenders, for both hard-positive and hard-negative. \n\nThe next part of class rectification involves comparing these mined samples. Relative comparison compares the relative distances of these samples from each other. Specifically, at the class level, we compare directly the prediction probabilities outputed by the model, such that $d_{+} = |P(a) - P(p)|$ and $d_{-} = P(a) - P(n)$, where $P$ represents the probability score(s) for that sample. Note that this constitutes one of the main differences between CRL Class-level Triplet Loss and `TripletMarginLoss` above: the distances $d$ calculated here are the distances between the probabilities that the model is predicting and the ground truth, instead of the Euclidean distance between the samples themselves (which is used in the instance-level variants of CRL). \n\nThe other main difference between CRL and `TripletMarginLoss` is that it is a batch-wise operation. That is, for every mini-batch, CRL defines the majority and minority classes of that batch, mines for hard-positives and hard-negatives relative to those minority classes, and then calculates triplet loss across those minority triplets. 
The triplet loss is calculated for every single sample of the minority class against some number $k$ of the hard-positives and hard-negatives each, which means we form a triplet for every single anchor against $k$ hard-positives and hard-negatives. If we have 10 anchors and $k = 25$, then we will have $10 * k * k$ triplets. This can be thought of as an expanded form of Online Triplet Loss. \n\nFinally, it is important to note that CRL always focuses the triplet calculation on the minority class, whatever that minority happens to be batch-to-batch. This ensures that the model is incrementally optimizing triplet loss for the minority of every batch, which is the main innovation of CRL in addressing highly imbalanced datasets. Note that the same operation applies if there are multiple minority classes such that their sum representation in the batch is $<=$ 50%.\n\nIn order to construct our loss function, let's now play with a toy dataset to find out what's going in and what's coming out of the model."
},
{
"metadata": {
"trusted": true
},
"cell_type": "code",
"source": "path = untar_data(URLs.MNIST_SAMPLE)\ndata = ImageDataBunch.from_folder(path)",
"execution_count": 9,
"outputs": []
},
{
"metadata": {
"trusted": true
},
"cell_type": "code",
"source": "data.show_batch(3, figsize=(5,5))",
"execution_count": 5,
"outputs": [
{
"data": {
"text/plain": "<Figure size 360x360 with 9 Axes>",
"image/png": "\n"
},
"output_type": "display_data",
"metadata": {
"needs_background": "light"
}
}
]
},
{
"metadata": {
"trusted": true
},
"cell_type": "code",
"source": "data.batch_size",
"execution_count": 34,
"outputs": [
{
"data": {
"text/plain": "64"
},
"output_type": "execute_result",
"execution_count": 34,
"metadata": {}
}
]
},
{
"metadata": {},
"cell_type": "markdown",
"source": "fast.ai provides a wonderful callback API that allows us to modify the training process in almost any way, including, among other things, intervening when the loss is about to be calculated. Before we write our callback, let's first summarize what we need from the model so that we know at which stages of the training we want to step in. \n\nIn order to implement CRL, we need to know, for every batch:\n1. Which class is the minority class?\n2. For that minority class, which samples of the same (right) class have the lowest probability prediction from the model (hard-positives)? \n3. For that minority class, which samples of a different (wrong) class had the highest probability prediction (hard-negative)? \n\nAfter reviewing the various callback methods available, there are two places we can step in: `on_batch_begin` and `on_loss_begin`. Let's try an example."
},
{
"metadata": {
"trusted": true
},
"cell_type": "code",
"source": "model = simple_cnn((3, 16, 16, 2))",
"execution_count": 10,
"outputs": []
},
{
"metadata": {
"trusted": true
},
"cell_type": "code",
"source": "# https://docs.fast.ai/callback.html#callback\nclass PrintLabelsAndOutput(LearnerCallback):\n def __init__(self, learn:Learner):\n super().__init__(learn)\n self.iters = 0 # Manage the number of printouts we get\n def on_batch_begin(self, last_target:Tensor, **kwargs:Any) -> Tensor:\n if self.iters < 3: print('Targets --> ' + str(last_target[:3]))\n def on_loss_begin(self, last_output:Tensor, **kwargs:Any) -> Tensor:\n if self.iters < 3: print('Outputs --> |0| |1|\\n' + str(last_output[:3]))\n self.iters += 1",
"execution_count": 13,
"outputs": []
},
{
"metadata": {
"trusted": true
},
"cell_type": "code",
"source": "learn = Learner(data, model, callback_fns=[PrintLabelsAndOutput])",
"execution_count": 14,
"outputs": []
},
{
"metadata": {
"trusted": true
},
"cell_type": "code",
"source": "learn.fit(1)",
"execution_count": 15,
"outputs": [
{
"data": {
"text/html": "Total time: 00:01 <p><table style='width:300px; margin-bottom:10px'>\n <tr>\n <th>epoch</th>\n <th>train_loss</th>\n <th>valid_loss</th>\n <th>time</th>\n </tr>\n <tr>\n <th>0</th>\n <th>0.066576</th>\n <th>0.055267</th>\n <th>00:01</th>\n </tr>\n</table>\n",
"text/plain": "<IPython.core.display.HTML object>"
},
"output_type": "display_data",
"metadata": {}
},
{
"text": "Targets --> tensor([1, 0, 1], device='cuda:0')\nOutputs --> |0| |1|\ntensor([[ 0.0000, 7.8249],\n [10.2639, 2.9326],\n [ 1.5382, 4.5256]], device='cuda:0', grad_fn=<SliceBackward>)\nTargets --> tensor([0, 1, 1], device='cuda:0')\nOutputs --> |0| |1|\ntensor([[7.9085, 2.0232],\n [1.2495, 5.1679],\n [0.2476, 6.8007]], device='cuda:0', grad_fn=<SliceBackward>)\nTargets --> tensor([1, 1, 0], device='cuda:0')\nOutputs --> |0| |1|\ntensor([[1.7952, 8.2486],\n [2.0123, 4.3107],\n [9.9175, 1.7809]], device='cuda:0', grad_fn=<SliceBackward>)\n",
"name": "stdout",
"output_type": "stream"
}
]
},
{
"metadata": {},
"cell_type": "markdown",
"source": "What are we seeing here? First, for our purposes, even though the batch size is 64, we are only taking a look at the first three samples of each batch, up to three batches. \n\nSo \"targets\" refers to the ground truth classes of the first three samples of our mini-batch. For our toy dataset, we only have two possible classes: `0` and `1`, referring to whether the number is a 3 or a 7. \"Outputs\" refers to the output of the model -- that is, the prediction the model has made as to whether the given sample is of class `0` or class `1`. If the first number in a pair is higher, then the model thinks the sample is more likely to be of class `0`; similarly, if the second number is higher, then the model thinks the sample is more likely to be of class `1`. If you run the above cell multiple times, you should see these predictions become more and more \"confident.\"\n\nAt this point, we should have everything we need in order to implement CRL. Let's first sort our batch according to majority and minority classes. Note that our toy dataset is balanced, so we won't necessarily see a benefit here. In fact, our code below may return the same class as minority and majority because of this. "
},
{
"metadata": {
"trusted": true
},
"cell_type": "code",
"source": "from collections import Counter",
"execution_count": 11,
"outputs": []
},
{
"metadata": {
"trusted": true
},
"cell_type": "code",
"source": "class SortMinorityClass(LearnerCallback):\n def __init__(self, learn:Learner):\n super().__init__(learn)\n self.iters = 0 # Manage the number of printouts we get\n def on_batch_begin(self, last_target:Tensor, **kwargs:Any) -> Tensor:\n if self.iters < 2:\n targets = last_target.cpu()\n target_indices = set(enumerate(targets))\n class_tally = Counter(targets)\n minority_class = min(class_tally, key=class_tally.get) # Find the class with the least num of samples\n minority_indices = {i[0] for i in target_indices if i[1] == minority_class}\n majority_class = max(class_tally, key=class_tally.get)\n majority_indices = {i[0] for i in target_indices if i[1] == majority_class}\n print('Minority Indices: ' + str(minority_indices))\n print('Majority Indices: ' + str(majority_indices))\n self.iters += 1",
"execution_count": 34,
"outputs": []
},
{
"metadata": {
"trusted": true
},
"cell_type": "code",
"source": "learn = Learner(data, model, callback_fns=[SortMinorityClass])",
"execution_count": 35,
"outputs": []
},
{
"metadata": {
"trusted": true
},
"cell_type": "code",
"source": "learn.fit(1)",
"execution_count": 36,
"outputs": [
{
"data": {
"text/html": "Total time: 00:01 <p><table style='width:300px; margin-bottom:10px'>\n <tr>\n <th>epoch</th>\n <th>train_loss</th>\n <th>valid_loss</th>\n <th>time</th>\n </tr>\n <tr>\n <th>0</th>\n <th>0.051920</th>\n <th>0.036756</th>\n <th>00:01</th>\n </tr>\n</table>\n",
"text/plain": "<IPython.core.display.HTML object>"
},
"output_type": "display_data",
"metadata": {}
},
{
"text": "Minority Indices: {0, 1, 2, 3, 5, 7, 10, 11, 13, 15, 17, 19, 20, 21, 22, 23, 26, 28, 30, 34, 35, 38, 44, 49, 50, 51, 52, 53, 54, 56, 58, 59, 60, 62}\nMajority Indices: {0, 1, 2, 3, 5, 7, 10, 11, 13, 15, 17, 19, 20, 21, 22, 23, 26, 28, 30, 34, 35, 38, 44, 49, 50, 51, 52, 53, 54, 56, 58, 59, 60, 62}\nMinority Indices: {0, 3, 4, 5, 7, 11, 12, 14, 15, 18, 20, 22, 23, 26, 32, 35, 36, 39, 42, 44, 46, 48, 51, 52, 53, 54, 57, 58, 60, 63}\nMajority Indices: {0, 3, 4, 5, 7, 11, 12, 14, 15, 18, 20, 22, 23, 26, 32, 35, 36, 39, 42, 44, 46, 48, 51, 52, 53, 54, 57, 58, 60, 63}\n",
"name": "stdout",
"output_type": "stream"
}
]
},
{
"metadata": {},
"cell_type": "markdown",
"source": "Now that we have the indices for every sample sorted by which class they belong to, we can use those indices to split up the predictions from the model and sort those predictions into the worst offenders. "
},
{
"metadata": {
"trusted": true
},
"cell_type": "code",
"source": "model = simple_cnn((3, 16, 16, 2))",
"execution_count": 4,
"outputs": []
},
{
"metadata": {
"trusted": true
},
"cell_type": "code",
"source": "from collections import Counter",
"execution_count": 5,
"outputs": []
},
{
"metadata": {
"trusted": true
},
"cell_type": "code",
"source": "class ClassLevelHardMining(LearnerCallback):\n \"\"\"An implementation of class-level hard sample mining that is intended only for the two-class case.\"\"\"\n def __init__(self, learn:Learner):\n super().__init__(learn)\n self.sample_counts = Counter(learn.data.y.items)\n self.omega = None # Class imbalance measure\n self.minority_class = None\n self.minority_indices = None\n self.majority_class = None\n self.majority_indices = None\n def _make_omega(self, b, a):\n if self.omega is not None:\n return self.omega\n return (self.sample_counts[b] - self.sample_counts[a]) / self.sample_counts[b]\n def on_batch_begin(self, last_target:Tensor, **kwargs:Any) -> Tensor:\n targets = last_target.cpu()\n target_indices = set(enumerate(targets))\n class_tally = Counter(targets.tolist())\n self.minority_class = min(class_tally, key=class_tally.get) # Find the class with the least num of samples\n self.minority_indices = tensor([i[0] for i in target_indices if i[1] == self.minority_class])\n self.majority_class = max(class_tally, key=class_tally.get)\n self.majority_indices = tensor([i[0] for i in target_indices if i[1] == self.majority_class])\n self.omega = self._make_omega(self.majority_class, self.minority_class)\n def on_loss_begin(self, last_output:Tensor, **kwargs:Any) -> Tensor:\n predictions = last_output.cpu()\n # Every minority sample is also treated as an anchor\n minority_predictions = anchors = predictions[self.minority_indices][:, self.minority_class]\n majority_predictions = predictions[self.majority_indices][:, self.majority_class]\n k = len(minority_predictions) if len(minority_predictions) < 25 else 25\n bottom_k_hard_pos = torch.sort(minority_predictions)[0][:k]\n top_k_hard_neg = torch.sort(majority_predictions, descending=True)[0][:k]\n return {'last_output': (anchors, bottom_k_hard_pos, top_k_hard_neg, predictions, self.omega)}",
"execution_count": 37,
"outputs": []
},
{
"metadata": {},
"cell_type": "markdown",
"source": "This callback will provide us with a set of anchors, the bottom-$k$ hard-positives, and the top-$k$ hard-negatives. Our loss function will then need to form triplets for every single anchor against all pairs of hard-positives and hard-negatives and then calculate the average sum (class probability) distances across all triplets. "
},
{
"metadata": {
"trusted": true
},
"cell_type": "code",
"source": "class ClassRectificationLoss(nn.Module):\n def __init__(self, eta=0.01, margin=0.5):\n super(ClassRectificationLoss, self).__init__()\n self.eta = eta # Subparameter of α\n self.margin = margin # Triplet Loss\n self.ce = CrossEntropyFlat()\n def _relative_comparison(self, a, p, n):\n # Cartesian product of all combinations\n num_a, num_p, num_n = len(a), len(p), len(n)\n a = a.view(-1, 1).expand(num_a, num_p**2).reshape(-1)\n p = p.view(-1, 1).expand(num_p, num_n).reshape(-1).repeat(num_a)\n n = n.repeat(num_a * num_p)\n a, p, n = torch.stack([a, p, n]) # Simplify tensor math\n d_p = (a - p)\n d_n = (a - n)\n losses = F.relu(self.margin + d_p - d_n)\n # TODO: Remove those losses == 0, then avg over the remaining triplets\n return losses.sum() / len(a)\n def forward(self, last_output, targets, reduction='none'):\n targets = targets.cpu()\n a, p, n, predictions, omega = last_output\n alpha = (self.eta * omega)\n return (alpha * self._relative_comparison(a, p, n)) + ((1 - alpha) * self.ce(predictions, targets))",
"execution_count": 38,
"outputs": []
},
{
"metadata": {
"trusted": true
},
"cell_type": "code",
"source": "learn = Learner(data, model, loss_func=ClassRectificationLoss(), callback_fns=[ClassLevelHardMining])",
"execution_count": 39,
"outputs": []
},
{
"metadata": {
"trusted": true
},
"cell_type": "code",
"source": "learn.fit_one_cycle(5)",
"execution_count": 40,
"outputs": [
{
"data": {
"text/html": "Total time: 00:13 <p><table style='width:300px; margin-bottom:10px'>\n <tr>\n <th>epoch</th>\n <th>train_loss</th>\n <th>valid_loss</th>\n <th>time</th>\n </tr>\n <tr>\n <th>0</th>\n <th>0.040992</th>\n <th>0.033535</th>\n <th>00:02</th>\n </tr>\n <tr>\n <th>1</th>\n <th>0.037327</th>\n <th>0.028478</th>\n <th>00:02</th>\n </tr>\n <tr>\n <th>2</th>\n <th>0.034604</th>\n <th>0.021936</th>\n <th>00:02</th>\n </tr>\n <tr>\n <th>3</th>\n <th>0.024023</th>\n <th>0.021723</th>\n <th>00:02</th>\n </tr>\n <tr>\n <th>4</th>\n <th>0.019512</th>\n <th>0.020085</th>\n <th>00:02</th>\n </tr>\n</table>\n",
"text/plain": "<IPython.core.display.HTML object>"
},
"output_type": "display_data",
"metadata": {}
}
]
},
{
"metadata": {
"trusted": true
},
"cell_type": "code",
"source": "learn.show_results(rows=5)",
"execution_count": 41,
"outputs": [
{
"data": {
"text/plain": "<Figure size 1440x1440 with 25 Axes>",
"image/png": "\n"
},
"output_type": "display_data",
"metadata": {
"needs_background": "light"
}
}
]
}
],
"metadata": {
"kernelspec": {
"name": "python3",
"display_name": "Python 3",
"language": "python"
},
"language_info": {
"mimetype": "text/x-python",
"pygments_lexer": "ipython3",
"version": "3.7.1",
"nbconvert_exporter": "python",
"name": "python",
"file_extension": ".py",
"codemirror_mode": {
"name": "ipython",
"version": 3
}
},
"gist": {
"id": "0453f149b41c46d060cbc7ecf0b5632d",
"data": {
"description": "Implementing Class Rectification Loss in fast.ai",
"public": false
}
},
"_draft": {
"nbviewer_url": "https://gist.github.com/0453f149b41c46d060cbc7ecf0b5632d"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
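
The class-level relative comparison described in the notebook can be checked by hand on a few numbers. What follows is a minimal pure-Python sketch, not the notebook's implementation: the function name `class_level_relative_loss` and the probability values are made up for illustration, and it assumes that every anchor is paired with every hard-positive and every hard-negative, with the hinge `max(0, margin + d+ - d-)` averaged over all triplets.

```python
from itertools import product

def class_level_relative_loss(anchors, hard_pos, hard_neg, margin=0.5):
    """Average the hinge loss over every (anchor, hard-positive, hard-negative) triplet.

    Inputs are plain lists of minority-class probabilities (hypothetical values).
    """
    losses = []
    for a, p, n in product(anchors, hard_pos, hard_neg):
        d_pos = abs(a - p)   # d+ = |P(a) - P(p)|
        d_neg = a - n        # d- = P(a) - P(n)
        losses.append(max(0.0, margin + d_pos - d_neg))
    return sum(losses) / len(losses), len(losses)

# Hypothetical minority-class probabilities for one mini-batch
anchors = [0.9, 0.6]    # every minority sample acts as an anchor
hard_pos = [0.2, 0.4]   # minority samples the model scored lowest
hard_neg = [0.8, 0.7]   # non-minority samples the model scored highest

loss, num_triplets = class_level_relative_loss(anchors, hard_pos, hard_neg)
print(num_triplets, loss)
```

With 2 anchors, 2 hard-positives, and 2 hard-negatives the loop visits 2 * 2 * 2 triplets, which is the same counting as the "10 anchors and k = 25" case in the notebook.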
@micimize

micimize commented Jul 4, 2020

Great explanation, but I don't think the counter usage works, as the targets are tensors and not plain numbers. That's why your Majority and Minority Indices are the same.

Here is some of the preliminary work I did that fixes this using torch.unique and generalizes up to SortMinorityClass:

from dataclasses import dataclass
import typing as t

def indices_of(occurrences, class_position):
    "extract a tensor of batch indices whose entry in occurrences equals class_position"
    return tensor([
        i for i, position in enumerate(occurrences)
        if int(position) == class_position
    ])

@dataclass
class BatchClassification:
    __slots__ = ['encoding', 'indices', 'frequency']
    encoding: Tensor # class encoding as a tensor
    indices: Tensor # indices within the batch
    frequency: int # number of occurrences in the class

def classifications_of(targets: Tensor, descending_frequency=False) -> t.Iterable[BatchClassification]:
    # torch.unique returns (unique values, inverse indices, counts) in that order
    class_encodings, class_indices, class_counts = torch.unique(targets, dim=0, return_counts=True, return_inverse=True)
    return sorted([
        # class_indices holds, for each sample, the position of its class in class_encodings
        BatchClassification(encoding, indices_of(class_indices, position), int(frequency))
        for position, (encoding, frequency) in enumerate(zip(class_encodings, class_counts))
    ], key=lambda bc: bc.frequency, reverse=descending_frequency)

class SortMinorityClass(LearnerCallback):
    def __init__(self, learn:Learner):
        super().__init__(learn)
        self.iters = 0  # Manage the number of printouts we get
    def on_batch_begin(self, last_target:Tensor, **kwargs:Any) -> Tensor:
        if self.iters < 2:
            for index, bc in enumerate(classifications_of(last_target)):
                print(f'frequency group {index}, count {str(bc.frequency)}: {bc}')
            self.iters += 1

Also, since the mining results are a function of outputs and targets anyways, couldn't we implement the entire mining operation within the loss function? I'm kinda new to ML so maybe I'm missing something

@amqdn
Author

amqdn commented Jul 17, 2020

@micimize

Thanks for taking the time to do that! I like the idea of using torch.unique to accomplish the sample counting. As you can see, I published this over a year ago, and I notice my relative inexperience with PyTorch and ML back then shows.

Re: mining results being a function of outputs and targets... If I understand your question: You're asking why it is that I included the mining operation in a callback_fn instead of inside the loss.forward. In looking at my code and trying to remember why, I think I couldn't come up with a way to retain the indices of the majority/minority classes using just the loss module. Since the majority/minority class designation changes from batch-to-batch dynamically in this paper, it's necessary to keep track of the indices (I think) in order to calculate the loss properly. That doesn't mean it's not possible, but I hadn't found a way.
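
For what it's worth, one way the indices could be recovered inside the loss itself is to recompute them from the targets on every call. The following is only a rough sketch under assumptions that are not in the notebook: the helper name `mine_in_loss` is hypothetical, the loss is assumed to receive raw logits and integer class targets, softmax is applied to obtain probabilities, and every non-minority sample is treated as a candidate hard-negative.

```python
import torch

def mine_in_loss(logits, targets, k=25):
    """Recompute minority-class indices and hard samples from one batch (hypothetical helper)."""
    classes, counts = torch.unique(targets, return_counts=True)
    minority = int(classes[counts.argmin()])  # least frequent class in this batch
    probs = torch.softmax(logits, dim=1)
    minority_prob = probs[:, minority]        # P(minority class) for every sample
    minority_idx = (targets == minority).nonzero(as_tuple=True)[0]
    other_idx = (targets != minority).nonzero(as_tuple=True)[0]
    anchors = minority_prob[minority_idx]     # every minority sample is an anchor
    k_pos = min(k, len(minority_idx))
    k_neg = min(k, len(other_idx))
    # Hard-positives: minority samples with the lowest minority-class probability
    hard_pos = torch.topk(anchors, k_pos, largest=False).values
    # Hard-negatives: other samples with the highest minority-class probability
    hard_neg = torch.topk(minority_prob[other_idx], k_neg, largest=True).values
    return minority, anchors, hard_pos, hard_neg

# Made-up batch: class 1 appears twice, class 0 three times
logits = torch.tensor([[2.0, 0.0], [0.0, 2.0], [1.0, 1.0], [3.0, 0.0], [0.0, 1.0]])
targets = torch.tensor([0, 1, 0, 0, 1])
minority, anchors, hard_pos, hard_neg = mine_in_loss(logits, targets, k=2)
print(minority, anchors, hard_pos, hard_neg)
```

Because everything is derived from `logits` and `targets` in a single call, no state has to be carried between the callback and the loss; whether that fits the rest of the fast.ai training loop is untested here.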

@2foil

2foil commented Sep 21, 2020

Actually, the Counter works well here, because the author calls the tensor's tolist() method before passing the targets to Counter.
Check the demo below 👇:

Screenshot 2020-09-21 23 08 18
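
A minimal sketch of the difference being discussed, assuming a small hand-made `targets` tensor (the values are hypothetical): iterating over a tensor yields 0-dim tensor objects, which hash by identity, so `Counter` treats each one as a distinct key, whereas `tolist()` yields plain Python ints that are counted by value.

```python
from collections import Counter
import torch

targets = torch.tensor([1, 0, 1, 1])

# Each element is a separate 0-dim tensor object, so every key is unique
by_tensor = Counter(targets)

# tolist() converts to Python ints, which are counted by value
by_value = Counter(targets.tolist())

print(len(by_tensor))   # number of distinct keys when counting tensor objects
print(by_value)         # per-class counts when counting plain ints
```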

@micimize

@2foil ah you're right – but it seems we're both right, because I was looking at the output of SortMinorityClass, which uses Counter(targets), which is later fixed in the final ClassLevelHardMining callback

@2foil

2foil commented Sep 22, 2020

@micimize Okay, I got it. Thanks for your explanation 😊.

@2foil

2foil commented Sep 22, 2020

@amqdn Thanks for your tutorial ❤️, it helps me a lot when implementing CRLloss.

Here I have one question.
Now I'm dealing with one training dataset, which has multiple majority and minority classes.
So how to compute the omega in CRLloss?

@amqdn
Author

amqdn commented Sep 23, 2020 via email

@2foil

2foil commented Sep 24, 2020

@amqdn Got it 😊, thanks for your explanation. ❤️
