|
# Import the required libraries
|
import os |
|
import rasterio # Require version > 1.3 |
|
import numpy as np |
|
from tqdm import tqdm |
|
from rasterio.enums import Resampling |
|
|
|
def rasterio_dtype_to_gdal_code(rasterio_dtype_str):
    """
    Translate a rasterio dtype name into the matching GDAL numeric type code.

    Only the seven real-valued types listed below are recognised; complex
    and 8-bit signed types are not mapped.

    :param rasterio_dtype_str: dtype name such as 'uint8', 'int16', 'float32'.
    :return: GDAL integer code (1 for Byte, 2 for UInt16, ... 7 for Float64).
    :raises ValueError: if the dtype name is not in the table.
    """
    gdal_codes = dict(
        uint8=1,    # GDT_Byte
        uint16=2,   # GDT_UInt16
        int16=3,    # GDT_Int16
        uint32=4,   # GDT_UInt32
        int32=5,    # GDT_Int32
        float32=6,  # GDT_Float32
        float64=7,  # GDT_Float64
    )

    code = gdal_codes.get(rasterio_dtype_str)
    if code is None:
        raise ValueError(f"Unsupported Rasterio dtype: {rasterio_dtype_str}")
    return code
|
|
|
def determine_nodata_value_and_type(data_type):
    """
    Determine a suitable NoData value for a GDAL data type code.

    The NoData value is an extreme of the type's range: the maximum for
    unsigned integer types and the most negative value for signed integer
    and floating-point types.

    :param data_type: GDAL numeric data type code (e.g., 6 for Float32).
    :return: (nodata_value, gdal_data_type) pair, where gdal_data_type is
             the same code that was passed in.
    :raises ValueError: if the code is not one of the supported types.
    """
    nodata_by_code = {
        1: 255,                        # Byte (GDT_Byte)
        2: 65535,                      # UInt16
        3: -32768,                     # Int16
        4: 4294967295,                 # UInt32
        5: -2147483648,                # Int32
        6: -3.40282346639e+38,         # Float32
        7: -1.7976931348623157e+308,   # Float64
        # NOTE(review): this function treats 12 as Int8, but GDAL's own enum
        # appears to assign 12 to GDT_UInt64 and 14 to GDT_Int8 (GDAL >= 3.7).
        # Kept unchanged for backward compatibility - confirm with callers.
        12: -128,
    }

    if data_type not in nodata_by_code:
        raise ValueError(f"Unsupported data type code: {data_type}")

    # Echo the validated code back instead of repeating it per branch; the
    # hand-written copies previously had UInt16 (2) and Int16 (3) swapped.
    return nodata_by_code[data_type], int(data_type)
|
|
|
def generate_dekads(start_year=1981, end_year=2024):
    """
    Yield every (year, month, dekad) triplet in chronological order.

    Each year contributes 36 triplets: months 1-12, each with dekads 1-3.

    :param start_year: First year (inclusive).
    :param end_year: Last year (inclusive).
    :return: Generator of (year, month, dekad) tuples.
    """
    months = range(1, 13)
    dekads_in_month = (1, 2, 3)
    yield from (
        (yr, mo, dk)
        for yr in range(start_year, end_year + 1)
        for mo in months
        for dk in dekads_in_month
    )
|
|
|
|
|
def dekad_to_filename(year, month, dekad, base_dir):
    """
    Build the path of the CHIRPS GeoTIFF for one dekad.

    The file name follows 'chirps-v3.0.<year>.<MM>.<dekad>.tif', where the
    month is zero-padded to two digits and the dekad is written as-is.

    :param year: Four-digit year (e.g., 1981).
    :param month: Month integer (1-12).
    :param dekad: Dekad integer (1, 2, 3).
    :param base_dir: Path to folder with CHIRPS data.
    :return: Full path to the CHIRPS GeoTIFF file.
    """
    return os.path.join(
        base_dir,
        f"chirps-v3.0.{year}.{month:02d}.{dekad}.tif",
    )
|
|
|
def get_day_from_dekad(dekad):
    """
    Return the two-digit day-of-month string on which a dekad starts.

        1 -> '01'
        2 -> '11'
        3 -> '21'

    :param dekad: Integer dekad (1, 2, or 3).
    :return: String day ('01', '11', or '21').
    :raises KeyError: if dekad is not 1, 2, or 3.
    """
    first_day_by_dekad = dict(zip((1, 2, 3), ('01', '11', '21')))
    return first_day_by_dekad[dekad]
|
|
|
def shift_dekad(year, month, dekad, shift):
    """
    Move a (year, month, dekad) triplet forward or backward by `shift` dekads.

    The triplet is flattened to a running dekad count
    (year * 36 + (month - 1) * 3 + (dekad - 1)), the shift is added, and the
    result is unpacked back into a triplet. No upper bound on the year is
    enforced here.

    :param year: Four-digit year.
    :param month: Month (1-12).
    :param dekad: Dekad (1,2,3).
    :param shift: Number of dekads to shift (e.g., -1 for previous dekad).
    :return: (new_year, new_month, new_dekad), or None when the shifted
             position would fall before year 0.
    """
    dekads_per_year = 36
    dekads_per_month = 3

    position = (
        year * dekads_per_year
        + (month - 1) * dekads_per_month
        + (dekad - 1)
        + shift
    )
    if position < 0:
        return None  # Means out of range

    target_year, within_year = divmod(position, dekads_per_year)
    month_offset, dekad_offset = divmod(within_year, dekads_per_month)
    return (target_year, month_offset + 1, dekad_offset + 1)
|
|
|
def sum_rasters(raster_paths, nodata_val=None):
    """
    Sum band 1 of several rasters (same shape, CRS, etc.) into one float32 array.

    Paths that are not existing files are skipped. The output profile is
    copied from the first raster that is opened and then overridden for a
    single-band, float32, deflate-compressed COG.

    A pixel is written as NoData when it is negative or non-finite in any
    input, or when the accumulated total is negative or non-finite. Checking
    every input (rather than only the final sum) prevents a negative NoData
    marker such as -9999 in one file from being cancelled out by large valid
    values in the others. NoData pixels are filled with the same value that
    is declared in the returned profile, so data and metadata agree.

    :param raster_paths: List of paths to raster files.
    :param nodata_val: If provided, use as NoData for output; otherwise a
                       fallback is derived from the dtype of the first raster
                       (for float32 this is about -3.40282346639e+38).
    :return: (summed_array, profile) or None if no valid input.
    :raises ValueError: if the rasters do not all have the same shape, or if
                        the fallback NoData is needed and the first raster's
                        dtype is not supported by the lookup helpers.
    """
    total = None        # running float32 sum of all inputs
    invalid = None      # True where any input (or the total) is unusable
    out_profile = None

    for path in raster_paths:
        if not os.path.isfile(path):
            continue

        with rasterio.open(path) as src:
            if out_profile is None:
                out_profile = src.profile.copy()
                if nodata_val is None:
                    # No explicit NoData: fall back to the extreme value
                    # associated with the first raster's dtype.
                    gdal_code = rasterio_dtype_to_gdal_code(out_profile['dtype'])
                    nodata_val, _ = determine_nodata_value_and_type(gdal_code)

            # astype() returns a fresh array, safe to accumulate into.
            data = src.read(1).astype(np.float32)

        bad = ~np.isfinite(data) | (data < 0)

        if total is None:
            total = data
            invalid = bad
        else:
            if data.shape != total.shape:
                raise ValueError(
                    f"Raster shape mismatch: {path} has shape {data.shape}, "
                    f"expected {total.shape}"
                )
            # Accumulate in place instead of stacking every input in memory.
            total += data
            invalid |= bad

    if total is None or out_profile is None:
        return None

    # Also catch overflow to inf (or anything else that went wrong) in the sum.
    invalid |= ~np.isfinite(total) | (total < 0)

    # Fill with the NoData value that the profile advertises.
    total[invalid] = nodata_val

    # Update profile for single-band float32 output
    out_profile.update({
        'count': 1,
        'driver': 'COG',
        'dtype': 'float32',
        'nodata': nodata_val,
        'add_overviews': True,
        'overview_levels': [2, 4, 8],
        'overview_resampling': Resampling.average.name,
        'compress': 'deflate',
        'blocksize': 512,
        'predictor': 2,
    })

    return total, out_profile
|
|
|
def main_chirps_rolling(
    base_dir,
    out_dir,
    start_year=1981,
    end_year=2024,
    accum_months_list=None,
    nodata_val=None
):
    """
    Main function to compute rolling accumulations from CHIRPS dekadal data.

    For every dekad between start_year and end_year and every accumulation
    window, the window's dekadal rasters (the current dekad plus the
    preceding ones, 3 dekads per month) are summed and written to
    '<out_dir>/month<N>_rolling_dekad/wld_cli_chirps3_month<N>_<YYYYMMDD>.tif',
    where the date is the first day of the current dekad. Windows for which
    any input file is missing are skipped.

    :param base_dir: Directory where CHIRPS files are located.
    :param out_dir: Root output directory.
    :param start_year: Start year of the data range.
    :param end_year: End year of the data range.
    :param accum_months_list: List of month-accumulation windows, e.g. [1,2,3,6,9,12,24].
                              Defaults to [1,2,3,6,9,12,18,24,36,48,60,72].
    :param nodata_val: If not None, force this NoData value for outputs.
                       Otherwise, determine dynamically from first input file.
    """
    # Avoid a mutable default argument; build the default list per call.
    if accum_months_list is None:
        accum_months_list = [1, 2, 3, 6, 9, 12, 18, 24, 36, 48, 60, 72]

    # Make sure the output directories exist (and remember them for later)
    print("Creating output directories...")
    out_subdirs = {}
    for n_months in accum_months_list:
        out_subdir = os.path.join(out_dir, f"month{n_months}_rolling_dekad")
        os.makedirs(out_subdir, exist_ok=True)
        out_subdirs[n_months] = out_subdir

    all_dekads = list(generate_dekads(start_year, end_year))
    total_dekads = len(all_dekads)

    print(f"Processing {total_dekads} dekads from {start_year} to {end_year}")
    print(f"Accumulation periods (months): {accum_months_list}")

    # Outer progress bar for the total dekads
    for (year, month, dekad) in tqdm(all_dekads, desc="Overall Progress"):
        # The output date stamp depends only on the current dekad.
        day_str = get_day_from_dekad(dekad)

        # For each accumulation period
        for n_months in accum_months_list:
            n_dekads = 3 * n_months  # 3 dekads per month

            # Walk backward from the current dekad (shift 0) through the
            # n_dekads - 1 preceding ones, keeping only files that exist.
            raster_files = []
            for shift_val in range(n_dekads):
                shifted = shift_dekad(year, month, dekad, -shift_val)
                if shifted is None:
                    # out of range (earlier than data start)
                    continue
                rf = dekad_to_filename(*shifted, base_dir)
                if os.path.isfile(rf):
                    raster_files.append(rf)

            # Require the full set (n_dekads); partial coverage is skipped.
            if len(raster_files) < n_dekads:
                tqdm.write(f"Skipping incomplete data: needed {n_dekads}, found {len(raster_files)}.")
                continue

            result = sum_rasters(raster_files, nodata_val=nodata_val)
            if result is None:
                tqdm.write(f"Warning: Failed to process {year}-{month:02d}-{dekad} ({n_months}-month)")
                continue
            summed_array, out_profile = result

            # Where to save
            out_name = f"wld_cli_chirps3_month{n_months}_{year}{month:02d}{day_str}.tif"
            out_path = os.path.join(out_subdirs[n_months], out_name)

            # Write out the sum as a single-band COG
            with rasterio.open(out_path, 'w', **out_profile) as dst:
                dst.write(summed_array, 1)

            tqdm.write(f"Created: {out_name}")

    print("\nProcessing completed!")
    print(f"Output files are saved in: {out_dir}")
|
|
|
if __name__ == "__main__":
    # Example usage:
    #     python chirps_rolling.py
    # Adjust the two directories and the call arguments below as needed.

    # Input: folder holding the CHIRPS dekadal .tif files
    CHIRPS_BASE_DIR = "/mnt/e/temp/chirps/dekad"
    # Output: root folder under which the per-window sub-folders are created
    OUTPUT_DIR = "/mnt/e/temp/chirps"

    # nodata_val=None lets the NoData value be determined from the first
    # raster; pass a number instead (e.g. -9999) to force a specific value.
    main_chirps_rolling(
        base_dir=CHIRPS_BASE_DIR,
        out_dir=OUTPUT_DIR,
        start_year=1981,
        end_year=2024,
        accum_months_list=[24],
        nodata_val=None,
    )