Created
February 8, 2022 04:32
-
-
Save emileten/cd28b028b3262dfe76fcb681eed129a3 to your computer and use it in GitHub Desktop.
Impose a cell-specific temporal cap on a [time, lat, lon] xarray dataset.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import xarray as xr | |
import numpy as np | |
### Fake data ### | |
def spatio_temporal_gcm_factory(
    x=None,
    start_date="1995-01-01",
    lat=None,
    lon=None,
    units="someunit",
):
    """Build a fake daily (time, lat, lon) DataArray for experiments.

    Parameters
    ----------
    x : array-like, optional
        Data of shape (n_time, len(lat), len(lon)). When omitted, a single
        time step of uniform random values matching the grid is generated.
    start_date : str
        First date of the daily "standard"-calendar time axis.
    lat, lon : array-like, optional
        Coordinate values. Default to a half-degree grid spanning
        [-90, 90] and [-180, 180] respectively.
    units : str
        Stored in the ``units`` attribute of the result.

    Returns
    -------
    xr.DataArray
        Array with dims ("time", "lat", "lon") and one time step per entry
        along the first axis of ``x``.
    """
    # Array defaults are created per call: defaults in the signature would be
    # evaluated once at definition time and shared by every call.
    if lat is None:
        lat = np.arange(-90, 90.5, 0.5)
    if lon is None:
        lon = np.arange(-180, 180.5, 0.5)
    if x is None:
        # Size the default data from the grid so custom lat/lon work too.
        x = np.random.rand(1, len(lat), len(lon))

    time = xr.cftime_range(
        start=start_date, freq="D", periods=len(x), calendar="standard"
    )
    out = xr.DataArray(
        data=x,
        coords={"time": time, "lat": lat, "lon": lon},
        dims=["time", "lat", "lon"],
        attrs={"units": units},
    )
    return out
# be fast : 2 time steps, 1 latitude value, 2 longitude values. | |
def tiny_factory(start_date):
    """Return a minimal random DataArray: 2 time steps, 1 latitude, 2 longitudes."""
    values = np.random.rand(2, 1, 2)
    return spatio_temporal_gcm_factory(
        x=values,
        start_date=start_date,
        lat=[1.0],
        lon=[1.0, 2.0],
    )
# to check if dask likes this workflow | |
def chunked_tiny_factory(start_date):
    """Return the tiny random DataArray as a chunked (dask-backed) array.

    Used to check that the workflow also runs on chunked data.
    """
    # can't chunk across time ! -> a single chunk along time (and lat),
    # chunks of two along lon.
    chunk_sizes = {'time': -1, 'lat': -1, 'lon': 2}
    return tiny_factory(start_date).chunk(chunk_sizes)
# Factory used to build every input below; swap it to change the test data.
factory = chunked_tiny_factory

# Two inputs start in 1950 and two in 2050.
historical_start = "1950-01-01"
future_start = "2050-01-01"

era = factory(start_date=historical_start)
clean_hist = factory(start_date=historical_start)
clean_future = factory(start_date=future_start)
ds_future = factory(start_date=future_start)
### Compute ### | |
def cell_max(da):
    """Return the maximum over ``time`` for every (lat, lon) grid cell.

    Parameters
    ----------
    da : xr.DataArray
        Array with at least the dimensions "time", "lat" and "lon".

    Returns
    -------
    xr.DataArray
        ``da`` reduced over "time", with "lat" and "lon" as the last two
        dimensions.
    """
    # A plain reduction over "time" already yields one value per grid cell.
    # Stacking the grid and grouping by cell creates one group per cell, which
    # is slow and generates a very large task graph on chunked data.
    return da.max("time").transpose(..., "lat", "lon")
era_cell_max = cell_max(era)
clean_hist_cell_max = cell_max(clean_hist)
clean_future_cell_max = cell_max(clean_future)
# Per-cell cap: the `era` maximum scaled by the future / historical ratio of maxima.
bias_corrected_cap = era_cell_max * (clean_future_cell_max / clean_hist_cell_max)
# Materialise the cap before editing it: on a chunked array `.values` returns a
# fresh NumPy copy, so writing into `.values[0][0]` would be silently discarded.
bias_corrected_cap = bias_corrected_cap.compute()
bias_corrected_cap[dict(lat=0, lon=0)] = 0.1  # so that behavior is obvious in the output.
time_steps = ds_future["time"].values  # to expand dims
### Swap values for replacement where needed
# Need to add time dimension to the 'max' (creating duplicates).
# The cap computed above is `bias_corrected_cap`; `cell_specific_cap` was never defined.
cell_specific_cap_expanded = bias_corrected_cap.expand_dims(dim=dict(time=time_steps))
# Keep values strictly below the cap; every other value is replaced by the cap.
final_output = ds_future.where(ds_future < cell_specific_cap_expanded, cell_specific_cap_expanded)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment