Skip to content

Instantly share code, notes, and snippets.

@RichardScottOZ
Forked from ljstrnadiii/reproject_resample.py
Created November 2, 2024 06:08
Show Gist options
  • Save RichardScottOZ/cac300edbb4b8345f8b659a9f21b0368 to your computer and use it in GitHub Desktop.
Save RichardScottOZ/cac300edbb4b8345f8b659a9f21b0368 to your computer and use it in GitHub Desktop.
Resample/Reproject COGs (Cloud Optimized GeoTIFFs) with Dask + rioxarray
import logging
import subprocess as sp
import tempfile
import threading
import geopandas as gpd
import numpy as np
import rioxarray
import xarray as xr
from rasterio.enums import Resampling
# rasterio Resampling member names that gdalwarp's ``-r`` option spells
# differently. Every other member is passed to gdalwarp under its own name.
# Members with no gdalwarp counterpart (e.g. ``gauss``) are still rejected by gdalwarp.
_GDALWARP_RESAMPLING_NAMES = {
    "nearest": "near",
    "cubic_spline": "cubicspline",
}


def build_warpvrt(
    in_file: str,
    out_file: str,
    te: tuple[float, float, float, float],
    target_crs: str,
    resx: float,
    resy: float,
    resampling_method: Resampling,
):
    """Warp ``in_file`` onto a target grid with ``gdalwarp``, writing ``out_file``.

    Args:
        in_file: Source raster path. Callers in this module pass a VRT built by
            ``gdalbuildvrt``.
        out_file: Destination path. ``-overwrite`` replaces it if it already
            exists; callers in this module pass pre-created temp-file paths.
            Those paths end in ``.vrt``, so presumably GDAL writes a warped
            VRT rather than materialising pixels. Confirm for your GDAL version.
        te: Target extent, passed to ``-te`` as ``xmin ymin xmax ymax`` in the
            units of ``target_crs``.
        target_crs: Target spatial reference, passed to ``-t_srs``.
        resx: Output pixel width, passed to ``-tr`` (target-CRS units).
        resy: Output pixel height, passed to ``-tr`` (target-CRS units).
        resampling_method: rasterio resampling enum. Its name is translated to
            gdalwarp's spelling where the two differ.

    Raises:
        RuntimeError: If ``gdalwarp`` exits with a non-zero status. The message
            includes GDAL's captured stderr.
    """
    method_name = resampling_method.name
    # e.g. rasterio "cubic_spline" -> gdalwarp "cubicspline"; unmapped names pass through.
    gdal_method = _GDALWARP_RESAMPLING_NAMES.get(method_name, method_name)
    cmd = [
        "gdalwarp",
        "-overwrite",
        "-t_srs",
        target_crs,
        "-tr",
        str(resx),
        str(resy),
        "-te",
        *[str(x) for x in te],
        "-r",
        gdal_method,
        in_file,
        out_file,
    ]
    try:
        sp.run(cmd, check=True, capture_output=True)
    except sp.CalledProcessError as e:
        # capture_output swallows GDAL's diagnostics; surface them in the error.
        stderr = (e.stderr or b"").decode(errors="replace").strip()
        raise RuntimeError(
            f"gdalwarp failed with exit code {e.returncode}: {stderr}"
        ) from e
def build_vrt(in_files: list[str], out_file: str):
    """Combine ``in_files`` into a single GDAL VRT at ``out_file`` via ``gdalbuildvrt``.

    Args:
        in_files: Raster paths to combine. Paths may be local files or
            GDAL virtual-filesystem paths such as ``/vsigs/...``. Any iterable
            is accepted; it is materialised once.
        out_file: Destination VRT path. ``-overwrite`` replaces it if it
            already exists.

    Raises:
        RuntimeError: If ``in_files`` is empty, or if ``gdalbuildvrt`` exits
            with a non-zero status. In the second case the message includes
            GDAL's captured stderr.
    """
    files = list(in_files)
    if not files:
        # With no inputs gdalbuildvrt has nothing to combine; fail early with a clear message.
        raise RuntimeError(f"build_vrt: no input files given for {out_file!r}")
    cmd = ["gdalbuildvrt", "-overwrite", out_file, *files]
    try:
        sp.run(cmd, check=True, capture_output=True)
    except sp.CalledProcessError as e:
        # capture_output swallows GDAL's diagnostics; surface them in the error.
        stderr = (e.stderr or b"").decode(errors="replace").strip()
        raise RuntimeError(
            f"gdalbuildvrt failed with exit code {e.returncode}: {stderr}"
        ) from e
def _temp_path(suffix: str) -> str:
    """Create an empty named temp file and return its path, leaving no handle open."""
    # tempfile.mkstemp returns an OS-level descriptor the caller must close;
    # NamedTemporaryFile(delete=False) closes it on exiting ``with`` and keeps the file.
    with tempfile.NamedTemporaryFile(suffix=suffix, delete=False) as handle:
        return handle.name


def _build_dataset(
    tile_asset_gdf: gpd.GeoDataFrame,
    resolution_m: float,
    target_crs: str,
    target_extent: tuple[float, float, float, float],
) -> xr.Dataset:
    """Build a lazily loaded dataset by warping tile assets onto one target grid.

    Rows of ``tile_asset_gdf`` are grouped by ``collection``, then by
    ``datetime``, then by ``crs``.

    - Each CRS group's ``url`` files are combined into one VRT.
    - That VRT is warped to ``target_crs`` / ``target_extent`` at
      ``resolution_m``.
    - The warped VRTs for one date are combined into one VRT, which is opened
      with rioxarray as a chunked (dask-backed) array.

    Args:
        tile_asset_gdf: Frame with at least the columns ``collection``,
            ``datetime``, ``crs`` and ``url``. ``gs://`` URLs are rewritten to
            GDAL's ``/vsigs/`` prefix.
        resolution_m: Output pixel size, used for both x and y. It is passed
            to gdalwarp in target-CRS units, so presumably ``target_crs`` is
            metric. Confirm.
        target_crs: Target spatial reference for the warp.
        target_extent: ``(xmin, ymin, xmax, ymax)`` in ``target_crs`` units.

    Returns:
        The ``xr.merge`` of one DataArray per collection. Each DataArray is
        named ``"variables"``, has dims ``(time, band, y, x)``, and its band
        coordinate is relabelled with that feature's ``band_names``.
        NOTE(review): merging same-named arrays outer-joins on band/time.
        Presumably band names are distinct across collections; overlapping
        labels with conflicting values would make ``xr.merge`` fail. Confirm.

    Note:
        The temporary VRT files are not deleted here. They must outlive this
        call because the returned chunked arrays read them lazily.
    """
    dsets = []
    for collection, feature_subset in tile_asset_gdf.groupby("collection"):
        # collection_to_feature is not defined in this file. Presumably it maps
        # a collection name to an object exposing ``resampling_method`` and
        # ``band_names`` (both are used below). Confirm against its definition.
        feature = collection_to_feature(collection)
        feature_date_dsets = []
        for date, feature_date_subset in feature_subset.groupby("datetime"):
            crs_vrts = []
            # Each CRS is combined and warped separately, since a single VRT
            # cannot combine sources in different projections.
            for _, crs_subset in feature_date_subset.groupby("crs"):
                in_files = [f.replace("gs://", "/vsigs/") for f in crs_subset.url]
                output_vrt = _temp_path("-crs_vrt.vrt")
                warped_vrt = _temp_path("-warped_crs_vrt.vrt")
                build_vrt(in_files, output_vrt)
                build_warpvrt(
                    in_file=output_vrt,
                    out_file=warped_vrt,
                    target_crs=target_crs,
                    resx=resolution_m,
                    resy=resolution_m,
                    te=target_extent,
                    resampling_method=feature.resampling_method,
                )
                crs_vrts.append(warped_vrt)
            # All warped VRTs now share one grid, so they can be combined directly.
            date_feature_vrt = _temp_path("-combined_vrt.vrt")
            build_vrt(crs_vrts, date_feature_vrt)
            array: xr.DataArray = rioxarray.open_rasterio(
                date_feature_vrt,
                chunks=(4, "auto", -1),  # (band, y, x): 4 bands, auto rows, full width
            )  # type: ignore
            array = array.expand_dims("time")
            # Truncate the group's timestamp to day precision, stored as datetime64[ns].
            array["time"] = [np.datetime64(date, "D").astype("datetime64[ns]")]  # type: ignore
            array.name = "variables"
            array["band"] = feature.band_names
            feature_date_dsets.append(array)
        date_combined = xr.concat(feature_date_dsets, dim="time")
        dsets.append(date_combined)
    return xr.merge(dsets)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment