Last active
May 26, 2023 14:56
-
-
Save calebrob6/83ad0ed9b78aad528fb027c1e37d3026 to your computer and use it in GitHub Desktop.
Implementation of four pansharpening methods based on code from https://github.com/ThomasWangWeiHong.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
"""This is a heavily adapted version of the code from https://github.com/ThomasWangWeiHong/Simple-Pansharpening-Algorithms.

Some useful resources about pansharpening:
- https://github.com/mapbox/rio-pansharpen
- https://github.com/gerasimosmichalitsianos/pansharpen
- https://desktop.arcgis.com/en/arcmap/latest/manage-data/raster-and-images/fundamentals-of-panchromatic-sharpening.htm
- https://gist.github.com/celoyd/5bb5417b24801e0446ad5977cc3581e4
"""
import os | |
from typing import Any, Dict, List, Optional | |
import time | |
import numpy as np | |
import rasterio | |
import rasterio.warp | |
import rasterio.mask | |
import cv2 | |
def pansharpen(
    panchromatic_fn: str,
    multispectral_fn: str,
    output_fn: str,
    geom: Optional[Dict[str, Any]] = None,
    idx_red: int = 1,
    idx_green: int = 2,
    idx_blue: int = 3,
    idx_nir: int = 4,
    method: str = "simple_brovey",
    W: float = 0.1,
    output_bands: List[int] = [2, 1, 0, 3],  # read-only here, never mutated
    verbose: bool = True,
) -> np.ndarray:
    """Pansharpen a multispectral image using its corresponding panchromatic image.

    One of the following algorithms is applied: Simple Brovey, Simple Mean, ESRI,
    Brovey.

    This is a heavily adapted version of the code from
    https://github.com/ThomasWangWeiHong/Simple-Pansharpening-Algorithms.

    Some notes:
    - The output file is written with the "COG" driver (cloud optimized GeoTIFF),
      LZW compressed, using the CRS and transform of the panchromatic input.
    - If both the height and the width of the multispectral imagery are smaller than
      those of the panchromatic data we assume it needs to be resampled and do so
      with linear resampling. Else, the height and width of both must match (i.e. the
      multispectral imagery should be manually resampled to match).
    - The inputs are cast to float32 for the computation. The result is cast back to
      the dtype of the panchromatic imagery; for integer dtypes the values are first
      clipped to the representable range so that out-of-range results saturate
      instead of wrapping around.
    - We do _not_ check to see if the aspect ratios of the panchromatic data and
      multispectral data approximately match. TODO: do this.
    - The band indices are used directly as zero-based positions on the band axis of
      the multispectral image. With the default values (1, 2, 3, 4) the
      `simple_brovey` and `brovey` methods therefore index position 4, i.e. they need
      a multispectral image with at least five bands; pass explicit indices for
      images with fewer bands.

    Args:
        panchromatic_fn: The path to a GeoTIFF containing the panchromatic band
        multispectral_fn: The path to a GeoTIFF containing the multispectral bands
        output_fn: The path to write the pansharpened result (must not exist)
        geom: Either a GeoJSON description (in EPSG:4326) of the subset area of the
            input to pansharpen or None to pansharpen the entire input
        idx_red: Zero based index of the red band in the multispectral image
        idx_green: Zero based index of the green band in the multispectral image
        idx_blue: Zero based index of the blue band in the multispectral image
        idx_nir: Zero based index of the NIR band in the multispectral image
        method: The type of pansharpening method to use. Options are: simple_brovey,
            brovey, simple_mean, esri.
        W: Weight value to use for the brovey pansharpening method
        output_bands: List of the (zero based) multispectral bands to pansharpen and
            include in the output file, in output order
        verbose: Flag to enable printing timing information and progress during
            execution

    Returns:
        pansharpened_image: The pansharpened multispectral bands in C x H x W format

    Raises:
        ValueError: if the `method` parameter is not valid, the panchromatic input is
            not single band, the inputs are in different projections, or the
            dimensions of the inputs don't make sense
        IOError: if the `output_fn` parameter already exists as a file
    """
    if method not in {"simple_brovey", "simple_mean", "esri", "brovey"}:
        raise ValueError(f"Method '{method}' not recognized")
    if os.path.exists(output_fn):
        raise IOError("The output file already exists")

    # Small constant added to denominators to avoid division by zero
    eps = 0.00001

    ###########################
    # Read panchromatic data
    ###########################
    tic = time.time()
    with rasterio.open(panchromatic_fn) as f:
        pan_metadata = f.profile
        pan_crs = f.crs.to_string()
        if geom is not None:
            # The geometry is given in EPSG:4326; reproject it to the CRS of the
            # imagery once and reuse it for the multispectral read below
            geom = rasterio.warp.transform_geom("epsg:4326", pan_crs, geom)
            pan_img, pan_transform = rasterio.mask.mask(f, [geom], crop=True)
        else:
            pan_img = f.read()
            pan_transform = f.transform

    # Validate explicitly (asserts are stripped under `python -O`) and index the band
    # axis instead of calling squeeze(), which would also drop a spatial axis of
    # length one
    if pan_img.ndim != 3 or pan_img.shape[0] != 1:
        raise ValueError("The panchromatic input should contain a single band")
    pan_img = pan_img[0].astype(np.float32)
    pan_height, pan_width = pan_img.shape
    out_dtype = pan_metadata["dtype"]

    if verbose:
        print(
            f"Finished reading the panchromatic data in {time.time() - tic:0.2f} seconds"
        )

    ###########################
    # Read multispectral data
    ###########################
    tic = time.time()
    with rasterio.open(multispectral_fn) as f:
        if f.crs.to_string() != pan_crs:
            raise ValueError("The two inputs should be in the same projection")
        if geom is not None:
            multispectral_img, _ = rasterio.mask.mask(f, [geom], crop=True)
        else:
            multispectral_img = f.read()

    # C x H x W -> H x W x C
    multispectral_img = np.moveaxis(multispectral_img, 0, -1).astype(np.float32)
    multispectral_height, multispectral_width = multispectral_img.shape[:2]

    if verbose:
        print(
            f"Finished reading the multispectral data in {time.time() - tic:0.2f} seconds"
        )

    ###########################
    # Resample the multispectral data if needed
    ###########################
    tic = time.time()
    if multispectral_height < pan_height and multispectral_width < pan_width:
        multispectral_img = cv2.resize(
            multispectral_img, (pan_width, pan_height), interpolation=cv2.INTER_LINEAR
        )
        if multispectral_img.ndim == 2:
            # cv2.resize drops the channel axis of single-channel inputs
            multispectral_img = multispectral_img[:, :, np.newaxis]
    elif multispectral_height == pan_height and multispectral_width == pan_width:
        pass  # The inputs are the same size, we assume that the multispectral data has
        # been resampled to the dimensions of the panchromatic data
    else:
        raise ValueError(
            f"The dimensions of the multispectral imagery ({multispectral_height},"
            + f" {multispectral_width}) are unexpected given the dimensions of"
            + f" the panchromatic image ({pan_height}, {pan_width})"
        )

    if verbose:
        print(
            "Finished resizing the multispectral data in"
            + f" {time.time() - tic:0.2f} seconds"
        )

    ###########################
    # Get down to business
    ###########################
    tic = time.time()
    pansharpened_image = np.zeros(
        (len(output_bands), pan_height, pan_width), dtype=out_dtype
    )

    if method == "simple_brovey":
        all_in = (
            multispectral_img[:, :, idx_red]
            + multispectral_img[:, :, idx_green]
            + multispectral_img[:, :, idx_blue]
            + multispectral_img[:, :, idx_nir]
        )
        # The ratio does not depend on the output band, so compute it once
        ratio = pan_img / (all_in + eps)
        for i, band in enumerate(output_bands):
            pansharpened_image[i, :, :] = _cast_to_dtype(
                multispectral_img[:, :, band] * ratio, out_dtype
            )
    elif method == "simple_mean":
        for i, band in enumerate(output_bands):
            pansharpened_image[i, :, :] = _cast_to_dtype(
                0.5 * (multispectral_img[:, :, band] + pan_img), out_dtype
            )
    elif method == "esri":
        # Difference between the panchromatic value and the mean over all
        # multispectral bands, added to every output band
        adj = pan_img - multispectral_img.mean(axis=2)
        for i, band in enumerate(output_bands):
            pansharpened_image[i, :, :] = _cast_to_dtype(
                multispectral_img[:, :, band] + adj, out_dtype
            )
    elif method == "brovey":
        dnf = (pan_img - W * multispectral_img[:, :, idx_nir]) / (
            W * multispectral_img[:, :, idx_red]
            + W * multispectral_img[:, :, idx_green]
            + W * multispectral_img[:, :, idx_blue]
            + eps
        )
        for i, band in enumerate(output_bands):
            pansharpened_image[i, :, :] = _cast_to_dtype(
                multispectral_img[:, :, band] * dnf, out_dtype
            )

    if verbose:
        print(f"Finished pansharpening in {time.time() - tic:0.2f} seconds")

    ###########################
    # Write output
    ###########################
    tic = time.time()
    new_metadata = {
        "driver": "COG",
        "dtype": out_dtype,
        "nodata": None,
        "width": pan_width,
        "height": pan_height,
        "count": len(output_bands),
        "crs": pan_crs,
        "transform": pan_transform,
        "compress": "lzw",
        "predictor": 2,
    }
    with rasterio.open(output_fn, "w", **new_metadata) as f:
        f.write(pansharpened_image)

    if verbose:
        print(f"Finished writing output in {time.time() - tic:0.2f} seconds")

    return pansharpened_image


def _cast_to_dtype(values: np.ndarray, dtype: Any) -> np.ndarray:
    """Cast `values` to `dtype`, saturating instead of wrapping for integer dtypes."""
    target = np.dtype(dtype)
    if np.issubdtype(target, np.integer):
        limits = np.iinfo(target)
        # Clip in float64 so that the bounds of integer types of up to 32 bits are
        # represented exactly; in-range values are still truncated by the cast below
        values = np.clip(values.astype(np.float64), limits.min, limits.max)
    return values.astype(target)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment