|
import dask.array as darray |
|
import numpy as np |
|
import xarray as xr |
|
|
|
from spharm import Spharmt |
|
|
|
|
|
# Names of the two horizontal dimensions every input array must carry.
LAT_STR = 'lat'
LON_STR = 'lon'
_HORIZONTAL_DIMS = (LAT_STR, LON_STR)

# Name of the single dimension that all other dims get stacked onto.
_NON_HORIZONTAL_DIM = 'non_horizontal'

# Name of the dimension indexing spectral coefficients.
_HARMONIC_DIM = 'harmonic'

# Sphere radius passed to Spharmt as ``rsphere``
# (presumably metres -- confirm against the spharm documentation).
RADIUS = 6370997.0
|
|
|
|
|
def wraps_dask_array(arr):
    """Return True if the data held by ``arr`` is a dask array

    Parameters
    ----------
    arr : xr.DataArray
        Array whose ``.data`` attribute is inspected

    Returns
    -------
    bool
    """
    underlying = arr.data
    return isinstance(underlying, darray.core.Array)
|
|
|
|
|
def non_horizontal_dims(arr):
    """Return the set of dimension names of ``arr`` other than lat and lon

    Parameters
    ----------
    arr : xr.DataArray
        Input DataArray

    Returns
    -------
    set
    """
    return {dim for dim in arr.dims if dim not in _HORIZONTAL_DIMS}
|
|
|
|
|
def stack_non_horizontal_dims(arr):
    """If present, stack all non-horizontal dims onto one dimension

    Parameters
    ----------
    arr : xr.DataArray
        Input DataArray

    Returns
    -------
    xr.DataArray
        ``arr`` unchanged when it only has horizontal dims; otherwise a
        DataArray in which every other dim is stacked onto a single
        dimension named by ``_NON_HORIZONTAL_DIM``.
    """
    # non_horizontal_dims returns a set, whose iteration order for strings
    # changes with the interpreter's hash seed.  Sorting makes the stacking
    # order (and hence the dim order after a later unstack) reproducible
    # between runs and identical for any two arrays sharing the same dims.
    dims_to_stack = tuple(sorted(non_horizontal_dims(arr), key=str))
    if not dims_to_stack:
        return arr
    return arr.stack(**{_NON_HORIZONTAL_DIM: dims_to_stack})
|
|
|
|
|
def order_dims_for_spharm(arr):
    """Order dims such that lat and lon come first

    Parameters
    ----------
    arr : xr.DataArray
        Input DataArray

    Returns
    -------
    xr.DataArray
        ``arr`` transposed to (lat, lon, <remaining dims>), where the
        remaining dims keep the relative order they had on ``arr``.
    """
    # Walk arr.dims instead of a set difference: set iteration order for
    # strings depends on the hash seed, which made the order of the trailing
    # dims vary from run to run.
    trailing = tuple(dim for dim in arr.dims if dim not in _HORIZONTAL_DIMS)
    return arr.transpose(*(_HORIZONTAL_DIMS + trailing))
|
|
|
|
|
def chunk_in_spherical_shells(arr):
    """Rechunk dask-backed data so each chunk spans the full lat-lon grid

    Arrays that do not wrap a dask array are returned untouched.

    Parameters
    ----------
    arr : xr.DataArray
        Input DataArray

    Returns
    -------
    xr.DataArray
    """
    if not wraps_dask_array(arr):
        return arr
    # One chunk along each horizontal dim; other dims keep their chunking.
    full_sphere = {dim: arr.sizes[dim] for dim in (LAT_STR, LON_STR)}
    return arr.chunk(full_sphere)
|
|
|
|
|
def flip_lat(arr):
    """Return ``arr`` with the order along the latitude dimension reversed"""
    reversed_order = slice(None, None, -1)
    return arr.isel(**{LAT_STR: reversed_order})
|
|
|
|
|
def orient_latitude_north_south(arr):
    """Orient data so that latitude decreases along its dimension

    If every step of the latitude coordinate is positive the array is
    reversed along latitude; otherwise it is returned as given.

    Parameters
    ----------
    arr : xr.DataArray
        Input DataArray

    Returns
    -------
    xr.DataArray, bool
        The (possibly flipped) array and a flag telling whether a flip
        was applied.
    """
    lat_steps = arr[LAT_STR].diff(LAT_STR)
    flipped = all(lat_steps > 0.)
    oriented = flip_lat(arr) if flipped else arr
    return oriented, flipped
|
|
|
|
|
def prep_for_spharm(arr):
    """Prepare DataArray for use with spharm

    Applies, in order: stacking of non-horizontal dims, reordering so that
    lat and lon lead, rechunking of dask data to whole lat-lon grids, and
    finally the latitude orientation step.

    Parameters
    ----------
    arr : xr.DataArray
        Input DataArray

    Returns
    -------
    xr.DataArray, bool
        The prepared array and the flag returned by
        ``orient_latitude_north_south`` (True if latitude was flipped).
    """
    layout_steps = (
        stack_non_horizontal_dims,
        order_dims_for_spharm,
        chunk_in_spherical_shells,
    )
    for step in layout_steps:
        arr = step(arr)
    return orient_latitude_north_south(arr)
|
|
|
|
|
def create_spharmt(arr):
    """Build a Spharmt instance sized to the lat-lon grid of ``arr``

    The transform object is created for a Gaussian grid on a sphere of
    radius ``RADIUS``.

    Parameters
    ----------
    arr : xr.DataArray
        Array providing the ``lon`` and ``lat`` sizes

    Returns
    -------
    Spharmt
    """
    n_lon = arr.sizes[LON_STR]
    n_lat = arr.sizes[LAT_STR]
    return Spharmt(n_lon, n_lat, rsphere=RADIUS, gridtype='gaussian')
|
|
|
|
|
def n_harmonics(arr, n_trunc=None):
    """Return the number of spectral coefficients for a given truncation

    Parameters
    ----------
    arr : xr.DataArray
        Only consulted when ``n_trunc`` is None, in which case the
        truncation defaults to the number of latitudes minus one.
    n_trunc : int, optional
        Truncation wavenumber

    Returns
    -------
    int
        ``(n_trunc + 1) * (n_trunc + 2) // 2``
    """
    truncation = arr.sizes[LAT_STR] - 1 if n_trunc is None else n_trunc
    return (truncation + 1) * (truncation + 2) // 2
|
|
|
|
|
def _grdtospec(st, arr, harmonics):
    """Wrap Spharmt.grdtospec to be dask compatible

    Parameters
    ----------
    st : Spharmt
        Transform object whose ``grdtospec`` method is applied
    arr : array-like
        Gridded data with the two horizontal axes leading; may be a dask
        array, in which case the transform is applied lazily per block.
    harmonics : int
        Length of the spectral axis of the result

    Returns
    -------
    array-like
        Result of ``st.grdtospec``; a dask array when ``arr`` is one.
    """
    if not isinstance(arr, darray.core.Array):
        return st.grdtospec(arr)

    if arr.ndim == 3:
        # The trailing (stacked) axis keeps whatever chunking it already has.
        chunks = ((harmonics, ), arr.chunks[-1])
    else:
        chunks = (harmonics, )
    # ``np.complex`` was merely an alias of the builtin ``complex`` and has
    # been removed from NumPy (deprecated in 1.20, gone in 1.24); the builtin
    # selects the same dtype on every NumPy version.
    # The two leading grid axes are replaced by one leading spectral axis.
    return darray.map_blocks(st.grdtospec, arr, chunks=chunks,
                             dtype=complex, drop_axis=(0, 1),
                             new_axis=(0, ))
|
|
|
|
|
def grdtospec(arr, prepped=False):
    """Transform data from grid space to spectral space

    Assumes data are on a Gaussian grid.

    Parameters
    ----------
    arr : xr.DataArray
        Gridded input
    prepped : bool, optional
        Set to True when ``arr`` has already been passed through
        ``prep_for_spharm``; otherwise that preparation is done here.

    Returns
    -------
    xr.DataArray
        Spectral coefficients along a ``harmonic`` dimension, followed by
        the stacked non-horizontal dimension when the input has one.
    """
    if not prepped:
        arr, _ = prep_for_spharm(arr)

    st = create_spharmt(arr)
    n_spec = n_harmonics(arr)

    # Start from the purely horizontal case and extend it when a stacked
    # non-horizontal dimension is present.
    spectral_dims = [_HARMONIC_DIM]
    output_sizes = {_HARMONIC_DIM: n_spec}
    if _NON_HORIZONTAL_DIM in arr.dims:
        spectral_dims.append(_NON_HORIZONTAL_DIM)
        output_sizes[_NON_HORIZONTAL_DIM] = arr.sizes[_NON_HORIZONTAL_DIM]

    return xr.apply_ufunc(
        _grdtospec, st, arr, n_spec,
        input_core_dims=[[], arr.dims, []],
        output_core_dims=[spectral_dims],
        output_sizes=output_sizes,
        exclude_dims=set(_HORIZONTAL_DIMS),
        dask='allowed',
    )
|
|
|
|
|
def _getu(st, vort, divg): |
|
"""Wrap getuv to only return u""" |
|
u, _ = st.getuv(vort, divg) |
|
return u |
|
|
|
|
|
def _getv(st, vort, divg): |
|
"""Wrap getuv to only return v""" |
|
_, v = st.getuv(vort, divg) |
|
return v |
|
|
|
|
|
def _getu_dask_allowed(st, vort, divg):
    """Wrap getuv to accept dask arrays and return just u

    Parameters
    ----------
    st : Spharmt
        Transform object; ``st.nlat`` and ``st.nlon`` size the output grid
    vort, divg : array-like
        Spectral vorticity and divergence with the spectral axis leading

    Returns
    -------
    array-like
        Gridded u; a lazily evaluated dask array when ``vort`` is one.
    """
    if not isinstance(vort, darray.core.Array):
        return _getu(st, vort, divg)

    grid_chunks = ((st.nlat, ), (st.nlon, ))
    if vort.ndim == 2:
        # Keep the chunking of the trailing (stacked) axis.
        grid_chunks = grid_chunks + (vort.chunks[-1], )
    # The leading spectral axis is replaced by the two horizontal axes.
    return darray.map_blocks(_getu, st, vort, divg, chunks=grid_chunks,
                             drop_axis=(0, ), new_axis=(0, 1))
|
|
|
|
|
def _getv_dask_allowed(st, vort, divg):
    """Wrap getuv to accept dask arrays and return just v

    Parameters
    ----------
    st : Spharmt
        Transform object; ``st.nlat`` and ``st.nlon`` size the output grid
    vort, divg : array-like
        Spectral vorticity and divergence with the spectral axis leading

    Returns
    -------
    array-like
        Gridded v; a lazily evaluated dask array when ``vort`` is one.
    """
    if not isinstance(vort, darray.core.Array):
        return _getv(st, vort, divg)

    grid_chunks = ((st.nlat, ), (st.nlon, ))
    if vort.ndim == 2:
        # Keep the chunking of the trailing (stacked) axis.
        grid_chunks = grid_chunks + (vort.chunks[-1], )
    # The leading spectral axis is replaced by the two horizontal axes.
    return darray.map_blocks(_getv, st, vort, divg, chunks=grid_chunks,
                             drop_axis=(0, ), new_axis=(0, 1))
|
|
|
|
|
def getuv(vort_grid, divg_grid):
    """Wrap getuv to accept *gridded* vorticity and divergence

    This allows us to more readily reattach latitude and longitude coordinates
    to the results.

    Parameters
    ----------
    vort_grid : xr.DataArray
        Gridded vorticity; also the source of the coordinates reattached
        to the outputs.
    divg_grid : xr.DataArray
        Gridded divergence

    Returns
    -------
    xr.DataArray, xr.DataArray
        Gridded u and v
    """
    # Capture the transform object and coordinates before the inputs are
    # restructured for spharm.
    st = create_spharmt(vort_grid)
    orig_coords = vort_grid.coords

    vort_grid, _ = prep_for_spharm(vort_grid)
    divg_grid, flipped = prep_for_spharm(divg_grid)
    vort_spec = grdtospec(vort_grid, prepped=True)
    divg_spec = grdtospec(divg_grid, prepped=True)

    def spectral_to_wind(component_func):
        """Apply one of the u/v wrappers to the spectral inputs"""
        return xr.apply_ufunc(
            component_func, st, vort_spec, divg_spec,
            input_core_dims=[[], vort_spec.dims, divg_spec.dims],
            output_core_dims=[divg_grid.dims],
            output_sizes=divg_grid.sizes,
            dask='allowed',
        )

    u = spectral_to_wind(_getu_dask_allowed)
    v = spectral_to_wind(_getv_dask_allowed)

    # u and v share their dims, so checking u alone covers both.
    if _NON_HORIZONTAL_DIM in u.dims:
        u, v = (wind.unstack(_NON_HORIZONTAL_DIM) for wind in (u, v))

    # Undo the latitude flip applied during preparation.
    if flipped:
        u, v = flip_lat(u), flip_lat(v)

    for name in orig_coords:
        original = orig_coords[name]
        u[name] = original
        v[name] = original

    return u, v
|
|
|
|
|
def _getvrt(st, u, v): |
|
"""Return the spectral version of the vorticty given gridded u, v""" |
|
vort, _ = st.getvrtdivspec(u, v) |
|
return vort |
|
|
|
|
|
def _getdiv(st, u, v): |
|
"""Return the spectral version of the divergence given gridded u, v""" |
|
_, div = st.getvrtdivspec(u, v) |
|
return div |
|
|
|
|
|
def _getvrt_dask_allowed(st, u, v, harmonics):
    """Wrap getvrtdivspec to accept dask arrays and return just vorticity

    Parameters
    ----------
    st : Spharmt
        Transform object whose ``getvrtdivspec`` method is applied
    u, v : array-like
        Gridded wind components with the two horizontal axes leading; when
        ``u`` is a dask array the transform is applied lazily per block.
    harmonics : int
        Length of the spectral axis of the result

    Returns
    -------
    array-like
        Spectral vorticity; a dask array when ``u`` is one.
    """
    if not isinstance(u, darray.core.Array):
        return _getvrt(st, u, v)

    if u.ndim == 3:
        # The trailing (stacked) axis keeps whatever chunking it already has.
        chunks = ((harmonics, ), u.chunks[-1])
    else:
        chunks = (harmonics, )
    # ``np.complex`` was merely an alias of the builtin ``complex`` and has
    # been removed from NumPy (deprecated in 1.20, gone in 1.24); the builtin
    # selects the same dtype on every NumPy version.
    # The two leading grid axes are replaced by one leading spectral axis.
    return darray.map_blocks(_getvrt, st, u, v, chunks=chunks,
                             dtype=complex, drop_axis=(0, 1),
                             new_axis=(0, ))
|
|
|
|
|
def _getdiv_dask_allowed(st, u, v, harmonics):
    """Wrap getvrtdivspec to accept dask arrays and return just divergence

    Parameters
    ----------
    st : Spharmt
        Transform object whose ``getvrtdivspec`` method is applied
    u, v : array-like
        Gridded wind components with the two horizontal axes leading; when
        ``u`` is a dask array the transform is applied lazily per block.
    harmonics : int
        Length of the spectral axis of the result

    Returns
    -------
    array-like
        Spectral divergence; a dask array when ``u`` is one.
    """
    if not isinstance(u, darray.core.Array):
        return _getdiv(st, u, v)

    if u.ndim == 3:
        # The trailing (stacked) axis keeps whatever chunking it already has.
        chunks = ((harmonics, ), u.chunks[-1])
    else:
        chunks = (harmonics, )
    # ``np.complex`` was merely an alias of the builtin ``complex`` and has
    # been removed from NumPy (deprecated in 1.20, gone in 1.24); the builtin
    # selects the same dtype on every NumPy version.
    # The two leading grid axes are replaced by one leading spectral axis.
    return darray.map_blocks(_getdiv, st, u, v, chunks=chunks,
                             dtype=complex, drop_axis=(0, 1),
                             new_axis=(0, ))
|
|
|
|
|
def getvrtdivspec(u_grid, v_grid):
    """Wrap getvrtdiv to return vorticity and divergence in spectral space

    Parameters
    ----------
    u_grid, v_grid : xr.DataArray
        Gridded wind components

    Returns
    -------
    xr.DataArray, xr.DataArray
        Spectral vorticity and spectral divergence, each with a
        ``harmonic`` dimension followed by the stacked non-horizontal
        dimension when the inputs have one.
    """
    # Size the transform from the grid before the inputs are restructured.
    st = create_spharmt(u_grid)

    u_grid, _ = prep_for_spharm(u_grid)
    v_grid, _ = prep_for_spharm(v_grid)

    n_spec = n_harmonics(u_grid)

    # Start from the purely horizontal case and extend it when a stacked
    # non-horizontal dimension is present.
    spectral_dims = [_HARMONIC_DIM]
    output_sizes = {_HARMONIC_DIM: n_spec}
    if _NON_HORIZONTAL_DIM in u_grid.dims:
        spectral_dims.append(_NON_HORIZONTAL_DIM)
        output_sizes[_NON_HORIZONTAL_DIM] = u_grid.sizes[_NON_HORIZONTAL_DIM]

    def wind_to_spectral(component_func):
        """Apply one of the vorticity/divergence wrappers to u and v"""
        return xr.apply_ufunc(
            component_func, st, u_grid, v_grid, n_spec,
            input_core_dims=[[], u_grid.dims, v_grid.dims, []],
            output_core_dims=[spectral_dims],
            exclude_dims=set(_HORIZONTAL_DIMS),
            output_sizes=output_sizes,
            dask='allowed',
        )

    vort_spec = wind_to_spectral(_getvrt_dask_allowed)
    divg_spec = wind_to_spectral(_getdiv_dask_allowed)
    return vort_spec, divg_spec
|
|
|
|
|
def _spectogrd(st, arr):
    """Transform a variable from spectral to grid space

    Parameters
    ----------
    st : Spharmt
        Transform object; ``st.nlat`` and ``st.nlon`` size the output grid
    arr : array-like
        Spectral data with the spectral axis leading; may be a dask array,
        in which case the transform is applied lazily per block.

    Returns
    -------
    array-like
        Result of ``st.spectogrd``; a dask array when ``arr`` is one.
    """
    if not isinstance(arr, darray.core.Array):
        return st.spectogrd(arr)

    grid_chunks = ((st.nlat, ), (st.nlon, ))
    if arr.ndim == 2:
        # Keep the chunking of the trailing (stacked) axis.
        grid_chunks = grid_chunks + (arr.chunks[-1], )
    # The leading spectral axis is replaced by the two horizontal axes.
    return darray.map_blocks(st.spectogrd, arr, chunks=grid_chunks,
                             drop_axis=(0, ), new_axis=(0, 1))
|
|
|
|
|
def getvrtdivgrid(u_grid, v_grid):
    """Wrap getvrtdiv to return *gridded* vorticity and divergence

    Parameters
    ----------
    u_grid : xr.DataArray
        Gridded u; also the source of the lat/lon coordinates reattached
        to the outputs.
    v_grid : xr.DataArray
        Gridded v

    Returns
    -------
    xr.DataArray, xr.DataArray
        Gridded vorticity and divergence
    """
    st = create_spharmt(u_grid)
    orig_coords = u_grid.coords
    vort_spec, divg_spec = getvrtdivspec(u_grid, v_grid)
    # Re-run the preparation on v to learn the prepared dim order and
    # whether latitude had to be flipped.
    v_grid, flipped = prep_for_spharm(v_grid)

    def spectral_to_grid(spec):
        """Bring one spectral field back onto the grid"""
        # NOTE(review): output_sizes comes from the *unprepared* u_grid,
        # whereas getuv passes the sizes of the prepared array -- confirm
        # whether this difference is intentional.
        return xr.apply_ufunc(
            _spectogrd, st, spec,
            input_core_dims=[[], vort_spec.dims],
            output_core_dims=[v_grid.dims],
            output_sizes=u_grid.sizes,
            dask='allowed',
        )

    vort_grid = spectral_to_grid(vort_spec)
    divg_grid = spectral_to_grid(divg_spec)

    # Undo the latitude flip applied during preparation.
    if flipped:
        vort_grid, divg_grid = flip_lat(vort_grid), flip_lat(divg_grid)

    for name in _HORIZONTAL_DIMS:
        original = orig_coords[name]
        vort_grid[name] = original
        divg_grid[name] = original

    if _NON_HORIZONTAL_DIM not in vort_grid.dims:
        return vort_grid, divg_grid
    return (vort_grid.unstack(_NON_HORIZONTAL_DIM),
            divg_grid.unstack(_NON_HORIZONTAL_DIM))