import dask.array as darray |
import numpy as np |
import xarray as xr |
from spharm import Spharmt |
# Names of the horizontal dimensions expected on input DataArrays.
LAT_STR = 'lat'
LON_STR = 'lon'
# Horizontal dims in the (lat, lon) order in which they are moved to the front
# of an array before it is handed to spharm.  The functions below reference
# this name, so it has to be defined at module level.
_HORIZONTAL_DIMS = (LAT_STR, LON_STR)
# Name of the single dimension onto which all non-horizontal dims are stacked.
_NON_HORIZONTAL_DIM = 'non_horizontal'
# Name of the dimension that holds spectral coefficients.
_HARMONIC_DIM = 'harmonic'
# Value passed to Spharmt as ``rsphere``.
RADIUS = 6370997.
def wraps_dask_array(arr):
    """Return True if ``arr.data`` is a dask array, False otherwise"""
    underlying = arr.data
    return isinstance(underlying, darray.core.Array)
def non_horizontal_dims(arr):
    """Return the set of dims of ``arr`` that are not horizontal dims"""
    return set(arr.dims).difference(_HORIZONTAL_DIMS)
def stack_non_horizontal_dims(arr):
    """If present, stack all non-horizontal dims onto one dimension

    Parameters
    ----------
    arr : xr.DataArray
        Input DataArray

    Returns
    -------
    xr.DataArray
        ``arr`` with its non-horizontal dims stacked onto the
        ``_NON_HORIZONTAL_DIM`` dimension, or ``arr`` unchanged if it has no
        non-horizontal dims.
    """
    # non_horizontal_dims returns a set, whose iteration order depends on
    # hashing and insertion order.  Sorting the names makes the stacking
    # order reproducible and identical for any two arrays that share the same
    # dims, whatever order those dims appear in.  ``key=str`` keeps this
    # working for dimension names that are not strings.
    dims_to_stack = tuple(sorted(non_horizontal_dims(arr), key=str))
    if dims_to_stack:
        arr = arr.stack(**{_NON_HORIZONTAL_DIM: dims_to_stack})
    return arr
def order_dims_for_spharm(arr):
    """Transpose ``arr`` so that the horizontal dims lead all other dims"""
    trailing = tuple(non_horizontal_dims(arr))
    return arr.transpose(*(_HORIZONTAL_DIMS + trailing))
def chunk_in_spherical_shells(arr):
    """Rechunk dask-backed data so lat and lon each fit in a single chunk

    Arrays that do not wrap a dask array are returned unchanged.
    """
    if not wraps_dask_array(arr):
        return arr
    # One chunk spanning the full extent of each horizontal dimension.
    full_extent = {dim: arr.sizes[dim] for dim in (LAT_STR, LON_STR)}
    return arr.chunk(full_extent)
def flip_lat(arr):
    """Return ``arr`` with the order of its latitude dimension reversed"""
    reverse = slice(None, None, -1)
    return arr.isel(**{LAT_STR: reverse})
def orient_latitude_north_south(arr):
    """Flip the latitude dimension if latitude values are increasing

    The array is reversed along latitude only when every consecutive
    difference of the latitude coordinate is positive; otherwise it is
    returned as is.

    Parameters
    ----------
    arr : xr.DataArray
        Input DataArray

    Returns
    -------
    xr.DataArray, bool
        The (possibly flipped) array and a flag telling whether it was
        flipped.
    """
    lat_steps = arr[LAT_STR].diff(LAT_STR)
    strictly_increasing = all(lat_steps > 0.)
    if not strictly_increasing:
        return arr, False
    return flip_lat(arr), True
def prep_for_spharm(arr):
    """Run the full preparation pipeline needed before calling spharm

    The steps are, in order: stack the non-horizontal dims, move the
    horizontal dims to the front, rechunk dask data along the horizontal
    dims, and orient the latitude dimension.

    Parameters
    ----------
    arr : xr.DataArray
        Input DataArray

    Returns
    -------
    xr.DataArray, bool
        The prepared array and the flag returned by
        ``orient_latitude_north_south``.
    """
    stacked = stack_non_horizontal_dims(arr)
    ordered = order_dims_for_spharm(stacked)
    shells = chunk_in_spherical_shells(ordered)
    oriented, flipped = orient_latitude_north_south(shells)
    return oriented, flipped
def create_spharmt(arr):
    """Build a Gaussian-grid Spharmt sized from the lon and lat dims of arr"""
    nlon = arr.sizes[LON_STR]
    nlat = arr.sizes[LAT_STR]
    return Spharmt(nlon, nlat, rsphere=RADIUS, gridtype='gaussian')
def n_harmonics(arr, n_trunc=None):
    """Return ``(n_trunc + 1) * (n_trunc + 2) // 2``

    Parameters
    ----------
    arr : xr.DataArray
        Only consulted when ``n_trunc`` is None.
    n_trunc : int, optional
        Defaults to the size of the latitude dimension of ``arr`` minus one.

    Returns
    -------
    int
    """
    truncation = arr.sizes[LAT_STR] - 1 if n_trunc is None else n_trunc
    return (truncation + 1) * (truncation + 2) // 2
def _grdtospec(st, arr, harmonics):
    """Wrap Spharmt.grdtospec to be dask compatible

    Parameters
    ----------
    st : Spharmt
        Object whose ``grdtospec`` method performs the transform.
    arr : dask array or other array
        Gridded data.  In the dask case the first two axes are dropped and
        replaced by one new leading axis of length ``harmonics``.
    harmonics : int
        Length of the output's leading axis; only used in the dask case.

    Returns
    -------
    A lazily evaluated dask array if ``arr`` is a dask array, otherwise the
    result of calling ``st.grdtospec(arr)`` directly.
    """
    if isinstance(arr, darray.core.Array):
        if arr.ndim == 3:
            # Keep the chunking of the trailing axis as is.
            chunks = ((harmonics, ), arr.chunks[-1])
        else:
            # Block shape of the one-dimensional output.
            chunks = (harmonics, )
        # ``np.complex`` was just an alias of the builtin ``complex`` and has
        # been removed from NumPy; ``np.complex128`` is the same dtype.
        return darray.map_blocks(st.grdtospec, arr, chunks=chunks,
                                 dtype=np.complex128, drop_axis=(0, 1),
                                 new_axis=(0, ))
    else:
        return st.grdtospec(arr)
def grdtospec(arr, prepped=False):
    """Transform data from grid space to spectral space

    Assumes data are on a Gaussian grid.

    Parameters
    ----------
    arr : xr.DataArray
        Gridded input.
    prepped : bool, optional
        Set to True if ``arr`` has already been passed through
        ``prep_for_spharm``; otherwise that is done here.

    Returns
    -------
    xr.DataArray
        Result with a ``_HARMONIC_DIM`` dimension in place of the horizontal
        dims.  If the input had non-horizontal dims they remain stacked on
        the ``_NON_HORIZONTAL_DIM`` dimension.
    """
    if not prepped:
        # The flip flag is not needed: the result has no latitude dimension.
        arr, _ = prep_for_spharm(arr)
    st = create_spharmt(arr)
    _n_harmonics = n_harmonics(arr)
    if _NON_HORIZONTAL_DIM in arr.dims:
        output_core_dims = [[_HARMONIC_DIM, _NON_HORIZONTAL_DIM]]
        output_sizes = {_HARMONIC_DIM: _n_harmonics,
                        _NON_HORIZONTAL_DIM: arr.sizes[_NON_HORIZONTAL_DIM]}
    else:
        output_core_dims = [[_HARMONIC_DIM]]
        output_sizes = {_HARMONIC_DIM: _n_harmonics}
    # ``st`` and ``_n_harmonics`` are plain Python objects, hence their empty
    # lists of core dims.
    return xr.apply_ufunc(_grdtospec, st, arr, _n_harmonics,
                          input_core_dims=[[], arr.dims, []],
                          output_core_dims=output_core_dims,
                          output_sizes=output_sizes,
                          exclude_dims=set(_HORIZONTAL_DIMS), dask='allowed')
def _getu(st, vort, divg):
    """Call ``st.getuv`` and keep only the first (u) result"""
    winds = st.getuv(vort, divg)
    u_wind, _unused = winds
    return u_wind
def _getv(st, vort, divg):
    """Call ``st.getuv`` and keep only the second (v) result"""
    winds = st.getuv(vort, divg)
    _unused, v_wind = winds
    return v_wind
def _getu_dask_allowed(st, vort, divg):
    """Compute u via ``_getu``, lazily when ``vort`` is a dask array

    In the dask case the leading axis of the input is dropped and replaced by
    two new leading axes of sizes ``st.nlat`` and ``st.nlon``.
    """
    if not isinstance(vort, darray.core.Array):
        return _getu(st, vort, divg)
    grid_chunks = ((st.nlat, ), (st.nlon, ))
    if vort.ndim == 2:
        # Carry the chunking of the trailing axis through unchanged.
        grid_chunks = grid_chunks + (vort.chunks[-1], )
    return darray.map_blocks(_getu, st, vort, divg, chunks=grid_chunks,
                             drop_axis=(0, ), new_axis=(0, 1))
def _getv_dask_allowed(st, vort, divg):
    """Compute v via ``_getv``, lazily when ``vort`` is a dask array

    In the dask case the leading axis of the input is dropped and replaced by
    two new leading axes of sizes ``st.nlat`` and ``st.nlon``.
    """
    if not isinstance(vort, darray.core.Array):
        return _getv(st, vort, divg)
    grid_chunks = ((st.nlat, ), (st.nlon, ))
    if vort.ndim == 2:
        # Carry the chunking of the trailing axis through unchanged.
        grid_chunks = grid_chunks + (vort.chunks[-1], )
    return darray.map_blocks(_getv, st, vort, divg, chunks=grid_chunks,
                             drop_axis=(0, ), new_axis=(0, 1))
def getuv(vort_grid, divg_grid):
    """Compute u and v from *gridded* vorticity and divergence

    Working from gridded inputs lets the original coordinates be reattached
    to the results.

    Parameters
    ----------
    vort_grid, divg_grid : xr.DataArray
        Gridded vorticity and divergence.

    Returns
    -------
    xr.DataArray, xr.DataArray
        u and v, carrying the coordinates of the original ``vort_grid``.
    """
    # Both of these must be taken from the input before it is prepped.
    st = create_spharmt(vort_grid)
    orig_coords = vort_grid.coords

    vort_prepped, _ = prep_for_spharm(vort_grid)
    divg_prepped, flipped = prep_for_spharm(divg_grid)
    vort_spec = grdtospec(vort_prepped, prepped=True)
    divg_spec = grdtospec(divg_prepped, prepped=True)

    ufunc_kwargs = dict(
        input_core_dims=[[], vort_spec.dims, divg_spec.dims],
        output_core_dims=[divg_prepped.dims],
        output_sizes=divg_prepped.sizes,
        dask='allowed')
    u, v = [xr.apply_ufunc(func, st, vort_spec, divg_spec, **ufunc_kwargs)
            for func in (_getu_dask_allowed, _getv_dask_allowed)]

    if _NON_HORIZONTAL_DIM in u.dims:
        u, v = [wind.unstack(_NON_HORIZONTAL_DIM) for wind in (u, v)]
    if flipped:
        # Undo the latitude flip applied during preparation.
        u, v = [flip_lat(wind) for wind in (u, v)]

    for name in orig_coords:
        u[name] = orig_coords[name]
        v[name] = orig_coords[name]
    return u, v
def _getvrt(st, u, v):
    """Call ``st.getvrtdivspec`` and keep only the first (vorticity) result"""
    spectra = st.getvrtdivspec(u, v)
    vorticity, _unused = spectra
    return vorticity
def _getdiv(st, u, v):
    """Call ``st.getvrtdivspec`` and keep only the second (divergence) result
    """
    spectra = st.getvrtdivspec(u, v)
    _unused, divergence = spectra
    return divergence
def _getvrt_dask_allowed(st, u, v, harmonics):
    """Wrap ``_getvrt`` to accept dask arrays

    Parameters
    ----------
    st : Spharmt
        Object passed through to ``_getvrt``.
    u, v : dask arrays or other arrays
        Gridded inputs.  In the dask case the first two axes are dropped and
        replaced by one new leading axis of length ``harmonics``.
    harmonics : int
        Length of the output's leading axis; only used in the dask case.

    Returns
    -------
    A lazily evaluated dask array if ``u`` is a dask array, otherwise the
    result of ``_getvrt(st, u, v)``.
    """
    if isinstance(u, darray.core.Array):
        if u.ndim == 3:
            # Keep the chunking of the trailing axis as is.
            chunks = ((harmonics, ), u.chunks[-1])
        else:
            # Block shape of the one-dimensional output.
            chunks = (harmonics, )
        # ``np.complex`` was just an alias of the builtin ``complex`` and has
        # been removed from NumPy; ``np.complex128`` is the same dtype.
        return darray.map_blocks(_getvrt, st, u, v, chunks=chunks,
                                 dtype=np.complex128, drop_axis=(0, 1),
                                 new_axis=(0,))
    else:
        return _getvrt(st, u, v)
def _getdiv_dask_allowed(st, u, v, harmonics):
    """Wrap ``_getdiv`` to accept dask arrays

    Parameters
    ----------
    st : Spharmt
        Object passed through to ``_getdiv``.
    u, v : dask arrays or other arrays
        Gridded inputs.  In the dask case the first two axes are dropped and
        replaced by one new leading axis of length ``harmonics``.
    harmonics : int
        Length of the output's leading axis; only used in the dask case.

    Returns
    -------
    A lazily evaluated dask array if ``u`` is a dask array, otherwise the
    result of ``_getdiv(st, u, v)``.
    """
    if isinstance(u, darray.core.Array):
        if u.ndim == 3:
            # Keep the chunking of the trailing axis as is.
            chunks = ((harmonics, ), u.chunks[-1])
        else:
            # Block shape of the one-dimensional output.
            chunks = (harmonics, )
        # ``np.complex`` was just an alias of the builtin ``complex`` and has
        # been removed from NumPy; ``np.complex128`` is the same dtype.
        return darray.map_blocks(_getdiv, st, u, v, chunks=chunks,
                                 dtype=np.complex128, drop_axis=(0, 1),
                                 new_axis=(0,))
    else:
        return _getdiv(st, u, v)
def getvrtdivspec(u_grid, v_grid):
    """Wrap getvrtdiv to return vorticity and divergence in spectral space

    Parameters
    ----------
    u_grid, v_grid : xr.DataArray
        Gridded inputs.

    Returns
    -------
    xr.DataArray, xr.DataArray
        Spectral vorticity and divergence, each with a ``_HARMONIC_DIM``
        dimension in place of the horizontal dims.  If the inputs had
        non-horizontal dims they remain stacked on the
        ``_NON_HORIZONTAL_DIM`` dimension.
    """
    st = create_spharmt(u_grid)
    # The flip flags are not needed: the results have no latitude dimension.
    u_grid, _ = prep_for_spharm(u_grid)
    v_grid, _ = prep_for_spharm(v_grid)
    _n_harmonics = n_harmonics(u_grid)
    if _NON_HORIZONTAL_DIM in u_grid.dims:
        output_core_dims = [[_HARMONIC_DIM, _NON_HORIZONTAL_DIM]]
        output_sizes = {_HARMONIC_DIM: _n_harmonics,
                        _NON_HORIZONTAL_DIM:
                            u_grid.sizes[_NON_HORIZONTAL_DIM]}
    else:
        output_core_dims = [[_HARMONIC_DIM]]
        output_sizes = {_HARMONIC_DIM: _n_harmonics}
    # ``st`` and ``_n_harmonics`` are plain Python objects, hence their empty
    # lists of core dims.
    common_kwargs = {'input_core_dims': [[], u_grid.dims, v_grid.dims, []],
                     'output_core_dims': output_core_dims,
                     'exclude_dims': set(_HORIZONTAL_DIMS),
                     'output_sizes': output_sizes,
                     'dask': 'allowed'}
    vort_spec = xr.apply_ufunc(_getvrt_dask_allowed, st, u_grid, v_grid,
                               _n_harmonics, **common_kwargs)
    divg_spec = xr.apply_ufunc(_getdiv_dask_allowed, st, u_grid, v_grid,
                               _n_harmonics, **common_kwargs)
    return vort_spec, divg_spec
def _spectogrd(st, arr):
    """Apply ``st.spectogrd`` to ``arr``, lazily when ``arr`` is a dask array

    In the dask case the leading axis of the input is dropped and replaced by
    two new leading axes of sizes ``st.nlat`` and ``st.nlon``.
    """
    if not isinstance(arr, darray.core.Array):
        return st.spectogrd(arr)
    grid_chunks = ((st.nlat, ), (st.nlon, ))
    if arr.ndim == 2:
        # Carry the chunking of the trailing axis through unchanged.
        grid_chunks = grid_chunks + (arr.chunks[-1], )
    return darray.map_blocks(st.spectogrd, arr, chunks=grid_chunks,
                             drop_axis=(0, ), new_axis=(0, 1))
def getvrtdivgrid(u_grid, v_grid):
    """Wrap getvrtdiv to return *gridded* vorticity and divergence

    Parameters
    ----------
    u_grid, v_grid : xr.DataArray
        Gridded inputs.

    Returns
    -------
    xr.DataArray, xr.DataArray
        Gridded vorticity and divergence, carrying the horizontal coordinates
        of the original ``u_grid``.
    """
    st = create_spharmt(u_grid)
    orig_coords = u_grid.coords
    vort_spec, divg_spec = getvrtdivspec(u_grid, v_grid)
    # Prep v_grid here as well so the dims of the output and the flip flag
    # are known.
    v_grid, flipped = prep_for_spharm(v_grid)
    # Take the sizes from the prepped array so they describe the same
    # dimensions as ``output_core_dims``.
    common_kwargs = {'input_core_dims': [[], vort_spec.dims],
                     'output_core_dims': [v_grid.dims],
                     'output_sizes': v_grid.sizes,
                     'dask': 'allowed'}
    vort_grid = xr.apply_ufunc(_spectogrd, st, vort_spec, **common_kwargs)
    divg_grid = xr.apply_ufunc(_spectogrd, st, divg_spec, **common_kwargs)
    if flipped:
        # Undo the latitude flip applied during preparation.
        vort_grid = flip_lat(vort_grid)
        divg_grid = flip_lat(divg_grid)
    for coord in _HORIZONTAL_DIMS:
        vort_grid[coord] = orig_coords[coord]
        divg_grid[coord] = orig_coords[coord]
    if _NON_HORIZONTAL_DIM in vort_grid.dims:
        return (vort_grid.unstack(_NON_HORIZONTAL_DIM),
                divg_grid.unstack(_NON_HORIZONTAL_DIM))
    else:
        return vort_grid, divg_grid