|
import dask.array as darray |
|
import numpy as np |
|
import xarray as xr |
|
|
|
from collections import defaultdict |
|
|
|
|
|
# Radius of the Earth in meters.
RADIUS = 6.371e6

# Dimension names used throughout this module.
LAT_DIM = "lat"
LON_DIM = "lon"
LAT_BOUNDS_DIM = "latb"
LON_BOUNDS_DIM = "lonb"
TIME_DIM = "time"
|
|
|
|
|
def sfc_area(latb, lonb, lat, lon):
    """Compute the surface area of each cell of a latitude-longitude grid.

    The area of a cell is ``RADIUS ** 2 * d(sin(latitude)) * d(longitude)``,
    where the differences are taken between consecutive bounds and angles
    are converted from degrees to radians.

    Parameters
    ----------
    latb : xr.DataArray
        Latitude bounds of the grid in degrees, along ``LAT_BOUNDS_DIM``.
    lonb : xr.DataArray
        Longitude bounds of the grid in degrees, along ``LON_BOUNDS_DIM``.
    lat : xr.DataArray
        Latitude of the grid cell centers; used as the ``LAT_DIM``
        coordinate of the result.
    lon : xr.DataArray
        Longitude of the grid cell centers; used as the ``LON_DIM``
        coordinate of the result.

    Returns
    -------
    xr.DataArray
        Cell areas (in units of ``RADIUS`` squared) on the
        ``(LAT_DIM, LON_DIM)`` grid.
    """
    # After the diff/rename the coordinate still holds values taken from the
    # bounds rather than the cell centers, so it is dropped here and replaced
    # by `lat` / `lon` below.  `drop_vars` replaces the deprecated `drop`.
    dsinlat = (
        np.sin(np.deg2rad(latb))
        .diff(LAT_BOUNDS_DIM)
        .rename({LAT_BOUNDS_DIM: LAT_DIM})
        .drop_vars(LAT_DIM)
    )
    dlon = (
        np.deg2rad(lonb)
        .diff(LON_BOUNDS_DIM)
        .rename({LON_BOUNDS_DIM: LON_DIM})
        .drop_vars(LON_DIM)
    )
    area = dsinlat * dlon * RADIUS ** 2.0
    area[LON_DIM] = lon
    area[LAT_DIM] = lat
    return area
|
|
|
|
|
def standardize(da, dim):
    """Remove the mean and scale to unit standard deviation along `dim`.

    Parameters
    ----------
    da : xr.DataArray or xr.Dataset
        Data to standardize.
    dim : Hashable
        Name of the dimension over which the mean and standard deviation
        are computed.

    Returns
    -------
    xr.DataArray or xr.Dataset
        ``(da - mean) / std`` with both statistics taken along `dim`.
    """
    anomaly = da - da.mean(dim)
    spread = da.std(dim)
    return anomaly / spread
|
|
|
|
|
def _fft(arr, *args, **kwargs):
    """Run a 1-D FFT with dask for dask arrays and with NumPy otherwise."""
    # No rechunking is done here: a dask array is handed to dask's fft with
    # whatever chunks it already has.
    backend = darray.fft if isinstance(arr, darray.core.Array) else np.fft
    return backend.fft(arr, *args, **kwargs)
|
|
|
|
|
def fft(da, dim, n=None, d=1.0, transformed_dim_name="freq", invert_freq=False):
    """Compute the one-dimensional discrete Fourier transform along `dim`.

    Steps:
        1. Transform the data along `dim`.
        2. Rename `dim` to `transformed_dim_name`.
        3. Attach the frequencies from ``np.fft.fftfreq`` as the coordinate,
           shifted so that the zero frequency sits in the center (and
           negated if `invert_freq` is set).
        4. Shift the data the same way so it lines up with the coordinate.

    Parameters
    ----------
    da : xr.DataArray
        Input DataArray, can be complex.
    dim : str
        Dimension name along which to compute the transform.
    n : int, optional (default None)
        Length of the transformed axis of the output.  If `n` is smaller
        than the length of the input, the input is cropped.  If it is
        larger, the input is padded with zeros.  If `n` is not given, the
        length of the input along `dim` is used.
    d : scalar
        Sample spacing (inverse of the sampling rate).  Defaults to 1.0.
    transformed_dim_name : str
        Dimension name after the transform.
    invert_freq : bool
        Whether to scale the inferred frequency by minus one (e.g. if by
        convention we write the frequency as -omega).  Only the coordinate
        labels change; the data are not reordered.

    Returns
    -------
    xr.DataArray
        The transformed data along `transformed_dim_name`, with coordinates
        computed by ``np.fft.fftfreq`` and shifted such that the zero
        frequency occurs in the middle.
    """
    spectrum = xr.apply_ufunc(
        _fft,
        da,
        input_core_dims=[[dim]],
        output_core_dims=[[dim]],
        kwargs=dict(n=n),
        dask="allowed",
    ).rename({dim: transformed_dim_name})

    n_freq = spectrum.sizes[transformed_dim_name]
    freqs = np.fft.fftshift(np.fft.fftfreq(n_freq, d))
    spectrum[transformed_dim_name] = -freqs if invert_freq else freqs

    # The coordinate is already centered; now center the data to match.
    return fftshift(spectrum, transformed_dim_name)
|
|
|
|
|
def _fftshift(arr, *args, **kwargs):
    """Apply fftshift with dask for dask arrays and with NumPy otherwise."""
    backend = darray.fft if isinstance(arr, darray.core.Array) else np.fft
    return backend.fftshift(arr, *args, **kwargs)
|
|
|
|
|
def fftshift(da, dim):
    """Move the zero-frequency component to the center along `dim`.

    Parameters
    ----------
    da : xr.DataArray
        Input DataArray.
    dim : str
        Dimension along which the shift is applied.

    Returns
    -------
    xr.DataArray
        Data shifted along `dim`.
    """
    core_dims = [[dim]]
    # apply_ufunc moves the core dimension to the last axis, hence axes=-1.
    return xr.apply_ufunc(
        _fftshift,
        da,
        input_core_dims=core_dims,
        output_core_dims=core_dims,
        kwargs=dict(axes=-1),
        dask="allowed",
    )
|
|
|
|
|
def _ifftshift(arr, *args, **kwargs):
    """Apply ifftshift with dask for dask arrays and with NumPy otherwise."""
    # No rechunking is done here: a dask array keeps its existing chunks.
    backend = darray.fft if isinstance(arr, darray.core.Array) else np.fft
    return backend.ifftshift(arr, *args, **kwargs)
|
|
|
|
|
def ifftshift(da, dim):
    """Undo `fftshift` along `dim`.

    Parameters
    ----------
    da : xr.DataArray
        Input DataArray.
    dim : str
        Dimension along which the inverse shift is applied.

    Returns
    -------
    xr.DataArray
        Data with the inverse shift applied along `dim`.
    """
    core_dims = [[dim]]
    # apply_ufunc moves the core dimension to the last axis, hence axes=-1.
    return xr.apply_ufunc(
        _ifftshift,
        da,
        input_core_dims=core_dims,
        output_core_dims=core_dims,
        kwargs=dict(axes=-1),
        dask="allowed",
    )
|
|
|
|
|
def _ifft(arr, *args, **kwargs):
    """Run a 1-D inverse FFT with dask for dask arrays, else with NumPy."""
    backend = darray.fft if isinstance(arr, darray.core.Array) else np.fft
    return backend.ifft(arr, *args, **kwargs)
|
|
|
|
|
def ifft(da, dim, n=None, norm=None):
    """Compute the one-dimensional inverse discrete Fourier Transform.

    Parameters
    ----------
    da : xr.DataArray
        Input DataArray.
    dim : str
        Dimension name along which to compute the inverse transform.
    n : int, optional
        Length of the transformed axis of the output.
    norm : str, optional
        Normalization mode forwarded to the underlying ``ifft`` routine
        (see ``numpy.fft.ifft``).  When None (the default) the argument is
        not passed at all, so the backend's default normalization is used.

    Returns
    -------
    xr.DataArray
        Inverse transform of `da` along `dim`.
    """
    ifft_kwargs = {"n": n, "axis": -1}
    if norm is not None:
        # Previously `norm` was accepted but silently ignored.  It is only
        # forwarded when requested so that default calls are unchanged.
        # NOTE(review): the dask backend may not accept `norm` in every dask
        # version -- confirm before relying on it with dask-backed input.
        ifft_kwargs["norm"] = norm
    return xr.apply_ufunc(
        _ifft,
        da,
        input_core_dims=[(dim,)],
        output_core_dims=[(dim,)],
        kwargs=ifft_kwargs,
        dask="allowed",
    )
|
|
|
|
|
def sample_spacing(da, dim, scale=360.0):
    """Compute the sample spacing along a dimension for use in fftfreq.

    Only the first two coordinate values along `dim` are inspected.

    Parameters
    ----------
    da : xr.DataArray
        Input DataArray.
    dim : str
        Name of the dimension.
    scale : float
        Scale factor to divide by (e.g. 360.0 for longitude, 2.0 * np.pi
        for time).

    Returns
    -------
    float
        Difference between the second and first coordinate values along
        `dim`, divided by `scale`.
    """
    coord = da[dim]
    first = coord.isel({dim: 0})
    second = coord.isel({dim: 1})
    step = (second - first).item()
    return step / scale
|
|
|
|
|
def filter(
    da,
    dim,
    n=None,
    d=1.0,
    invert_freq=False,
    bounds=(-np.inf, np.inf),
    absolute_value=False,
):
    """Filter a DataArray along `dim` by zeroing out spectral modes.

    The data are transformed with `fft`, every mode whose frequency falls
    outside `bounds` is set to zero, and the result is transformed back
    with `ifft`.

    Parameters
    ----------
    da : xr.DataArray
        Input DataArray, can be complex.
    dim : str
        Dimension name along which to compute the transform.
    n : int, optional (default None)
        Length of the transformed axis of the output.  If `n` is smaller
        than the length of the input, the input is cropped.  If it is
        larger, the input is padded with zeros.  If `n` is not given, the
        length of the input along `dim` is used.
    d : scalar
        Sample spacing (inverse of the sampling rate).  Defaults to 1.0.
    invert_freq : bool
        Invert the sign of the transformed dimension before applying
        `bounds`.
    bounds : tuple
        Inclusive ``(lower, upper)`` range of wavenumbers or frequencies to
        keep.
    absolute_value : bool (default False)
        Whether the condition should apply to the absolute value of the
        frequency.

    Returns
    -------
    xr.DataArray
        Filtered data along `dim`, carrying the `dim` coordinate of `da`.
        The output is whatever `ifft` returns, so callers wanting real
        data take ``.real`` themselves.
    """
    low, high = bounds
    spectrum = fft(da, dim, n=n, d=d, invert_freq=invert_freq)

    freq = np.abs(spectrum.freq) if absolute_value else spectrum.freq
    keep = (freq >= low) & (freq <= high)

    # Modes outside the band become NaN under `where` and are then zeroed.
    masked = spectrum.where(keep).fillna(0.0)

    # Put the transformed axis into a single chunk before inverting.
    unshifted = ifftshift(masked, "freq").chunk({"freq": -1})
    result = ifft(unshifted, "freq").rename({"freq": dim})
    result[dim] = da[dim]
    return result
|
|
|
|
|
def _matmul(arr1, arr2):
    """Matrix-multiply with dask when `arr1` is a dask array, else NumPy."""
    backend = darray if isinstance(arr1, darray.core.Array) else np
    return backend.matmul(arr1, arr2)
|
|
|
|
|
def regress(ds, index, sampling_dim):
    """Regress variables against an index along a sampling dimension.

    For every variable the result is the sum over `sampling_dim` of
    ``variable * index`` divided by the number of finite samples of the
    variable.  Missing values contribute zero to the sum and are excluded
    from the sample count.

    Assumes the dimension of the index is the sampling dimension.

    Parameters
    ----------
    ds : xr.Dataset or xr.DataArray
        Input data.  A DataArray is converted with ``to_dataset()``.
    index : xr.DataArray
        Index DataArray; must be 1D along `sampling_dim`.
    sampling_dim : Hashable
        Name of the sampling dimension.

    Returns
    -------
    xr.Dataset
        One regression map per input variable, with `sampling_dim` removed.
    """
    if isinstance(ds, xr.DataArray):
        ds = ds.to_dataset()

    # Group variables by their non-sampling ("structure") dimensions so that
    # each group can be stacked into a single 2D (structure, sample) problem.
    structure_dims = defaultdict(list)
    for var in ds.data_vars:
        sdims = tuple(sorted(set(ds[var].dims) - {sampling_dim}))
        structure_dims[sdims].append(var)

    # `chunks` has no entries for dimensions that are not dask-backed.
    chunks = ds.chunks or {}

    datasets = []
    for sdims, variables in structure_dims.items():
        structure = ds[variables]
        structure = structure.stack(structure=sdims)
        structure = structure.transpose("structure", sampling_dim)

        # Chunk to preserve block size.  Skipped for data that are not
        # dask-backed, where there is no block size to preserve.
        if all(dim in chunks for dim in sdims):
            structure_chunk_sizes = [chunks[dim][0] for dim in sdims]
            # np.prod replaces np.product, which NumPy 2.0 removed.
            structure_chunk_size = int(np.prod(structure_chunk_sizes))
            structure = structure.chunk({"structure": structure_chunk_size})

        result = xr.apply_ufunc(
            _matmul,
            structure.fillna(0.0),
            index,
            input_core_dims=[("structure", sampling_dim), (sampling_dim,)],
            output_core_dims=[("structure",)],
            dask="allowed",
        )
        # The builtin float replaces the np.float alias, which NumPy 1.24
        # removed.
        samples = xr.apply_ufunc(
            np.isfinite, structure, dask="parallelized", output_dtypes=[float]
        ).sum(sampling_dim)
        result = result / samples
        datasets.append(result.unstack("structure"))
    return xr.merge(datasets)
|
|
|
|
|
def lag_regress(da, index, sampling_dim, sampling_subset, lag):
    """Regress data against a standardized index at a given lag.

    The data are shifted by ``-lag`` positions along `sampling_dim`, so a
    positive `lag` pairs each index sample with the data `lag` steps later.

    Parameters
    ----------
    da : xr.DataArray or xr.Dataset
        Data to regress against the index.
    index : xr.DataArray
        Index for the regression.
    sampling_dim : Hashable
        Name of the dimension the index uses.
    sampling_subset : xr.DataArray
        Boolean DataArray indicating which points of the timeseries are
        used in the calculation (e.g. to subset by season).
    lag : int
        Integer offset along the sampling dimension.

    Returns
    -------
    xr.DataArray or xr.Dataset
        Whatever `regress` returns for the shifted, subset data.
    """
    lagged, aligned_index = xr.align(da.shift({sampling_dim: -lag}), index)
    subset = {sampling_dim: sampling_subset}
    # The index is standardized over the selected samples only.
    normalized_index = standardize(aligned_index.sel(subset), sampling_dim)
    return regress(lagged.sel(subset), normalized_index, sampling_dim)
|
|
|
|
|
if __name__ == "__main__":
    # Example: lag regression of precipitation onto a filtered regional
    # precipitation index, plotted for lags of -2 to +2 days.
    import cartopy.crs as ccrs
    import matplotlib.pyplot as plt
    import pandas as pd

    from dask.diagnostics import ProgressBar
    from faceted import faceted

    DATA = "/archive/Spencer.Clark/idealized_smd/smd_T2LeMvB150_1xCO2/gfdl.ncrc3-default-repro/history/atmos_4xday.zarr.zip"
    ds = xr.open_zarr(DATA)
    area = sfc_area(ds.latb, ds.lonb, ds.lat, ds.lon)
    precip = ds.precip.chunk({TIME_DIM: -1, LAT_DIM: 8})
    d_lon = sample_spacing(ds, LON_DIM, 360.0)
    d_time = 0.25  # Units of days
    # Number of samples per day, so lags in days can be turned into offsets.
    samples_per_day = int(round(1 / d_time))

    # Filter in time first (sign-inverted frequencies of at least 1/15 per
    # day), then in longitude (wavenumbers between -30 and -3).
    filtered = filter(
        precip - precip.mean([TIME_DIM, LON_DIM]),
        TIME_DIM,
        invert_freq=True,
        d=d_time,
        bounds=(1 / 15, np.inf),
    )
    filtered = filter(filtered, LON_DIM, d=d_lon, bounds=(-30, -3))
    filtered = filtered.real

    # The index is the area-weighted mean of the filtered field over a box.
    region = {LON_DIM: slice(75, 85), LAT_DIM: slice(12.5, 22.5)}
    index = filtered.sel(region).weighted(area.sel(region)).mean([LAT_DIM, LON_DIM])

    # Restrict samples to months 6-9 of years 10-20.
    jjas = ds.time.dt.month.isin(range(6, 10)) & ds.time.dt.year.isin(range(10, 21))

    lag_days = range(-5, 6)
    lag_regression_sequence = xr.concat(
        [
            lag_regress(ds.precip, index, TIME_DIM, jjas, samples_per_day * lag)
            for lag in lag_days
        ],
        pd.Index(lag_days, name="lag_day"),
    )

    with ProgressBar():
        result = lag_regression_sequence.compute()

    fig, axes, cax = faceted(
        5,
        1,
        width=5.0,
        aspect=0.5,
        axes_kwargs={"projection": ccrs.PlateCarree()},
        cbar_mode="single",
        cbar_location="bottom",
        cbar_pad=0.3,
        internal_pad=0.3,
        bottom_pad=0.5,
    )

    VMIN = -3
    VMAX = 3

    for ax, lag_day in zip(axes, range(-2, 3)):
        c = result.precip.sel(lag_day=lag_day).plot.contourf(
            ax=ax,
            transform=ccrs.PlateCarree(),
            add_colorbar=False,
            vmin=VMIN,
            vmax=VMAX,
            levels=21,
            cmap="BrBG",
        )
        ax.set_xlim([60, 110])
        ax.set_ylim([0, 25])
        ax.coastlines(lw=0.5)

    # All panels share the same levels, so the last contour set is used for
    # the single shared colorbar.
    plt.colorbar(
        c, cax=cax, orientation="horizontal", label="Precipitation anomaly [mm/day]"
    )
    fig.savefig("lag-regression-example.pdf")