Last active
January 9, 2017 03:51
-
-
Save jnothman/a125ae314ba5297a79a0617d496a2833 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| import numpy as np | |
| from scipy import sparse | |
def flexible_concatenate(it, final_len=None):
    """Concatenate the elements of an iterable along their first axis.

    The first element produced by ``it`` acts as the *prototype*: its type
    decides how every following element is accumulated and what is returned.
    Supported prototypes are lists, numpy arrays, scipy sparse matrices or
    arrays, and (possibly nested) tuples thereof.

    Parameters
    ----------
    it : iterable
        Yields the chunks to concatenate.  It is consumed exactly once, so
        generators are fine.  Every chunk is expected to have the same
        structure as the first one.
    final_len : int, optional
        Total number of rows (summed first-axis lengths) that ``it`` will
        produce.  When given, numpy chunks are written straight into one
        preallocated output array instead of being collected and
        concatenated at the end.

    Returns
    -------
    list, numpy.ndarray, scipy sparse container or tuple
        * list prototype: a single flat list;
        * ndarray prototype: one array stacked along axis 0;
        * sparse prototype: the chunks stacked vertically and converted to
          the prototype's ``format``;
        * tuple prototype: a tuple whose members are concatenated
          independently according to the rules above.

    Raises
    ------
    ValueError
        If ``it`` yields nothing, or if ``final_len`` is given and does not
        match the number of rows actually produced.
    NotImplementedError
        If the first element (or a member of a first tuple) is of an
        unsupported type.

    Notes
    -----
    With ``final_len``, the output array takes its dtype and trailing shape
    from the first chunk; later chunks are cast to that dtype by numpy's
    slice assignment rather than promoted as ``np.concatenate`` would do.

    For tuple chunks the row count is taken from the last member; members
    are not checked for having equal lengths.

    Examples
    --------
    >>> import numpy as np
    >>> from scipy import sparse
    >>>
    >>> def make_example(typ):
    ...     yield typ([1, 2])
    ...     yield typ([3])
    ...     yield typ([4, 5, 6])
    ...
    >>> flexible_concatenate(make_example(list))
    [1, 2, 3, 4, 5, 6]
    >>> flexible_concatenate(make_example(np.array))
    array([1, 2, 3, 4, 5, 6])
    >>> flexible_concatenate(zip(make_example(list), make_example(np.array)))
    ([1, 2, 3, 4, 5, 6], array([1, 2, 3, 4, 5, 6]))
    >>> flexible_concatenate(make_example(np.array), final_len=6)
    array([1, 2, 3, 4, 5, 6])
    >>> flexible_concatenate(make_example(
    ...     lambda x: np.array(x).reshape(-1, 1)))
    ...     # doctest: +NORMALIZE_WHITESPACE
    array([[1], [2], [3], [4], [5], [6]])
    >>> M = flexible_concatenate(make_example(
    ...     lambda x: sparse.csr_matrix(np.array(x).reshape(-1, 1))))
    >>> M.format
    'csr'
    >>> M.toarray().ravel().tolist()
    [1, 2, 3, 4, 5, 6]
    >>> M = flexible_concatenate(make_example(
    ...     lambda x: sparse.csc_matrix(np.array(x).reshape(-1, 1))))
    >>> M.format
    'csc'
    >>> M.toarray().ravel().tolist()
    [1, 2, 3, 4, 5, 6]
    """

    def is_preallocated(prototype):
        # Only dense arrays can be written into a fixed-size buffer, and only
        # when the caller told us how large that buffer has to be.
        return final_len is not None and isinstance(prototype, np.ndarray)

    def make_accumulator(prototype):
        # Build the container that will collect the chunks for `prototype`.
        if isinstance(prototype, tuple):
            return tuple(make_accumulator(y_proto) for y_proto in prototype)
        if is_preallocated(prototype):
            return np.empty((final_len,) + prototype.shape[1:],
                            dtype=prototype.dtype)
        if (isinstance(prototype, (list, np.ndarray))
                or sparse.issparse(prototype)):
            return []
        # Fail before consuming the rest of the iterator: finalize() could
        # not produce a result for this type anyway.
        raise NotImplementedError('No finalizing for accumulation of %s'
                                  % type(prototype))

    def accumulate(x, accumulator, prototype, offset):
        # Store chunk `x` and return the number of rows it contributed.
        # `offset` is the number of rows stored before this chunk.
        if isinstance(prototype, tuple):
            n_rows = 0
            for y, y_acc, y_prototype in zip(x, accumulator, prototype):
                n_rows = accumulate(y, y_acc, y_prototype, offset)
            # XXX: could check that all n_rows are identical
            return n_rows
        if is_preallocated(prototype):
            stop = offset + len(x)
            if stop > final_len:
                # Without this check numpy raises an obscure broadcasting
                # error, or silently drops a length-1 chunk.
                raise ValueError('Expected %d, got at least %d'
                                 % (final_len, stop))
            accumulator[offset:stop] = x
            return len(x)
        if isinstance(prototype, list):
            accumulator.extend(x)
            return len(x)
        accumulator.append(x)
        if hasattr(x, 'shape'):
            # len() is not defined for sparse containers.
            return x.shape[0]
        return len(x)

    def finalize(accumulator, prototype):
        # Turn the accumulator into the final concatenated object.
        if isinstance(prototype, tuple):
            return tuple(finalize(y_acc, y_prototype)
                         for y_acc, y_prototype in zip(accumulator, prototype))
        if isinstance(prototype, list) or is_preallocated(prototype):
            # Already in final form: a flat list or the filled buffer.
            return accumulator
        if isinstance(prototype, np.ndarray):
            return np.concatenate(accumulator, axis=0)
        if sparse.issparse(prototype):
            return sparse.vstack(accumulator).asformat(prototype.format)
        raise NotImplementedError('No finalizing for accumulation of %s'
                                  % type(prototype))

    it = iter(it)
    missing = object()
    first = next(it, missing)
    if first is missing:
        raise ValueError('Require at least one output from the iterator')

    accumulator = make_accumulator(first)
    offset = accumulate(first, accumulator, first, 0)
    for x in it:
        offset += accumulate(x, accumulator, first, offset)

    # An explicit check (not `assert`): under `python -O` an assert would be
    # skipped and a partially filled np.empty buffer could be returned.
    if final_len is not None and offset != final_len:
        raise ValueError('Expected %d, got %d' % (final_len, offset))
    return finalize(accumulator, first)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment