Skip to content

Instantly share code, notes, and snippets.

@angeligareta
Created June 8, 2021 10:28
Show Gist options
  • Save angeligareta/83d9024c5e72ac9ebc34c9f0b073c64c to your computer and use it in GitHub Desktop.
Save angeligareta/83d9024c5e72ac9ebc34c9f0b073c64c to your computer and use it in GitHub Desktop.
Method to generate class weights given a multi-class or multi-label set of classes using Python, supporting one-hot-encoded labels.
import numpy as np
from sklearn.utils.class_weight import compute_class_weight
from sklearn.preprocessing import MultiLabelBinarizer
def generate_class_weights(class_series, multi_class=True, one_hot_encoded=False):
    """
    Generate balanced class weights for a set of multi-class or multi-label labels,
    supporting both one-hot-encoded and raw label formats.

    Examples of different formats of class_series and their outputs:
    - generate_class_weights(['mango', 'lemon', 'banana', 'mango'], multi_class=True, one_hot_encoded=False)
      {'banana': 1.3333333333333333, 'lemon': 1.3333333333333333, 'mango': 0.6666666666666666}
    - generate_class_weights([[1, 0, 0], [0, 1, 0], [0, 0, 1], [1, 0, 0]], multi_class=True, one_hot_encoded=True)
      {0: 0.6666666666666666, 1: 1.3333333333333333, 2: 1.3333333333333333}
    - generate_class_weights([['mango', 'lemon'], ['mango'], ['lemon', 'banana'], ['lemon']], multi_class=False, one_hot_encoded=False)
      {'banana': 1.3333333333333333, 'lemon': 0.4444444444444444, 'mango': 0.6666666666666666}
    - generate_class_weights([[0, 1, 1], [0, 0, 1], [1, 1, 0], [0, 1, 0]], multi_class=False, one_hot_encoded=True)
      {0: 1.3333333333333333, 1: 0.4444444444444444, 2: 0.6666666666666666}

    Args:
        class_series: Labels, one entry per sample. Either raw labels (scalar per
            sample for multi-class, iterable of labels per sample for multi-label)
            or one-hot-encoded rows when one_hot_encoded=True.
        multi_class: True for multi-class (exactly one label per sample), False for
            multi-label (zero or more labels per sample).
        one_hot_encoded: True when class_series rows are already one-hot encoded.

    Returns:
        dict in the format {class_label: class_weight}. When the input is one-hot
        encoded, class_label is the column index of the label as the dataset was
        processed. In multi-class this is np.unique(class_series) and in
        multi-label np.unique(np.concatenate(class_series)).

    Author: Angel Igareta ([email protected])
    """
    if multi_class:
        # If classes are one-hot encoded, transform to categorical labels so that
        # sklearn's compute_class_weight can be used.
        if one_hot_encoded:
            class_series = np.argmax(class_series, axis=1)

        # Compute class weights with sklearn's "balanced" heuristic.
        class_labels = np.unique(class_series)
        class_weights = compute_class_weight(class_weight='balanced', classes=class_labels, y=class_series)
        return dict(zip(class_labels, class_weights))
    else:
        # It is necessary that the multi-label values are one-hot encoded;
        # binarize them if they are not.
        mlb = None
        if not one_hot_encoded:
            mlb = MultiLabelBinarizer()
            class_series = mlb.fit_transform(class_series)

        binarized = np.asarray(class_series)
        n_samples, n_classes = binarized.shape

        # Per-class frequency: number of samples in which each label appears.
        # Vectorized replacement for the O(n_samples * n_classes) Python loop.
        class_count = np.count_nonzero(binarized, axis=0)

        # sklearn's "balanced" formula: n_samples / (n_classes * freq).
        # Classes never seen in the data get a neutral weight of 1.
        class_weights = [n_samples / (n_classes * freq) if freq > 0 else 1 for freq in class_count]
        class_labels = range(n_classes) if mlb is None else mlb.classes_
        return dict(zip(class_labels, class_weights))
@gbiz123
Copy link

gbiz123 commented Apr 3, 2023

Thank you. You can also use class_count = np.array(class_series).sum(axis=0) to count the labels in one line

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment