Created July 28, 2015 06:29
-
-
Save shoyer/ae30a1200f749c84b4c4 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"cells": [ | |
{ | |
"cell_type": "code", | |
"execution_count": 399, | |
"metadata": { | |
"collapsed": false | |
}, | |
"outputs": [], | |
"source": [ | |
"from xray.core.pycompat import OrderedDict\n", | |
"from xray.core.npcompat import stack\n", | |
"from scipy.spatial import cKDTree as KDTree\n", | |
"import numpy as np\n", | |
"import xray\n", | |
"\n", | |
"\n", | |
"def _as_kdtree_data(data):\n", | |
" data = np.asarray(data)\n", | |
" if data.ndim == 1:\n", | |
" data = data[:, np.newaxis]\n", | |
" return data\n", | |
"\n", | |
"\n", | |
"class KDTreeIndex(object):\n", | |
" \"\"\"A pandas.Index-like object that uses a KDTree for spatial indexing\n", | |
" \"\"\"\n", | |
" def __init__(self, data, leafsize=100):\n", | |
" self.data = _as_kdtree_data(data)\n", | |
" self.leafsize = leafsize\n", | |
" self._kdtree = None\n", | |
"\n", | |
" @property\n", | |
" def shape(self):\n", | |
" return len(self.data)\n", | |
" \n", | |
" @property\n", | |
" def dtype(self):\n", | |
" return self.data.dtype\n", | |
" \n", | |
" @property\n", | |
" def values(self):\n", | |
" # TODO: make immutable\n", | |
" return self.data\n", | |
"\n", | |
" @property\n", | |
" def kdtree(self):\n", | |
" if self._kdtree is None:\n", | |
" self._kdtree = KDTree(self.data, self.leafsize)\n", | |
" return self._kdtree\n", | |
" \n", | |
" def __repr__(self):\n", | |
" return '%r\\n%r' % (type(self), self.data)\n", | |
"\n", | |
" def get_indexer(self, target, max_distance=np.infty):\n", | |
" # TODO: handle other KDTreeIndex objects with a fast path\n", | |
" target = _as_kdtree_data(target)\n", | |
" _, indexer = self.kdtree.query(\n", | |
" target, distance_upper_bound=max_distance)\n", | |
" # like pandas Index objects, use -1 to mark missing values\n", | |
" indexer[indexer == len(self.kdtree.data)] = -1\n", | |
" return indexer\n", | |
" \n", | |
" def __getitem__(self, key):\n", | |
" if isinstance(key, tuple):\n", | |
" key, = key\n", | |
" return type(self)(self.data[key], self.leafsize)\n", | |
"\n", | |
" \n", | |
"def _shape(coords):\n", | |
" shapes = set(c.shape for c in coords)\n", | |
" if len(shapes) > 1:\n", | |
" raise ValueError('data has inconsistent shapes: %r' % shapes)\n", | |
" return shapes.pop()\n", | |
"\n", | |
"\n", | |
"def _stack(coords):\n", | |
" return stack([c.ravel() for c in coords], axis=-1)\n", | |
"\n", | |
"\n", | |
"def unravel_index_missing_okay(indexer, shape):\n", | |
" \"\"\"A variant of np.unravel_index that understands -1 as a sentinel\n", | |
" value for missing indexes.\n", | |
" \"\"\"\n", | |
" missing_values = indexer == -1\n", | |
" any_missing = np.any(missing_values)\n", | |
" if any_missing:\n", | |
" indexer[missing_values] = 0\n", | |
" indexers = np.unravel_index(indexer, shape)\n", | |
" if any_missing:\n", | |
" indexers = tuple(np.where(missing_values, -1, i)\n", | |
" for i in indexers)\n", | |
" return indexers\n", | |
"\n", | |
"\n", | |
"class NDIndex(object):\n", | |
" \"\"\"A wrapper of Index-like objects to make them handle N-dimensional\n", | |
" indexing.\n", | |
" \n", | |
" Currently designed as a wrapper for KDTreeIndex, but eventually might\n", | |
" handle other indexes for multi-dimensional data like Ball Trees or\n", | |
" R-Trees.\n", | |
" \"\"\"\n", | |
" def __init__(self, index_cls, coords):\n", | |
" coords = [np.asarray(c) for c in coords]\n", | |
" data = _stack(coords)\n", | |
" self._index_cls = index_cls\n", | |
" self._coords = coords\n", | |
" self._index = index_cls(data)\n", | |
" self.shape = _shape(coords)\n", | |
" \n", | |
" @property\n", | |
" def dtype(self):\n", | |
" return self._index.dtype\n", | |
" \n", | |
" @property\n", | |
" def values(self):\n", | |
" # TODO: cythonize?\n", | |
" data = np.empty(self.shape, dtype=object)\n", | |
" for idx in np.ndindex(*self.shape):\n", | |
" data[idx] = tuple(c[idx] for c in self._coords)\n", | |
" data.flags.writeable = False\n", | |
" return data\n", | |
"\n", | |
" def __repr__(self):\n", | |
" return ('%r\\nwrapping: %r\\n%r'\n", | |
" % (type(self), self._index_cls, self.values))\n", | |
"\n", | |
" def get_indexers(self, target, **kwargs):\n", | |
" \"\"\"Returns a tuple of indexers, suitable for use in isel_points\n", | |
" \n", | |
" Using this with isel_points will require that it is able to\n", | |
" handle N-dimensional arrays as input, but this already works\n", | |
" with NumPy's fancy (vectorized) indexing.\n", | |
" \"\"\"\n", | |
" target = [np.asarray(t) for t in target]\n", | |
" target_shape = _shape(target)\n", | |
" target_data = _stack(target)\n", | |
" indexer = self._index.get_indexer(target_data, **kwargs)\n", | |
" indexers = unravel_index_missing_okay(indexer, self.shape)\n", | |
" indexers = tuple(i.reshape(target_shape) for i in indexers)\n", | |
" return indexers\n", | |
"\n", | |
" def __getitem__(self, key):\n", | |
" return type(self)(self._index_cls,\n", | |
" [c[key] for c in self._coords])\n" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 400, | |
"metadata": { | |
"collapsed": false | |
}, | |
"outputs": [], | |
"source": [ | |
"index = NDIndex(KDTreeIndex, [[[0, 1, 2], [3, 4, 4]], [[-1, 2, 2.5], [3.5, 4, 4.2]]])" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 401, | |
"metadata": { | |
"collapsed": false | |
}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"<class '__main__.NDIndex'>\n", | |
"wrapping: <class '__main__.KDTreeIndex'>\n", | |
"array([[(0, -1.0), (1, 2.0), (2, 2.5)],\n", | |
" [(3, 3.5), (4, 4.0), (4, 4.2000000000000002)]], dtype=object)" | |
] | |
}, | |
"execution_count": 401, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"index" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 402, | |
"metadata": { | |
"collapsed": false | |
}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"array([[(0, -1.0), (1, 2.0), (2, 2.5)],\n", | |
" [(3, 3.5), (4, 4.0), (4, 4.2000000000000002)]], dtype=object)" | |
] | |
}, | |
"execution_count": 402, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"index.values" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 403, | |
"metadata": { | |
"collapsed": false | |
}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"<xray.Dataset>\n", | |
"Dimensions: (x: 2, y: 3)\n", | |
"Coordinates:\n", | |
" grid (x, y) object (0, -1.0) (1, 2.0) (2, 2.5) (3, 3.5) (4, 4.0) ...\n", | |
" * x (x) int64 0 1\n", | |
" * y (y) int64 0 1 2\n", | |
"Data variables:\n", | |
" *empty*" | |
] | |
}, | |
"execution_count": 403, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"xray.Dataset({}, {'grid': (('x', 'y'), index)})" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 404, | |
"metadata": { | |
"collapsed": false | |
}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"<class '__main__.NDIndex'>\n", | |
"wrapping: <class '__main__.KDTreeIndex'>\n", | |
"array([[(2, 2.5), (1, 2.0), (0, -1.0)],\n", | |
" [(4, 4.2000000000000002), (4, 4.0), (3, 3.5)]], dtype=object)" | |
] | |
}, | |
"execution_count": 404, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"index[:2, ::-1]" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 405, | |
"metadata": { | |
"collapsed": false | |
}, | |
"outputs": [], | |
"source": [ | |
"xi, yi = index.get_indexers([[[0, 1], [3, 4]],\n", | |
" [[-1, 3], [3, 1.5]]],\n", | |
" max_distance=2)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 406, | |
"metadata": { | |
"collapsed": false | |
}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"array([[ 0, 0],\n", | |
" [ 1, -1]])" | |
] | |
}, | |
"execution_count": 406, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"xi" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 407, | |
"metadata": { | |
"collapsed": false | |
}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"array([[ 0, 1],\n", | |
" [ 0, -1]])" | |
] | |
}, | |
"execution_count": 407, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"yi" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 408, | |
"metadata": { | |
"collapsed": false | |
}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"<class '__main__.NDIndex'>\n", | |
"wrapping: <class '__main__.KDTreeIndex'>\n", | |
"array([[(2, 2.5), (1, 2.0), (0, -1.0)],\n", | |
" [(4, 4.2000000000000002), (4, 4.0), (3, 3.5)]], dtype=object)" | |
] | |
}, | |
"execution_count": 408, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"index[:2, ::-1]" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 409, | |
"metadata": { | |
"collapsed": false | |
}, | |
"outputs": [], | |
"source": [ | |
"xs, ys = index[:2, ::-1].get_indexers([[2.3, 4], [5, 4]])" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 410, | |
"metadata": { | |
"collapsed": false | |
}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"(array([1, 1]), array([2, 1]))" | |
] | |
}, | |
"execution_count": 410, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"xs, ys" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 416, | |
"metadata": { | |
"collapsed": false | |
}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"<class '__main__.NDIndex'>\n", | |
"wrapping: <class '__main__.KDTreeIndex'>\n", | |
"array([(3, 3.5), (4, 4.0)], dtype=object)" | |
] | |
}, | |
"execution_count": 416, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"index[:2, ::-1][xs, ys]" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": { | |
"collapsed": true | |
}, | |
"outputs": [], | |
"source": [] | |
} | |
], | |
"metadata": { | |
"kernelspec": { | |
"display_name": "Python 2", | |
"language": "python", | |
"name": "python2" | |
}, | |
"language_info": { | |
"codemirror_mode": { | |
"name": "ipython", | |
"version": 2 | |
}, | |
"file_extension": ".py", | |
"mimetype": "text/x-python", | |
"name": "python", | |
"nbconvert_exporter": "python", | |
"pygments_lexer": "ipython2", | |
"version": "2.7.10" | |
} | |
}, | |
"nbformat": 4, | |
"nbformat_minor": 0 | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment