Last active
April 26, 2023 08:38
-
-
Save sgillies/8d0b77699dea4de9ca1ab6ec5552e1c8 to your computer and use it in GitHub Desktop.
Simple and fancy ways of computing stats of a rasterio dataset
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
"""Three ways of computing stats of a rasterio dataset.""" | |
import dask.array | |
import numpy | |
import rasterio | |
from rasterio.windows import Window | |
# Approach 1: load every band of the dataset into memory at once and take
# the per-band mean. This is the shortest version, but memory use grows
# with the size of the dataset, so it does not scale to very large files.
with rasterio.open(
    "/home/seangillies/projects/rasterio/tests/data/RGB.byte.tif"
) as dataset:
    # masked=True yields a masked array, so masked pixels are left out of
    # the mean.
    all_bands = dataset.read(masked=True)
    # Reduce over the last two axes, leaving one value per band.
    band_means = all_bands.mean(axis=(1, 2))
    print(band_means)
# Compute mean, block-wise. Saves memory, but is not especially fast.
#
# Per-band sums and valid-pixel counts are accumulated window by window and
# divided once at the end. Accumulating sums (with masked entries filled as
# 0) instead of "window mean * window count" matters: the mean of a band
# that is fully masked inside one window is itself masked, and adding a
# masked element to the accumulator would mask that band's total for the
# rest of the run, even when other windows hold valid pixels.
with rasterio.open(
    "/home/seangillies/projects/rasterio/tests/data/RGB.byte.tif"
) as dataset:
    count = numpy.zeros((dataset.count,), dtype="int64")
    total = numpy.zeros((dataset.count,), dtype="float64")
    for _, window in dataset.block_windows():
        data = dataset.read(window=window, masked=True)
        # Number of unmasked pixels per band in this window.
        count += data.count(axis=(1, 2))
        # Sum in float64 so small integer dtypes cannot overflow; a band
        # with no valid pixels in this window contributes exactly 0.
        window_sum = data.sum(axis=(1, 2), dtype="float64")
        total += numpy.ma.filled(window_sum, 0.0)
    # numpy.ma.divide masks bands whose count is 0 (no valid pixels in the
    # whole dataset) instead of emitting a divide-by-zero warning and NaN.
    print(numpy.ma.divide(total, count))
# Compute mean using dask.array and a helper class.
class DatasetArray:
    """Minimal array-like wrapper around an open dataset.

    Exposes ``ndim``, ``shape`` and ``dtype`` and supports slicing through
    ``__getitem__``, so that it can be handed to ``dask.array.from_array``.
    Data is only read from the dataset when a slice is requested.

    The wrapped ``dataset`` must provide ``count``, ``height``, ``width``,
    ``dtypes``, ``indexes`` and a ``read(indexes=..., window=...,
    masked=...)`` method.
    """

    def __init__(self, dataset) -> None:
        self.dataset = dataset
        self.ndim = 3
        # Axis order is (band, row, column).
        self.shape = (dataset.count, dataset.height, dataset.width)
        # The dtype of the first band is used for the whole array.
        self.dtype = dataset.dtypes[0]

    def __getitem__(self, key):
        """Read the part of the dataset selected by ``key``.

        ``key`` is a tuple of slices ordered like ``self.shape``: either
        ``(bands, rows, cols)`` or ``(rows, cols)``. With a two-element
        key all bands are read. Returns the masked array produced by
        ``dataset.read``.
        """
        # The last two axes are (rows, cols), matching self.shape.
        row_slice = key[-2]
        col_slice = key[-1]
        if len(key) == 3:
            band_slice = key[0]
            # slice.indices() resolves open-ended, negative and stepped
            # slices against the band count. Array positions are 0-based;
            # the dataset's band indexes start at 1, hence the + 1.
            band_positions = range(*band_slice.indices(self.shape[0]))
            indexes = [i + 1 for i in band_positions]
        else:
            indexes = self.dataset.indexes
        # Passing height/width lets from_slices resolve open-ended slices
        # such as slice(None) against the dataset dimensions.
        window = Window.from_slices(
            row_slice,
            col_slice,
            height=self.shape[1],
            width=self.shape[2],
        )
        return self.dataset.read(indexes=indexes, window=window, masked=True)
# Approach 3: wrap the open dataset in DatasetArray, turn it into a dask
# array, and let dask evaluate the per-band mean.
with rasterio.open(
    "/home/seangillies/projects/rasterio/tests/data/RGB.byte.tif"
) as dataset:
    wrapped = DatasetArray(dataset)
    arr = dask.array.from_array(wrapped)
    # Building the mean is lazy; compute() triggers the actual reads, so it
    # must run while the dataset is still open.
    lazy_means = arr.mean(axis=(1, 2))
    print(lazy_means.compute())
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment