Last active
May 22, 2024 02:20
-
-
Save calebrob6/af17853f9d9e9cac2b817950226c37c9 to your computer and use it in GitHub Desktop.
Runs tiled FastSAM inference on large satellite image scenes (GeoTIFF input) and writes the resulting binary mask as a GeoTIFF
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Requires the `segment-geospatial` package https://samgeo.gishub.org/ | |
import argparse | |
import os | |
import cv2 | |
import numpy as np | |
import rasterio | |
import rasterio.features | |
import rasterio.transform | |
import rasterio.windows | |
from tqdm.contrib import itertools | |
import logging | |
from samgeo.fast_sam import SamGeo | |
def add_parser_args(parser):
    """Register this script's command-line options on ``parser`` (in place).

    Args:
        parser: An ``argparse.ArgumentParser``; options are added to it and
            nothing is returned.

    ``--input_fn`` and ``--output_fn`` are required because ``main`` uses both
    unconditionally; leaving them optional meant a missing path only failed
    later with an unhelpful error on ``None`` instead of a usage message.
    """
    parser.add_argument(
        "--input_fn", type=str, required=True, help="Input GeoTIFF imagery"
    )
    parser.add_argument(
        "--output_fn", type=str, required=True, help="Output GeoTIFF mask"
    )
    parser.add_argument(
        "--device", type=int, default=0, help="Device to run inference on"
    )
    parser.add_argument(
        "--patch_size", type=int, default=1024, help="Patch size for inference"
    )
    parser.add_argument(
        "--upsample_size",
        type=int,
        default=1,
        help="Upsample factor for input imagery",
    )
    parser.add_argument(
        "--padding", type=int, default=256, help="Padding for input imagery"
    )
    parser.add_argument(
        "--overwrite", action="store_true", help="Overwrite existing output file"
    )
    parser.add_argument(
        "--skip_valid_checks", action="store_true", help="Skip validation checks"
    )
    # NOTE(review): run_sam_on_img compares this against mask pixel counts at
    # the resolution the model returns (upsampled when --upsample_size > 1).
    parser.add_argument(
        "--size_filter_pixels",
        type=int,
        default=100_000,
        help="Filter size for mask postprocessing",
    )
    parser.add_argument(
        "--iou",
        type=float,
        default=0.01,
        help="Intersection over union threshold for mask postprocessing",
    )
    parser.add_argument(
        "--confidence",
        type=float,
        default=0.00,
        help="Confidence threshold for mask postprocessing",
    )
    parser.add_argument(
        "--max_detections",
        type=int,
        default=10_000,
        help="Maximum number of detections for mask postprocessing",
    )
def run_sam_on_img(img, sam, upsample_size, size_filter_pixels, **kwargs):
    """Segment one image patch with FastSAM and return a cleaned binary mask.

    Args:
        img: Patch array; ``img.shape`` must unpack into exactly three values
            ``(height, width, channels)``.
        sam: Segmenter exposing ``set_image`` and ``everything_prompt``
            (``samgeo.fast_sam.SamGeo`` in this script).
        upsample_size: Integer factor; the accumulator for the model's masks
            is allocated at ``upsample_size`` times the patch size, so the
            returned masks are expected to have that shape.
        size_filter_pixels: Candidate masks whose pixel count (at the model's
            output resolution) is >= this value are discarded.
        **kwargs: Passed unchanged to ``sam.set_image`` (``main`` sends
            ``imgsz``, ``iou``, ``conf``, ``max_det`` and ``device``).

    Returns:
        ``uint8`` array of shape ``(height, width)`` holding 0/1 values.
    """
    height, width, _channels = img.shape

    # The FastSAM predict options are undocumented in samgeo; see
    # https://github.com/CASIA-IVA-Lab/FastSAM/blob/main/predict.py#L87
    sam.set_image(img, **kwargs)
    candidates = sam.everything_prompt(output=None)

    if len(candidates) == 0:
        # No detections: an all-background mask already at patch size.
        cleaned = np.zeros((height, width), dtype=np.uint8)
    else:
        areas = candidates.sum(dim=(1, 2))
        hits = np.zeros(
            (height * upsample_size, width * upsample_size), dtype=np.int32
        )
        for k in range(candidates.shape[0]):
            # Keep only masks smaller than the size threshold.
            if areas[k] < size_filter_pixels:
                hits += candidates[k].cpu().numpy() > 0
        binary = (hits > 0).astype(np.uint8)
        # Morphological closing (3x3) then opening (8x8) on the union mask.
        closing_kernel = np.ones((3, 3), np.uint8)
        opening_kernel = np.ones((8, 8), np.uint8)
        binary = cv2.morphologyEx(binary, cv2.MORPH_CLOSE, closing_kernel)
        cleaned = cv2.morphologyEx(
            binary.astype(np.uint8), cv2.MORPH_OPEN, opening_kernel
        )

    # Shrink back to the patch size when the model output was upsampled.
    out_h, out_w = cleaned.shape
    if (out_h, out_w) != (height, width):
        cleaned = cv2.resize(
            cleaned.astype(np.uint8),
            (width, height),
            interpolation=cv2.INTER_NEAREST,
        )
    return cleaned
def _tile_starts(length, patch_size, stride):
    """Return patch start offsets that tile ``[0, length)`` along one axis.

    Offsets advance by ``stride``; the last patch is placed flush against the
    far edge, so every patch lies fully inside the image and every pixel is
    covered. Requires ``length >= patch_size`` and ``stride > 0``.
    """
    starts = list(range(0, length - patch_size, stride))
    starts.append(length - patch_size)
    return starts


def _valid_span(start, length, patch_size, padding):
    """Return the ``[lo, hi)`` image range whose predictions a patch keeps.

    Interior sides drop ``padding`` pixels (overlapping neighbours cover them
    with better context); a side that touches the image border keeps its
    padding because no other patch covers those pixels.
    """
    lo = start if start == 0 else start + padding
    end = start + patch_size
    hi = end if end == length else end - padding
    return lo, hi


def main(args):
    """Tile the input GeoTIFF, segment each patch with FastSAM, write a mask.

    Args:
        args: Namespace produced by a parser configured with
            ``add_parser_args``.

    Raises:
        FileNotFoundError: Input file is missing (unless ``skip_valid_checks``).
        FileExistsError: Output exists and ``--overwrite`` was not given
            (unless ``skip_valid_checks``).
        ValueError: Bad file extensions (unless ``skip_valid_checks``), a
            padding/patch-size combination with no positive stride, fewer than
            3 bands, or an image smaller than one patch.
    """
    # Explicit raises instead of asserts, which are stripped under ``python -O``.
    if not args.skip_valid_checks:
        if not os.path.exists(args.input_fn):
            raise FileNotFoundError(f"Input file {args.input_fn} does not exist")
        for fn in (args.input_fn, args.output_fn):
            if not fn.endswith((".tif", ".tiff")):
                raise ValueError(f"Expected a .tif/.tiff path, got {fn}")
        if os.path.exists(args.output_fn) and not args.overwrite:
            raise FileExistsError(f"Output file {args.output_fn} already exists")
        elif os.path.exists(args.output_fn) and args.overwrite:
            print("WARNING: Will overwrite existing output file")

    patch_size = args.patch_size
    padding = args.padding
    stride = patch_size - 2 * padding
    if padding < 0 or stride <= 0:
        raise ValueError(
            f"padding ({padding}) must be >= 0 and smaller than half the "
            f"patch size ({patch_size})"
        )

    # Keep the dataset open for the whole run rather than re-opening per tile.
    with rasterio.open(args.input_fn) as src:
        height, width = src.shape
        profile = src.profile
        if src.count > 3:
            print("Input imagery has more than 3 bands, will use only the first 3")
        elif src.count < 3:
            raise ValueError("Input imagery must have at least 3 bands")
        if height < patch_size or width < patch_size:
            raise ValueError(
                f"Input imagery ({height}x{width}) is smaller than the patch "
                f"size ({patch_size})"
            )

        # Load SAM (after validating the input, so bad inputs fail fast).
        print("Loading SAM model...")
        device = f"cuda:{args.device}"
        sam = SamGeo(model="FastSAM-x.pt")
        # Silence the ultralytics (YOLOv8) logger.
        logging.getLogger("ultralytics").setLevel(logging.CRITICAL)

        outputs = np.zeros((height, width), dtype=np.uint8)
        ys = _tile_starts(height, patch_size, stride)
        xs = _tile_starts(width, patch_size, stride)

        print("Running inference...")
        # ``itertools`` here is ``tqdm.contrib.itertools``: product + progress bar.
        for y, x in itertools.product(ys, xs):
            window = rasterio.windows.Window(x, y, patch_size, patch_size)
            # Read only bands 1-3 (rasterio band indexes are 1-based) as HxWx3.
            img = src.read(indexes=[1, 2, 3], window=window).transpose(1, 2, 0).copy()
            # TODO: It makes sense that some blurring here might help, but IDK
            mask = run_sam_on_img(
                img,
                sam,
                upsample_size=args.upsample_size,
                size_filter_pixels=args.size_filter_pixels,
                imgsz=int(patch_size * args.upsample_size),
                iou=args.iou,
                conf=args.confidence,
                max_det=args.max_detections,
                device=device,
            )
            y0, y1 = _valid_span(y, height, patch_size, padding)
            x0, x1 = _valid_span(x, width, patch_size, padding)
            outputs[y0:y1, x0:x1] = mask[y0 - y : y1 - y, x0 - x : x1 - x]

    # Single-band uint8 mask; 0 is written as nodata as before.
    profile["count"] = 1
    profile["dtype"] = "uint8"
    profile["nodata"] = 0
    with rasterio.open(args.output_fn, "w", **profile) as dst:
        dst.write(outputs, 1)
if __name__ == "__main__":
    # Script entry point: build the CLI, parse sys.argv, and run main().
    cli_parser = argparse.ArgumentParser()
    add_parser_args(cli_parser)
    main(cli_parser.parse_args())
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment