Skip to content

Instantly share code, notes, and snippets.

@lucaswells
Last active September 26, 2024 18:41
Show Gist options
  • Save lucaswells/fd2fd73c513872966c1a0257afee1887 to your computer and use it in GitHub Desktop.
Save lucaswells/fd2fd73c513872966c1a0257afee1887 to your computer and use it in GitHub Desktop.
Convert GeoTIFF to Zarr array
import os
import time

import matplotlib.pyplot as plt
import numpy as np
import rasterio
import zarr
from rasterio.windows import Window
def convert(raster_filepath, chunk_mbs=1):
    """
    Convert a raster file to a chunked zarr array.

    Tested with GeoTIFF, but should work with other raster formats that
    rasterio can open. The zarr array is written next to the input file,
    with the input's final extension replaced by ``.zarr``. The array has
    shape ``(height, width, n_bands)``; band ``k`` of the raster (rasterio
    band indices start at 1) is stored at ``[:, :, k - 1]``.

    Parameters
    ----------
    raster_filepath : string
        Path and filename of input raster
    chunk_mbs : float, optional
        Desired size (MB) of each chunk in the zarr file. Each chunk is a
        square tile of a single band. Defaults to 1.

    Returns
    -------
    string
        Path of the zarr array that was written.

    Raises
    ------
    ValueError
        If ``chunk_mbs`` is not positive.
    """
    if chunk_mbs <= 0:
        raise ValueError(f"chunk_mbs must be positive, got {chunk_mbs}")

    # Strip only the final extension. Splitting on '.' instead breaks
    # for extensionless paths and for directories whose names contain dots.
    zarray_filepath = f"{os.path.splitext(raster_filepath)[0]}.zarr"

    # The context manager closes the raster even if reading/writing fails
    with rasterio.open(raster_filepath) as raster:
        # Extract metadata we need for initializing the zarr array
        width = raster.width
        height = raster.height
        n_bands = raster.count
        # NOTE(review): assumes all bands share the first band's dtype
        dtype = raster.dtypes[0].lower()

        # Side length of a square single-band chunk holding ~chunk_mbs MB.
        # np.dtype covers every dtype rasterio reports (e.g. 'uint8' for
        # byte rasters), unlike a hand-maintained lookup table.
        itemsize = np.dtype(dtype).itemsize
        side = max(1, int((chunk_mbs * 1e6 / itemsize) ** 0.5))
        chunk_rows, chunk_cols = side, side

        zarray = zarr.open(
            zarray_filepath,
            mode='w',
            shape=(height, width, n_bands),
            # One band per chunk. This matches the band-by-band writes below,
            # so each chunk is written exactly once and has the requested size.
            chunks=(chunk_rows, chunk_cols, 1),
            dtype=dtype,
        )

        # Let's add the metadata to the zarr file
        zarray.attrs['width'] = width
        zarray.attrs['height'] = height
        zarray.attrs['count'] = n_bands
        zarray.attrs['dtype'] = dtype
        zarray.attrs['bounds'] = raster.bounds
        zarray.attrs['transform'] = raster.transform
        # A raster without georeferencing has crs None
        zarray.attrs['crs'] = (
            raster.crs.to_string() if raster.crs is not None else None
        )

        # Loop through bands; raster band indices start at 1
        for k in raster.indexes:
            # Read and write one chunk-sized window at a time, so the whole
            # band never has to fit in memory.
            for j in range(0, width, chunk_cols):
                print(f'column {j} of {width}')
                # The last column/row of tiles may be narrower than a chunk
                n_cols = min(chunk_cols, width - j)
                for i in range(0, height, chunk_rows):
                    n_rows = min(chunk_rows, height - i)
                    data = raster.read(k, window=Window(j, i, n_cols, n_rows))
                    zarray[i:i + n_rows, j:j + n_cols, k - 1] = data

    # No need to close the zarr array explicitly
    return zarray_filepath
def test(raster_filepath, zarr_filepath, window):
    """
    Validate that the data were correctly copied to the zarr file.

    Reads the same window of band 1 from the raster and from the zarr
    array, times both reads, shows them side by side, and prints the
    number of pixels that differ.

    Parameters
    ----------
    raster_filepath: string
        Path and filename of the raster file
    zarr_filepath: string
        Path and filename of the zarr file, laid out as
        (rows, cols, bands) with band 1 at index 0
    window: 4-tuple
        Window to extract sub array; (x1, y1, x2, y2) in pixel
        coordinates, with x2 and y2 exclusive

    Returns
    -------
    bool
        True if both sub arrays have the same shape and values
        (NaNs at the same positions count as equal).
    """
    x1, y1, x2, y2 = window

    # Read a sub array of band 1 using rasterio
    st = time.perf_counter()
    with rasterio.open(raster_filepath) as raster:
        raster_sub = raster.read(1, window=Window(x1, y1, x2 - x1, y2 - y1))
    raster_time = time.perf_counter() - st

    # Read the same sub array using zarr. Read-only mode means a wrong path
    # raises instead of silently creating a new store.
    st = time.perf_counter()
    zarray = zarr.open(zarr_filepath, mode='r')
    # Band 1 lives at index 0 of the band axis. Without this index the
    # result is 3-D and does not line up with raster_sub.
    zarray_sub = zarray[y1:y2, x1:x2, 0]
    zarr_time = time.perf_counter() - st

    # Check for visual differences
    _, ax = plt.subplots(1, 2)
    ax[0].imshow(raster_sub)
    ax[0].set_title(f'RasterIO ({raster_time:.4f} secs)')
    ax[1].imshow(zarray_sub)
    ax[1].set_title(f'Zarr ({zarr_time:.4f} secs)')
    plt.show()

    # Check for numerical differences
    if raster_sub.shape != zarray_sub.shape:
        print(f'shape mismatch: {raster_sub.shape} vs {zarray_sub.shape}')
        return False

    # Count differing pixels rather than summing differences. A sum lets
    # positive and negative errors cancel, and unsigned subtraction wraps.
    mismatched = raster_sub != zarray_sub
    if np.issubdtype(raster_sub.dtype, np.floating):
        # NaN != NaN, so treat NaNs in the same position as equal
        mismatched &= ~(np.isnan(raster_sub) & np.isnan(zarray_sub))
    n_diff = int(np.count_nonzero(mismatched))
    print(f'{n_diff} differing pixels')
    return n_diff == 0
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment