Created
May 22, 2019 11:59
-
-
Save ivankeller/4b26b1b5092465da17a43504982f59ab to your computer and use it in GitHub Desktop.
Subsample two-class sets into subsets given a ratio of items.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import pytest | |
import numpy as np | |
def r_cut(x0: float, x1: float, r: float) -> "tuple[float, float]":
    """Return y0, y1 such that y1/(y0+y1) = r, y0<=x0, y1<=x1 and y0+y1 is maximum.

    Parameters
    ----------
    x0, x1 : float
        Available amounts of class 0 and class 1, each >= 0.
    r : float
        Target share of class 1 in the result, >= 0 and <= 1.

    Returns
    -------
    (float, float)
        The amounts ``(y0, y1)`` to keep from each class.

    Raises
    ------
    ValueError
        If ``r`` is not within [0, 1] (NaN included) or an amount is negative.
    """
    # `not 0 <= r <= 1` also rejects NaN, which fails every comparison.
    if not 0 <= r <= 1:
        raise ValueError(f"r must be within [0, 1], got {r!r}")
    if x0 < 0 or x1 < 0:
        raise ValueError(f"x0 and x1 must be non-negative, got {x0!r} and {x1!r}")

    # The extreme ratios keep one class whole and drop the other entirely;
    # handling them first also avoids dividing by r or (1 - r) being zero.
    if r == 0:
        return x0, 0
    if r == 1:
        return 0, x1

    total = x0 + x1
    if total == 0:
        # Nothing available on either side: the only feasible cut is empty.
        return x0, x1

    actual_ratio = x1 / total  # share of class 1 when everything is kept
    if r <= actual_ratio:
        # Class 1 is over-represented: keep all of class 0, trim class 1.
        y0 = x0
        y1 = y0 * r / (1 - r)
    else:
        # Class 1 is under-represented: keep all of class 1, trim class 0.
        y1 = x1
        y0 = y1 * (1 - r) / r
    return y0, y1
def r_sample(set0, set1, r):
    """Return how many items to take from each of two sets to reach a given ratio.

    Given two sets of items corresponding to 2 classes,
    return the length of a subset of each one such that
    #class1 / (#class0 + #class1) is as close to r as whole items allow,
    maximizing the number of items used in total.

    Parameters
    ----------
    set0, set1 : sequence of objects to sample (only their lengths are used)
    r : float
        Target share of class 1, passed through to ``r_cut``.

    Returns
    -------
    (int, int)
        Number of items to keep from ``set0`` and from ``set1``.
    """
    x0, x1 = len(set0), len(set1)
    if x0 == 0 and x1 == 0:
        # Nothing to sample from: short-circuit so that no ratio arithmetic
        # is attempted on an empty pool.
        return 0, 0

    def whole_items(amount):
        # Round down to a whole item count, but first snap values that sit a
        # rounding error away from an integer (e.g. 6.999999999999998 -> 7).
        # Plain int() truncation would silently lose one item in such cases.
        nearest = round(amount)
        if abs(amount - nearest) <= 1e-12 * max(1.0, abs(amount)):
            return int(nearest)
        return int(amount)

    y0, y1 = r_cut(x0, x1, r)
    return whole_items(y0), whole_items(y1)
# Each case pairs the (x0, x1, r) arguments with the (y0, y1) cut expected
# from r_cut; R denotes the natural ratio x1 / (x0 + x1).
_R_CUT_CASES = [
    ((10, 5, 0), (10, 0)),        # special case r == 0
    ((10, 5, 1), (0, 5)),         # special case r == 1
    ((10, 5, 0.5), (5, 5)),       # x0 > x1 and r > R
    ((5, 10, 0.5), (5, 5)),       # x0 < x1 and r < R
    ((9, 10, 0.1), (9, 1)),       # x0 < x1 and r < R
    ((10, 9, 0.1), (10, 10 / 9)),  # x0 > x1 and r < R
    ((5, 10, 0.9), (10 / 9, 10)),  # x0 < x1 and r > R
]


@pytest.mark.parametrize("test_input, expected", _R_CUT_CASES)
def test_r_cut(test_input, expected):
    """Compare r_cut's result with the expected cut, up to float tolerance."""
    x0, x1, r = test_input
    actual = r_cut(x0, x1, r)
    np.testing.assert_allclose(actual, expected)
# Each case pairs (set0, set1, r) with the expected (len0, len1) result;
# the sets are dummy lists because only their lengths matter here.
_R_SAMPLE_CASES = [
    (([1] * 10, [1] * 5, 0), (10, 0)),
    (([1] * 10, [1] * 5, 1), (0, 5)),
    (([1] * 10, [1] * 5, 0.5), (5, 5)),
    (([1] * 5, [1] * 10, 0.5), (5, 5)),
    (([1] * 9, [1] * 10, 0.1), (9, 1)),
    (([1] * 10, [1] * 9, 0.1), (10, 1)),
    (([1] * 5, [1] * 10, 0.9), (1, 10)),
]


@pytest.mark.parametrize("test_input, expected", _R_SAMPLE_CASES)
def test_r_sample(test_input, expected):
    """Check that r_sample returns the expected whole-item counts."""
    set0, set1, r = test_input
    counts = r_sample(set0, set1, r)
    assert counts == expected
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment