Created
June 5, 2017 19:31
-
-
Save jhumigas/aa74e7a08fb654e46f0fba4e705ba4c2 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
""" | |
Small example of function to reorder a data set. | |
""" | |
#!/usr/bin/env python | |
# -*- coding: utf-8 -*- | |
import numpy as np | |
def extract_k(labels, k):
    """Build an index ordering that puts a fraction k of every class first.

    For each distinct label value (visited in the sorted order returned by
    ``np.unique``), the first ``int(count * k)`` occurrences of that class
    are placed in the leading part of the result. The remaining occurrences
    of every class follow, again grouped class by class.

    Args:
        labels(np.ndarray): n*1 matrix (or flat array / sequence) of labels.
        k(float): Fraction of each class to put first, between 0 and 1.

    Returns:
        np.ndarray: Integer indices into ``labels``; the leading entries
            cover the first k fraction of each class, the trailing entries
            cover everything else.

    Raises:
        ValueError: If ``k`` is not within [0, 1].
    """
    if not (0 <= k <= 1):
        raise ValueError('k has to be between 0 and 1')

    # Accept plain sequences too; element-wise comparison below needs an array.
    labels = np.asarray(labels)

    # Counting through np.unique (rather than a histogram with numeric bin
    # edges) also works for non-numeric labels such as strings.
    classes, counts = np.unique(labels, return_counts=True)

    head_parts = []
    tail_parts = []
    for cls, count in zip(classes, counts):
        # Row indices of the samples belonging to this class.
        members = np.where(labels == cls)[0]
        # Truncate toward zero: a class never contributes more than k of itself.
        n_head = int(count * k)
        head_parts.append(members[:n_head])
        tail_parts.append(members[n_head:])

    # No labels at all: np.concatenate cannot handle an empty list.
    if not head_parts:
        return np.array([], dtype=int)

    return np.concatenate(head_parts + tail_parts).astype(int)
def reorder_k(X, y, k):
    """Reorder samples and labels using the index ordering from extract_k.

    The ordering is obtained by calling ``extract_k(y, k)`` and is applied
    to the rows of ``X`` and to ``y`` alike, so samples and labels stay
    aligned.

    Args:
        X(np.ndarray): n*d matrix where each row represents a sample
        y(np.ndarray): n*1 matrix where each row represents a label
        k: Forwarded unchanged to ``extract_k``.

    Return:
        tuple: Reordered samples and reordered labels
    """
    order = extract_k(y, k)
    samples = X[order, :]
    targets = y[order]
    return samples, targets
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment