Skip to content

Instantly share code, notes, and snippets.

@j08lue
Created April 1, 2016 09:33
Show Gist options
  • Save j08lue/358568a103165e6de058d34538d5fbe3 to your computer and use it in GitHub Desktop.
Save j08lue/358568a103165e6de058d34538d5fbe3 to your computer and use it in GitHub Desktop.
import numpy as np
from scipy.spatial import cKDTree as KDTree
import xarray as xr
def _broadcast_coords(spat_only, coord_names):
    """Broadcast the named coordinate variables onto the full spatial grid.

    Returns one NumPy array per name in ``coord_names``. Every array has its
    axes in ``tuple(spat_only.dims)`` order, so a flat index into any of them
    can be unravelled back onto ``spat_only.dims`` unambiguously.
    """
    dim_order = tuple(spat_only.dims)
    broadcast = xr.broadcast(*[spat_only[name] for name in coord_names])
    return [arr.transpose(*dim_order).values for arr in broadcast]


def grid_to_points(grid, points, coord_names):
    """Index a gridded dataset with a Pandas DataFrame of station coordinates.

    For each row of ``points``, find the grid cell whose coordinates (the
    variables named in ``coord_names``) are nearest, and select that cell
    from ``grid``. The selections are concatenated along a new ``station``
    dimension labelled with ``points.index``.

    Parameters
    ----------
    grid : xr.Dataset or xr.DataArray
        Gridded source data.
    points : pd.DataFrame
        Query points; must have one column per entry in ``coord_names``.
    coord_names : list of str
        Coordinate names present both in ``grid`` and as columns of
        ``points``.

    Returns
    -------
    Same type as ``grid``, with a new ``station`` dimension.

    Raises
    ------
    ValueError
        If ``coord_names`` is empty or ``points`` has no rows.

    Notes
    -----
    Nearest neighbours are found with a KD-tree, so distances are Euclidean
    in the raw coordinate units. For lat/lon in degrees this is not a
    great-circle distance; verify that this is acceptable for your grid.

    Credits:
    https://github.com/pydata/xarray/issues/214#issuecomment-119036789
    """
    if not coord_names:
        raise ValueError("No coordinate names provided")
    if len(points) == 0:
        raise ValueError("No query points provided")

    # Dimensions spanned by the coordinate variables are "spatial". Every
    # other dimension is pinned to index 0 only to obtain a purely spatial
    # view for building the search tree.
    spat_dims = {d for n in coord_names for d in grid[n].dims}
    not_spatial = set(grid.dims) - spat_dims
    spat_only = grid.isel(**{n: 0 for n in not_spatial})
    spatial_dim_order = tuple(spat_only.dims)

    # Flatten the spatial grid into an (n_cells, n_coords) point cloud.
    coords = _broadcast_coords(spat_only, coord_names)
    tree = KDTree(np.column_stack([c.ravel() for c in coords]))

    query = np.column_stack([np.asarray(points[n].values) for n in coord_names])
    _, flat_idx = tree.query(query)
    # Convert flat cell numbers back to per-dimension integer positions,
    # ordered like `spatial_dim_order` (see _broadcast_coords).
    grid_idx = np.unravel_index(flat_idx, coords[0].shape)

    station_da = xr.DataArray(name='station', dims='station',
                              data=points.index.values)
    selections = [
        grid.isel(**dict(zip(spatial_dim_order, idx)))
        for idx in zip(*grid_idx)
    ]
    return xr.concat(selections, dim=station_da)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment