Created
February 28, 2026 03:56
-
-
Save calebrob6/e71adbc64a94e362ec7c251e4fbc5223 to your computer and use it in GitHub Desktop.
Compute per-block AEF embedding stats and PCA rasterization
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| #!/usr/bin/env python3 | |
| """Compute mean and stdev of Alpha Earth Foundation embeddings per census block. | |
| Required input data | |
| ------------------- | |
| 1. Census block shapefiles: | |
| - Go to https://www.census.gov/cgi-bin/geo/shapefiles/index.php | |
| - Select year "2025" (or desired year) and layer type "Blocks (2020)" | |
| - Choose a state (e.g. "Washington") and download the ZIP | |
| - Unzip to get tl_2025_53_tabblock20.shp (and companion files) | |
| 2. Alpha Earth Foundation (AEF) tile index: | |
| - Download from https://data.source.coop/tge-labs/aef/v1/annual/aef_index.gpkg | |
| - e.g.: curl -O https://data.source.coop/tge-labs/aef/v1/annual/aef_index.gpkg | |
| The script will automatically download the AEF raster tiles it needs to a | |
| local cache directory (default: ./aef_cache). | |
| """ | |
| import argparse | |
| import os | |
| import re | |
| import sys | |
| import threading | |
| from collections import defaultdict | |
| from concurrent.futures import ThreadPoolExecutor, as_completed | |
| from pathlib import Path | |
| import geopandas as gpd | |
| import numpy as np | |
| import pandas as pd | |
| import rasterio | |
| import rasterio.mask | |
| import requests | |
| from shapely.ops import transform as shapely_transform | |
| from tqdm import tqdm | |
| import pyproj | |
def parse_args():
    """Parse command-line options for the per-block AEF statistics pipeline."""
    p = argparse.ArgumentParser(
        description="Compute mean/stdev AEF embeddings per WA census block."
    )
    # Data sources and destinations.
    p.add_argument("--year", type=int, default=2020, help="AEF data year (default: 2020)")
    p.add_argument("--cache-dir", type=str, default="./aef_cache", help="Local tile cache directory")
    p.add_argument("--output", type=str, default="wa_block_aef_stats.geoparquet", help="Output geoparquet path")
    p.add_argument("--blocks", type=str, default="tl_2025_53_tabblock20.shp", help="Blocks shapefile")
    p.add_argument("--index", type=str, default="aef_index.gpkg", help="AEF tile index geopackage")
    p.add_argument("--workers", type=int, default=8, help="Number of parallel threads (default: 8)")
    # Pooling configuration.
    p.add_argument(
        "--pool", type=str, default="mean_std", choices=["mean_std", "gem"],
        help="Pooling method: 'mean_std' for mean+stdev, 'gem' for Generalized Mean (GeM, p=3) (default: mean_std)",
    )
    p.add_argument("--gem-p", type=float, default=3.0, help="Exponent for GeM pooling (default: 3.0)")
    # Block filtering.
    p.add_argument(
        "--bbox", type=float, nargs=4, metavar=("WEST", "SOUTH", "EAST", "NORTH"),
        help="Bounding box filter for blocks in EPSG:4326 (west south east north)",
    )
    p.add_argument("--skip-water", action="store_true", default=True,
                   help="Skip blocks with zero land area (default: True)")
    p.add_argument("--no-skip-water", action="store_false", dest="skip_water",
                   help="Include all-water blocks")
    p.add_argument("--max-land-km2", type=float, default=100.0,
                   help="Exclude blocks with land area above this threshold in km² (default: 100)")
    return p.parse_args()
def s3_to_https(s3_path: str) -> tuple[str, str]:
    """Map an AEF S3 tile path to its public HTTPS TIFF and VRT URLs.

    Example:
        s3://us-west-2.opendata.source.coop/tge-labs/aef/v1/annual/2020/10N/x.tiff
        -> (https://data.source.coop/.../2020/10N/x.tiff,
            https://data.source.coop/.../2020/10N/x.vrt)
    """
    # Everything after the bucket host is the object key, which maps 1:1
    # onto the data.source.coop HTTPS layout.
    key = s3_path.split("opendata.source.coop/", 1)[1]
    tiff_url = f"https://data.source.coop/{key}"
    vrt_url = re.sub(r"\.tiff$", ".vrt", tiff_url)
    return tiff_url, vrt_url
def load_and_filter(args):
    """Load census blocks and the AEF tile index, then match tiles to blocks.

    Returns a 5-tuple:
      blocks         -- filtered blocks GeoDataFrame in its native CRS
      blocks_4326    -- the same blocks reprojected to EPSG:4326
      tile_to_blocks -- tile path -> set of overlapping block GEOID20s
      block_to_tiles -- block GEOID20 -> set of overlapping tile paths
      tile_crs       -- tile path -> CRS string, only for the tiles needed
    """
    print(f"Loading blocks from {args.blocks}...")
    blocks = gpd.read_file(args.blocks)
    print(f" {len(blocks)} blocks loaded")

    # Optionally drop blocks whose 2020 land area is zero (all water).
    if args.skip_water:
        n_before = len(blocks)
        blocks = blocks[blocks["ALAND20"] > 0].copy()
        print(f" Dropped {n_before - len(blocks)} all-water blocks, {len(blocks)} remaining")

    # Optionally drop blocks above the land-area cap (very large rural blocks).
    if args.max_land_km2 is not None:
        max_land_m2 = args.max_land_km2 * 1e6
        n_before = len(blocks)
        blocks = blocks[blocks["ALAND20"] <= max_land_m2].copy()
        print(f" Dropped {n_before - len(blocks)} blocks > {args.max_land_km2} km², {len(blocks)} remaining")

    print(f"Loading AEF index from {args.index}...")
    index = gpd.read_file(args.index)
    index = index[index["year"] == args.year].copy()
    print(f" {len(index)} tiles for year {args.year}")

    # The index geometries are compared in EPSG:4326, so reproject the blocks.
    blocks_4326 = blocks.to_crs("EPSG:4326")

    # Optional bbox pre-filter to cut down the spatial join.
    if args.bbox:
        from shapely.geometry import box
        west, south, east, north = args.bbox
        keep = blocks_4326.intersects(box(west, south, east, north))
        blocks = blocks[keep].copy()
        blocks_4326 = blocks_4326[keep].copy()
        print(f" Filtered to {len(blocks)} blocks within bbox {args.bbox}")

    print("Computing tile-block overlaps via spatial join...")
    joined = gpd.sjoin(blocks_4326[["GEOID20", "geometry"]], index, how="inner", predicate="intersects")

    # Build the bidirectional tile <-> block overlap maps.
    tile_to_blocks = defaultdict(set)
    block_to_tiles = defaultdict(set)
    for geoid, tile_path in zip(joined["GEOID20"], joined["path"]):
        tile_to_blocks[tile_path].add(geoid)
        block_to_tiles[geoid].add(tile_path)

    unique_tiles = set(tile_to_blocks)
    print(f" {len(unique_tiles)} unique tiles needed for {len(block_to_tiles)} blocks")

    # CRS lookup restricted to the tiles we will actually download.
    tile_crs = {
        path: crs
        for path, crs in zip(index["path"], index["crs"])
        if path in unique_tiles
    }
    return blocks, blocks_4326, tile_to_blocks, block_to_tiles, tile_crs
def download_tiles(tile_paths, tile_crs_map, cache_dir, year, workers):
    """Download the TIFF and VRT for each tile in parallel, caching locally.

    Each VRT is rewritten so its <SourceDataset> points at the cached TIFF
    sitting next to it, and is written via a temp file + atomic rename so an
    interrupted run can never leave a truncated VRT that the existence check
    below would later treat as a complete cache hit.

    Args:
        tile_paths: iterable of s3:// tile paths from the AEF index.
        tile_crs_map: tile path -> CRS string. Unused here; kept so the
            call signature stays stable for existing callers.
        cache_dir: root directory of the local tile cache.
        year: AEF data year, used as a cache subdirectory.
        workers: number of parallel download threads.

    Returns:
        dict mapping tile path -> local VRT Path.
    """
    cache = Path(cache_dir)
    local_vrts = {}  # tile_path -> local VRT path
    tiles_to_download = []
    for tile_path in tile_paths:
        tiff_url, vrt_url = s3_to_https(tile_path)
        # URL layout is .../<year>/<utm_zone>/<filename>.tiff
        parts = tiff_url.rsplit("/", 2)
        utm_zone = parts[-2]
        filename = parts[-1]
        vrt_filename = filename.replace(".tiff", ".vrt")
        tile_dir = cache / str(year) / utm_zone
        local_tiff = tile_dir / filename
        local_vrt = tile_dir / vrt_filename
        local_vrts[tile_path] = local_vrt
        if local_vrt.exists() and local_tiff.exists():
            continue  # both artifacts already cached
        tiles_to_download.append((tile_path, tiff_url, vrt_url, tile_dir, local_tiff, local_vrt))
    if not tiles_to_download:
        print("All tiles already cached.")
        return local_vrts
    print(f"Downloading {len(tiles_to_download)} tiles with {workers} threads...")
    pbar = tqdm(total=len(tiles_to_download), desc="Downloading tiles")

    def _download_one(item):
        """Fetch one tile's TIFF (if missing) and rewrite+store its VRT."""
        tile_path, tiff_url, vrt_url, tile_dir, local_tiff, local_vrt = item
        tile_dir.mkdir(parents=True, exist_ok=True)
        if not local_tiff.exists():
            _download_file(tiff_url, local_tiff)
        if not local_vrt.exists():
            resp = requests.get(vrt_url, timeout=60)
            resp.raise_for_status()
            # Point the VRT at the cached TIFF next to it rather than the
            # original remote source.
            vrt_content = re.sub(
                r"<SourceDataset[^>]*>[^<]+</SourceDataset>",
                f'<SourceDataset relativeToVRT="1">{local_tiff.name}</SourceDataset>',
                resp.text,
            )
            # Write-then-rename: a crash mid-write leaves only the temp file,
            # so the next run will re-download instead of using a partial VRT.
            tmp_vrt = local_vrt.with_name(local_vrt.name + ".tmp")
            tmp_vrt.write_text(vrt_content)
            tmp_vrt.replace(local_vrt)
        pbar.update(1)  # tqdm.update is internally locked; safe across threads

    with ThreadPoolExecutor(max_workers=workers) as pool:
        futures = [pool.submit(_download_one, item) for item in tiles_to_download]
        for f in as_completed(futures):
            f.result()  # surface any download exceptions
    pbar.close()
    return local_vrts
def _download_file(url, dest):
    """Stream *url* to *dest* with a progress bar.

    Downloads into a sibling ".part" file and atomically renames it to
    *dest* only on success. The original wrote straight into *dest*, so an
    interrupted download left a truncated file that download_tiles' existence
    check would treat as a complete cached tile on the next run.
    """
    tmp = dest.with_name(dest.name + ".part")
    resp = requests.get(url, stream=True, timeout=300)
    resp.raise_for_status()
    total = int(resp.headers.get("content-length", 0))
    try:
        with open(tmp, "wb") as f:
            with tqdm(total=total, unit="B", unit_scale=True, desc=dest.name, leave=False) as pbar:
                for chunk in resp.iter_content(chunk_size=1024 * 1024):
                    f.write(chunk)
                    pbar.update(len(chunk))
        tmp.replace(dest)  # atomic rename on POSIX
    finally:
        # On failure the partial file is removed; after a successful rename
        # the temp path no longer exists and this is a no-op.
        tmp.unlink(missing_ok=True)
def _process_one_tile(tile_path, vrt_path, tile_crs_str, geoids, block_geom_4326, block_pbar):
    """Extract valid embedding pixels for every block overlapping one tile.

    Args:
        tile_path: tile identifier (unused here beyond bookkeeping by callers).
        vrt_path: local VRT file for the tile.
        tile_crs_str: CRS of the tile raster; block geometries are reprojected
            from EPSG:4326 into it before masking.
        geoids: GEOID20s of blocks that intersect this tile.
        block_geom_4326: GEOID20 -> shapely geometry in EPSG:4326.
        block_pbar: optional tqdm; advanced once per geoid, including skips.

    Returns:
        dict mapping geoid -> (n_bands, n_pixels) float32 array of pixels in
        which every band is non-nodata. Blocks outside the raster extent or
        with no valid pixels are omitted.
    """
    n_bands = 64
    nodata = -128  # AEF tiles are int8; -128 marks nodata
    results = {}
    transformer = pyproj.Transformer.from_crs("EPSG:4326", tile_crs_str, always_xy=True)
    # Load the whole VRT into RAM so the repeated mask() calls below do not
    # re-read from disk for every block.
    with rasterio.open(str(vrt_path)) as src:
        profile = src.profile.copy()
        profile.update(driver="GTiff")
        data = src.read()
    # Fix: the original created the MemoryFile outside any context manager and
    # only closed it on the success path, leaking the in-memory raster if a
    # mask call raised. The with-block guarantees release on all paths.
    with rasterio.MemoryFile() as memfile:
        with memfile.open(**profile) as mem_dst:
            mem_dst.write(data)
        del data  # free the raw array once it is copied into the memfile
        with memfile.open() as mem_src:
            for geoid in geoids:
                geom_utm = shapely_transform(transformer.transform, block_geom_4326[geoid])
                try:
                    out_image, _ = rasterio.mask.mask(
                        mem_src, [geom_utm], crop=True, all_touched=True, nodata=nodata
                    )
                except ValueError:
                    # Geometry does not overlap the raster extent.
                    if block_pbar is not None:
                        block_pbar.update(1)
                    continue
                pixels = out_image.reshape(n_bands, -1)
                # A pixel is valid only when *every* band is non-nodata.
                valid_mask = np.all(pixels != nodata, axis=0)
                if valid_mask.sum() == 0:
                    if block_pbar is not None:
                        block_pbar.update(1)
                    continue
                results[geoid] = pixels[:, valid_mask].astype(np.float32)
                if block_pbar is not None:
                    block_pbar.update(1)
    return results
def _pool_row(pixel_arrays, band_names, pool_method, gem_p):
    """Pool one block's pixel arrays into a dict of stat columns.

    pixel_arrays is a list of (n_bands, n_pixels) float32 arrays (one per
    tile the block touched); an empty list yields NaN columns so the output
    schema stays uniform across blocks.
    """
    row = {}
    if pixel_arrays:
        all_pixels = np.concatenate(pixel_arrays, axis=1)  # (64, total_pixels)
        if pool_method == "gem":
            # GeM: (mean(x^p))^(1/p) per band. Values are shifted to be
            # non-negative first because int8 embeddings can be negative.
            # NOTE(review): the shift uses each block's own per-band minimum,
            # so GeM values are not strictly comparable across blocks —
            # confirm this is intended before comparing blocks directly.
            shifted = all_pixels - all_pixels.min(axis=1, keepdims=True)
            shifted = np.maximum(shifted, 1e-6)  # avoid 0**p
            gem = np.power(np.mean(np.power(shifted, gem_p), axis=1), 1.0 / gem_p)
            for i, name in enumerate(band_names):
                row[f"gem_{name}"] = gem[i]
        else:
            means = all_pixels.mean(axis=1)
            stds = all_pixels.std(axis=1)
            for i, name in enumerate(band_names):
                row[f"mean_{name}"] = means[i]
                row[f"std_{name}"] = stds[i]
    else:
        nan_cols = (
            [f"gem_{name}" for name in band_names]
            if pool_method == "gem"
            else [f"{p}_{name}" for name in band_names for p in ("mean", "std")]
        )
        for col in nan_cols:
            row[col] = np.nan
    return row


def compute_block_stats(blocks, blocks_4326, tile_to_blocks, block_to_tiles, tile_crs_map, local_vrts, workers, pool_method="mean_std", gem_p=3.0):
    """Compute per-block statistics across all 64 embedding dimensions.

    Tiles are processed by a thread pool (rasterio releases the GIL during
    reads, and masking is I/O-heavy); per-block pixel arrays are accumulated
    under a lock because a block can span multiple tiles. Pooling of the
    accumulated pixels is delegated to _pool_row.

    Args:
        blocks: filtered blocks frame; only its GEOID20 column is read here.
        blocks_4326: blocks reprojected to EPSG:4326 (source of geometries).
        tile_to_blocks: tile path -> set of overlapping block GEOID20s.
        block_to_tiles: unused; kept so the call signature stays stable.
        tile_crs_map: tile path -> CRS string.
        local_vrts: tile path -> local VRT Path (from download_tiles).
        workers: number of tile-processing threads.
        pool_method: "mean_std" or "gem".
        gem_p: exponent for GeM pooling.

    Returns:
        pandas DataFrame with one row per block (GEOID20 + stat columns);
        blocks with no valid pixels get NaN stats.
    """
    n_bands = 64
    # geoid -> list of (n_bands, n_pixels) arrays; a block spanning multiple
    # tiles accumulates one array per tile.
    block_pixels = defaultdict(list)
    block_geom_4326 = dict(zip(blocks_4326["GEOID20"], blocks_4326.geometry))
    sorted_tiles = sorted(local_vrts.keys())
    total_block_tile_pairs = sum(len(geoids) for geoids in tile_to_blocks.values())
    print(f"Processing {len(sorted_tiles)} tiles ({total_block_tile_pairs} block×tile pairs) with {workers} threads...")
    tile_pbar = tqdm(total=len(sorted_tiles), desc="Tiles completed", position=0)
    block_pbar = tqdm(total=total_block_tile_pairs, desc="Blocks masked", position=1)
    lock = threading.Lock()  # guards block_pixels across worker threads

    def _process_and_merge(tile_path):
        """Process one tile and merge its per-block pixels under the lock."""
        vrt_path = local_vrts[tile_path]
        tile_crs_str = tile_crs_map[tile_path]
        geoids = tile_to_blocks[tile_path]
        tile_results = _process_one_tile(tile_path, vrt_path, tile_crs_str, geoids, block_geom_4326, block_pbar)
        with lock:
            for geoid, pixels in tile_results.items():
                block_pixels[geoid].append(pixels)
        tile_pbar.update(1)

    with ThreadPoolExecutor(max_workers=workers) as pool:
        futures = [pool.submit(_process_and_merge, tp) for tp in sorted_tiles]
        for f in as_completed(futures):
            f.result()  # surface worker exceptions
    tile_pbar.close()
    block_pbar.close()

    print(f"Aggregating statistics (method={pool_method})...")
    band_names = [f"A{i:02d}" for i in range(n_bands)]
    results = []
    # Emit a row for every surviving block, even those with no valid pixels.
    for geoid in tqdm(sorted(set(blocks["GEOID20"])), desc="Computing stats"):
        row = {"GEOID20": geoid}
        row.update(_pool_row(block_pixels.get(geoid, []), band_names, pool_method, gem_p))
        results.append(row)
    return pd.DataFrame(results)
def main():
    """Run the full pipeline: load, download, pool, and save."""
    args = parse_args()

    # 1. Determine which AEF tiles overlap which census blocks.
    blocks, blocks_4326, tile_to_blocks, block_to_tiles, tile_crs_map = load_and_filter(args)

    # 2. Fetch (or reuse cached copies of) the needed tiles.
    local_vrts = download_tiles(
        tile_paths=set(tile_to_blocks.keys()),
        tile_crs_map=tile_crs_map,
        cache_dir=args.cache_dir,
        year=args.year,
        workers=args.workers,
    )

    # 3. Pool embeddings per block.
    df = compute_block_stats(
        blocks, blocks_4326, tile_to_blocks, block_to_tiles,
        tile_crs_map, local_vrts, args.workers, args.pool, args.gem_p,
    )

    # 4. Join stats back onto block geometries and write geoparquet.
    print(f"Saving results to {args.output}...")
    merged = blocks[["GEOID20", "geometry"]].merge(df, on="GEOID20", how="inner")
    gdf = gpd.GeoDataFrame(merged, geometry="geometry")
    gdf.to_parquet(args.output, index=False)
    print(f"Done. {len(gdf)} blocks, {len(gdf.columns) - 2} stat columns.")


if __name__ == "__main__":
    main()
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| #!/usr/bin/env python3 | |
| """Reduce AEF block stats to 3-component PCA and rasterize as a COG.""" | |
| import argparse | |
| import geopandas as gpd | |
| import numpy as np | |
| import rasterio | |
| from rasterio.enums import Resampling | |
| from rasterio.features import rasterize | |
| from rasterio.transform import from_bounds | |
| from sklearn.decomposition import PCA | |
def parse_args():
    """Parse command-line options for the PCA rasterization step."""
    p = argparse.ArgumentParser(
        description="PCA-reduce AEF block stats and rasterize to a 3-band COG."
    )
    p.add_argument("input", help="Input geoparquet with block stats (e.g. wa_block_aef_stats.geoparquet)")
    p.add_argument("-o", "--output", default=None, help="Output COG path (default: <input_stem>_pca.tif)")
    p.add_argument("--resolution", type=float, default=10.0, help="Raster resolution in CRS units (default: 10m)")
    p.add_argument("--crs", default="EPSG:32610", help="Target CRS for rasterization (default: EPSG:32610)")
    p.add_argument("--n-components", type=int, default=3, help="Number of PCA components (default: 3)")
    p.add_argument("--clip-percentile", type=float, default=2.0, help="Percentile for uint8 clipping (default: 2)")
    return p.parse_args()
def main():
    """PCA-reduce AEF block statistics and rasterize the components to a GeoTIFF.

    Fixes two edge-case crashes in the original: a ZeroDivisionError when a
    PCA component is constant (the clip percentiles coincide), and an
    exception from min()/max() when a band rasterizes to no positive pixels.
    """
    args = parse_args()
    if args.output is None:
        from pathlib import Path
        args.output = str(Path(args.input).stem) + "_pca.tif"

    # 1. Load block stats and project to the target CRS.
    print(f"Loading {args.input}...")
    gdf = gpd.read_parquet(args.input)
    gdf = gdf.to_crs(args.crs)

    # 2. Fit PCA on the embedding stat columns (mean_/std_/gem_ prefixes).
    stat_cols = [c for c in gdf.columns if c.startswith(("mean_", "std_", "gem_"))]
    print(f" {len(gdf)} blocks, {len(stat_cols)} stat columns")
    X = gdf[stat_cols].values.astype(np.float32)
    # Fit only on complete rows; NaN rows stay NaN and become nodata pixels.
    valid = ~np.isnan(X).any(axis=1)
    pca = PCA(n_components=args.n_components)
    pca.fit(X[valid])
    print(f" PCA explained variance: {pca.explained_variance_ratio_}")
    pcs = np.full((len(X), args.n_components), np.nan, dtype=np.float32)
    pcs[valid] = pca.transform(X[valid])

    # 3. Rasterize each component, rescaled to 1..255 uint8 (0 = nodata).
    bounds = gdf.total_bounds
    res = args.resolution
    width = int(np.ceil((bounds[2] - bounds[0]) / res))
    height = int(np.ceil((bounds[3] - bounds[1]) / res))
    # Snap the upper-right corner to whole pixels so the transform is exact.
    transform = from_bounds(bounds[0], bounds[1], bounds[0] + width * res, bounds[1] + height * res, width, height)
    print(f" Raster: {width} x {height} px ({width * height / 1e6:.1f} Mpx)")
    lo_p = args.clip_percentile
    hi_p = 100 - args.clip_percentile
    bands = []
    for i in range(args.n_components):
        vals = pcs[:, i]
        valid_vals = vals[~np.isnan(vals)]
        lo, hi = np.percentile(valid_vals, [lo_p, hi_p])
        if hi <= lo:
            # Degenerate (near-constant) component: widen the range so the
            # division below is defined; all valid blocks then scale to 1.
            hi = lo + 1.0
        clipped = np.clip(vals, lo, hi)
        scaled = ((clipped - lo) / (hi - lo) * 254 + 1).astype(np.uint8)
        scaled[np.isnan(vals)] = 0  # nodata
        band = rasterize(
            list(zip(gdf.geometry, scaled)),
            out_shape=(height, width), transform=transform,
            fill=0, dtype=np.uint8, all_touched=True,
        )
        bands.append(band)
        lit = band[band > 0]
        if lit.size:  # guard: min()/max() on an empty array would raise
            print(f" Band {i + 1} (PC{i + 1}): min={lit.min()}, max={lit.max()}")

    # 4. Write a tiled, compressed GeoTIFF, then add overviews in place.
    # NOTE(review): overviews are appended after the main write, so the file
    # is not guaranteed to have strict COG layout; post-process with a COG
    # translator if strict compliance matters.
    profile = {
        "driver": "GTiff",
        "dtype": "uint8",
        "width": width,
        "height": height,
        "count": args.n_components,
        "crs": args.crs,
        "transform": transform,
        "nodata": 0,
        "compress": "deflate",
        "tiled": True,
        "blockxsize": 512,
        "blockysize": 512,
    }
    with rasterio.open(args.output, "w", **profile) as dst:
        for i in range(args.n_components):
            dst.write(bands[i], i + 1)
            dst.set_band_description(i + 1, f"PC{i + 1} ({pca.explained_variance_ratio_[i]:.1%})")
    with rasterio.open(args.output, "r+") as dst:
        dst.build_overviews([2, 4, 8, 16, 32], Resampling.average)
        dst.update_tags(ns="rio_overview", resampling="average")
    print(f"Saved {args.output}")


if __name__ == "__main__":
    main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment

Computes statistics of AlphaEarth embeddings over Census polygons (and PCA visualizations of the embeddings),
e.g., over Census blocks in the Seattle area.