Last active
August 24, 2020 16:20
-
-
Save vincentsarago/dd3b3f91a5d1562048b95f024c8de6d4 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
"""Common dependency.""" | |
import re | |
from dataclasses import dataclass, field | |
from enum import Enum | |
from typing import Dict, Optional, Tuple, Type, Union | |
from urllib.parse import urlparse | |
import morecantile | |
import numpy | |
from rasterio.enums import Resampling | |
from rio_tiler.colormap import cmap | |
from rio_tiler.io import BaseReader | |
from titiler import settings | |
from titiler.custom import cmap as custom_colormap | |
from titiler.custom import tms as custom_tms | |
from titiler.utils import get_hash | |
from fastapi import Query | |
from starlette.requests import Request | |
################################################################################
#                       CMAP AND TMS Customization
#
# Custom TileMatrixSets are registered on morecantile's default registry.
morecantile.tms.register(custom_tms.EPSG3413)
# REGISTER CUSTOM TMS
#
# e.g morecantile.tms.register(custom_tms.my_custom_tms)

# Custom colormaps are registered on rio-tiler's `cmap` object.
cmap.register("above", custom_colormap.above_cmap)
# REGISTER CUSTOM COLORMAP HERE
#
# e.g cmap.register("customRed", custom_colormap.custom_red)

################################################################################
# DO NOT UPDATE
# Create ENUMS with all CMAP and TMS for documentation and validation.
# Each member's name and value are the same string.
ColorMapNames = Enum(  # type: ignore
    "ColorMapNames", [(name, name) for name in sorted(cmap.list())]
)
TileMatrixSetNames = Enum(  # type: ignore
    "TileMatrixSetNames", [(name, name) for name in sorted(morecantile.tms.list())]
)
MosaicTileMatrixSetNames = Enum(  # type: ignore
    "MosaicTileMatrixSetNames", [("WebMercatorQuad", "WebMercatorQuad")]
)
ResamplingNames = Enum(  # type: ignore
    "ResamplingNames", [(method.name, method.name) for method in Resampling]
)
async def request_hash(request: Request) -> str:
    """Create an id for the request by hashing its parameters with ``get_hash``.

    Args:
        request: incoming request; its query-string and path parameters are
            both part of the hashed input.

    Returns:
        The value returned by ``get_hash`` for the merged parameters.
    """
    # Merge into a single mapping before unpacking. Unpacking both mappings
    # directly in the call raises ``TypeError`` (multiple values for a keyword
    # argument) when the same name is present in the query string and in the
    # path. Path parameters win on collision.
    hash_params = {**dict(request.query_params), **request.path_params}
    return get_hash(**hash_params)
def TMSParams(
    TileMatrixSetId: TileMatrixSetNames = Query(
        TileMatrixSetNames.WebMercatorQuad,  # type: ignore
        description="TileMatrixSet Name (default: 'WebMercatorQuad')",
    )
) -> morecantile.TileMatrixSet:
    """Dependency returning the TileMatrixSet registered under the requested name."""
    tms_name = TileMatrixSetId.name
    return morecantile.tms.get(tms_name)
def MosaicTMSParams(
    TileMatrixSetId: MosaicTileMatrixSetNames = Query(
        MosaicTileMatrixSetNames.WebMercatorQuad,  # type: ignore
        description="TileMatrixSet Name (default: 'WebMercatorQuad')",
    )
) -> morecantile.TileMatrixSet:
    """Dependency returning the TileMatrixSet for mosaic endpoints.

    The accepted names are limited to the members of
    ``MosaicTileMatrixSetNames``.
    """
    tms_name = TileMatrixSetId.name
    return morecantile.tms.get(tms_name)
@dataclass
class DefaultDependency:
    """Base class shared by the endpoint dependencies."""

    # Extra keyword arguments collected by the dependency. Not an ``__init__``
    # argument; every instance starts with its own empty dict.
    kwargs: dict = field(default_factory=dict, init=False)
@dataclass
class PathParams(DefaultDependency):
    """Dataset path dependency built from the ``url`` query parameter."""

    url: str = Query(..., description="Dataset URL")
    # Placeholder: optional reader class overriding the factory default.
    reader: Optional[Type[BaseReader]] = field(default=None, init=False)

    def __post_init__(self):
        """Rewrite ``mosaicid`` URLs to the default mosaic backend location."""
        parsed = urlparse(self.url)
        # Only the part of the scheme after the last "+" is compared, so a
        # prefixed scheme (e.g. ``landsat+mosaicid``) also matches.
        scheme_suffix = parsed.scheme.split("+")[-1]

        # Placeholder in case we want to allow landsat+mosaicid
        # by default we store the mosaicjson as a GZ compressed json (.json.gz) file
        if scheme_suffix == "mosaicid":
            self.url = f"{settings.DEFAULT_MOSAIC_BACKEND}{settings.DEFAULT_MOSAIC_HOST}/{parsed.netloc}.json.gz"

        # if parsed.scheme.startswith("landsat+"):
        #     self.reader = L8Reader
        #     self.url = self.url.replace("landsat+", "")
        # elif parsed.scheme.startswith("sentinel2cog+"):
        #     self.reader = Sentinel2COG
        #     self.url = self.url.replace("sentinel2cog+", "")
        # ....
@dataclass
class AssetsParams(DefaultDependency):
    """Dependency exposing the optional ``assets`` query parameter."""

    assets: Optional[str] = Query(
        None,
        title="Asset indexes",
        description="comma (',') delimited asset names (might not be an available options of some readers)",
    )
    kwargs: dict = field(default_factory=dict, init=False)

    def __post_init__(self):
        """Store the comma-split asset names under ``kwargs["assets"]``."""
        if self.assets is None:
            # Nothing requested: leave ``kwargs`` empty.
            return
        self.kwargs["assets"] = self.assets.split(",")
@dataclass
class CommonParams(DefaultDependency):
    """Query parameters shared by the reader endpoints."""

    bidx: Optional[str] = Query(
        None, title="Band indexes", description="comma (',') delimited band indexes",
    )
    nodata: Optional[Union[str, int, float]] = Query(
        None, title="Nodata value", description="Overwrite internal Nodata value"
    )
    resampling_method: ResamplingNames = Query(
        ResamplingNames.nearest, description="Resampling method."  # type: ignore
    )

    def __post_init__(self):
        """Derive ``indexes`` from ``bidx`` and convert ``nodata`` to a float."""
        if self.bidx:
            # Every run of digits found in ``bidx`` becomes one band index.
            self.indexes = tuple(map(int, re.findall(r"\d+", self.bidx)))
        else:
            self.indexes = None

        if self.nodata is not None:
            if self.nodata == "nan":
                self.nodata = numpy.nan
            else:
                self.nodata = float(self.nodata)
@dataclass
class MetadataParams(CommonParams):
    """Common Metadata parameters.

    Adds the percentile, histogram and bounds options of the metadata
    endpoint on top of ``CommonParams``.
    """

    pmin: float = Query(2.0, description="Minimum percentile")
    pmax: float = Query(98.0, description="Maximum percentile")
    max_size: int = Query(1024, description="Maximum image size to read onto.")
    histogram_bins: Optional[int] = Query(None, description="Histogram bins.")
    histogram_range: Optional[str] = Query(
        None, description="comma (',') delimited Min,Max histogram bounds"
    )
    # The keyword is ``description``; the misspelled ``descriptions`` was not
    # picked up as the parameter description.
    bounds: Optional[str] = Query(
        None,
        description="comma (',') delimited Bounding box coordinates from which to calculate image statistics.",
    )
    # Histogram options assembled in ``__post_init__``; not an init argument.
    hist_options: dict = field(init=False, default_factory=dict)

    def __post_init__(self):
        """Parse the histogram and bounds parameters.

        Runs the ``CommonParams`` post-init first, then fills ``hist_options``
        with ``bins`` / ``range`` when provided and converts ``bounds`` from a
        comma-delimited string to a tuple of floats.
        """
        super().__post_init__()

        if self.histogram_bins:
            self.hist_options.update(dict(bins=self.histogram_bins))

        if self.histogram_range:
            self.hist_options.update(
                dict(range=list(map(float, self.histogram_range.split(","))))
            )

        if self.bounds:
            self.bounds = tuple(map(float, self.bounds.split(",")))
@dataclass
class PointParams(DefaultDependency):
    """Query parameters of the point endpoint."""

    bidx: Optional[str] = Query(
        None, title="Band indexes", description="comma (',') delimited band indexes",
    )
    nodata: Optional[Union[str, int, float]] = Query(
        None, title="Nodata value", description="Overwrite internal Nodata value"
    )
    expression: Optional[str] = Query(
        None,
        title="Band Math expression",
        description="rio-tiler's band math expression (e.g B1/B2)",
    )

    def __post_init__(self):
        """Derive ``indexes`` from ``bidx`` and convert ``nodata`` to a float."""
        if self.bidx:
            # Every run of digits found in ``bidx`` becomes one band index.
            self.indexes = tuple(map(int, re.findall(r"\d+", self.bidx)))
        else:
            self.indexes = None

        if self.nodata is not None:
            if self.nodata == "nan":
                self.nodata = numpy.nan
            else:
                self.nodata = float(self.nodata)
@dataclass
class TileParams(CommonParams):
    """Query parameters shared by the image-producing endpoints."""

    expression: Optional[str] = Query(
        None,
        title="Band Math expression",
        description="rio-tiler's band math expression (e.g B1/B2)",
    )
    rescale: Optional[str] = Query(
        None,
        title="Min/Max data Rescaling",
        description="comma (',') delimited Min,Max bounds",
    )
    color_formula: Optional[str] = Query(
        None,
        title="Color Formula",
        description="rio-color formula (info: https://github.com/mapbox/rio-color)",
    )
    color_map: Optional[ColorMapNames] = Query(
        None, description="rio-tiler's colormap name"
    )
    resampling_method: ResamplingNames = Query(
        ResamplingNames.nearest, description="Resampling method."  # type: ignore
    )
    # Set in ``__post_init__`` from ``color_map``; not an init argument.
    colormap: Optional[Dict[int, Tuple[int, int, int, int]]] = field(init=False)

    def __post_init__(self):
        """Run the parent post-init, then look up the requested colormap."""
        super().__post_init__()
        if self.color_map:
            self.colormap = cmap.get(self.color_map.value)
        else:
            self.colormap = None
@dataclass
class ImageParams(TileParams):
    """Query parameters of the preview and crop endpoints."""

    max_size: Optional[int] = Query(
        1024, description="Maximum image size to read onto."
    )
    height: Optional[int] = Query(None, description="Force output image height.")
    width: Optional[int] = Query(None, description="Force output image width.")

    def __post_init__(self):
        """Run the parent post-init, then drop ``max_size`` when both sizes are forced."""
        super().__post_init__()
        both_sizes_forced = self.width and self.height
        if both_sizes_forced:
            self.max_size = None
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
"""API.""" | |
import abc | |
from dataclasses import dataclass, field | |
import os | |
from urllib.parse import urlencode | |
from typing import Callable, Dict, Type, Optional, Union, List | |
import pkg_resources | |
from rasterio.transform import from_bounds | |
from fastapi import APIRouter, Depends, Path, Query | |
from starlette.requests import Request | |
from starlette.responses import Response | |
from starlette.templating import Jinja2Templates | |
from rio_tiler.io import BaseReader | |
from rio_tiler_crs import COGReader | |
from .. import utils | |
from ..db.memcache import CacheLayer | |
from ..models.cog import cogBounds, cogInfo, cogMetadata | |
from ..models.mapbox import TileJSON | |
from ..ressources.common import img_endpoint_params | |
from ..ressources.enums import ImageMimeTypes, ImageType, MimeTypes | |
from ..ressources.responses import XMLResponse | |
from ..dependencies import ( | |
PathParams, | |
AssetsParams, | |
DefaultDependency, | |
MetadataParams, | |
TileParams, | |
ImageParams, | |
PointParams, | |
TMSParams, | |
request_hash, | |
) | |
# Jinja2 environment rooted at the ``templates`` resource of the ``titiler``
# package.
template_dir = pkg_resources.resource_filename(
    "titiler",
    "templates",
)
templates = Jinja2Templates(
    directory=template_dir,
)
# ref: https://github.com/python/mypy/issues/5374 | |
@dataclass  # type: ignore
class BaseFactory(metaclass=abc.ABCMeta):
    """Abstract endpoint factory: holds the router, reader and dependencies.

    Subclasses implement ``register_routes``, which is called at the end of
    ``__post_init__``.
    """

    reader: Type[BaseReader] = COGReader
    reader_options: Dict = field(default_factory=dict)

    # FastAPI router
    router: APIRouter = field(default_factory=APIRouter)

    # Endpoint Dependencies
    tms_dependency: Callable = TMSParams
    path_dependency: Type[PathParams] = PathParams
    tiles_dependency: Type[TileParams] = TileParams
    point_dependency: Type[PointParams] = PointParams

    # Add `assets` options in endpoint
    add_asset_deps: bool = False

    # Router Prefix is needed to find the path for /tile if the TilerFactory.router is mounted
    # with other router (multiple `.../tile` routes).
    # e.g if you mount the route with `/cog` prefix, set router_prefix to cog and
    # each routes will be prefixed with `cog_`, which will let starlette retrieve routes url (Reverse URL lookups)
    router_prefix: str = ""

    def __post_init__(self):
        """Pick the options dependency, normalise the prefix and register routes."""
        if self.add_asset_deps:
            self.options = AssetsParams
        else:
            self.options = DefaultDependency

        # A non-empty prefix gets a trailing underscore so that route names
        # read ``<prefix>_<name>``.
        if self.router_prefix:
            self.router_prefix = f"{self.router_prefix}_"

        self.register_routes()

    @abc.abstractmethod
    def register_routes(self):
        """Register the factory's endpoints on ``self.router``."""
@dataclass
class TilerFactory(BaseFactory):
    """Tiler Factory.

    Registers the dataset endpoints on ``self.router``: /bounds, /info,
    /metadata, /tiles, /tilejson.json, /WMTSCapabilities.xml and /point, plus
    the optional /preview and /crop endpoints.
    """

    # Endpoint Dependencies
    metadata_dependency: Type[MetadataParams] = MetadataParams
    img_dependency: Type[ImageParams] = ImageParams

    # Add/Remove some endpoints
    add_preview: bool = True
    add_part: bool = True

    def register_routes(self):
        """
        This Method register routes to the router.

        Because we wrap the endpoints in a class we cannot define the routes as
        methods (because of the self argument). The HACK is to define routes inside
        the class method and register them after the class initialisation.

        """
        # Default Routes
        # (/bounds, /info, /metadata, /tile, /tilejson.json, /WMTSCapabilities.xml and /point)
        self._bounds()
        if self.add_asset_deps:
            self._info_with_assets()
        else:
            self._info()
        self._metadata()
        self._tile()
        self._point()

        # Optional Routes
        if self.add_preview:
            self._preview()
        if self.add_part:
            self._part()

    ############################################################################
    # /bounds
    ############################################################################
    def _bounds(self):
        """Register /bounds endpoint to router."""

        @self.router.get(
            "/bounds",
            response_model=cogBounds,
            responses={200: {"description": "Return dataset's bounds."}},
            name=f"{self.router_prefix}bounds",
        )
        def bounds(src_path=Depends(self.path_dependency)):
            """Return the bounds of the COG."""
            # A reader set by the path dependency overrides the factory default.
            reader = src_path.reader or self.reader
            with reader(src_path.url, **self.reader_options) as src_dst:
                return {"bounds": src_dst.bounds}

    ############################################################################
    # /info - with assets
    ############################################################################
    def _info_with_assets(self):
        """Register /info endpoint to router (readers exposing assets)."""

        @self.router.get(
            "/info",
            response_model=Union[List[str], Dict[str, cogInfo]],
            response_model_exclude={"minzoom", "maxzoom", "center"},
            response_model_exclude_none=True,
            responses={200: {"description": "Return dataset's basic info."}},
            name=f"{self.router_prefix}info",
        )
        def info(
            src_path=Depends(self.path_dependency),
            options: AssetsParams = Depends()
        ):
            """Return basic info."""
            reader = src_path.reader or self.reader
            with reader(src_path.url, **self.reader_options) as src_dst:
                # Without an ``assets`` selection, answer with the reader's
                # ``assets`` attribute instead of per-asset info.
                if not options.kwargs.get("assets"):
                    return src_dst.assets
                info = src_dst.info(**options.kwargs)
            return info

    ############################################################################
    # /info - without assets
    ############################################################################
    def _info(self):
        """Register /info endpoint to router."""

        @self.router.get(
            "/info",
            response_model=Union[List[str], Dict[str, cogInfo], cogInfo],
            response_model_exclude={"minzoom", "maxzoom", "center"},
            response_model_exclude_none=True,
            responses={200: {"description": "Return dataset's basic info."}},
            name=f"{self.router_prefix}info",
        )
        def info(src_path=Depends(self.path_dependency)):
            """Return basic info."""
            reader = src_path.reader or self.reader
            with reader(src_path.url, **self.reader_options) as src_dst:
                info = src_dst.info()
            return info

    ############################################################################
    # /metadata
    ############################################################################
    def _metadata(self):
        """Register /metadata endpoint to router."""

        @self.router.get(
            "/metadata",
            response_model=Union[cogMetadata, Dict[str, cogMetadata]],
            response_model_exclude={"minzoom", "maxzoom", "center"},
            response_model_exclude_none=True,
            responses={200: {"description": "Return dataset's metadata."}},
            name=f"{self.router_prefix}metadata",
        )
        def metadata(
            src_path=Depends(self.path_dependency),
            params=Depends(self.metadata_dependency),
            options=Depends(self.options),
        ):
            """Return metadata."""
            reader = src_path.reader or self.reader
            with reader(src_path.url, **self.reader_options) as src_dst:
                info = src_dst.metadata(
                    params.pmin,
                    params.pmax,
                    nodata=params.nodata,
                    indexes=params.indexes,
                    max_size=params.max_size,
                    hist_options=params.hist_options,
                    bounds=params.bounds,
                    resampling_method=params.resampling_method.name,
                    **options.kwargs,
                )
            return info

    ############################################################################
    # /tiles
    ############################################################################
    def _tile(self):
        """Register the /tiles, /tilejson.json and /WMTSCapabilities.xml endpoints."""
        tile_endpoint_params = img_endpoint_params.copy()
        tile_endpoint_params["name"] = f"{self.router_prefix}tile"

        @self.router.get(r"/tiles/{z}/{x}/{y}", **tile_endpoint_params)
        @self.router.get(r"/tiles/{z}/{x}/{y}.{format}", **tile_endpoint_params)
        @self.router.get(r"/tiles/{z}/{x}/{y}@{scale}x", **tile_endpoint_params)
        @self.router.get(r"/tiles/{z}/{x}/{y}@{scale}x.{format}", **tile_endpoint_params)
        @self.router.get(r"/tiles/{TileMatrixSetId}/{z}/{x}/{y}", **tile_endpoint_params)
        @self.router.get(
            r"/tiles/{TileMatrixSetId}/{z}/{x}/{y}.{format}", **tile_endpoint_params
        )
        @self.router.get(
            r"/tiles/{TileMatrixSetId}/{z}/{x}/{y}@{scale}x", **tile_endpoint_params
        )
        @self.router.get(
            r"/tiles/{TileMatrixSetId}/{z}/{x}/{y}@{scale}x.{format}",
            **tile_endpoint_params,
        )
        def tile(
            z: int = Path(..., ge=0, le=30, description="Mercator tiles's zoom level"),
            x: int = Path(..., description="Mercator tiles's column"),
            y: int = Path(..., description="Mercator tiles's row"),
            tms=Depends(self.tms_dependency),
            scale: int = Query(
                1, gt=0, lt=4, description="Tile size scale. 1=256x256, 2=512x512..."
            ),
            format: ImageType = Query(None, description="Output image type. Default is auto."),
            src_path=Depends(self.path_dependency),
            params=Depends(self.tiles_dependency),
            options=Depends(self.options),
            cache_client: CacheLayer = Depends(utils.get_cache),
            request_id: str = Depends(request_hash),
        ):
            """Create map tile from a dataset."""
            timings = []
            headers: Dict[str, str] = {}

            tilesize = scale * 256

            content = None
            if cache_client:
                # Best effort: any cache failure is treated as a cache miss.
                try:
                    content, ext = cache_client.get_image_from_cache(request_id)
                    format = ImageType[ext]
                    headers["X-Cache"] = "HIT"
                except Exception:
                    content = None

            if not content:
                with utils.Timer() as t:
                    reader = src_path.reader or self.reader
                    with reader(src_path.url, tms=tms, **self.reader_options) as src_dst:
                        tile, mask = src_dst.tile(
                            x,
                            y,
                            z,
                            tilesize=tilesize,
                            indexes=params.indexes,
                            expression=params.expression,
                            nodata=params.nodata,
                            resampling_method=params.resampling_method.name,
                            **options.kwargs,
                        )
                        # A requested colormap takes precedence over one
                        # exposed by the reader.
                        colormap = (
                            params.colormap or getattr(src_dst, "colormap", None)
                        )
                timings.append(("Read", t.elapsed))

                # Auto format: JPEG when the mask is fully set, PNG otherwise.
                if not format:
                    format = ImageType.jpg if mask.all() else ImageType.png

                with utils.Timer() as t:
                    tile = utils.postprocess(
                        tile,
                        mask,
                        rescale=params.rescale,
                        color_formula=params.color_formula,
                    )
                timings.append(("Post-process", t.elapsed))

                bounds = tms.xy_bounds(x, y, z)
                dst_transform = from_bounds(*bounds, tilesize, tilesize)
                with utils.Timer() as t:
                    content = utils.reformat(
                        tile,
                        mask,
                        format,
                        colormap=colormap,
                        transform=dst_transform,
                        crs=tms.crs,
                    )
                timings.append(("Format", t.elapsed))

                # Only freshly rendered tiles are written to the cache; a
                # cache hit does not need to be stored again.
                if cache_client and content:
                    cache_client.set_image_cache(request_id, (content, format.value))

            if timings:
                headers["X-Server-Timings"] = "; ".join(
                    ["{} - {:0.2f}".format(name, time * 1000) for (name, time) in timings]
                )

            return Response(
                content, media_type=ImageMimeTypes[format.value].value, headers=headers,
            )

        @self.router.get(
            "/tilejson.json",
            response_model=TileJSON,
            responses={200: {"description": "Return a tilejson"}},
            response_model_exclude_none=True,
            name=f"{self.router_prefix}tilejson",
        )
        @self.router.get(
            "/{TileMatrixSetId}/tilejson.json",
            response_model=TileJSON,
            responses={200: {"description": "Return a tilejson"}},
            response_model_exclude_none=True,
            name=f"{self.router_prefix}tilejson",
        )
        def tilejson(
            request: Request,
            tms=Depends(self.tms_dependency),
            src_path=Depends(self.path_dependency),
            tile_format: Optional[ImageType] = Query(
                None, description="Output image type. Default is auto."
            ),
            tile_scale: int = Query(
                1, gt=0, lt=4, description="Tile size scale. 1=256x256, 2=512x512..."
            ),
            minzoom: Optional[int] = Query(None, description="Overwrite default minzoom."),
            maxzoom: Optional[int] = Query(None, description="Overwrite default maxzoom."),
        ):
            """Return TileJSON document for a dataset."""
            kwargs = {
                "z": "{z}",
                "x": "{x}",
                "y": "{y}",
                "scale": tile_scale,
                "TileMatrixSetId": tms.identifier,
            }
            if tile_format:
                kwargs["format"] = tile_format.value

            # Forward the remaining query parameters to the tile URL; the ones
            # consumed by this endpoint are dropped.
            q = dict(request.query_params)
            q.pop("TileMatrixSetId", None)
            q.pop("tile_format", None)
            q.pop("tile_scale", None)
            q.pop("minzoom", None)
            q.pop("maxzoom", None)
            qs = urlencode(list(q.items()))

            tiles_url = request.url_for(f"{self.router_prefix}tile", **kwargs)
            tiles_url += f"?{qs}"

            reader = src_path.reader or self.reader
            with reader(src_path.url, tms=tms, **self.reader_options) as src_dst:
                center = list(src_dst.center)
                # ``is not None`` so that an explicit ``minzoom=0`` is honoured,
                # consistently with the ``minzoom`` entry below.
                if minzoom is not None:
                    center[-1] = minzoom
                tjson = {
                    "bounds": src_dst.bounds,
                    "center": tuple(center),
                    "minzoom": minzoom if minzoom is not None else src_dst.minzoom,
                    "maxzoom": maxzoom if maxzoom is not None else src_dst.maxzoom,
                    "name": os.path.basename(src_path.url),
                    "tiles": [tiles_url],
                }

            return tjson

        @self.router.get(
            "/WMTSCapabilities.xml",
            response_class=XMLResponse,
            name=f"{self.router_prefix}wmts",
            tags=["OGC"],
        )
        @self.router.get(
            "/{TileMatrixSetId}/WMTSCapabilities.xml",
            response_class=XMLResponse,
            name=f"{self.router_prefix}wmts",
            tags=["OGC"],
        )
        def wmts(
            request: Request,
            tms=Depends(self.tms_dependency),
            src_path=Depends(self.path_dependency),
            tile_format: ImageType = Query(
                ImageType.png, description="Output image type. Default is png."
            ),
            tile_scale: int = Query(
                1, gt=0, lt=4, description="Tile size scale. 1=256x256, 2=512x512..."
            ),
            minzoom: Optional[int] = Query(None, description="Overwrite default minzoom."),
            maxzoom: Optional[int] = Query(None, description="Overwrite default maxzoom."),
        ):
            """OGC WMTS endpoint."""
            kwargs = {
                "z": "{TileMatrix}",
                "x": "{TileCol}",
                "y": "{TileRow}",
                "scale": tile_scale,
                "format": tile_format.value,
                "TileMatrixSetId": tms.identifier,
            }
            tiles_endpoint = request.url_for(f"{self.router_prefix}tile", **kwargs)

            # Forward the remaining query parameters to the tile URL; the ones
            # consumed by this endpoint (and SERVICE/REQUEST) are dropped.
            q = dict(request.query_params)
            q.pop("TileMatrixSetId", None)
            q.pop("tile_format", None)
            q.pop("tile_scale", None)
            q.pop("minzoom", None)
            q.pop("maxzoom", None)
            q.pop("SERVICE", None)
            q.pop("REQUEST", None)
            qs = urlencode(list(q.items()))
            tiles_endpoint += f"?{qs}"

            reader = src_path.reader or self.reader
            with reader(src_path.url, tms=tms, **self.reader_options) as src_dst:
                bounds = src_dst.bounds
                minzoom = minzoom if minzoom is not None else src_dst.minzoom
                maxzoom = maxzoom if maxzoom is not None else src_dst.maxzoom

            media_type = ImageMimeTypes[tile_format.value].value

            # One <TileMatrix> element per zoom level, both bounds included.
            tileMatrix = []
            for zoom in range(minzoom, maxzoom + 1):
                matrix = tms.matrix(zoom)
                tm = f"""
                        <TileMatrix>
                            <ows:Identifier>{matrix.identifier}</ows:Identifier>
                            <ScaleDenominator>{matrix.scaleDenominator}</ScaleDenominator>
                            <TopLeftCorner>{matrix.topLeftCorner[0]} {matrix.topLeftCorner[1]}</TopLeftCorner>
                            <TileWidth>{matrix.tileWidth}</TileWidth>
                            <TileHeight>{matrix.tileHeight}</TileHeight>
                            <MatrixWidth>{matrix.matrixWidth}</MatrixWidth>
                            <MatrixHeight>{matrix.matrixHeight}</MatrixHeight>
                        </TileMatrix>"""
                tileMatrix.append(tm)

            return templates.TemplateResponse(
                "wmts.xml",
                {
                    "request": request,
                    "tiles_endpoint": tiles_endpoint,
                    "bounds": bounds,
                    "tileMatrix": tileMatrix,
                    "tms": tms,
                    "title": "Cloud Optimized GeoTIFF",
                    "layer_name": "cogeo",
                    "media_type": media_type,
                },
                media_type=MimeTypes.xml.value,
            )

    ############################################################################
    # /point
    ############################################################################
    def _point(self):
        """Register /point endpoint to router."""

        @self.router.get(
            r"/point/{lon},{lat}",
            responses={200: {"description": "Return a value for a point"}},
            name=f"{self.router_prefix}point",
        )
        def point(
            response: Response,
            lon: float = Path(..., description="Longitude"),
            lat: float = Path(..., description="Latitude"),
            src_path=Depends(self.path_dependency),
            params=Depends(self.point_dependency),
            options=Depends(self.options),
        ):
            """Get Point value for a dataset."""
            timings = []

            with utils.Timer() as t:
                reader = src_path.reader or self.reader
                with reader(src_path.url, **self.reader_options) as src_dst:
                    values = src_dst.point(
                        lon,
                        lat,
                        indexes=params.indexes,
                        expression=params.expression,
                        nodata=params.nodata,
                        **options.kwargs,
                    )
            timings.append(("Read", t.elapsed))

            # The endpoint returns a plain dict, so the timing header is set on
            # the injected ``Response`` object; otherwise it would never be
            # sent.
            if timings:
                response.headers["X-Server-Timings"] = "; ".join(
                    ["{} - {:0.2f}".format(name, time * 1000) for (name, time) in timings]
                )

            return {"coordinates": [lon, lat], "values": values}

    ############################################################################
    # /preview (Optional)
    ############################################################################
    def _preview(self):
        """Register /preview endpoint to router."""
        prev_endpoint_params = img_endpoint_params.copy()
        prev_endpoint_params["name"] = f"{self.router_prefix}preview"

        @self.router.get(r"/preview", **prev_endpoint_params)
        @self.router.get(r"/preview.{format}", **prev_endpoint_params)
        def preview(
            format: ImageType = Query(
                None, description="Output image type. Default is auto."
            ),
            src_path=Depends(self.path_dependency),
            params=Depends(self.img_dependency),
            options=Depends(self.options),
        ):
            """Create preview of a dataset."""
            timings = []
            headers: Dict[str, str] = {}

            with utils.Timer() as t:
                reader = src_path.reader or self.reader
                with reader(src_path.url, **self.reader_options) as src_dst:
                    data, mask = src_dst.preview(
                        height=params.height,
                        width=params.width,
                        max_size=params.max_size,
                        indexes=params.indexes,
                        expression=params.expression,
                        nodata=params.nodata,
                        resampling_method=params.resampling_method.name,
                        **options.kwargs,
                    )
                    # A requested colormap takes precedence over one exposed
                    # by the reader.
                    colormap = (
                        params.colormap or getattr(src_dst, "colormap", None)
                    )
            timings.append(("Read", t.elapsed))

            # Auto format: JPEG when the mask is fully set, PNG otherwise.
            if not format:
                format = ImageType.jpg if mask.all() else ImageType.png

            with utils.Timer() as t:
                data = utils.postprocess(
                    data,
                    mask,
                    rescale=params.rescale,
                    color_formula=params.color_formula,
                )
            timings.append(("Post-process", t.elapsed))

            with utils.Timer() as t:
                content = utils.reformat(data, mask, format, colormap=colormap)
            timings.append(("Format", t.elapsed))

            if timings:
                headers["X-Server-Timings"] = "; ".join(
                    ["{} - {:0.2f}".format(name, time * 1000) for (name, time) in timings]
                )

            return Response(
                content, media_type=ImageMimeTypes[format.value].value, headers=headers,
            )

    ############################################################################
    # /crop (Optional)
    ############################################################################
    def _part(self):
        """Register /crop endpoint to router."""
        part_endpoint_params = img_endpoint_params.copy()
        part_endpoint_params["name"] = f"{self.router_prefix}part"

        # @router.get(r"/crop/{minx},{miny},{maxx},{maxy}", **part_endpoint_params)
        @self.router.get(
            r"/crop/{minx},{miny},{maxx},{maxy}.{format}",
            **part_endpoint_params,
        )
        def part(
            minx: float = Path(..., description="Bounding box min X"),
            miny: float = Path(..., description="Bounding box min Y"),
            maxx: float = Path(..., description="Bounding box max X"),
            maxy: float = Path(..., description="Bounding box max Y"),
            format: ImageType = Query(None, description="Output image type."),
            src_path=Depends(self.path_dependency),
            params=Depends(self.img_dependency),
            options=Depends(self.options),
        ):
            """Create image from part of a dataset."""
            timings = []
            headers: Dict[str, str] = {}

            with utils.Timer() as t:
                reader = src_path.reader or self.reader
                with reader(src_path.url, **self.reader_options) as src_dst:
                    data, mask = src_dst.part(
                        [minx, miny, maxx, maxy],
                        height=params.height,
                        width=params.width,
                        max_size=params.max_size,
                        indexes=params.indexes,
                        expression=params.expression,
                        nodata=params.nodata,
                        resampling_method=params.resampling_method.name,
                        **options.kwargs,
                    )
                    # A requested colormap takes precedence over one exposed
                    # by the reader.
                    colormap = (
                        params.colormap or getattr(src_dst, "colormap", None)
                    )
            timings.append(("Read", t.elapsed))

            # Auto format: JPEG when the mask is fully set, PNG otherwise.
            if not format:
                format = ImageType.jpg if mask.all() else ImageType.png

            with utils.Timer() as t:
                data = utils.postprocess(
                    data,
                    mask,
                    rescale=params.rescale,
                    color_formula=params.color_formula,
                )
            timings.append(("Post-process", t.elapsed))

            with utils.Timer() as t:
                content = utils.reformat(data, mask, format, colormap=colormap)
            timings.append(("Format", t.elapsed))

            if timings:
                headers["X-Server-Timings"] = "; ".join(
                    ["{} - {:0.2f}".format(name, time * 1000) for (name, time) in timings]
                )

            return Response(
                content,
                media_type=ImageMimeTypes[format.value].value,
                headers=headers,
            )
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
"""API.""" | |
from dataclasses import dataclass, field | |
import os | |
from urllib.parse import urlencode | |
from typing import Callable, Dict, Optional | |
import pkg_resources | |
from rasterio.transform import from_bounds | |
from cogeo_mosaic.mosaic import MosaicJSON | |
from cogeo_mosaic.utils import get_footprints | |
from cogeo_mosaic.backends import BaseBackend, MosaicBackend | |
from rio_tiler.io import BaseReader | |
from rio_tiler_crs import COGReader | |
from rio_tiler.constants import MAX_THREADS | |
from fastapi import Depends, Path, Query | |
from starlette.requests import Request | |
from starlette.responses import Response | |
from starlette.templating import Jinja2Templates | |
from .. import utils | |
from ..db.memcache import CacheLayer | |
from ..models.cog import cogBounds | |
from ..models.mapbox import TileJSON | |
from ..models.mosaic import CreateMosaicJSON, UpdateMosaicJSON, mosaicInfo | |
from ..ressources.common import img_endpoint_params | |
from ..ressources.enums import ImageMimeTypes, ImageType, MimeTypes, PixelSelectionMethod | |
from ..ressources.responses import XMLResponse | |
from ..errors import BadRequestError, TileNotFoundError | |
from ..dependencies import MosaicTMSParams, request_hash | |
from .factory import BaseFactory | |
# Jinja2 environment rooted at the ``templates`` resource of the ``titiler``
# package.
template_dir = pkg_resources.resource_filename(
    "titiler",
    "templates",
)
templates = Jinja2Templates(
    directory=template_dir,
)
@dataclass
class MosaicTilerFactory(BaseFactory):
    """Factory registering the MosaicJSON endpoints on the router."""

    # Backend used to open the mosaic, and reader used for its datasets.
    reader: BaseBackend = MosaicBackend
    dataset_reader: BaseReader = COGReader

    # Endpoint Dependencies
    tms_dependency: Callable = MosaicTMSParams

    add_asset_deps: bool = True  # We add if by default

    def __post_init__(self):
        """Delegate to the ``BaseFactory`` post-init."""
        super().__post_init__()

    def register_routes(self):
        """
        This Method register routes to the router.

        Because we wrap the endpoints in a class we cannot define the routes as
        methods (because of the self argument). The HACK is to define routes inside
        the class method and register them after the class initialisation.

        """
        # Mosaic definition routes (read / create / update)
        self._read()
        self._create()
        self._update()

        # Dataset routes
        self._bounds()
        self._info()
        self._tile()
        self._point()
############################################################################ | |
# /read | |
############################################################################ | |
def _read(self): | |
"""Add / - GET (Read) route.""" | |
@self.router.get( | |
"", | |
response_model=MosaicJSON, | |
response_model_exclude_none=True, | |
responses={200: {"description": "Return MosaicJSON definition"}}, | |
name=f"{self.router_prefix}read", | |
) | |
def read(src_path=Depends(self.path_dependency),): | |
"""Read a MosaicJSON""" | |
with self.reader(src_path.url) as mosaic: | |
return mosaic.mosaic_def | |
############################################################################ | |
# /create | |
############################################################################ | |
def _create(self): | |
"""Add / - POST (create) route.""" | |
@self.router.post( | |
"", | |
response_model=MosaicJSON, | |
response_model_exclude_none=True, | |
name=f"{self.router_prefix}create", | |
) | |
def create(body: CreateMosaicJSON): | |
"""Create a MosaicJSON""" | |
mosaic = MosaicJSON.from_urls( | |
body.files, | |
minzoom=body.minzoom, | |
maxzoom=body.maxzoom, | |
max_threads=body.max_threads, | |
) | |
src_path = self.path_dependency(body.url) | |
reader = src_path.reader or self.dataset_reader | |
with self.reader(src_path.url, mosaic_def=mosaic, reader=reader) as mosaic: | |
try: | |
mosaic.write() | |
except NotImplementedError: | |
raise BadRequestError( | |
f"{mosaic.__class__.__name__} does not support write operations" | |
) | |
return mosaic.mosaic_def | |
############################################################################ | |
# /update | |
############################################################################ | |
def _update(self): | |
"""Add / - PUT (update) route.""" | |
@self.router.put( | |
"", | |
response_model=MosaicJSON, | |
response_model_exclude_none=True, | |
name=f"{self.router_prefix}update", | |
) | |
def update_mosaicjson(body: UpdateMosaicJSON): | |
"""Update an existing MosaicJSON""" | |
src_path = self.path_dependency(body.url) | |
reader = src_path.reader or self.dataset_reader | |
with self.reader(src_path.url, reader=reader) as mosaic: | |
features = get_footprints(body.files, max_threads=body.max_threads) | |
try: | |
mosaic.update(features, add_first=body.add_first, quiet=True) | |
except NotImplementedError: | |
raise BadRequestError( | |
f"{mosaic.__class__.__name__} does not support update operations" | |
) | |
return mosaic.mosaic_def | |
############################################################################ | |
# /bounds | |
############################################################################ | |
def _bounds(self): | |
"""Register /bounds endpoint to router.""" | |
@self.router.get( | |
"/bounds", | |
response_model=cogBounds, | |
responses={200: {"description": "Return the bounds of the MosaicJSON"}}, | |
name=f"{self.router_prefix}bounds", | |
) | |
def bounds(src_path=Depends(self.path_dependency)): | |
"""Return the bounds of the COG.""" | |
with self.reader(src_path.url) as src_dst: | |
return {"bounds": src_dst.bounds} | |
############################################################################ | |
# /info | |
############################################################################ | |
def _info(self): | |
"""Register /info endpoint to router.""" | |
@self.router.get( | |
"/info", | |
response_model=mosaicInfo, | |
responses={200: {"description": "Return info about the MosaicJSON"}}, | |
name=f"{self.router_prefix}info", | |
) | |
def info(src_path=Depends(self.path_dependency)): | |
"""Return basic info.""" | |
with self.reader(src_path.url) as src_dst: | |
info = src_dst.info() | |
return info | |
############################################################################ | |
# /tiles | |
############################################################################ | |
def _tile(self): | |
tile_endpoint_params = img_endpoint_params.copy() | |
tile_endpoint_params["name"] = f"{self.router_prefix}tile" | |
@self.router.get(r"/tiles/{z}/{x}/{y}", **tile_endpoint_params) | |
@self.router.get(r"/tiles/{z}/{x}/{y}.{format}", **tile_endpoint_params) | |
@self.router.get(r"/tiles/{z}/{x}/{y}@{scale}x", **tile_endpoint_params) | |
@self.router.get(r"/tiles/{z}/{x}/{y}@{scale}x.{format}", **tile_endpoint_params) | |
@self.router.get(r"/tiles/{TileMatrixSetId}/{z}/{x}/{y}", **tile_endpoint_params) | |
@self.router.get( | |
r"/tiles/{TileMatrixSetId}/{z}/{x}/{y}.{format}", **tile_endpoint_params | |
) | |
@self.router.get( | |
r"/tiles/{TileMatrixSetId}/{z}/{x}/{y}@{scale}x", **tile_endpoint_params | |
) | |
@self.router.get( | |
r"/tiles/{TileMatrixSetId}/{z}/{x}/{y}@{scale}x.{format}", | |
**tile_endpoint_params, | |
) | |
def tile( | |
z: int = Path(..., ge=0, le=30, description="Mercator tiles's zoom level"), | |
x: int = Path(..., description="Mercator tiles's column"), | |
y: int = Path(..., description="Mercator tiles's row"), | |
tms=Depends(self.tms_dependency), | |
scale: int = Query( | |
1, gt=0, lt=4, description="Tile size scale. 1=256x256, 2=512x512..." | |
), | |
format: ImageType = Query(None, description="Output image type. Default is auto."), | |
src_path=Depends(self.path_dependency), | |
params=Depends(self.tiles_dependency), | |
options=Depends(self.options), | |
pixel_selection: PixelSelectionMethod = Query( | |
PixelSelectionMethod.first, description="Pixel selection method." | |
), | |
cache_client: CacheLayer = Depends(utils.get_cache), | |
request_id: str = Depends(request_hash), | |
): | |
"""Create map tile from a COG.""" | |
timings = [] | |
headers: Dict[str, str] = {} | |
tilesize = scale * 256 | |
content = None | |
if cache_client: | |
try: | |
content, ext = cache_client.get_image_from_cache(request_id) | |
format = ImageType[ext] | |
headers["X-Cache"] = "HIT" | |
except Exception: | |
content = None | |
if not content: | |
with utils.Timer() as t: | |
reader = src_path.reader or self.dataset_reader | |
reader_options = {**self.reader_options, "tms": tms} | |
threads = int(os.getenv("MOSAIC_CONCURRENCY", MAX_THREADS)) | |
with self.reader( | |
src_path.url, reader=reader, reader_options=reader_options | |
) as src_dst: | |
(data, mask), assets_used = src_dst.tile( | |
x, | |
y, | |
z, | |
pixel_selection=pixel_selection.method(), | |
threads=threads, | |
tilesize=tilesize, | |
indexes=params.indexes, | |
expression=params.expression, | |
nodata=params.nodata, | |
resampling_method=params.resampling_method.name, | |
**options.kwargs, | |
) | |
timings.append(("Read-tile", t.elapsed)) | |
if data is None: | |
raise TileNotFoundError(f"Tile {z}/{x}/{y} was not found") | |
if not format: | |
format = ImageType.jpg if mask.all() else ImageType.png | |
with utils.Timer() as t: | |
data = utils.postprocess( | |
data, | |
mask, | |
rescale=params.rescale, | |
color_formula=params.color_formula, | |
) | |
timings.append(("Post-process", t.elapsed)) | |
bounds = tms.xy_bounds(x, y, z) | |
dst_transform = from_bounds(*bounds, tilesize, tilesize) | |
with utils.Timer() as t: | |
content = utils.reformat( | |
tile, | |
mask, | |
format, | |
colormap=params.colormap, | |
transform=dst_transform, | |
crs=tms.crs, | |
) | |
timings.append(("Format", t.elapsed)) | |
if cache_client and content: | |
cache_client.set_image_cache(request_id, (content, format.value)) | |
if timings: | |
headers["X-Server-Timings"] = "; ".join( | |
["{} - {:0.2f}".format(name, time * 1000) for (name, time) in timings] | |
) | |
if assets_used: | |
headers["X-Assets"] = ",".join(assets_used) | |
return Response( | |
content, media_type=ImageMimeTypes[format.value].value, headers=headers, | |
) | |
@self.router.get( | |
"/tilejson.json", | |
response_model=TileJSON, | |
responses={200: {"description": "Return a tilejson"}}, | |
response_model_exclude_none=True, | |
name=f"{self.router_prefix}tilejson", | |
) | |
@self.router.get( | |
"/{TileMatrixSetId}/tilejson.json", | |
response_model=TileJSON, | |
responses={200: {"description": "Return a tilejson"}}, | |
response_model_exclude_none=True, | |
name=f"{self.router_prefix}tilejson", | |
) | |
def tilejson( | |
request: Request, | |
tms=Depends(self.tms_dependency), | |
src_path=Depends(self.path_dependency), | |
tile_format: Optional[ImageType] = Query( | |
None, description="Output image type. Default is auto." | |
), | |
tile_scale: int = Query( | |
1, gt=0, lt=4, description="Tile size scale. 1=256x256, 2=512x512..." | |
), | |
minzoom: Optional[int] = Query(None, description="Overwrite default minzoom."), | |
maxzoom: Optional[int] = Query(None, description="Overwrite default maxzoom."), | |
): | |
"""Return TileJSON document for a COG.""" | |
kwargs = { | |
"z": "{z}", | |
"x": "{x}", | |
"y": "{y}", | |
"scale": tile_scale, | |
"TileMatrixSetId": tms.identifier, | |
} | |
if tile_format: | |
kwargs["format"] = tile_format.value | |
q = dict(request.query_params) | |
q.pop("TileMatrixSetId", None) | |
q.pop("tile_format", None) | |
q.pop("tile_scale", None) | |
q.pop("minzoom", None) | |
q.pop("maxzoom", None) | |
qs = urlencode(list(q.items())) | |
tiles_url = request.url_for(f"{self.router_prefix}tile", **kwargs) | |
tiles_url += f"?{qs}" | |
with self.reader(src_path.url,) as src_dst: | |
center = list(src_dst.center) | |
if minzoom: | |
center[-1] = minzoom | |
tjson = { | |
"bounds": src_dst.bounds, | |
"center": tuple(center), | |
"minzoom": minzoom if minzoom is not None else src_dst.minzoom, | |
"maxzoom": maxzoom if maxzoom is not None else src_dst.maxzoom, | |
"name": os.path.basename(src_path.url), | |
"tiles": [tiles_url], | |
} | |
return tjson | |
@self.router.get( | |
"/WMTSCapabilities.xml", | |
response_class=XMLResponse, | |
name=f"{self.router_prefix}wmts", | |
tags=["OGC"], | |
) | |
@self.router.get( | |
"/{TileMatrixSetId}/WMTSCapabilities.xml", | |
response_class=XMLResponse, | |
name=f"{self.router_prefix}wmts", | |
tags=["OGC"], | |
) | |
def wmts( | |
request: Request, | |
tms=Depends(self.tms_dependency), | |
src_path=Depends(self.path_dependency), | |
tile_format: ImageType = Query( | |
ImageType.png, description="Output image type. Default is png." | |
), | |
tile_scale: int = Query( | |
1, gt=0, lt=4, description="Tile size scale. 1=256x256, 2=512x512..." | |
), | |
minzoom: Optional[int] = Query(None, description="Overwrite default minzoom."), | |
maxzoom: Optional[int] = Query(None, description="Overwrite default maxzoom."), | |
): | |
"""OGC WMTS endpoint.""" | |
kwargs = { | |
"z": "{TileMatrix}", | |
"x": "{TileCol}", | |
"y": "{TileRow}", | |
"scale": tile_scale, | |
"format": tile_format.value, | |
"TileMatrixSetId": tms.identifier, | |
} | |
tiles_endpoint = request.url_for(f"{self.router_prefix}tile", **kwargs) | |
q = dict(request.query_params) | |
q.pop("TileMatrixSetId", None) | |
q.pop("tile_format", None) | |
q.pop("tile_scale", None) | |
q.pop("minzoom", None) | |
q.pop("maxzoom", None) | |
q.pop("SERVICE", None) | |
q.pop("REQUEST", None) | |
qs = urlencode(list(q.items())) | |
tiles_endpoint += f"?{qs}" | |
with self.reader(src_path.url) as src_dst: | |
bounds = src_dst.bounds | |
minzoom = minzoom if minzoom is not None else src_dst.minzoom | |
maxzoom = maxzoom if maxzoom is not None else src_dst.maxzoom | |
media_type = ImageMimeTypes[tile_format.value].value | |
tileMatrix = [] | |
for zoom in range(minzoom, maxzoom + 1): | |
matrix = tms.matrix(zoom) | |
tm = f""" | |
<TileMatrix> | |
<ows:Identifier>{matrix.identifier}</ows:Identifier> | |
<ScaleDenominator>{matrix.scaleDenominator}</ScaleDenominator> | |
<TopLeftCorner>{matrix.topLeftCorner[0]} {matrix.topLeftCorner[1]}</TopLeftCorner> | |
<TileWidth>{matrix.tileWidth}</TileWidth> | |
<TileHeight>{matrix.tileHeight}</TileHeight> | |
<MatrixWidth>{matrix.matrixWidth}</MatrixWidth> | |
<MatrixHeight>{matrix.matrixHeight}</MatrixHeight> | |
</TileMatrix>""" | |
tileMatrix.append(tm) | |
return templates.TemplateResponse( | |
"wmts.xml", | |
{ | |
"request": request, | |
"tiles_endpoint": tiles_endpoint, | |
"bounds": bounds, | |
"tileMatrix": tileMatrix, | |
"tms": tms, | |
"title": "Cloud Optimized GeoTIFF", | |
"layer_name": "cogeo", | |
"media_type": media_type, | |
}, | |
media_type=MimeTypes.xml.value, | |
) | |
############################################################################ | |
# /point (Optional) | |
############################################################################ | |
def _point(self): | |
@self.router.get( | |
r"/point/{lon},{lat}", | |
responses={200: {"description": "Return a value for a point"}}, | |
name=f"{self.router_prefix}point", | |
) | |
def point( | |
lon: float = Path(..., description="Longitude"), | |
lat: float = Path(..., description="Latitude"), | |
src_path=Depends(self.path_dependency), | |
params=Depends(self.point_dependency), | |
options=Depends(self.options), | |
): | |
"""Get Point value for a Mosaic.""" | |
timings = [] | |
headers: Dict[str, str] = {} | |
threads = int(os.getenv("MOSAIC_CONCURRENCY", MAX_THREADS)) | |
with utils.Timer() as t: | |
reader = src_path.reader or self.dataset_reader | |
with self.reader( | |
src_path.url, | |
reader=reader, | |
reader_options=self.reader_options, | |
) as src_dst: | |
values = src_dst.point( | |
lon, | |
lat, | |
threads=threads, | |
indexes=params.indexes, | |
expression=params.expression, | |
nodata=params.nodata, | |
**options.kwargs, | |
) | |
timings.append(("Read", t.elapsed)) | |
if timings: | |
headers["X-Server-Timings"] = "; ".join( | |
["{} - {:0.2f}".format(name, time * 1000) for (name, time) in timings] | |
) | |
return {"coordinates": [lon, lat], "values": values} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment