Skip to content

Instantly share code, notes, and snippets.

@prl900
Created April 11, 2026 23:31
Show Gist options
  • Select an option

  • Save prl900/dbe30ea82975682d1b4f4003a0341f4e to your computer and use it in GitHub Desktop.

Select an option

Save prl900/dbe30ea82975682d1b4f4003a0341f4e to your computer and use it in GitHub Desktop.
import os
import sys
import numpy as np
import xarray as xr
from datetime import datetime, timedelta
import calendar
import shutil
import time
def downscale_era5_temp(year):
    """
    Downscale ERA5 2 m temperature (and derived variables) for one year.

    For every day of *year* this reads hourly ERA5 surface fields (t2m,
    ssrd, d2m) and pressure-level temperature (1000/950 hPa), derives a
    local lapse rate, reduces Tmax/Tmin to sea level, re-interpolates them
    onto a 500 m DEM grid and corrects back to terrain height.  The
    radiation-weighted temperature/vapour-pressure fields are written on a
    coarser 0.1 degree (340 x 420) grid.  Daily slices are accumulated into
    one NetCDF file per output variable, which is finally moved to the
    OzWALD output directory.

    Parameters
    ----------
    year : int
        Year to process.

    Returns
    -------
    None
    """
    # Define directories
    codedir = '/g/data/xc0/project/OzWALD/R2020/ERA5/'  # NOTE(review): currently unused
    tempdir = os.getcwd() + '/'

    # Spatial subset. Latitude range is given north-to-south because the
    # ERA5 latitude coordinate is descending, so slice(-10, -44) is valid.
    lat_range = (-10, -44)
    lon_range = (112, 154)

    # Static elevation: ERA5 surface geopotential, converted to metres.
    fnz = '/g/data/xc0/user/pablo/AusEnvERA5/static_era5.nc'
    ds_static = xr.open_dataset(fnz)
    z = ds_static['z'].sel(latitude=slice(*lat_range), longitude=slice(*lon_range))
    z = z / 9.8  # Convert from m**2 s**-2 to m above sea level
    _, Nlat, Nlon = z.shape

    # High-resolution DEM used as the downscaling target grid.
    fndem = 'DEM_500m.nc'
    ds_dem = xr.open_dataset(fndem)
    DEM = ds_dem['DEM'].astype(np.float32)

    # Output directory and variables
    outdir = '/g/data/ub8/au/OzWALD/daily/meteo/'
    outvars = ['Tmax', 'Tmin', 'kTeff', 'kTavg', 'VPeff']

    # Input file pattern; {var}/{tstart}/{tend} are substituted per month.
    fn_pattern = '/g/data/xc0/user/pablo/AusEnvERA5/{var}_era5_oper_sfc_{tstart}-{tend}.nc'

    yyyy = year
    ydays = 366 if calendar.isleap(yyyy) else 365
    ddone = 0
    start_time = time.time()

    # Map each output variable to its temporary yearly file path.
    output_ds = {}
    for varname in outvars:
        outfn = f'OzWALD.{varname}.{yyyy}.nc'
        output_ds[varname] = os.path.join(tempdir, outfn)

    # Track which yearly files have been initialised.  The original code
    # only created them on Jan 1, which broke the run whenever the January
    # input file was missing (the month was skipped with `continue`).
    created = set()

    for mm in range(1, 13):
        print(f'Loading data for month {mm}')
        Nd = calendar.monthrange(yyyy, mm)[1]

        # Build the monthly filenames.
        fn = fn_pattern.replace('{tstart}', f'{yyyy:04d}{mm:02d}01')
        fn = fn.replace('{tend}', f'{yyyy:04d}{mm:02d}{Nd:02d}')
        fn_temp = fn.replace('{var}', 't2m')
        fn_rad = fn.replace('{var}', 'ssrd')
        fn_dewt = fn.replace('{var}', 'd2m')

        if not os.path.exists(fn_temp):
            print(f'Warning: file {fn_temp} not found')
            continue

        print(f'Reading {fn_temp}')
        # Hourly 2 m temperature for the month, in Celsius.
        ds_temp = xr.open_dataset(fn_temp)
        Ta = ds_temp['t2m'].sel(latitude=slice(*lat_range),
                                longitude=slice(*lon_range))
        Ta = Ta - 273.15  # Convert to Celsius

        # Downwelling shortwave radiation (used as a weighting field).
        ds_rad = xr.open_dataset(fn_rad)
        rad = ds_rad['ssrd'].sel(latitude=slice(*lat_range),
                                 longitude=slice(*lon_range))

        # Dewpoint temperature -> vapour pressure (Magnus/Tetens, hPa).
        ds_dewt = xr.open_dataset(fn_dewt)
        dewt = ds_dewt['d2m'].sel(latitude=slice(*lat_range),
                                  longitude=slice(*lon_range))
        dewt = dewt - 273.15
        vp = 6.1078 * np.exp((17.269 * dewt) / (237.3 + dewt))

        # Pressure-level temperature for the lapse rate.
        # NOTE(review): the 'surface'->'pressure' and 'global'->'aus'
        # replacements are no-ops on the current fn_pattern (it contains
        # 'sfc', not 'surface') — confirm the intended pressure-level path.
        fn_tp = fn.replace('surface', 'pressure').replace('{var}', 't').replace('global', 'aus')
        ds_pressure = xr.open_dataset(fn_tp)
        Tz1000 = ds_pressure['t'].sel(pressure_level=1000,
                                      latitude=slice(*lat_range),
                                      longitude=slice(*lon_range))
        Tz950 = ds_pressure['t'].sel(pressure_level=950,
                                     latitude=slice(*lat_range),
                                     longitude=slice(*lon_range))
        # Lapse rate in K per metre; 0.12 hPa/m converts the 50 hPa layer
        # thickness to an approximate height difference.
        LapseR = (Tz1000 - Tz950) / ((1000 - 950) / 0.12)

        for dd in range(1, Nd + 1):
            print(f'Processing day {dd}')
            # String date selects all hourly steps within the day.
            day_date = f'{yyyy:04d}-{mm:02d}-{dd:02d}'
            Ta_day = Ta.sel(valid_time=day_date)
            rad_day = rad.sel(valid_time=day_date)
            vp_day = vp.sel(valid_time=day_date)
            LR_day = LapseR.sel(valid_time=day_date)

            Tmax = Ta_day.max(dim='valid_time')
            Tmin = Ta_day.min(dim='valid_time')

            # Lapse rate at the hour when the max/min temperature occurred.
            LRmax = xr.full_like(Tmax, np.nan)
            LRmin = xr.full_like(Tmin, np.nan)
            for t_idx in range(len(Ta_day.valid_time)):
                mask_max = (Ta_day.isel(valid_time=t_idx) == Tmax)
                LRmax = xr.where(mask_max, LR_day.isel(valid_time=t_idx), LRmax)
                mask_min = (Ta_day.isel(valid_time=t_idx) == Tmin)
                LRmin = xr.where(mask_min, LR_day.isel(valid_time=t_idx), LRmin)

            # Downscale Tmax: reduce to sea level, interpolate to the DEM
            # grid, then correct back using the high-resolution elevation.
            Tsmax = Tmax + LRmax * z
            Tsmax_hr = Tsmax.interp(latitude=DEM.latitude,
                                    longitude=DEM.longitude,
                                    method='linear')
            lapse_max_hr = LRmax.interp(latitude=DEM.latitude,
                                        longitude=DEM.longitude,
                                        method='linear')
            Tmax_out = (Tsmax_hr - lapse_max_hr * DEM).astype(np.float32)

            # Same procedure for Tmin.
            Tsmin = Tmin + LRmin * z
            Tsmin_hr = Tsmin.interp(latitude=DEM.latitude,
                                    longitude=DEM.longitude,
                                    method='linear')
            lapse_min_hr = LRmin.interp(latitude=DEM.latitude,
                                        longitude=DEM.longitude,
                                        method='linear')
            Tmin_out = (Tsmin_hr - lapse_min_hr * DEM).astype(np.float32)

            # Radiation-weighted effective temperature, scaled into the
            # daily range (0 = Tmin, 1 = Tmax).
            Teff = (Ta_day * rad_day).sum(dim='valid_time') / rad_day.sum(dim='valid_time')
            Ta_range = Ta_day.max(dim='valid_time') - Ta_day.min(dim='valid_time')
            kTeff = (Teff - Ta_day.min(dim='valid_time')) / Ta_range

            # Coarse 0.1 degree output grid (340 x 420 cell centres).
            res = 0.1
            target_lat = np.linspace(lat_range[1] - res / 2, lat_range[0] + res / 2, 340)
            target_lon = np.linspace(lon_range[0] + res / 2, lon_range[1] - res / 2, 420)
            kTeff_out = kTeff.interp(latitude=target_lat,
                                     longitude=target_lon,
                                     method='linear').astype(np.float32)

            # Daily-mean temperature scaled into the daily range.
            Tavg = Ta_day.mean(dim='valid_time')
            kTavg = (Tavg - Ta_day.min(dim='valid_time')) / Ta_range
            kTavg_out = kTavg.interp(latitude=target_lat,
                                     longitude=target_lon,
                                     method='linear').astype(np.float32)

            # Radiation-weighted effective vapour pressure.
            VPeff = (vp_day * rad_day).sum(dim='valid_time') / rad_day.sum(dim='valid_time')
            VPeff_out = VPeff.interp(latitude=target_lat,
                                     longitude=target_lon,
                                     method='linear').astype(np.float32)

            # 1-indexed day of year.
            doy = (datetime(yyyy, mm, dd) - datetime(yyyy - 1, 12, 31)).days

            data_map = {
                'Tmax': Tmax_out,
                'Tmin': Tmin_out,
                'kTeff': kTeff_out,
                'kTavg': kTavg_out,
                'VPeff': VPeff_out
            }

            # Write this day's slice to each yearly file, creating the
            # file the first time we actually have data for the variable.
            for varname in outvars:
                tempfn = output_ds[varname]
                map_data = data_map[varname]
                if varname not in created:
                    create_yearly_dataset(tempfn, varname, yyyy, ydays, map_data)
                    created.add(varname)
                append_daily_data(tempfn, varname, map_data, doy)

            ddone += 1
            tsofar = time.time() - start_time
            timeleft = (tsofar / ddone) * (ydays - ddone) / 60  # in minutes
            print(f'Done {day_date}')
            print(f'Estimated time remaining: {timeleft:.0f} minutes')

        # Close the monthly datasets to free memory.
        ds_temp.close()
        ds_rad.close()
        ds_dewt.close()
        ds_pressure.close()

    tsofar = time.time() - start_time
    print(f'Completed processing year {yyyy} in {tsofar/60:.0f} minutes')

    # Move finished yearly files to the output directory.
    for varname in outvars:
        outfn = f'OzWALD.{varname}.{yyyy}.nc'
        src = os.path.join(tempdir, outfn)
        dst = os.path.join(outdir, outfn)
        if os.path.exists(src):
            shutil.move(src, dst)
            print(f'Moved {outfn} to {outdir}')
    return None
def create_yearly_dataset(filename, varname, year, ndays, template_data):
    """
    Create a yearly output NetCDF file pre-filled with NaN.

    The file has dimensions (time, latitude, longitude), one daily time
    step per day of the year, and the spatial grid of *template_data*.

    Parameters
    ----------
    filename : str
        Output filename.
    varname : str
        Variable name to create in the file.
    year : int
        Year (used for the time coordinate and metadata).
    ndays : int
        Number of days in the year (365 or 366).
    template_data : xr.DataArray
        Template data array supplying latitude/longitude coordinates.
    """
    # Daily time coordinate starting on Jan 1.
    start_date = datetime(year, 1, 1)
    time_coord = [start_date + timedelta(days=i) for i in range(ndays)]

    # NaN-filled float32 cube; days are filled in later by append_daily_data.
    ds = xr.Dataset({
        varname: (['time', 'latitude', 'longitude'],
                  np.full((ndays, len(template_data.latitude),
                           len(template_data.longitude)), np.nan, dtype=np.float32))
    }, coords={
        'time': time_coord,
        'latitude': template_data.latitude,
        'longitude': template_data.longitude
    })

    # Minimal metadata.
    ds[varname].attrs['units'] = 'various'
    ds[varname].attrs['long_name'] = varname
    ds.attrs['description'] = f'OzWALD {varname} for {year}'
    ds.attrs['created'] = datetime.now().isoformat()

    ds.to_netcdf(filename, mode='w')
    ds.close()
def append_daily_data(filename, varname, data, doy):
    """
    Write one day's data into an existing yearly NetCDF file.

    Parameters
    ----------
    filename : str
        NetCDF filename (created by create_yearly_dataset).
    varname : str
        Variable name to update.
    data : xr.DataArray
        Daily data slice; must match the file's spatial grid.
    doy : int
        Day of year (1-indexed).
    """
    # Load the dataset fully into memory and close the file handle BEFORE
    # rewriting it.  The original code called to_netcdf on the same path
    # while the (lazy) dataset still held the file open, which fails or
    # corrupts the file on locking netCDF backends.
    ds = xr.open_dataset(filename)
    ds.load()
    ds.close()
    ds[varname][doy - 1, :, :] = data.values
    # NOTE(review): rewriting the whole yearly file once per day is O(n^2)
    # I/O over the year; writing via the netCDF4 library directly would be
    # cheaper if this becomes a bottleneck.
    ds.to_netcdf(filename, mode='w')
    ds.close()
if __name__ == '__main__':
    # Year may be given as the first command-line argument; default 2025
    # preserves the original behaviour when run without arguments.
    year = int(sys.argv[1]) if len(sys.argv) > 1 else 2025
    downscale_era5_temp(year)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment