Skip to content

Instantly share code, notes, and snippets.

@shoyer
Created July 28, 2015 06:29
Show Gist options
  • Save shoyer/ae30a1200f749c84b4c4 to your computer and use it in GitHub Desktop.
Save shoyer/ae30a1200f749c84b4c4 to your computer and use it in GitHub Desktop.
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "code",
"execution_count": 399,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"from xray.core.pycompat import OrderedDict\n",
"from xray.core.npcompat import stack\n",
"from scipy.spatial import cKDTree as KDTree\n",
"import numpy as np\n",
"import xray\n",
"\n",
"\n",
"def _as_kdtree_data(data):\n",
" data = np.asarray(data)\n",
" if data.ndim == 1:\n",
" data = data[:, np.newaxis]\n",
" return data\n",
"\n",
"\n",
"class KDTreeIndex(object):\n",
" \"\"\"A pandas.Index-like object that uses a KDTree for spatial indexing\n",
" \"\"\"\n",
" def __init__(self, data, leafsize=100):\n",
" self.data = _as_kdtree_data(data)\n",
" self.leafsize = leafsize\n",
" self._kdtree = None\n",
"\n",
" @property\n",
" def shape(self):\n",
" return len(self.data)\n",
" \n",
" @property\n",
" def dtype(self):\n",
" return self.data.dtype\n",
" \n",
" @property\n",
" def values(self):\n",
" # TODO: make immutable\n",
" return self.data\n",
"\n",
" @property\n",
" def kdtree(self):\n",
" if self._kdtree is None:\n",
" self._kdtree = KDTree(self.data, self.leafsize)\n",
" return self._kdtree\n",
" \n",
" def __repr__(self):\n",
" return '%r\\n%r' % (type(self), self.data)\n",
"\n",
" def get_indexer(self, target, max_distance=np.infty):\n",
" # TODO: handle other KDTreeIndex objects with a fast path\n",
" target = _as_kdtree_data(target)\n",
" _, indexer = self.kdtree.query(\n",
" target, distance_upper_bound=max_distance)\n",
" # like pandas Index objects, use -1 to mark missing values\n",
" indexer[indexer == len(self.kdtree.data)] = -1\n",
" return indexer\n",
" \n",
" def __getitem__(self, key):\n",
" if isinstance(key, tuple):\n",
" key, = key\n",
" return type(self)(self.data[key], self.leafsize)\n",
"\n",
" \n",
"def _shape(coords):\n",
" shapes = set(c.shape for c in coords)\n",
" if len(shapes) > 1:\n",
" raise ValueError('data has inconsistent shapes: %r' % shapes)\n",
" return shapes.pop()\n",
"\n",
"\n",
"def _stack(coords):\n",
" return stack([c.ravel() for c in coords], axis=-1)\n",
"\n",
"\n",
"def unravel_index_missing_okay(indexer, shape):\n",
" \"\"\"A variant of np.unravel_index that understands -1 as a sentinel\n",
" value for missing indexes.\n",
" \"\"\"\n",
" missing_values = indexer == -1\n",
" any_missing = np.any(missing_values)\n",
" if any_missing:\n",
" indexer[missing_values] = 0\n",
" indexers = np.unravel_index(indexer, shape)\n",
" if any_missing:\n",
" indexers = tuple(np.where(missing_values, -1, i)\n",
" for i in indexers)\n",
" return indexers\n",
"\n",
"\n",
"class NDIndex(object):\n",
" \"\"\"A wrapper of Index-like objects to make them handle N-dimensional\n",
" indexing.\n",
" \n",
" Currently designed as a wrapper for KDTreeIndex, but eventually might\n",
" handle other indexes for multi-dimensional data like Ball Trees or\n",
" R-Trees.\n",
" \"\"\"\n",
" def __init__(self, index_cls, coords):\n",
" coords = [np.asarray(c) for c in coords]\n",
" data = _stack(coords)\n",
" self._index_cls = index_cls\n",
" self._coords = coords\n",
" self._index = index_cls(data)\n",
" self.shape = _shape(coords)\n",
" \n",
" @property\n",
" def dtype(self):\n",
" return self._index.dtype\n",
" \n",
" @property\n",
" def values(self):\n",
" # TODO: cythonize?\n",
" data = np.empty(self.shape, dtype=object)\n",
" for idx in np.ndindex(*self.shape):\n",
" data[idx] = tuple(c[idx] for c in self._coords)\n",
" data.flags.writeable = False\n",
" return data\n",
"\n",
" def __repr__(self):\n",
" return ('%r\\nwrapping: %r\\n%r'\n",
" % (type(self), self._index_cls, self.values))\n",
"\n",
" def get_indexers(self, target, **kwargs):\n",
" \"\"\"Returns a tuple of indexers, suitable for use in isel_points\n",
" \n",
" Using this with isel_points will require that it is able to\n",
" handle N-dimensional arrays as input, but this already works\n",
" with NumPy's fancy (vectorized) indexing.\n",
" \"\"\"\n",
" target = [np.asarray(t) for t in target]\n",
" target_shape = _shape(target)\n",
" target_data = _stack(target)\n",
" indexer = self._index.get_indexer(target_data, **kwargs)\n",
" indexers = unravel_index_missing_okay(indexer, self.shape)\n",
" indexers = tuple(i.reshape(target_shape) for i in indexers)\n",
" return indexers\n",
"\n",
" def __getitem__(self, key):\n",
" return type(self)(self._index_cls,\n",
" [c[key] for c in self._coords])\n"
]
},
{
"cell_type": "code",
"execution_count": 400,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"index = NDIndex(KDTreeIndex, [[[0, 1, 2], [3, 4, 4]], [[-1, 2, 2.5], [3.5, 4, 4.2]]])"
]
},
{
"cell_type": "code",
"execution_count": 401,
"metadata": {
"collapsed": false
},
"outputs": [
{
"data": {
"text/plain": [
"<class '__main__.NDIndex'>\n",
"wrapping: <class '__main__.KDTreeIndex'>\n",
"array([[(0, -1.0), (1, 2.0), (2, 2.5)],\n",
" [(3, 3.5), (4, 4.0), (4, 4.2000000000000002)]], dtype=object)"
]
},
"execution_count": 401,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"index"
]
},
{
"cell_type": "code",
"execution_count": 402,
"metadata": {
"collapsed": false
},
"outputs": [
{
"data": {
"text/plain": [
"array([[(0, -1.0), (1, 2.0), (2, 2.5)],\n",
" [(3, 3.5), (4, 4.0), (4, 4.2000000000000002)]], dtype=object)"
]
},
"execution_count": 402,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"index.values"
]
},
{
"cell_type": "code",
"execution_count": 403,
"metadata": {
"collapsed": false
},
"outputs": [
{
"data": {
"text/plain": [
"<xray.Dataset>\n",
"Dimensions: (x: 2, y: 3)\n",
"Coordinates:\n",
" grid (x, y) object (0, -1.0) (1, 2.0) (2, 2.5) (3, 3.5) (4, 4.0) ...\n",
" * x (x) int64 0 1\n",
" * y (y) int64 0 1 2\n",
"Data variables:\n",
" *empty*"
]
},
"execution_count": 403,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"xray.Dataset({}, {'grid': (('x', 'y'), index)})"
]
},
{
"cell_type": "code",
"execution_count": 404,
"metadata": {
"collapsed": false
},
"outputs": [
{
"data": {
"text/plain": [
"<class '__main__.NDIndex'>\n",
"wrapping: <class '__main__.KDTreeIndex'>\n",
"array([[(2, 2.5), (1, 2.0), (0, -1.0)],\n",
" [(4, 4.2000000000000002), (4, 4.0), (3, 3.5)]], dtype=object)"
]
},
"execution_count": 404,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"index[:2, ::-1]"
]
},
{
"cell_type": "code",
"execution_count": 405,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"xi, yi = index.get_indexers([[[0, 1], [3, 4]],\n",
" [[-1, 3], [3, 1.5]]],\n",
" max_distance=2)"
]
},
{
"cell_type": "code",
"execution_count": 406,
"metadata": {
"collapsed": false
},
"outputs": [
{
"data": {
"text/plain": [
"array([[ 0, 0],\n",
" [ 1, -1]])"
]
},
"execution_count": 406,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"xi"
]
},
{
"cell_type": "code",
"execution_count": 407,
"metadata": {
"collapsed": false
},
"outputs": [
{
"data": {
"text/plain": [
"array([[ 0, 1],\n",
" [ 0, -1]])"
]
},
"execution_count": 407,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"yi"
]
},
{
"cell_type": "code",
"execution_count": 408,
"metadata": {
"collapsed": false
},
"outputs": [
{
"data": {
"text/plain": [
"<class '__main__.NDIndex'>\n",
"wrapping: <class '__main__.KDTreeIndex'>\n",
"array([[(2, 2.5), (1, 2.0), (0, -1.0)],\n",
" [(4, 4.2000000000000002), (4, 4.0), (3, 3.5)]], dtype=object)"
]
},
"execution_count": 408,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"index[:2, ::-1]"
]
},
{
"cell_type": "code",
"execution_count": 409,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"xs, ys = index[:2, ::-1].get_indexers([[2.3, 4], [5, 4]])"
]
},
{
"cell_type": "code",
"execution_count": 410,
"metadata": {
"collapsed": false
},
"outputs": [
{
"data": {
"text/plain": [
"(array([1, 1]), array([2, 1]))"
]
},
"execution_count": 410,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"xs, ys"
]
},
{
"cell_type": "code",
"execution_count": 416,
"metadata": {
"collapsed": false
},
"outputs": [
{
"data": {
"text/plain": [
"<class '__main__.NDIndex'>\n",
"wrapping: <class '__main__.KDTreeIndex'>\n",
"array([(3, 3.5), (4, 4.0)], dtype=object)"
]
},
"execution_count": 416,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"index[:2, ::-1][xs, ys]"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 2",
"language": "python",
"name": "python2"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 2
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython2",
"version": "2.7.10"
}
},
"nbformat": 4,
"nbformat_minor": 0
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment