"""
Perform a seasonal cycle decomposition and determine the drivers of the trends.
"""
import xarray as xr
import numpy as np
import joblib
import os


# SEASONAL CYCLE FITTING #

def graven2013_seasonal_fit(da, n_jobs=24):
    """
    Fits a seasonal cycle to data using cos and sin functions,
    accounting for annual and 6-monthly cycles, following the
    harmonic approach of Graven et al. (2013).

    Assumes a (time, lat, lon) DataArray where time has monthly resolution.
    """
    def fit_graven2013_numpy(x, y):
        """
        Fits a seasonal cycle to data using cos and sin functions.
        Accounts for annual and 6-monthly cycles.

        Args:
            x (array-like): the time dimension (decimal years)
            y (array-like): the time-varying value to fit a seasonal cycle to

        Returns:
            array-like: a fitted seasonal cycle (y-hat)
        """
        from scipy import optimize
        from numpy import sin, cos, pi, nan

        def func(t, a1, a2, a3, a4, a5, a6, a7):
            """Quadratic trend plus annual and semi-annual harmonics
            (the functional form as defined by Peter)."""
            return (
                a1 + a2*t + a3*t**2 +
                a4 * sin(2 * pi * t) +
                a5 * cos(2 * pi * t) +
                a6 * sin(4 * pi * t) +
                a7 * cos(4 * pi * t))

        try:
            params, params_covariance = optimize.curve_fit(func, x, y)
        except RuntimeError:
            # if the fit does not converge, return an all-NaN estimate
            params = [nan] * 7
        yhat = func(x, *params)

        return yhat
    da_stack = da.stack(coords=['lat', 'lon'])
    Ynan = da_stack.values.T  # (n_points, n_time)
    Yhat = Ynan * np.nan
    # only fit pixels with complete time series
    nan_mask = np.isnan(Ynan).any(axis=1)
    Y = Ynan[~nan_mask]

    # the x-input has to be in years for the fitting function to work
    x = da.time.values.astype('datetime64[M]').astype(float) / 12
    n_yrs = x.size // 12

    lat = da.lat.values
    lon = da.lon.values

    pool = joblib.Parallel(n_jobs=n_jobs)
    func = joblib.delayed(fit_graven2013_numpy)
    queue = [func(x, y) for y in Y]
    Yhat[~nan_mask] = np.array(pool(queue))

    # average the fitted years into a single monthly climatology
    Yhat_clim = (
        Yhat
        .reshape(lat.size, lon.size, n_yrs, 12)
        .mean(axis=2)
        .transpose(2, 0, 1))

    djf = Yhat_clim[[-1, 0, 1]].mean(axis=0)
    jja = Yhat_clim[[5, 6, 7]].mean(axis=0)
    rmse = (
        (np.square(Yhat - Ynan).mean(axis=1)**0.5)
        .reshape(lat.size, lon.size))

    dims = 'month', 'lat', 'lon'
    ds = xr.Dataset(coords=dict(month=range(1, 13), lat=da.lat.values, lon=da.lon.values))
    ds['seasonal_clim'] = xr.DataArray(Yhat_clim, dims=dims)
    ds['seas_jja_minus_djf'] = xr.DataArray(jja - djf, dims=dims[1:]).assign_attrs(description='JJA - DJF')
    ds['rmse'] = xr.DataArray(rmse, dims=dims[1:])

    return ds
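

def _demo_graven_fit():
    """
    Minimal usage sketch (not part of the original gist; all names and
    values below are illustrative assumptions). Builds a synthetic
    monthly (time, lat, lon) DataArray with a known annual cycle and
    checks that graven2013_seasonal_fit recovers its JJA - DJF signal.
    """
    import pandas as pd

    time = pd.date_range('2000-01-01', '2004-12-01', freq='MS')
    lat = np.arange(-80.0, 90.0, 20.0)
    lon = np.arange(0.0, 360.0, 20.0)

    t = np.arange(time.size) / 12  # decimal years
    cycle = 30 * np.cos(2 * np.pi * t)  # 30 uatm annual harmonic
    data = np.tile(cycle[:, None, None], (1, lat.size, lon.size))
    da = xr.DataArray(
        data, coords=dict(time=time, lat=lat, lon=lon),
        dims=['time', 'lat', 'lon'])

    fit = graven2013_seasonal_fit(da, n_jobs=2)
    # for a pure 30 * cos annual cycle, JJA - DJF ~ -55 everywhere
    print(fit.seas_jja_minus_djf.mean().values)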


def climatology_based_seasonal_fit(da):
    """
    Simply calculates the seasonal cycle as the monthly climatology.

    Args:
        da (xr.DataArray): input array with a time dimension that is monthly

    Returns:
        xr.Dataset: the monthly climatology, the JJA - DJF difference,
            and the RMSE of the climatology against the input data
    """
    clim = da.groupby('time.month').mean('time')

    ds = xr.Dataset()
    ds['seasonal_clim'] = clim
    ds['seas_jja_minus_djf'] = (
        clim.sel(month=[6, 7, 8]).mean('month') -
        clim.sel(month=[12, 1, 2]).mean('month'))
    ds['rmse'] = np.square(da - clim.sel(month=da.time.dt.month)).mean('time')**0.5

    return ds


def fit_seasonal_cycle(da, window_size_years=3, func=graven2013_seasonal_fit):
    """
    Fits the seasonal cycle for each calendar year using a rolling
    window of window_size_years centred on that year.
    """
    amplitudes = []
    y0, y1 = da.time.dt.year[[0, -1]].values
    nicename = '.'.join([func.__module__, func.__name__])

    dy = int((window_size_years - 1) / 2)
    years = list(range(y0, y1 + 1))
    for year in years:
        print(year, end=', ')
        t0 = year - dy
        t1 = year + dy
        s0 = str(t0)
        s1 = str(t1)
        subset = da.sel(time=slice(s0, s1)).dropna('time', how='all')
        amplitudes.append(func(subset))
    print('')

    amplitude = (
        xr.concat(amplitudes, 'time')
        .assign_coords(time=[np.datetime64(f"{y}-01-01") for y in years])
        .assign_attrs(description=f'Seasonal cycle fit with {nicename}')
    )

    return amplitude
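

def _demo_fit_seasonal_cycle():
    """
    Sketch of the rolling-window wrapper (illustrative, not from the
    original gist): one seasonal fit per calendar year, each based on a
    3-year window centred on that year. The cheap climatology-based
    fitter is used here to keep the example fast.
    """
    import pandas as pd

    time = pd.date_range('2000-01-01', '2005-12-01', freq='MS')
    t = np.arange(time.size) / 12
    data = (10 * np.cos(2 * np.pi * t))[:, None, None]
    da = xr.DataArray(
        data, coords=dict(time=time, lat=[0.5], lon=[0.5]),
        dims=['time', 'lat', 'lon'], name='demo')

    amp = fit_seasonal_cycle(da, window_size_years=3,
                             func=climatology_based_seasonal_fit)
    print(amp.seas_jja_minus_djf.squeeze().values)  # ~ -18.2 for each year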


def create_seasonal_cycle_fit_netcdf(ds, dest='.'):
    ds = ds.load()

    name = os.path.join(dest, 'seasonal_cycle_fit_{key}_{y0}-{y1}.nc')
    os.makedirs(dest, exist_ok=True)

    y0, y1 = ds.time.dt.year[[0, -1]].values
    for key in ds:
        sname = name.format(key=key, y0=y0, y1=y1)
        if os.path.exists(sname):
            continue
        print(f'Working on file: {sname}')
        da = ds[key]
        # output contains seasonal cycle fits, JJA-DJF, and RMSE
        seasonal_cycle = fit_seasonal_cycle(da, window_size_years=3)
        save_nc_zip(seasonal_cycle, sname)

    # now working on only the JJA-DJF differences
    sname = os.path.join(dest, f'seasonal_cycle_fit_jja-djf_{y0}-{y1}.nc')
    if os.path.exists(sname):
        return xr.open_dataset(sname)

    delta_seas = xr.Dataset()
    for key in ds:
        fname = name.format(key=key, y0=y0, y1=y1)
        delta_seas[key] = xr.open_dataset(fname).seas_jja_minus_djf
    delta_seas = delta_seas.assign_attrs(
        description='JJA-DJF calculated using the Graven et al. (2013) approach')
    save_nc_zip(delta_seas, sname)

    return delta_seas


def prep_data(pco2, dic, alk, temp, salt):
    # salinity-normalise DIC and alkalinity to remove the dilution
    # signal, clipping to plausible surface-ocean bounds
    salt_norm = salt.mean('time')
    sdic = (dic / salt * salt_norm).clip(1500, 3000).rename('s_dissicos')
    salk = (alk / salt * salt_norm).clip(1500, 3000).rename('s_talkos')
    pco2 = pco2.rename('spco2')
    temp = temp.rename('tos')
    salt = salt.rename('sos')

    ds = xr.merge([pco2, sdic, salk, temp, salt]).assign_attrs(
        description=(
            's_dissicos and s_talkos were normalised with the long-term '
            'mean of sos'))
    ds = ds.assign_coords(lon=lambda x: x.lon % 360).sortby('lon')
    ds = ds.astype('float32')

    return ds


# SENSITIVITY GAMMA #

def revelle_factor(alk, dic):
    """Revelle factor (d ln pCO2 / d ln DIC), from the analytical
    carbonate-system approximation in terms of alkalinity and DIC."""
    gamma = (
        (3 * alk * dic - 2 * dic ** 2) /
        ((2 * dic - alk) * (alk - dic)))
    return gamma


def alkalinity_buffer_factor(alk, dic):
    """Alkalinity buffer factor (d ln pCO2 / d ln Alk); same
    approximation as the Revelle factor above."""
    gamma = -(alk ** 2) / ((2 * dic - alk) * (alk - dic))
    return gamma
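

def _demo_buffer_factors():
    """
    Worked numeric example (values are illustrative assumptions): for
    typical surface-ocean values of DIC ~ 2000 and Alk ~ 2300 umol/kg,
    the Revelle factor should land near the canonical 10-15 range, with
    the alkalinity factor of similar magnitude but opposite sign.
    """
    alk, dic = 2300.0, 2000.0
    print('Revelle factor:   ', revelle_factor(alk, dic))            # ~ 11.4
    print('Alk buffer factor:', alkalinity_buffer_factor(alk, dic))  # ~ -10.4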


def create_gamma_netcdf(alk, dic, temp, dest='.'):
    y0, y1 = temp.time.dt.year[[0, -1]].values
    name = os.path.join(dest, f'spco2_sensitivity_gamma_{y0}-{y1}.nc')
    file = return_ds_else_create_dir(name)
    if file is not None:
        return file

    gamma = xr.Dataset()
    gamma['s_dissicos'] = revelle_factor(alk, dic)
    gamma['s_talkos'] = alkalinity_buffer_factor(alk, dic)
    # thermal sensitivity of pCO2: d(ln pCO2)/dT = 0.0423 per degC
    # (Takahashi et al. 1993)
    gamma['tos'] = temp * 0 + 0.0423
    # freshwater (salinity) sensitivity combines the direct dilution
    # effect (1) with its imprint on DIC and alkalinity
    gamma['freshwater'] = 1 + gamma['s_talkos'] + gamma['s_dissicos']
    save_nc_zip(gamma, name)

    return gamma


# DECOMPOSITION #

def make_dataset_for_decomposition(spco2, dic, alk, temp, salt, dest='.'):
    sname = os.path.join(dest, 'spco2_decomp_inputs.nc')
    if os.path.exists(sname):
        return xr.open_dataset(sname)
    os.makedirs(dest, exist_ok=True)

    print('Preparing data')
    ds_mon = prep_data(spco2, dic, alk, temp, salt).load()  # used for seasonal cycle
    ds_an = ds_mon.resample(time='1AS').mean()  # annual data used most of the time

    print('Fitting seasonal cycle')
    delta_seas = create_seasonal_cycle_fit_netcdf(ds_mon, dest=dest).load()

    print('Calculating sensitivities')
    gamma = create_gamma_netcdf(ds_an.s_talkos, ds_an.s_dissicos, ds_an.tos, dest=dest)
    gamma = gamma.rename(freshwater='sos')  # so keys match the variable loop below
    gamma = gamma.load()

    # tos is set to 1 since the thermal gamma (0.0423) is an absolute
    # sensitivity, so no division by the variable itself is needed
    ds_an['tos'] = ds_an.tos * 0 + 1

    print('Creating input data')
    component = ['gamma', 'pco2', 'delta_seas', 'var']
    variables = ['s_dissicos', 's_talkos', 'tos', 'sos']
    dat = xr.Dataset()
    for key in variables:
        dat[key] = xr.concat(
            [gamma[key], ds_an.spco2, delta_seas[key], ds_an[key]],
            'component').assign_coords(component=component)
    dat = dat.rename(sos='freshwater')
    save_nc_zip(dat, sname)

    return dat


def _decompose_spco2(input_a, input_b, spco2, dest='.'):
    """
    Decomposes the pCO2 seasonal difference into driver contributions.

    input_a holds the terms that are kept fixed (the abundant case),
    input_b holds the term that is allowed to vary (the shifting case).
    Each contribution is a first-order term gamma * pCO2 * delta_seas / var,
    with one of the three factors taken from input_b in turn.
    """
    a = input_a
    b = input_b

    assert a.component.values.tolist() == ['gamma', 'pco2', 'delta_seas', 'var']
    assert b.component.values.tolist() == ['gamma', 'pco2', 'delta_seas', 'var']
    assert list(a.keys()) == ['s_dissicos', 's_talkos', 'tos', 'freshwater']
    assert list(b.keys()) == ['s_dissicos', 's_talkos', 'tos', 'freshwater']

    decomp = xr.Dataset()
    for key in a:
        decomp[key] = xr.concat([
            (b[key][0] * a[key][1] * a[key][2] / a[key][3]).assign_coords(component='gamma'),
            (a[key][0] * b[key][1] * a[key][2] / a[key][3]).assign_coords(component='pco2'),
            (a[key][0] * a[key][1] * b[key][2] / a[key][3]).assign_coords(component='delta_seas')],
            'component')
    decomp = decomp.where(lambda x: x != 0)

    delta_seas = create_seasonal_cycle_fit_netcdf(spco2.to_dataset(name='spco2'), dest=dest).spco2
    region = regions_seasonal_signal(spco2)
    sign = ((region.where(lambda x: x > 0) - 1) // 2 * 2 - 1)
    weights = seaflux_weighting(dest)

    decomp = decomp.to_array(dim='driver', name='spco2_decomp').to_dataset()
    decomp['spco2_seas_dif'] = delta_seas
    decomp['region_mask'] = region
    decomp['sign'] = sign.assign_attrs(description='multiply to get to winter - summer')
    decomp['weights'] = weights
    decomp = decomp.assign_coords(driver=['sDIC', 'sAlk', 'Thermal', 'Freshwater'])

    return decomp
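

def _demo_first_order_term():
    """
    Scalar sketch of the first-order term the decomposition is built on
    (numbers are illustrative assumptions). A driver X contributes
        dpCO2 ~= gamma_X * pCO2 * delta_X / X
    e.g. sDIC with gamma ~ 11.4, pCO2 ~ 370 uatm, and a JJA - DJF change
    of -20 umol/kg on a mean sDIC of 2000 umol/kg:
    """
    gamma, pco2, delta_x, x = 11.4, 370.0, -20.0, 2000.0
    contribution = gamma * pco2 * delta_x / x
    print('sDIC-driven JJA-DJF pCO2 change ~', contribution, 'uatm')  # ~ -42
    # the thermal term uses gamma = 0.0423 with the 'var' slot set to 1,
    # since d(ln pCO2)/dT is an absolute rather than relative sensitivity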


def decompose_spco2_with_avg(spco2, dic, alk, temp, salt, dest='.', overwrite=False):
    sname = os.path.join(dest, 'spco2_decomposed_avg.nc')
    if not overwrite and os.path.isfile(sname):
        return xr.open_dataset(sname)
    if overwrite and os.path.isfile(sname):
        os.remove(sname)

    inputs = make_dataset_for_decomposition(spco2, dic, alk, temp, salt, dest)
    average = inputs.mean('time')

    decomp = _decompose_spco2(average, inputs, spco2, dest)
    decomp['spco2_decomp_regional'] = make_regional_decomp(decomp)
    save_nc_zip(decomp, sname)

    return decomp


def decompose_spco2_with_trend(spco2, dic, alk, temp, salt, dest='.', overwrite=False):
    sname = os.path.join(dest, 'spco2_decomposed_trend.nc')
    if not overwrite and os.path.isfile(sname):
        return xr.open_dataset(sname)
    if overwrite and os.path.isfile(sname):
        os.remove(sname)

    inputs = make_dataset_for_decomposition(spco2, dic, alk, temp, salt, dest)
    trends = make_rolling_trends_netcdf(inputs, dest)

    decomp = _decompose_spco2(inputs, trends, spco2, dest)
    decomp['spco2_decomp_regional'] = make_regional_decomp(decomp)
    save_nc_zip(decomp, sname)

    return decomp


# REGIONAL FUNCTIONS AND DATASETS #

def make_regional_decomp(decomp_out, variable='spco2_decomp'):
    ds = decomp_out
    weights = seaflux_weighting()

    # flip JJA-DJF to winter-summer so hemispheres are comparable
    decomp = ds[variable] * ds.sign
    regions = ds.region_mask.where(lambda x: x > 0)
    regional = regional_aggregation(decomp, regions, weights)

    # regional_aggregation already concatenates along 'region'
    region_names = ['NH-HL', 'NH-LL', 'SH-LL', 'SH-HL']
    regional = regional.assign_coords(region=region_names)
    regional = regional.assign_attrs(description='weighted averages')

    return regional


def regional_aggregation(xda, region_mask, weights=None, func='mean'):
    regional = []
    # groupby over a 2D mask stacks lat/lon, hence the unstack below
    for r, da in xda.groupby(region_mask):
        da = da.unstack()
        if weights is not None:
            da = da.weighted(weights)
        da = getattr(da, func)(['lat', 'lon'])
        da = da.assign_coords(region=r)
        regional.append(da)
    regional = xr.concat(regional, 'region')
    return regional


def regions_seasonal_signal(pco2):
    """
    Creates a mask for pCO2 dividing the ocean into four latitudinal bands:
    Northern/Southern Hemisphere, Low/High latitudes.
    The split of the latitudes is based on the change in seasonal cycle,
    which is expected to be at roughly 40 deg N/S.
    """
    def rolling_annual_avg(da, window=3):
        return da.rolling(time=window, center=True, min_periods=1).mean()

    pco2 = pco2.assign_coords(lon=lambda a: a.lon % 360).sortby('lon').load()
    quarterly = pco2.resample(time='1QS-Dec').mean('time')

    # select summer and winter data - DJF has to be shifted so that the
    # year matches the winter year data
    DJF = rolling_annual_avg(quarterly[0::4].assign_coords(time=lambda a: a.time.dt.year + 1))
    JJA = rolling_annual_avg(quarterly[2::4].assign_coords(time=lambda a: a.time.dt.year))
    seas = (JJA - DJF).mean('time')

    # getting the seasonal masks with certain conditions met
    NHHL = (((seas < 0) & (seas.lat > 35)) | ((seas > 0) & (seas.lat > 55))) & (seas.lat < 65)
    NHLL = ((seas > 0) & (seas.lat > 10) & (seas.lat < 50)) | ((seas.lat > 10) & (seas.lat < 25))
    SHLL = ((seas < 0) & (seas.lat < -10) & (seas.lat > -55)) | ((seas.lat < -10) & (seas.lat > -30))
    SHHL = (seas > 0) & (seas.lat < -30) & (seas.lat > -65)

    # region numbers increase from north to south
    mask = (1*NHHL + 2*NHLL + 3*SHLL + 4*SHHL).where(seas.notnull()).fillna(0)
    mask = mask.rename('pco2_seasonal_cycle_mask')

    return mask
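

def _demo_region_sign():
    """
    Sketch of the sign convention used by _decompose_spco2 (derived from
    the code, restated here for clarity): region ids 1-4 are NH-HL,
    NH-LL, SH-LL, SH-HL, and `(r - 1) // 2 * 2 - 1` maps the NH regions
    to -1 and the SH regions to +1, flipping JJA - DJF into
    winter - summer in both hemispheres.
    """
    for r in [1, 2, 3, 4]:
        print(r, (r - 1) // 2 * 2 - 1)  # -> 1 -1, 2 -1, 3 1, 4 1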


def seaflux_weighting(dest='.'):
    import pyseaflux

    sname = os.path.join(dest, 'weights_flux_based_seaflux.nc')
    if os.path.isfile(sname):
        return xr.open_dataarray(sname)

    ds = pyseaflux.get_seaflux_data(verbose=False).sel(wind='ERA5').drop('wind')
    mask = ds.sol_Weiss74[0].notnull()

    ds = ds[[
        'kw_scaled',
        'sol_Weiss74',
        'ice']]
    ds['area'] = pyseaflux.get_area_from_dataset(ds)
    ds['ice_free'] = 1 - ds.ice.fillna(0)
    ds = ds.drop('ice')

    da = ds.to_array('variable')
    var = da['variable'].values.astype(str).tolist()
    var_str = ' * '.join(var)
    da = da.prod('variable')

    weights = (
        da.load().where(mask).mean('time')
        .fillna(0)
        .assign_coords(lon=lambda a: (a.lon % 360))
        .sortby('lon')
        .rename(var_str)
        .assign_attrs(
            description=f'weighting is {var_str}'))

    # scale the weights to the range [0, 1]
    weights = weights - weights.min()
    weights = weights / weights.max()

    weights.to_netcdf(sname)

    # return the weights themselves (not the file name) so callers get
    # the same type as the cached branch above
    return weights


# TREND FUNCTIONS #

def make_rolling_trends_netcdf(inputs, dest='.'):
    sname = os.path.join(dest, 'spco2_decomp_input_trends.nc')
    file = return_ds_else_create_dir(sname)
    if file is not None:
        return file

    print('Calculating trends')
    trends = (
        calc_rolling_trend(inputs.to_array(), n_jobs=36)
        .to_dataset(dim='variable')
        .transpose('component', 'time', 'lat', 'lon'))
    save_nc_zip(trends, sname)

    return trends


def calc_trend(xda, dim='time'):
    if xda[dim].size <= 2:
        # too few points for a linear fit - return a NaN dummy trend
        # (drop the scalar time coord so it concatenates cleanly)
        return xda.isel({dim: 0}, drop=True) * np.nan

    if xda.name is None:
        name = 'slope'
    else:
        name = xda.name + '_slope'

    # use years as the fit coordinate so the slope is in units per year
    xda = xda.assign_coords(time=lambda x: x.time.dt.year)
    trend = (
        xda.polyfit('time', 1, skipna=False)
        .polyfit_coefficients[0]
        .drop('degree')
        .rename(name)
        .assign_attrs(units='units/year')
    )

    return trend


def calc_rolling_trend(da_in, window_size=3, n_jobs=36):
    roll = da_in.rolling(time=window_size, min_periods=3, center=True)
    time = xr.concat([t for t, _ in roll], 'time')

    if n_jobs > 1:
        pool = joblib.Parallel(n_jobs=n_jobs)
        func = joblib.delayed(calc_trend)
        queue = [func(xda) for _, xda in roll]
        trends = pool(queue)
    else:
        func = calc_trend
        trends = [func(xda) for _, xda in roll]

    trends = xr.concat(trends, 'time').assign_coords(time=time)
    return trends
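

def _demo_rolling_trend():
    """
    Sketch with synthetic data (illustrative, not from the original
    gist): an annual series with a slope of 2 units/year should yield
    rolling 3-year trends of ~2, with NaN at the edge windows where
    fewer than three points are available.
    """
    import pandas as pd

    time = pd.date_range('2000-01-01', periods=8, freq='AS')
    da = xr.DataArray(2.0 * np.arange(8), coords=dict(time=time),
                      dims=['time'], name='demo')
    trends = calc_rolling_trend(da, window_size=3, n_jobs=1)
    print(trends.values)  # ~ [nan, 2, 2, 2, 2, 2, 2, nan]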


# FILE UTILS #

def save_nc_zip(ds, sname):
    for key in ds.data_vars:
        ds[key] = ds[key].astype('float32')
    comp = dict(complevel=4, zlib=True)
    ds.to_netcdf(sname, encoding={k: comp for k in ds})


def return_ds_else_create_dir(sname):
    from pathlib import Path

    sname = Path(sname)
    if sname.is_file():
        return xr.open_dataset(sname)
    else:
        sname.parent.mkdir(exist_ok=True, parents=True)
        return None
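

if __name__ == '__main__':
    # Illustrative driver (not part of the original gist). The full
    # decomposition needs real model output plus pyseaflux downloads,
    # e.g. decompose_spco2_with_avg(spco2, dic, alk, temp, salt, dest='.'),
    # so only the light, self-contained sketches are run here.
    _demo_buffer_factors()
    _demo_first_order_term()
    _demo_region_sign()
    _demo_rolling_trend()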