Skip to content

Instantly share code, notes, and snippets.

@darrenwiens
Created April 6, 2023 20:20
Show Gist options
  • Save darrenwiens/9c3eb33165f5c598476b435feb1d2be5 to your computer and use it in GitHub Desktop.
Save darrenwiens/9c3eb33165f5c598476b435feb1d2be5 to your computer and use it in GitHub Desktop.
Segment Anything API: segments a static Mapbox satellite image.
from fastapi import FastAPI
from pydantic import BaseModel
import torch
import torchvision
import numpy as np
from PIL import Image
import random
import requests
import base64
import io
import os
from segment_anything import sam_model_registry, SamPredictor
from fastapi.middleware.cors import CORSMiddleware
# Mapbox access token, taken from the environment (None when the variable is unset).
MAPBOX_TOKEN = os.environ.get("MAPBOX_TOKEN")

app = FastAPI()

# Browser origins allowed to call this API; add your own front-end origin here.
origins = ["http://localhost:8001"]

# CORS: the listed origins may send credentialed requests with any method
# and any header.
app.add_middleware(
    CORSMiddleware,
    allow_origins=origins,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
# Local checkpoint file. Download it from:
# https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth
sam_checkpoint = "sam_vit_h_4b8939.pth"
# Registry key selecting the model builder (the checkpoint name suggests the
# ViT-H weights -- confirm the two match if either is changed).
model_type = "default"

# Build the model once at import time; every request reuses it.
_build_sam = sam_model_registry[model_type]
sam = _build_sam(checkpoint=sam_checkpoint)

# The model is left on whatever device it was loaded onto. To run on a GPU
# (use it if you've got it), set `device = "cuda"` and call
# `sam.to(device=device)` here.
predictor = SamPredictor(sam)
def get_mapbox_img(
    map_center_lat, map_center_lon, map_width_px, map_height_px, map_zoom_level
):
    """Fetch a static Mapbox satellite image for the given map view.

    Args:
        map_center_lat: Latitude of the map centre.
        map_center_lon: Longitude of the map centre.
        map_width_px: Requested image width in pixels.
        map_height_px: Requested image height in pixels.
        map_zoom_level: Map zoom level.

    Returns:
        A fully decoded ``PIL.Image.Image`` built from the response body.

    Raises:
        requests.HTTPError: If Mapbox answers with a 4xx/5xx status
            (for example a missing or invalid ``MAPBOX_TOKEN``).
        requests.RequestException: On connection failure or timeout.
    """
    style = "mapbox/satellite-v9"
    # NOTE(review): Mapbox's documented URL shape puts the optional `@2x`
    # suffix after `{width}x{height}`; here it follows the latitude. The URL is
    # kept exactly as it was because the caller uses click coordinates as pixel
    # positions in the returned image, so changing the image size would
    # misplace them -- confirm before moving it.
    url = f"https://api.mapbox.com/styles/v1/{style}/static/{map_center_lon},{map_center_lat}@2x,{map_zoom_level},0/{map_width_px}x{map_height_px}?access_token={MAPBOX_TOKEN}"

    # A timeout stops a stalled upstream from hanging the request forever.
    response = requests.get(url, timeout=30)
    # Surface HTTP errors here instead of handing an error body to PIL.
    response.raise_for_status()

    # Decode from the complete in-memory body so no streamed connection is
    # left open behind a lazily loaded image.
    im = Image.open(io.BytesIO(response.content))
    im.load()
    return im
# Request body for POST /segment/: the map view to render plus the clicked pixel.
class SegmentRequest(BaseModel):
    # Size of the static map image to request, in pixels.
    map_width_px: int
    map_height_px: int

    # Map view: zoom level and centre coordinates.
    map_zoom_level: float
    map_center_lat: float
    map_center_lon: float

    # Pixel position of the click; used as the point prompt for segmentation.
    click_x_px: int
    click_y_px: int
def _mask_to_data_uri(mask):
    """Render a 2-D boolean mask as a randomly coloured RGBA PNG data URI."""
    mask = np.asarray(mask, dtype=bool)

    # One random colour for the whole mask. Pixels inside the mask are opaque;
    # everything else stays transparent black.
    colour = [random.randint(0, 255) for _ in range(3)]
    rgba = np.zeros((*mask.shape, 4), dtype=np.uint8)
    rgba[mask] = (*colour, 255)

    buffered = io.BytesIO()
    Image.fromarray(rgba).save(buffered, format="PNG")
    encoded = base64.b64encode(buffered.getvalue()).decode("utf-8")
    return f"data:image/png;base64,{encoded}"  # return a data uri


@app.post("/segment/")
async def create_item(segment_req: SegmentRequest):
    """Segment the object under a clicked pixel of a Mapbox satellite image.

    Fetches the static image for the requested map view, prompts the shared
    SAM predictor with the clicked pixel, and returns the highest-scoring
    mask as a semi-transparent overlay.

    Args:
        segment_req: Map view (size, zoom, centre) and click position.

    Returns:
        A ``data:image/png;base64,...`` URI holding an RGBA PNG in which the
        mask is drawn in one random colour and the rest is transparent.

    Note:
        The handler is ``async`` but contains no ``await``, so the event loop
        runs one segmentation at a time. That also keeps two requests from
        interleaving ``set_image``/``predict`` on the shared ``predictor``.
    """
    print("PyTorch version:", torch.__version__)
    print("Torchvision version:", torchvision.__version__)
    print("CUDA is available:", torch.cuda.is_available())

    image = get_mapbox_img(
        segment_req.map_center_lat,
        segment_req.map_center_lon,
        segment_req.map_width_px,
        segment_req.map_height_px,
        segment_req.map_zoom_level,
    )
    # Normalise to 3-channel RGB so palette, greyscale or alpha-bearing
    # responses still give the HxWx3 array the predictor is fed.
    image = np.array(image.convert("RGB"))
    predictor.set_image(image)

    # A single point prompt at the clicked pixel, labelled 1.
    input_point = np.array([[segment_req.click_x_px, segment_req.click_y_px]])
    input_label = np.array([1])
    masks, scores, _ = predictor.predict(
        point_coords=input_point,
        point_labels=input_label,
        multimask_output=True,
    )

    # Several candidate masks come back; keep the best-scoring one.
    mask = masks[np.argmax(scores)]
    return _mask_to_data_uri(mask)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment