Created
April 6, 2023 20:20
-
-
Save darrenwiens/9c3eb33165f5c598476b435feb1d2be5 to your computer and use it in GitHub Desktop.
Segment Anything API: segments a static Mapbox satellite image at a clicked point.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from fastapi import FastAPI | |
from pydantic import BaseModel | |
import torch | |
import torchvision | |
import numpy as np | |
from PIL import Image | |
import random | |
import requests | |
import base64 | |
import io | |
import os | |
from segment_anything import sam_model_registry, SamPredictor | |
from fastapi.middleware.cors import CORSMiddleware | |
# --- Module-level configuration (runs once at import time) ---

MAPBOX_TOKEN = os.getenv("MAPBOX_TOKEN")  # get your Mapbox token from an environment variable

app = FastAPI()

# Browser origins allowed to call this API (CORS).
origins = [
    "http://localhost:8001",  # allow your origin to make requests
]
app.add_middleware(
    CORSMiddleware,
    allow_origins=origins,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# SAM checkpoint weights; local file, download from:
# https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth
sam_checkpoint = "sam_vit_h_4b8939.pth"
# device = "cuda"  # use it if you've got it
model_type = "default"  # registry key; "default" selects the ViT-H variant matching this checkpoint
sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
# sam.to(device=device)
# Built once at import and reused by every request, so the model is only
# loaded a single time for the process lifetime.
predictor = SamPredictor(sam)
def get_mapbox_img(
    map_center_lat, map_center_lon, map_width_px, map_height_px, map_zoom_level
):
    """Fetch a static Mapbox satellite image centered on a lon/lat point.

    Parameters:
        map_center_lat: center latitude in degrees.
        map_center_lon: center longitude in degrees.
        map_width_px: requested image width in px (served @2x retina).
        map_height_px: requested image height in px (served @2x retina).
        map_zoom_level: Mapbox zoom level.

    Returns:
        PIL.Image.Image decoded from the HTTP response.

    Raises:
        requests.HTTPError: on a non-2xx Mapbox response (e.g. bad token).
        requests.Timeout: if Mapbox does not respond within 30 seconds.
    """
    style = "mapbox/satellite-v9"
    # Static Images API: lon,lat order, @2x scale, bearing fixed at 0.
    url = (
        f"https://api.mapbox.com/styles/v1/{style}/static/"
        f"{map_center_lon},{map_center_lat}@2x,{map_zoom_level},0/"
        f"{map_width_px}x{map_height_px}?access_token={MAPBOX_TOKEN}"
    )
    # stream=True lets PIL read straight from the socket. A timeout and an
    # explicit status check were added: the original could hang forever and
    # would hand PIL an HTML/JSON error page on API failure.
    resp = requests.get(url, stream=True, timeout=30)
    resp.raise_for_status()
    return Image.open(resp.raw)
class SegmentRequest(BaseModel):
    """Request body for POST /segment/: the map view to fetch from Mapbox
    plus the pixel the user clicked, used as the SAM point prompt."""

    map_width_px: int      # static-image width to request, px
    map_height_px: int     # static-image height to request, px
    map_zoom_level: float  # Mapbox zoom level
    map_center_lat: float  # map center latitude, degrees
    map_center_lon: float  # map center longitude, degrees
    click_x_px: int        # clicked x coordinate, px, in image space
    click_y_px: int        # clicked y coordinate, px, in image space
def _mask_to_data_uri(mask):
    """Render a boolean (H, W) mask as a base64 PNG data URI.

    The mask becomes an RGBA image: each RGB channel is the mask tinted by
    an independent random factor (so successive segments get visually
    distinct colors) and the alpha channel is the mask itself, so only
    segmented pixels show when overlaid on the map.
    """
    bool_img = Image.fromarray(mask).convert("L")
    r, g, b = bool_img.convert("RGB").split()
    r = r.point(lambda i: i * random.random())
    g = g.point(lambda i: i * random.random())
    b = b.point(lambda i: i * random.random())
    rgba_img = Image.merge("RGBA", (r, g, b, bool_img))
    buffered = io.BytesIO()
    rgba_img.save(buffered, format="PNG")
    encoded = base64.b64encode(buffered.getvalue()).decode("utf-8")
    return f"data:image/png;base64,{encoded}"  # return a data uri


@app.post("/segment/")
def create_item(segment_req: SegmentRequest):
    """Segment the satellite image at the clicked pixel.

    Fetches the static Mapbox image described by the request, runs SAM
    with the click as a single positive point prompt, keeps the
    highest-scoring of the three candidate masks, and returns it encoded
    as a PNG data URI string.

    NOTE: declared as a plain ``def`` (not ``async def``): SAM inference
    blocks for seconds, and in an ``async def`` handler that would stall
    the whole event loop. As a sync handler FastAPI runs it in its
    threadpool instead. (Per-request debug prints of library versions
    were also removed.)
    """
    image = np.array(
        get_mapbox_img(
            segment_req.map_center_lat,
            segment_req.map_center_lon,
            segment_req.map_width_px,
            segment_req.map_height_px,
            segment_req.map_zoom_level,
        )
    )
    predictor.set_image(image)

    # Single positive point prompt at the clicked pixel. The int() casts
    # from the original were dropped: pydantic already coerced these
    # fields to int.
    input_point = np.array([[segment_req.click_x_px, segment_req.click_y_px]])
    input_label = np.array([1])
    masks, scores, logits = predictor.predict(
        point_coords=input_point,
        point_labels=input_label,
        multimask_output=True,
    )
    # Keep only the mask SAM scored highest.
    best_mask = masks[np.argmax(scores)]
    return _mask_to_data_uri(best_mask)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment