Skip to content

Instantly share code, notes, and snippets.

@calebrob6
Created February 28, 2026 03:56
Show Gist options
  • Select an option

  • Save calebrob6/e71adbc64a94e362ec7c251e4fbc5223 to your computer and use it in GitHub Desktop.

Select an option

Save calebrob6/e71adbc64a94e362ec7c251e4fbc5223 to your computer and use it in GitHub Desktop.
Compute per-block AEF embedding stats and PCA rasterization
#!/usr/bin/env python3
"""Compute mean and stdev of Alpha Earth Foundation embeddings per census block.
Required input data
-------------------
1. Census block shapefiles:
- Go to https://www.census.gov/cgi-bin/geo/shapefiles/index.php
- Select year "2025" (or desired year) and layer type "Blocks (2020)"
- Choose a state (e.g. "Washington") and download the ZIP
- Unzip to get tl_2025_53_tabblock20.shp (and companion files)
2. Alpha Earth Foundation (AEF) tile index:
- Download from https://data.source.coop/tge-labs/aef/v1/annual/aef_index.gpkg
- e.g.: curl -O https://data.source.coop/tge-labs/aef/v1/annual/aef_index.gpkg
The script will automatically download the AEF raster tiles it needs to a
local cache directory (default: ./aef_cache).
"""
import argparse
import os
import re
import sys
import threading
from collections import defaultdict
from concurrent.futures import ThreadPoolExecutor, as_completed
from pathlib import Path
import geopandas as gpd
import numpy as np
import pandas as pd
import rasterio
import rasterio.mask
import requests
from shapely.ops import transform as shapely_transform
from tqdm import tqdm
import pyproj
def parse_args():
    """Build and parse the command-line options for the block-stats script."""
    p = argparse.ArgumentParser(
        description="Compute mean/stdev AEF embeddings per WA census block."
    )
    p.add_argument("--year", type=int, default=2020, help="AEF data year (default: 2020)")
    p.add_argument("--cache-dir", type=str, default="./aef_cache", help="Local tile cache directory")
    p.add_argument("--output", type=str, default="wa_block_aef_stats.geoparquet", help="Output geoparquet path")
    p.add_argument("--blocks", type=str, default="tl_2025_53_tabblock20.shp", help="Blocks shapefile")
    p.add_argument("--index", type=str, default="aef_index.gpkg", help="AEF tile index geopackage")
    p.add_argument("--workers", type=int, default=8, help="Number of parallel threads (default: 8)")
    p.add_argument(
        "--pool", type=str, default="mean_std", choices=["mean_std", "gem"],
        help="Pooling method: 'mean_std' for mean+stdev, 'gem' for Generalized Mean (GeM, p=3) (default: mean_std)",
    )
    p.add_argument("--gem-p", type=float, default=3.0, help="Exponent for GeM pooling (default: 3.0)")
    p.add_argument(
        "--bbox", type=float, nargs=4, metavar=("WEST", "SOUTH", "EAST", "NORTH"),
        help="Bounding box filter for blocks in EPSG:4326 (west south east north)",
    )
    # Paired flags: --skip-water is on by default, --no-skip-water clears the
    # same `skip_water` destination.
    p.add_argument("--skip-water", action="store_true", default=True,
                   help="Skip blocks with zero land area (default: True)")
    p.add_argument("--no-skip-water", action="store_false", dest="skip_water",
                   help="Include all-water blocks")
    p.add_argument("--max-land-km2", type=float, default=100.0,
                   help="Exclude blocks with land area above this threshold in km² (default: 100)")
    return p.parse_args()
def s3_to_https(s3_path: str) -> tuple[str, str]:
    """Map an AEF S3 object path to its public HTTPS (TIFF, VRT) URL pair.

    Example:
        s3://us-west-2.opendata.source.coop/tge-labs/aef/v1/annual/2020/10N/x.tiff
        -> https://data.source.coop/tge-labs/aef/v1/annual/2020/10N/x.tiff
           (plus the same URL with .tiff swapped for .vrt)
    """
    _, object_key = s3_path.split("opendata.source.coop/", 1)
    tiff_url = "https://data.source.coop/" + object_key
    vrt_url = re.sub(r"\.tiff$", ".vrt", tiff_url)
    return tiff_url, vrt_url
def load_and_filter(args):
    """Load blocks and index, filter index to year, find overlapping tiles.

    Returns a 5-tuple:
      blocks         -- filtered blocks GeoDataFrame in its native CRS
      blocks_4326    -- the same rows reprojected to EPSG:4326
      tile_to_blocks -- dict: AEF tile path -> set of overlapping GEOID20s
      block_to_tiles -- dict: GEOID20 -> set of overlapping tile paths
      tile_crs       -- dict: tile path -> CRS string taken from the index
    """
    print(f"Loading blocks from {args.blocks}...")
    blocks = gpd.read_file(args.blocks)
    print(f" {len(blocks)} blocks loaded")
    # Filter out all-water blocks (TIGER ALAND20 = land area in m²)
    if args.skip_water:
        n_before = len(blocks)
        blocks = blocks[blocks["ALAND20"] > 0].copy()
        print(f" Dropped {n_before - len(blocks)} all-water blocks, {len(blocks)} remaining")
    # Filter out huge blocks (threshold given in km², ALAND20 is m²)
    if args.max_land_km2 is not None:
        max_land_m2 = args.max_land_km2 * 1e6
        n_before = len(blocks)
        blocks = blocks[blocks["ALAND20"] <= max_land_m2].copy()
        print(f" Dropped {n_before - len(blocks)} blocks > {args.max_land_km2} km², {len(blocks)} remaining")
    print(f"Loading AEF index from {args.index}...")
    index = gpd.read_file(args.index)
    # Filter index to requested year
    index = index[index["year"] == args.year].copy()
    print(f" {len(index)} tiles for year {args.year}")
    # Reproject blocks to EPSG:4326 to match index CRS
    blocks_4326 = blocks.to_crs("EPSG:4326")
    # Apply bbox filter to reduce work; the mask is computed in 4326 and
    # applied to both frames so they stay row-aligned.
    if args.bbox:
        from shapely.geometry import box
        west, south, east, north = args.bbox
        bbox_geom = box(west, south, east, north)
        mask = blocks_4326.intersects(bbox_geom)
        blocks = blocks[mask].copy()
        blocks_4326 = blocks_4326[mask].copy()
        print(f" Filtered to {len(blocks)} blocks within bbox {args.bbox}")
    # Spatial join: find which tiles overlap which blocks
    print("Computing tile-block overlaps via spatial join...")
    joined = gpd.sjoin(blocks_4326[["GEOID20", "geometry"]], index, how="inner", predicate="intersects")
    # Build tile -> [block GEOID20s] mapping (and the inverse); sets because
    # a block may intersect several tiles and vice versa.
    tile_to_blocks = defaultdict(set)
    block_to_tiles = defaultdict(set)
    for _, row in joined.iterrows():
        tile_path = row["path"]
        geoid = row["GEOID20"]
        tile_to_blocks[tile_path].add(geoid)
        block_to_tiles[geoid].add(tile_path)
    unique_tiles = set(tile_to_blocks.keys())
    print(f" {len(unique_tiles)} unique tiles needed for {len(block_to_tiles)} blocks")
    # Build a lookup from tile path to its CRS (only for tiles we will use)
    tile_crs = {}
    for _, row in index.iterrows():
        if row["path"] in unique_tiles:
            tile_crs[row["path"]] = row["crs"]
    return blocks, blocks_4326, tile_to_blocks, block_to_tiles, tile_crs
def download_tiles(tile_paths, tile_crs_map, cache_dir, year, workers):
    """Download TIFF and VRT for each tile in parallel, rewrite VRT source path.

    Returns a dict mapping the original tile path to the local VRT Path.
    Tiles whose TIFF and VRT both exist in the cache are skipped.
    """
    # NOTE(review): tile_crs_map is accepted but never used in this function.
    cache = Path(cache_dir)
    local_vrts = {}  # tile_path -> local VRT path
    tiles_to_download = []
    for tile_path in tile_paths:
        tiff_url, vrt_url = s3_to_https(tile_path)
        # URL tail is .../<year>/<utm_zone>/<filename>.tiff; mirror that layout
        # under the cache directory.
        parts = tiff_url.rsplit("/", 2)
        utm_zone = parts[-2]
        filename = parts[-1]
        vrt_filename = filename.replace(".tiff", ".vrt")
        tile_dir = cache / str(year) / utm_zone
        local_tiff = tile_dir / filename
        local_vrt = tile_dir / vrt_filename
        local_vrts[tile_path] = local_vrt
        if local_vrt.exists() and local_tiff.exists():
            continue
        tiles_to_download.append((tile_path, tiff_url, vrt_url, tile_dir, local_tiff, local_vrt))
    if not tiles_to_download:
        print("All tiles already cached.")
        return local_vrts
    print(f"Downloading {len(tiles_to_download)} tiles with {workers} threads...")
    pbar = tqdm(total=len(tiles_to_download), desc="Downloading tiles")
    def _download_one(item):
        # Worker: fetch one tile's TIFF (if missing) and its VRT (if missing).
        tile_path, tiff_url, vrt_url, tile_dir, local_tiff, local_vrt = item
        tile_dir.mkdir(parents=True, exist_ok=True)
        if not local_tiff.exists():
            _download_file(tiff_url, local_tiff)
        if not local_vrt.exists():
            resp = requests.get(vrt_url, timeout=60)
            resp.raise_for_status()
            vrt_content = resp.text
            # Rewrite the VRT's <SourceDataset> so it references the cached
            # local TIFF (relative path) instead of the remote source.
            vrt_content = re.sub(
                r"<SourceDataset[^>]*>[^<]+</SourceDataset>",
                f'<SourceDataset relativeToVRT="1">{local_tiff.name}</SourceDataset>',
                vrt_content,
            )
            local_vrt.write_text(vrt_content)
        pbar.update(1)
    with ThreadPoolExecutor(max_workers=workers) as pool:
        futures = [pool.submit(_download_one, item) for item in tiles_to_download]
        for f in as_completed(futures):
            f.result()  # raise any exceptions
    pbar.close()
    return local_vrts
def _download_file(url, dest):
    """Stream *url* to *dest* atomically, with a transfer progress bar.

    The payload is written to a temporary ``<dest>.part`` file and renamed
    into place only on success. The original version wrote directly to
    *dest*, so an interrupted download left a truncated file that
    ``download_tiles``'s ``local_tiff.exists()`` check would then treat as
    a complete cached tile on the next run.

    Parameters:
        url  -- HTTPS URL to fetch.
        dest -- pathlib.Path of the final destination file.

    Raises:
        requests.HTTPError on a non-2xx response; any I/O error propagates
        after the partial file has been removed.
    """
    resp = requests.get(url, stream=True, timeout=300)
    resp.raise_for_status()
    total = int(resp.headers.get("content-length", 0))
    tmp = dest.with_suffix(dest.suffix + ".part")
    try:
        with open(tmp, "wb") as f:
            with tqdm(total=total, unit="B", unit_scale=True, desc=dest.name, leave=False) as pbar:
                for chunk in resp.iter_content(chunk_size=1024 * 1024):
                    if chunk:  # filter out keep-alive chunks
                        f.write(chunk)
                        pbar.update(len(chunk))
    except BaseException:
        # Never leave a partial file behind to be mistaken for a cached tile.
        tmp.unlink(missing_ok=True)
        raise
    os.replace(tmp, dest)  # atomic rename; overwrites any stale copy
