Skip to content

Instantly share code, notes, and snippets.

@blu3r4y
Created March 9, 2021 18:54
Show Gist options
  • Save blu3r4y/09dde4d982f8ff772f21c2aea79fbb5d to your computer and use it in GitHub Desktop.
Save blu3r4y/09dde4d982f8ff772f21c2aea79fbb5d to your computer and use it in GitHub Desktop.
An efficient tie-breaking, numpy-based function that computes the majority class from multiple predictions
import numpy as np
def majority_vote(*arrays: np.ndarray, n_classes: int = None) -> np.ndarray:
    """
    Compute the element-wise majority class label of several predictions.

    Given an arbitrary number of integer-typed, one-dimensional input vectors
    that hold non-negative class labels, compute a new vector that contains,
    for each observation, the label that occurs most often among the input
    vectors. Ties are broken randomly for each observation (using the global
    `np.random` state).

    If you omit the `n_classes` argument, it is inferred as the largest label
    found in the input vectors plus one, so labels do not have to be contiguous.

    :param arrays: One or more one-dimensional, integer-typed input vectors
        of equal length with labels in the range `[0, n_classes)`
    :param n_classes: (Optional) The number of classes, i.e. an exclusive
        upper bound for the labels in all input vectors
    :return: A single one-dimensional, integer-typed output vector
    :raises ValueError: If no input vector is given
    :raises AssertionError: If the input vectors are not one-dimensional,
        differ in length, are not integer-typed, or hold negative labels
    :raises IndexError: If a label is not smaller than `n_classes`
    """
    if not arrays:
        raise ValueError("at least one input array is required")

    assert all(arr.ndim == 1 for arr in arrays), \
        "the input arrays must be one-dimensional"
    assert all(arr.shape[0] == arrays[0].shape[0] for arr in arrays), \
        "the input arrays must all be of the same length"
    # the array dtype must be a sub-type of the abstract integer type
    # (the reversed argument order would reject e.g. int32 or uint8 arrays)
    assert all(np.issubdtype(arr.dtype, np.integer) for arr in arrays), \
        "the input arrays must be integer-typed"
    # negative labels would silently wrap around when used as an index
    assert all((arr >= 0).all() for arr in arrays), \
        "the class labels must be non-negative"

    n_observations = arrays[0].shape[0]
    if n_observations == 0:
        # nothing to vote on; also avoids an argmax over an empty axis
        return np.empty(0, dtype=np.intp)

    # infer the number of classes from the largest label, so that
    # non-contiguous label sets (e.g. {0, 2}) stay within bounds
    if n_classes is None:
        n_classes = int(max(arr.max() for arr in arrays)) + 1

    # count the votes per observation (rows) and class (columns)
    # without materializing a one-hot tensor for every input vector
    votes = np.zeros((n_observations, n_classes), dtype=float)
    rows = np.arange(n_observations)
    for arr in arrays:
        votes[rows, arr] += 1

    # turn the vote counts into vote fractions
    n_models = len(arrays)
    votes /= n_models

    # break ties at random: two distinct vote fractions differ by at least
    # 1 / n_models, so noise below half of that can only reorder exact ties
    delta = 0.5 / n_models
    votes += np.random.uniform(0, delta, votes.shape)

    # the class with the highest (perturbed) vote fraction wins
    return np.argmax(votes, axis=1)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment