Created
April 13, 2015 23:21
-
-
Save barentsen/7450661b124a4b60c482 to your computer and use it in GitHub Desktop.
Example coalesce() function for use with astropy's MaskedColumn objects
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def coalesce(columns):
    """Coalesces masked columns.

    Starts from a copy of the first column and, for each following column
    in order, fills every element that is still masked in the result with
    that column's value wherever the latter is not masked.

    Parameters
    ----------
    columns : iterable of type `MaskedColumn`
        A single column or a sequence of columns, highest priority first.

    Returns
    -------
    column : coalesced result
        A modified copy of the first column; the inputs are not altered.

    Raises
    ------
    TypeError
        If `columns` is rejected by `_get_list_of_columns`.
    ValueError
        If the columns do not all have the same shape.
    """
    columns = _get_list_of_columns(columns)  # validates input

    # Element-wise coalescing only makes sense for equally shaped columns;
    # fail early with a clear message instead of an obscure indexing error.
    expected_shape = columns[0].shape
    for col in columns[1:]:
        if col.shape != expected_shape:
            raise ValueError('all columns must have the same shape')

    result = columns[0].copy()
    for col in columns[1:]:
        # Positions still masked in the result but available in `col`.
        mask_coalesce = result.mask & ~col.mask
        result.data[mask_coalesce] = col.data[mask_coalesce]
        result.mask[mask_coalesce] = False
    return result
def _get_list_of_columns(columns): | |
""" | |
Check that columns is a Column or sequence of Columns. Returns the | |
corresponding list of Columns. | |
""" | |
import collections | |
# Make sure we have a list of things | |
if not isinstance(columns, collections.Sequence): | |
columns = [columns] | |
# Make sure each thing is a Column | |
if any(not isinstance(x, table.Column) for x in columns) or len(columns) == 0: | |
raise TypeError('`columns` arg must be a Column or sequence of Columns') | |
return columns |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
class TestCoalesce:
    """Tests for `coalesce` built on three small three-element masked columns."""

    def setup_method(self, method):
        # c1 has nothing masked, c2 masks its first element and c3 masks
        # its second element.
        self.c1 = MaskedColumn(data=[1, 2, 3], mask=[False, False, False],
                               name='col1')
        self.c2 = MaskedColumn(data=[4, 5, 6], mask=[True, False, False],
                               name='col2')
        self.c3 = MaskedColumn(data=[7, 8, 9], mask=[False, True, False],
                               name='col3')

    def test_basic(self):
        # A fully unmasked first column is returned as-is.
        all_three = coalesce([self.c1, self.c2, self.c3])
        assert np.all(all_three == self.c1)

        # The masked first element of c2 is taken from the next column.
        from_c1 = coalesce([self.c2, self.c1])
        assert np.all(from_c1 == [1, 5, 6])

        from_c3 = coalesce([self.c2, self.c3])
        assert np.all(from_c3 == [7, 5, 6])

    def test_single_column(self):
        # A bare column (not wrapped in a list) is accepted.
        for column in (self.c1, self.c2, self.c3):
            assert np.all(coalesce(column) == column)

    def test_bad_input_type(self):
        # Empty list, non-column scalar, and a list with a non-column item.
        bad_inputs = ([], 1, [self.c1, 1])
        for bad in bad_inputs:
            with pytest.raises(TypeError):
                coalesce(bad)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment