Created
June 8, 2021 10:28
Method to generate class weights given a multi-class or multi-label set of classes using Python, supporting one-hot-encoded labels.
import numpy as np
from sklearn.utils.class_weight import compute_class_weight
from sklearn.preprocessing import MultiLabelBinarizer


def generate_class_weights(class_series, multi_class=True, one_hot_encoded=False):
    """
    Generate class weights for a set of multi-class or multi-label labels, one-hot-encoded or not.

    Some examples of different formats of class_series and their outputs:
    - generate_class_weights(['mango', 'lemon', 'banana', 'mango'], multi_class=True, one_hot_encoded=False)
      {'banana': 1.3333333333333333, 'lemon': 1.3333333333333333, 'mango': 0.6666666666666666}
    - generate_class_weights([[1, 0, 0], [0, 1, 0], [0, 0, 1], [1, 0, 0]], multi_class=True, one_hot_encoded=True)
      {0: 0.6666666666666666, 1: 1.3333333333333333, 2: 1.3333333333333333}
    - generate_class_weights([['mango', 'lemon'], ['mango'], ['lemon', 'banana'], ['lemon']], multi_class=False, one_hot_encoded=False)
      {'banana': 1.3333333333333333, 'lemon': 0.4444444444444444, 'mango': 0.6666666666666666}
    - generate_class_weights([[0, 1, 1], [0, 0, 1], [1, 1, 0], [0, 1, 0]], multi_class=False, one_hot_encoded=True)
      {0: 1.3333333333333333, 1: 0.4444444444444444, 2: 0.6666666666666666}

    The output is a dictionary in the format {class_label: class_weight}. If the input is one-hot encoded,
    class_label is the index at which the label appeared when the dataset was processed:
    np.unique(class_series) for multi-class and np.unique(np.concatenate(class_series)) for multi-label.

    Author: Angel Igareta ([email protected])
    """
    if multi_class:
        # If labels are one-hot encoded, transform them to categorical labels for compute_class_weight.
        if one_hot_encoded:
            class_series = np.argmax(class_series, axis=1)

        # Compute class weights with sklearn's balanced method.
        class_labels = np.unique(class_series)
        class_weights = compute_class_weight(class_weight='balanced', classes=class_labels, y=class_series)
        return dict(zip(class_labels, class_weights))
    else:
        # Multi-label values must be one-hot encoded before counting.
        mlb = None
        if not one_hot_encoded:
            mlb = MultiLabelBinarizer()
            class_series = mlb.fit_transform(class_series)

        n_samples = len(class_series)
        n_classes = len(class_series[0])

        # Count the frequency of each class.
        class_count = [0] * n_classes
        for classes in class_series:
            for index in range(n_classes):
                if classes[index] != 0:
                    class_count[index] += 1

        # Compute class weights using the balanced method; guard against zero-frequency classes.
        class_weights = [n_samples / (n_classes * freq) if freq > 0 else 1 for freq in class_count]
        class_labels = range(len(class_weights)) if mlb is None else mlb.classes_
        return dict(zip(class_labels, class_weights))
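The multi-label branch applies the same "balanced" heuristic sklearn uses for multi-class: weight = n_samples / (n_classes * frequency). A minimal self-contained check of that formula, using the one-hot multi-label example from the docstring:

```python
import numpy as np

# One-hot multi-label example from the docstring above.
labels = np.array([[0, 1, 1], [0, 0, 1], [1, 1, 0], [0, 1, 0]])
n_samples, n_classes = labels.shape  # 4 samples, 3 classes

# Per-class positive counts: [1, 3, 2]
class_count = labels.sum(axis=0)

# Balanced weights: n_samples / (n_classes * freq)
class_weights = n_samples / (n_classes * class_count)
print(dict(zip(range(n_classes), class_weights)))
# → {0: 1.3333333333333333, 1: 0.4444444444444444, 2: 0.6666666666666666}
```

This reproduces the fourth docstring example: rarer classes (class 0 appears once) get larger weights than frequent ones (class 1 appears three times).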
Thank you. You can also use `class_count = np.array(class_series).sum(axis=0)` to count the labels in one line.
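The vectorized one-liner from the comment matches the nested counting loop whenever the labels are 0/1 one-hot arrays (with arbitrary non-zero values, `sum` would add the values rather than count them). A quick sketch comparing the two:

```python
import numpy as np

# One-hot multi-label example from the docstring.
class_series = [[0, 1, 1], [0, 0, 1], [1, 1, 0], [0, 1, 0]]
n_classes = len(class_series[0])

# Loop-based count as written in the gist.
class_count = [0] * n_classes
for classes in class_series:
    for index in range(n_classes):
        if classes[index] != 0:
            class_count[index] += 1

# Vectorized count suggested in the comment.
vectorized_count = np.array(class_series).sum(axis=0)

print(class_count, vectorized_count.tolist())  # → [1, 3, 2] [1, 3, 2]
```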