import xarray as xr
import numpy as np
import pandas as pd
from pathlib import Path
from typing import Dict

import rioxarray  # noqa: F401  (import registers the .rio accessor on xarray objects)
import geopandas as gpd
from tqdm import tqdm

from scripts.clip_netcdf_to_shapefile import (
    prepare_rio_data,
    rasterize_all_geoms,
    create_timeseries_of_masked_datasets,
)
def initialise_rio_geospatial(
    ds: xr.Dataset, crs: str = "epsg:3035", lon_dim: str = "x", lat_dim: str = "y"
) -> xr.Dataset:
    """Tell rioxarray which dims are spatial and write the CRS metadata."""
    ds = ds.rio.set_spatial_dims(x_dim=lon_dim, y_dim=lat_dim)
    ds = ds.rio.write_crs(crs)
    return ds
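
# Usage sketch (hedged): "example.nc" and "discharge" stand in for any dataset
# on the EPSG:3035 grid with "x"/"y" spatial dims.
#
#   ds = initialise_rio_geospatial(
#       xr.open_dataset("example.nc")[["discharge"]], crs="epsg:3035"
#   )
#   ds.rio.crs  # -> the CRS just written, epsg:3035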
def reproject_ds(ds: xr.Dataset, reproject_crs: str = "EPSG:4326") -> xr.Dataset:
    try:
        ds = ds.rio.reproject(reproject_crs)
    except MemoryError:
        # fall back to reprojecting one timestep at a time
        all_times = []
        for time in tqdm(ds.time.values, desc="Reproject Time"):
            reprojected = ds.sel(time=time).rio.reproject(reproject_crs)
            all_times.append(reprojected)
        print("Concatenating all times")
        ds = xr.concat(all_times, dim="time")
    return ds
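
# Note + usage sketch (hedged): ds.rio.reproject materialises the whole array,
# so large (time, y, x) cubes can exhaust memory; the except branch above
# falls back to reprojecting timestep by timestep instead.
#
#   ds_latlon = reproject_ds(ds, reproject_crs="EPSG:4326")
#   assert ds_latlon.rio.crs.to_epsg() == 4326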
def convert_xy_to_station_id(ds: xr.Dataset, gauge_latlons: pd.DataFrame) -> xr.Dataset:
    assert all(np.isin(["gauge_lat", "gauge_lon"], gauge_latlons.columns))
    all_sids = []
    # extract the nearest grid cell for each gauge location
    for sid, row in gauge_latlons.iterrows():
        lat, lon = row["gauge_lat"], row["gauge_lon"]
        data = ds.sel(x=lon, y=lat, method="nearest")
        # attach the station_id as a length-1 coordinate
        data = data.assign_coords({"station_id": [sid]})
        all_sids.append(data)
    # join along station_id, making it a dimension -> (station_id,)
    _ds = xr.concat(all_sids, dim="station_id")
    # restore the time dimension if a single timestep was selected -> (time, station_id)
    if ds.time.size == 1:
        _ds = _ds.expand_dims("time")
    return _ds
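
# Usage sketch (hedged): gauge_latlons is assumed to be indexed by station_id
# with "gauge_lat"/"gauge_lon" columns, e.g. built by hand:
#
#   gauge_latlons = pd.DataFrame(
#       {"gauge_lat": [52.1, 51.5], "gauge_lon": [-1.3, -0.1]},
#       index=[1001, 1002],  # hypothetical station ids
#   )
#   sid_ds = convert_xy_to_station_id(ds_latlon, gauge_latlons)
#   # -> dims (time, station_id)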
def select_point_data_from_gridded_modelGrid_output(
    modelGrid_ds: xr.Dataset, gauge_latlons: pd.DataFrame
) -> xr.Dataset:
    all_times = []
    pbar = tqdm(modelGrid_ds["time"].values, desc="Reproject and Select Station ID")
    for time in pbar:
        pbar.set_postfix_str(str(time))
        # reproject a single timestep at a time (avoids memory issues)
        single_time_ds = reproject_ds(
            modelGrid_ds.sel(time=time), reproject_crs="EPSG:4326"
        )
        _ds = convert_xy_to_station_id(single_time_ds, gauge_latlons)
        all_times.append(_ds)
    print("Concatenating all timesteps")
    sid_ds = xr.concat(all_times, dim="time")
    return sid_ds
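
# Usage sketch (hedged), mirroring the TARGET branch in __main__ below:
#
#   target = initialise_rio_geospatial(
#       xr.open_dataset(target_file)[["discharge"]], crs="epsg:3035"
#   )
#   sid_ds = select_point_data_from_gridded_modelGrid_output(
#       modelGrid_ds=target, gauge_latlons=gauge_latlons
#   )
#   # -> discharge at the nearest grid cell to each gauge: dims (time, station_id)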
def _rasterize_shapefile(
    input_ds: xr.Dataset,
    shp_filepath: Path,
    id_column: str = "ID_STRING",
    shape_dimension: str = "station_id",
    lat_dim: str = "y",
    lon_dim: str = "x",
) -> xr.Dataset:
    # 1. create one xarray shape mask per catchment, keyed by station_id
    # (a single timestep is enough: the masks are purely spatial)
    single_time_ds = reproject_ds(input_ds.isel(time=0), reproject_crs="EPSG:4326")
    gdf = gpd.read_file(shp_filepath)
    # ensure the data are properly initialised (e.g. both share the same CRS)
    single_time_ds, gdf = prepare_rio_data(
        single_time_ds, gdf, lat_dim=lat_dim, lon_dim=lon_dim
    )
    # rasterize the shapefile geometries onto the model grid
    masks = rasterize_all_geoms(
        ds=single_time_ds,
        gdf=gdf,
        id_column=id_column,
        shape_dimension=shape_dimension,
        geometry_column="geometry",
        lat_dim=lat_dim,
        lon_dim=lon_dim,
    )
    return masks
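
# Sketch of the expected mask structure (hedged; the exact layout depends on
# rasterize_all_geoms in scripts.clip_netcdf_to_shapefile): one 2D grid mask
# per shapefile geometry, stacked along the station_id dimension.
#
#   masks = _rasterize_shapefile(input_ds, shp_filepath)
#   masks.sel(station_id="some_id")  # hypothetical id -> a single catchment mask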
def select_catchment_average_from_gridded_modelGrid_output(
    modelGrid_data: xr.Dataset,
    shp_filepath: Path,
    id_column: str = "ID_STRING",
    shape_dimension: str = "station_id",
    lat_dim: str = "y",
    lon_dim: str = "x",
) -> xr.Dataset:
    masks = _rasterize_shapefile(
        input_ds=modelGrid_data,
        shp_filepath=shp_filepath,
        id_column=id_column,
        shape_dimension=shape_dimension,
        lat_dim=lat_dim,
        lon_dim=lon_dim,
    )
    # for each timestep, extract the mean value over every catchment mask
    pbar = tqdm(modelGrid_data["time"].values, desc="Create Mean of Masked Area")
    all_times = []
    for time in pbar:
        pbar.set_postfix_str(str(time))
        # 2. create a timeseries of mean values (chop region of interest)
        single_time_ds = reproject_ds(
            modelGrid_data.sel(time=time), reproject_crs="EPSG:4326"
        )
        _ds = create_timeseries_of_masked_datasets(
            ds=single_time_ds,
            masks=masks,
            shape_dimension=shape_dimension,
            lat_dim=lat_dim,
            lon_dim=lon_dim,
            use_pbar=False,
        )
        # expand the time dimension -> (time, station_id)
        if single_time_ds.time.size == 1:
            _ds = _ds.expand_dims("time")
        all_times.append(_ds)
    print("Concatenating all times")
    ds = xr.concat(all_times, dim="time")
    return ds
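
# Usage sketch (hedged): shapefile path is hypothetical, the ID column is the
# script default; the result holds one catchment-mean timeseries per geometry.
#
#   ds = select_catchment_average_from_gridded_modelGrid_output(
#       modelGrid_data=input_ds,
#       shp_filepath=Path("catchment_boundaries.shp"),  # hypothetical path
#       id_column="ID_STRING",
#   )
#   # -> dims (time, station_id)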
def process_catchment_means(
    files: Dict[str, Path], data_dir: Path, shp_filepath: Path
) -> None:
    """Compute a catchment-average timeseries for each (variable, file) pair.

    To re-run a single variable, pass a one-item dict,
    e.g. {"Tavg": input_files["Tavg"]}.
    """
    for var, filepath in files.items():
        # load the gridded data and drop the grid-mapping variable
        input_ds = xr.open_dataset(filepath).drop_vars("lambert_azimuthal_equal_area")
        input_ds = initialise_rio_geospatial(input_ds[[var]], crs="epsg:3035")
        # 1. reproject gridded modelGrid output to latlon
        # 2. rasterize the catchment shapefiles
        # 3. calculate the mean over each catchment area
        ds = select_catchment_average_from_gridded_modelGrid_output(
            modelGrid_data=input_ds, shp_filepath=shp_filepath
        )
        ds.to_netcdf(data_dir / f"{var.lower()}_modelGrid.nc")


if __name__ == "__main__":
    data_dir = Path("/DataDrive200/data")
    cwat_dir = data_dir / "modelGrid"
    input_files = {
        "Precipitation": cwat_dir / "Precipitation_daily.nc",
        "Tavg": cwat_dir / "Tavg_daily.nc",
    }
    target_file = cwat_dir / "discharge_daily.nc"
    hidden_files = {
        "channelStorage": cwat_dir / "channelStorage_daily.nc",
        "storGroundwater": cwat_dir / "storGroundwater_daily.nc",
        "sum_interceptStor": cwat_dir / "sum_interceptStor_daily.nc",
        "SnowCover": cwat_dir / "SnowCover_daily.nc",
        "sum_w1": cwat_dir / "sum_w1_daily.nc",
        "sum_w2": cwat_dir / "sum_w2_daily.nc",
        "sum_w3": cwat_dir / "sum_w3_daily.nc",
    }
    assert all(f.exists() for f in list(input_files.values()) + [target_file])

    shp_filepath = (
        data_dir
        / "CAMELS_GB_DATASET"
        / "Catchment_Boundaries"
        / "CAMELS_GB_catchment_boundaries.shp"
    )

    # toggle which processing steps to run
    TARGET = False
    INPUT = False
    HIDDEN = True
    JOIN = False

    # target data (DISCHARGE AT A POUR POINT)
    if TARGET:
        target = xr.open_dataset(target_file)
        target = initialise_rio_geospatial(target[["discharge"]], crs="epsg:3035")
        static = xr.open_dataset(data_dir / "static.nc")
        gauge_latlons = static[["gauge_lat", "gauge_lon"]].to_dataframe()
        # 1. reproject gridded modelGrid output to latlon
        # 2. get the gauge lat/lon locations
        # 3. select the nearest grid cell to each gauge
        sid_ds = select_point_data_from_gridded_modelGrid_output(
            modelGrid_ds=target, gauge_latlons=gauge_latlons
        )
        sid_ds.to_netcdf(data_dir / "modelGrid_discharge.nc")

    # input data (mean over catchment area - "lumped")
    if INPUT:
        process_catchment_means(input_files, data_dir=data_dir, shp_filepath=shp_filepath)

    # hidden state data (storage variables), processed identically to the inputs
    if HIDDEN:
        process_catchment_means(hidden_files, data_dir=data_dir, shp_filepath=shp_filepath)

    if JOIN:
        # join all of the processed files together -> "modelGrid.nc"
        all_files = list(data_dir.glob("*_modelGrid.nc")) + [
            data_dir / "modelGrid_discharge.nc"
        ]
        list_ds = []
        for fpath in all_files:
            ds_ = xr.open_dataset(fpath)
            # station_id must share a dtype across files for the merge to align
            ds_["station_id"] = ds_["station_id"].astype(int)
            if "spatial_ref" in ds_.data_vars:
                ds_ = ds_.drop_vars("spatial_ref")
            list_ds.append(ds_)
        all_ds = xr.merge(list_ds)
        all_ds.to_netcdf(data_dir / "modelGrid.nc")
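
# Expected shape of the merged output (hedged sketch): every catchment-mean
# variable plus gauge discharge, aligned on (time, station_id), e.g.:
#
#   all_ds = xr.open_dataset(data_dir / "modelGrid.nc")
#   all_ds["Precipitation"]  # dims (time, station_id)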