def _process_one_tile(tile_path, vrt_path, tile_crs_str, geoids, block_geom_4326, block_pbar):
    """Process a single tile: load into memory then mask each overlapping block.

    Parameters:
        tile_path       -- original AEF tile path (unused here beyond identity).
        vrt_path        -- local VRT for the tile.
        tile_crs_str    -- CRS string of the tile (a UTM zone per the index).
        geoids          -- iterable of GEOID20s whose blocks intersect the tile.
        block_geom_4326 -- dict: GEOID20 -> shapely geometry in EPSG:4326.
        block_pbar      -- shared tqdm bar, advanced once per geoid (may be None).

    Returns dict: GEOID20 -> float32 array of shape (64, n_valid_pixels).
    Blocks with no valid pixels (all nodata, or geometry outside the raster)
    are simply absent from the result.
    """
    n_bands = 64
    nodata = -128  # sentinel used when masking; AEF bands are int8 per this choice
    results = {}  # geoid -> (n_bands, n_pixels) array
    # Reusable transformer from block CRS (4326) into the tile's UTM CRS.
    transformer = pyproj.Transformer.from_crs("EPSG:4326", tile_crs_str, always_xy=True)
    # Read entire VRT into a MemoryFile so repeated mask() calls hit RAM, not disk
    with rasterio.open(str(vrt_path)) as src:
        profile = src.profile.copy()
        profile.update(driver="GTiff")
        data = src.read()
    memfile = rasterio.MemoryFile()
    with memfile.open(**profile) as mem_dst:
        mem_dst.write(data)
    del data  # free the raw array; the MemoryFile now owns a copy
    with memfile.open() as src:
        for geoid in geoids:
            geom_4326 = block_geom_4326[geoid]
            # Project the block geometry into the tile's CRS before masking.
            geom_utm = shapely_transform(transformer.transform, geom_4326)
            try:
                out_image, _ = rasterio.mask.mask(
                    src, [geom_utm], crop=True, all_touched=True, nodata=nodata
                )
            except ValueError:
                # rasterio raises ValueError when the geometry does not
                # overlap the raster; skip but still advance the bar.
                if block_pbar is not None:
                    block_pbar.update(1)
                continue
            pixels = out_image.reshape(n_bands, -1)
            # Keep only pixels valid in every band.
            valid_mask = np.all(pixels != nodata, axis=0)
            if valid_mask.sum() == 0:
                if block_pbar is not None:
                    block_pbar.update(1)
                continue
            valid_pixels = pixels[:, valid_mask].astype(np.float32)
            results[geoid] = valid_pixels
            if block_pbar is not None:
                block_pbar.update(1)
    memfile.close()
    return results
def compute_block_stats(blocks, blocks_4326, tile_to_blocks, block_to_tiles, tile_crs_map, local_vrts, workers, pool_method="mean_std", gem_p=3.0):
    """Compute per-block statistics across all 64 embedding dimensions using threaded I/O.

    Tiles are processed in parallel; pixel arrays for blocks that span
    multiple tiles are accumulated under a lock and concatenated before
    pooling. Returns a DataFrame with one row per GEOID20 in *blocks* and
    either mean_/std_ or gem_ columns per band; blocks with no valid
    pixels get NaN stats.
    """
    n_bands = 64
    # For blocks spanning multiple tiles we accumulate pixels
    block_pixels = defaultdict(list)  # geoid -> list of (n_bands, n_pixels) arrays
    block_geom_4326 = dict(zip(blocks_4326["GEOID20"], blocks_4326.geometry))
    sorted_tiles = sorted(local_vrts.keys())
    total_block_tile_pairs = sum(len(geoids) for geoids in tile_to_blocks.values())
    print(f"Processing {len(sorted_tiles)} tiles ({total_block_tile_pairs} block×tile pairs) with {workers} threads...")
    tile_pbar = tqdm(total=len(sorted_tiles), desc="Tiles completed", position=0)
    block_pbar = tqdm(total=total_block_tile_pairs, desc="Blocks masked", position=1)
    lock = threading.Lock()
    def _process_and_merge(tile_path):
        # Worker: mask all of one tile's blocks, then merge under the lock.
        vrt_path = local_vrts[tile_path]
        tile_crs_str = tile_crs_map[tile_path]
        geoids = tile_to_blocks[tile_path]
        tile_results = _process_one_tile(tile_path, vrt_path, tile_crs_str, geoids, block_geom_4326, block_pbar)
        with lock:
            for geoid, pixels in tile_results.items():
                block_pixels[geoid].append(pixels)
            tile_pbar.update(1)
    with ThreadPoolExecutor(max_workers=workers) as pool:
        futures = [pool.submit(_process_and_merge, tp) for tp in sorted_tiles]
        for f in as_completed(futures):
            f.result()
    tile_pbar.close()
    block_pbar.close()
    # Aggregate stats
    print(f"Aggregating statistics (method={pool_method})...")
    results = []
    band_names = [f"A{i:02d}" for i in range(n_bands)]
    all_geoids = set(blocks["GEOID20"])
    for geoid in tqdm(sorted(all_geoids), desc="Computing stats"):
        row = {"GEOID20": geoid}
        pixel_arrays = block_pixels.get(geoid, [])
        if pixel_arrays:
            all_pixels = np.concatenate(pixel_arrays, axis=1)  # (64, total_pixels)
            if pool_method == "gem":
                # GeM: ( mean(x^p) )^(1/p), applied per band
                # Shift to non-negative range since int8 values can be negative
                # NOTE(review): the shift uses each block's own per-band min,
                # which presumably makes GeM values not directly comparable
                # across blocks — confirm this is intended.
                shifted = all_pixels - all_pixels.min(axis=1, keepdims=True)
                shifted = np.maximum(shifted, 1e-6)  # avoid zero
                gem = np.power(np.mean(np.power(shifted, gem_p), axis=1), 1.0 / gem_p)
                for i, name in enumerate(band_names):
                    row[f"gem_{name}"] = gem[i]
            else:
                means = all_pixels.mean(axis=1)
                stds = all_pixels.std(axis=1)
                for i, name in enumerate(band_names):
                    row[f"mean_{name}"] = means[i]
                    row[f"std_{name}"] = stds[i]
        else:
            # No valid pixels anywhere for this block: emit NaN columns so
            # every input GEOID20 still appears in the output frame.
            if pool_method == "gem":
                for i, name in enumerate(band_names):
                    row[f"gem_{name}"] = np.nan
            else:
                for i, name in enumerate(band_names):
                    row[f"mean_{name}"] = np.nan
                    row[f"std_{name}"] = np.nan
        results.append(row)
    return pd.DataFrame(results)
def main():
    """End-to-end driver: load inputs, fetch tiles, compute stats, save output."""
    args = parse_args()
    # Step 1: filter blocks and determine which AEF tiles cover them.
    blocks, blocks_4326, tile_to_blocks, block_to_tiles, tile_crs_map = load_and_filter(args)
    # Step 2: ensure every needed tile (TIFF + rewritten VRT) is cached locally.
    local_vrts = download_tiles(
        tile_paths=set(tile_to_blocks.keys()),
        tile_crs_map=tile_crs_map,
        cache_dir=args.cache_dir,
        year=args.year,
        workers=args.workers,
    )
    # Step 3: pool embedding pixels into per-block statistics.
    stats = compute_block_stats(
        blocks, blocks_4326, tile_to_blocks, block_to_tiles,
        tile_crs_map, local_vrts, args.workers, args.pool, args.gem_p,
    )
    # Step 4: Save output as geoparquet (join stats with block geometries)
    print(f"Saving results to {args.output}...")
    merged = blocks[["GEOID20", "geometry"]].merge(stats, on="GEOID20", how="inner")
    merged = gpd.GeoDataFrame(merged, geometry="geometry")
    merged.to_parquet(args.output, index=False)
    print(f"Done. {len(merged)} blocks, {len(merged.columns) - 2} stat columns.")


if __name__ == "__main__":
    main()
#!/usr/bin/env python3
"""Reduce AEF block stats to 3-component PCA and rasterize as a COG."""
import argparse
import geopandas as gpd
import numpy as np
import rasterio
from rasterio.enums import Resampling
from rasterio.features import rasterize
from rasterio.transform import from_bounds
from sklearn.decomposition import PCA
def parse_args():
    """Build and parse the command-line options for the PCA/rasterize script."""
    ap = argparse.ArgumentParser(
        description="PCA-reduce AEF block stats and rasterize to a 3-band COG."
    )
    ap.add_argument("input", help="Input geoparquet with block stats (e.g. wa_block_aef_stats.geoparquet)")
    ap.add_argument("-o", "--output", default=None, help="Output COG path (default: <input_stem>_pca.tif)")
    ap.add_argument("--resolution", type=float, default=10.0, help="Raster resolution in CRS units (default: 10m)")
    ap.add_argument("--crs", default="EPSG:32610", help="Target CRS for rasterization (default: EPSG:32610)")
    ap.add_argument("--n-components", type=int, default=3, help="Number of PCA components (default: 3)")
    ap.add_argument("--clip-percentile", type=float, default=2.0, help="Percentile for uint8 clipping (default: 2)")
    return ap.parse_args()
def main():
    """Load block stats, PCA-reduce to N components, rasterize to a uint8 COG.

    Pipeline: read geoparquet -> fit PCA on NaN-free rows -> per-component
    percentile clip and scale to 1..255 (0 reserved for nodata) -> rasterize
    block polygons -> write tiled deflate GeoTIFF with overviews.
    """
    args = parse_args()
    if args.output is None:
        from pathlib import Path
        args.output = str(Path(args.input).stem) + "_pca.tif"
    # 1. Load data
    print(f"Loading {args.input}...")
    gdf = gpd.read_parquet(args.input)
    gdf = gdf.to_crs(args.crs)
    # 2. PCA on stat columns (covers both pooling modes' column prefixes)
    stat_cols = [c for c in gdf.columns if c.startswith(("mean_", "std_", "gem_"))]
    print(f" {len(gdf)} blocks, {len(stat_cols)} stat columns")
    X = gdf[stat_cols].values.astype(np.float32)
    # Drop rows with NaN for PCA fit, but keep them for output (they'll get nodata)
    valid = ~np.isnan(X).any(axis=1)
    pca = PCA(n_components=args.n_components)
    pca.fit(X[valid])
    print(f" PCA explained variance: {pca.explained_variance_ratio_}")
    pcs = np.full((len(X), args.n_components), np.nan, dtype=np.float32)
    pcs[valid] = pca.transform(X[valid])
    # 3. Rasterize at target resolution; grid is anchored at the layer's
    # lower-left corner and rounded up to whole pixels.
    bounds = gdf.total_bounds
    res = args.resolution
    width = int(np.ceil((bounds[2] - bounds[0]) / res))
    height = int(np.ceil((bounds[3] - bounds[1]) / res))
    transform = from_bounds(bounds[0], bounds[1], bounds[0] + width * res, bounds[1] + height * res, width, height)
    print(f" Raster: {width} x {height} px ({width * height / 1e6:.1f} Mpx)")
    lo_p = args.clip_percentile
    hi_p = 100 - args.clip_percentile
    bands = []
    for i in range(args.n_components):
        vals = pcs[:, i]
        valid_vals = vals[~np.isnan(vals)]
        # Percentile-clip, then scale into 1..255 so 0 stays free as nodata.
        # NOTE(review): if hi == lo (constant component) this divides by
        # zero, and NaN rows pass through astype(uint8) before being zeroed
        # on the next line — confirm inputs make both cases unreachable.
        lo, hi = np.percentile(valid_vals, [lo_p, hi_p])
        clipped = np.clip(vals, lo, hi)
        scaled = ((clipped - lo) / (hi - lo) * 254 + 1).astype(np.uint8)
        scaled[np.isnan(vals)] = 0  # nodata
        shapes = list(zip(gdf.geometry, scaled))
        band = rasterize(
            shapes, out_shape=(height, width), transform=transform,
            fill=0, dtype=np.uint8, all_touched=True,
        )
        bands.append(band)
        # NOTE(review): band[band > 0].min() raises if a band rasterizes to
        # all zeros (e.g. empty input) — acceptable for a one-off script.
        print(f" Band {i + 1} (PC{i + 1}): min={band[band > 0].min()}, max={band[band > 0].max()}")
    # 4. Write COG-style GeoTIFF: tiled + deflate, then internal overviews.
    profile = {
        "driver": "GTiff",
        "dtype": "uint8",
        "width": width,
        "height": height,
        "count": args.n_components,
        "crs": args.crs,
        "transform": transform,
        "nodata": 0,
        "compress": "deflate",
        "tiled": True,
        "blockxsize": 512,
        "blockysize": 512,
    }
    with rasterio.open(args.output, "w", **profile) as dst:
        for i in range(args.n_components):
            dst.write(bands[i], i + 1)
            # Band description records the explained-variance share per PC.
            dst.set_band_description(i + 1, f"PC{i + 1} ({pca.explained_variance_ratio_[i]:.1%})")
    # Reopen in update mode to build averaged overview pyramids.
    with rasterio.open(args.output, "r+") as dst:
        dst.build_overviews([2, 4, 8, 16, 32], Resampling.average)
        dst.update_tags(ns="rio_overview", resampling="average")
    print(f"Saved {args.output}")


if __name__ == "__main__":
    main()
@calebrob6
Copy link
Author

Computes statistics of AlphaEarth embeddings over Census polygons (and PCA visualizations of the embeddings)

E.g. over Census blocks over Seattle

image

@calebrob6
Copy link
Author

Over all WA

image

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment