Skip to content

Instantly share code, notes, and snippets.

@vincentsarago
Last active August 24, 2020 16:20
Show Gist options
  • Save vincentsarago/dd3b3f91a5d1562048b95f024c8de6d4 to your computer and use it in GitHub Desktop.
Save vincentsarago/dd3b3f91a5d1562048b95f024c8de6d4 to your computer and use it in GitHub Desktop.
"""Common dependency."""
import re
from dataclasses import dataclass, field
from enum import Enum
from typing import Dict, Optional, Tuple, Type, Union
from urllib.parse import urlparse
import morecantile
import numpy
from rasterio.enums import Resampling
from rio_tiler.colormap import cmap
from rio_tiler.io import BaseReader
from titiler import settings
from titiler.custom import cmap as custom_colormap
from titiler.custom import tms as custom_tms
from titiler.utils import get_hash
from fastapi import Query
from starlette.requests import Request
################################################################################
# CMAP AND TMS Customization

# REGISTER CUSTOM TMS
#
# e.g morecantile.tms.register(custom_tms.my_custom_tms)
morecantile.tms.register(custom_tms.EPSG3413)

# REGISTER CUSTOM COLORMAP HERE
#
# e.g cmap.register("customRed", custom_colormap.custom_red)
cmap.register("above", custom_colormap.above_cmap)

################################################################################
# DO NOT UPDATE
# Create ENUMS with all CMAP and TMS for documentation and validation.
# Every Enum maps each registered name to itself (member name == member value),
# so it must be built AFTER the custom registrations above.
ColorMapNames = Enum(  # type: ignore
    "ColorMapNames", [(name, name) for name in sorted(cmap.list())]
)
TileMatrixSetNames = Enum(  # type: ignore
    "TileMatrixSetNames", [(name, name) for name in sorted(morecantile.tms.list())]
)
# Only WebMercatorQuad is offered to mosaic endpoints (see MosaicTMSParams).
MosaicTileMatrixSetNames = Enum(  # type: ignore
    "MosaicTileMatrixSetNames", [("WebMercatorQuad", "WebMercatorQuad")]
)
ResamplingNames = Enum(  # type: ignore
    "ResamplingNames", [(method.name, method.name) for method in Resampling]
)
async def request_hash(request: Request) -> str:
    """Create a SHA224 id (via ``get_hash``) from the request parameters.

    Query parameters and path parameters are merged into a single mapping
    before hashing. Previously they were splatted as two separate ``**``
    arguments, which raised ``TypeError`` whenever a name appeared in both
    (e.g. ``?format=png`` on a ``/tiles/.../{z}/{x}/{y}.{format}`` route).
    On a clash the path value wins.
    NOTE(review): this presumably matches the value the endpoint itself
    receives for such a parameter; confirm against FastAPI's binding rules.

    For requests without clashing names, the keyword arguments passed to
    ``get_hash`` are identical, in the same order, so existing cache keys are
    unchanged.
    """
    params = {**dict(request.query_params), **request.path_params}
    return get_hash(**params)
def TMSParams(
    TileMatrixSetId: TileMatrixSetNames = Query(
        TileMatrixSetNames.WebMercatorQuad,  # type: ignore
        description="TileMatrixSet Name (default: 'WebMercatorQuad')",
    )
) -> morecantile.TileMatrixSet:
    """TileMatrixSet Dependency.

    Args:
        TileMatrixSetId: a member of ``TileMatrixSetNames`` (any registered
            morecantile TMS); defaults to WebMercatorQuad.

    Returns:
        The ``morecantile.TileMatrixSet`` registered under that name.
    """
    selected_name = TileMatrixSetId.name
    return morecantile.tms.get(selected_name)
def MosaicTMSParams(
    TileMatrixSetId: MosaicTileMatrixSetNames = Query(
        MosaicTileMatrixSetNames.WebMercatorQuad,  # type: ignore
        description="TileMatrixSet Name (default: 'WebMercatorQuad')",
    )
) -> morecantile.TileMatrixSet:
    """TileMatrixSet Dependency for mosaic endpoints.

    Same lookup as ``TMSParams``, but the choices come from
    ``MosaicTileMatrixSetNames``, which only contains WebMercatorQuad.

    Returns:
        The ``morecantile.TileMatrixSet`` registered under the selected name.
    """
    selected_name = TileMatrixSetId.name
    return morecantile.tms.get(selected_name)
@dataclass
class DefaultDependency:
    """Dependency Base Class.

    ``kwargs`` collects extra keyword arguments that endpoints forward to the
    dataset reader. It is not an ``__init__`` parameter; every instance starts
    with its own empty dict.
    """

    kwargs: dict = field(default_factory=dict, init=False)
@dataclass
class PathParams(DefaultDependency):
    """Create dataset path from args.

    ``url`` comes from the required ``url`` query parameter. ``reader`` is a
    placeholder that would let a URL select a dedicated reader class; it stays
    ``None`` here, so endpoints use ``src_path.reader or <factory reader>``.
    """

    url: str = Query(..., description="Dataset URL")
    reader: Optional[Type[BaseReader]] = field(init=False, default=None)  # Placeholder

    def __post_init__(self):
        """Define dataset URL.

        A URL whose scheme ends with ``mosaicid`` (e.g. ``mosaicid://<id>``) is
        rewritten to the default mosaic backend/host location of that id.
        """
        parsed_url = urlparse(self.url)

        # Placeholder: scheme prefixes such as "landsat+" or "sentinel2cog+"
        # could select a dedicated reader here (set `self.reader` and strip the
        # prefix from `self.url`). Only the last "+"-separated part is checked,
        # which keeps room for combinations like "landsat+mosaicid".
        *_, base_scheme = parsed_url.scheme.split("+")
        if base_scheme != "mosaicid":
            return

        # by default we store the mosaicjson as a GZ compressed json (.json.gz) file
        self.url = f"{settings.DEFAULT_MOSAIC_BACKEND}{settings.DEFAULT_MOSAIC_HOST}/{parsed_url.netloc}.json.gz"
@dataclass
class AssetsParams(DefaultDependency):
    """Create dataset path from args.

    Parses the comma-delimited ``assets`` query parameter into
    ``kwargs["assets"]`` (a list of asset names). The ``kwargs`` field itself
    is inherited from ``DefaultDependency``.
    """

    assets: Optional[str] = Query(
        None,
        title="Asset indexes",
        description="comma (',') delimited asset names (might not be an available options of some readers)",
    )

    def __post_init__(self):
        """Post Init.

        Whitespace around each name is stripped and blank entries (from
        ``?assets=`` or a trailing comma) are dropped. If no name remains,
        ``kwargs`` gets no ``assets`` key, so endpoints that test
        ``options.kwargs.get("assets")`` treat the request as having no asset
        selection instead of trying to read an asset named ``""``.
        """
        if self.assets is None:
            return

        names = [name.strip() for name in self.assets.split(",")]
        names = [name for name in names if name]
        if names:
            self.kwargs["assets"] = names
@dataclass
class CommonParams(DefaultDependency):
    """Common Reader params.

    After initialisation:
      * ``indexes`` is a tuple of the integers found in ``bidx``, or ``None``
        when ``bidx`` is not given / empty;
      * ``nodata`` (if given) is ``numpy.nan`` for the literal ``"nan"``,
        otherwise ``float(nodata)``.
    """

    bidx: Optional[str] = Query(
        None, title="Band indexes", description="comma (',') delimited band indexes",
    )
    nodata: Optional[Union[str, int, float]] = Query(
        None, title="Nodata value", description="Overwrite internal Nodata value"
    )
    resampling_method: ResamplingNames = Query(
        ResamplingNames.nearest, description="Resampling method."  # type: ignore
    )

    def __post_init__(self):
        """Post Init: parse band indexes and normalise nodata."""
        if self.bidx:
            self.indexes = tuple(map(int, re.findall(r"\d+", self.bidx)))
        else:
            self.indexes = None

        raw_nodata = self.nodata
        if raw_nodata is not None:
            if raw_nodata == "nan":
                self.nodata = numpy.nan
            else:
                self.nodata = float(raw_nodata)
@dataclass
class MetadataParams(CommonParams):
    """Common Metadata parameters.

    After initialisation (in addition to ``CommonParams`` parsing):
      * ``hist_options`` holds ``bins`` and/or ``range`` (list of floats parsed
        from ``histogram_range``) when those parameters are given;
      * ``bounds`` (if given) is converted to a tuple of floats.
    """

    pmin: float = Query(2.0, description="Minimum percentile")
    pmax: float = Query(98.0, description="Maximum percentile")
    max_size: int = Query(1024, description="Maximum image size to read onto.")
    histogram_bins: Optional[int] = Query(None, description="Histogram bins.")
    histogram_range: Optional[str] = Query(
        None, description="comma (',') delimited Min,Max histogram bounds"
    )
    bounds: Optional[str] = Query(
        None,
        # Was misspelled `descriptions=`, so the text never reached the API docs.
        description="comma (',') delimited Bounding box coordinates from which to calculate image statistics.",
    )
    hist_options: dict = field(init=False, default_factory=dict)

    def __post_init__(self):
        """Post Init."""
        super().__post_init__()

        if self.histogram_bins:
            self.hist_options.update(dict(bins=self.histogram_bins))

        if self.histogram_range:
            self.hist_options.update(
                dict(range=list(map(float, self.histogram_range.split(","))))
            )

        if self.bounds:
            self.bounds = tuple(map(float, self.bounds.split(",")))
@dataclass
class PointParams(DefaultDependency):
    """Point Parameters.

    After initialisation ``indexes`` holds the integers found in ``bidx`` (or
    ``None`` when ``bidx`` is not given / empty), and ``nodata`` (if given) is
    ``numpy.nan`` for the literal ``"nan"``, otherwise ``float(nodata)``.
    """

    bidx: Optional[str] = Query(
        None, title="Band indexes", description="comma (',') delimited band indexes",
    )
    nodata: Optional[Union[str, int, float]] = Query(
        None, title="Nodata value", description="Overwrite internal Nodata value"
    )
    expression: Optional[str] = Query(
        None,
        title="Band Math expression",
        description="rio-tiler's band math expression (e.g B1/B2)",
    )

    def __post_init__(self):
        """Post Init: parse band indexes and normalise nodata."""
        band_tokens = re.findall(r"\d+", self.bidx) if self.bidx else None
        self.indexes = None if band_tokens is None else tuple(map(int, band_tokens))

        raw_nodata = self.nodata
        if raw_nodata is not None:
            if raw_nodata == "nan":
                self.nodata = numpy.nan
            else:
                self.nodata = float(raw_nodata)
@dataclass
class TileParams(CommonParams):
    """Common Tile parameters.

    Extends ``CommonParams`` (band indexes, nodata, resampling method) with
    band math, rescaling, color formula and colormap options. ``colormap`` is
    resolved in ``__post_init__`` from ``color_map`` via ``cmap.get``.
    """

    expression: Optional[str] = Query(
        None,
        title="Band Math expression",
        description="rio-tiler's band math expression (e.g B1/B2)",
    )
    rescale: Optional[str] = Query(
        None,
        title="Min/Max data Rescaling",
        description="comma (',') delimited Min,Max bounds",
    )
    color_formula: Optional[str] = Query(
        None,
        title="Color Formula",
        description="rio-color formula (info: https://github.com/mapbox/rio-color)",
    )
    color_map: Optional[ColorMapNames] = Query(
        None, description="rio-tiler's colormap name"
    )
    # `resampling_method` is inherited unchanged from CommonParams (it used to
    # be redeclared here with an identical Query definition).

    # Resolved colormap, or None when no `color_map` is requested. The None
    # default keeps the attribute readable even if __post_init__ is overridden.
    colormap: Optional[Dict[int, Tuple[int, int, int, int]]] = field(
        init=False, default=None
    )

    def __post_init__(self):
        """Post Init."""
        super().__post_init__()
        self.colormap = cmap.get(self.color_map.value) if self.color_map else None
@dataclass
class ImageParams(TileParams):
    """Common Image parameters.

    Adds output-size controls on top of ``TileParams``. When both ``width``
    and ``height`` are given, ``max_size`` is cleared to ``None``.
    """

    max_size: Optional[int] = Query(
        1024, description="Maximum image size to read onto."
    )
    height: Optional[int] = Query(None, description="Force output image height.")
    width: Optional[int] = Query(None, description="Force output image width.")

    def __post_init__(self):
        """Post Init: an explicit width + height overrides max_size."""
        super().__post_init__()
        explicit_size = bool(self.width) and bool(self.height)
        if explicit_size:
            self.max_size = None
"""API."""
import abc
from dataclasses import dataclass, field
import os
from urllib.parse import urlencode
from typing import Callable, Dict, Type, Optional, Union, List
import pkg_resources
from rasterio.transform import from_bounds
from fastapi import APIRouter, Depends, Path, Query
from starlette.requests import Request
from starlette.responses import Response
from starlette.templating import Jinja2Templates
from rio_tiler.io import BaseReader
from rio_tiler_crs import COGReader
from .. import utils
from ..db.memcache import CacheLayer
from ..models.cog import cogBounds, cogInfo, cogMetadata
from ..models.mapbox import TileJSON
from ..ressources.common import img_endpoint_params
from ..ressources.enums import ImageMimeTypes, ImageType, MimeTypes
from ..ressources.responses import XMLResponse
from ..dependencies import (
PathParams,
AssetsParams,
DefaultDependency,
MetadataParams,
TileParams,
ImageParams,
PointParams,
TMSParams,
request_hash,
)
# Jinja2 templates shipped in the titiler package (e.g. "wmts.xml").
template_dir: str = pkg_resources.resource_filename("titiler", "templates")
templates = Jinja2Templates(directory=template_dir)
# ref: https://github.com/python/mypy/issues/5374
@dataclass  # type: ignore
class BaseFactory(metaclass=abc.ABCMeta):
    """Tiler Factory.

    Abstract base holding the reader, router and endpoint dependencies shared
    by concrete factories. ``__post_init__`` picks the extra-options
    dependency, normalises ``router_prefix`` and then calls the abstract
    ``register_routes``.
    """

    reader: Type[BaseReader] = COGReader
    reader_options: Dict = field(default_factory=dict)

    # FastAPI router
    router: APIRouter = field(default_factory=APIRouter)

    # Endpoint Dependencies
    tms_dependency: Callable = field(default=TMSParams)
    path_dependency: Type[PathParams] = field(default=PathParams)
    tiles_dependency: Type[TileParams] = field(default=TileParams)
    point_dependency: Type[PointParams] = field(default=PointParams)

    # Add `assets` options in endpoint
    add_asset_deps: bool = False

    # Router Prefix is needed to find the path for /tile if the TilerFactory.router is mounted
    # with other router (multiple `.../tile` routes).
    # e.g if you mount the route with `/cog` prefix, set router_prefix to cog and
    # each routes will be prefixed with `cog_`, which will let starlette retrieve routes url (Reverse URL lookups)
    router_prefix: str = ""

    def __post_init__(self):
        """Select the options dependency, suffix the prefix, register routes."""
        if self.add_asset_deps:
            self.options = AssetsParams
        else:
            self.options = DefaultDependency

        # Route names are built as f"{router_prefix}<name>", hence the "_".
        if self.router_prefix:
            self.router_prefix += "_"

        self.register_routes()

    @abc.abstractmethod
    def register_routes(self):
        """Register the factory's endpoints on ``self.router``."""
        ...
@dataclass
class TilerFactory(BaseFactory):
    """Tiler Factory.

    Registers the dataset endpoints on ``self.router`` when the factory is
    initialised (``BaseFactory.__post_init__`` calls ``register_routes``).
    Every endpoint opens the dataset with ``src_path.reader or self.reader``
    and forwards ``self.reader_options`` to it.
    """

    # Endpoint Dependencies
    metadata_dependency: Type[MetadataParams] = MetadataParams
    img_dependency: Type[ImageParams] = ImageParams

    # Add/Remove some endpoints
    add_preview: bool = True
    add_part: bool = True

    @staticmethod
    def _timings_header(timings) -> str:
        """Format ``(name, seconds)`` pairs as an ``X-Server-Timings`` value (ms)."""
        return "; ".join(
            "{} - {:0.2f}".format(name, time * 1000) for name, time in timings
        )

    def register_routes(self):
        """
        This method registers routes on the router.

        Because we wrap the endpoints in a class we cannot define the routes as
        methods (because of the self argument). The HACK is to define routes inside
        the class method and register them after the class initialisation.
        """
        # Default Routes
        # (/bounds, /info, /metadata, /tile, /tilejson.json, /WMTSCapabilities.xml and /point)
        self._bounds()
        if self.add_asset_deps:
            self._info_with_assets()
        else:
            self._info()
        self._metadata()
        self._tile()
        self._point()

        # Optional Routes
        if self.add_preview:
            self._preview()

        if self.add_part:
            self._part()

    ############################################################################
    # /bounds
    ############################################################################
    def _bounds(self):
        """Register /bounds endpoint to router."""

        @self.router.get(
            "/bounds",
            response_model=cogBounds,
            responses={200: {"description": "Return dataset's bounds."}},
            name=f"{self.router_prefix}bounds",
        )
        def bounds(src_path=Depends(self.path_dependency)):
            """Return the bounds of the COG."""
            reader = src_path.reader or self.reader
            with reader(src_path.url, **self.reader_options) as src_dst:
                return {"bounds": src_dst.bounds}

    ############################################################################
    # /info - with assets
    ############################################################################
    def _info_with_assets(self):
        """Register /info endpoint (with `assets` option) to router."""

        @self.router.get(
            "/info",
            response_model=Union[List[str], Dict[str, cogInfo]],
            response_model_exclude={"minzoom", "maxzoom", "center"},
            response_model_exclude_none=True,
            responses={200: {"description": "Return dataset's basic info."}},
            name=f"{self.router_prefix}info",
        )
        def info(
            src_path=Depends(self.path_dependency),
            options: AssetsParams = Depends(),
        ):
            """Return basic info.

            Without an ``assets`` selection, return the dataset's asset list.
            """
            reader = src_path.reader or self.reader
            with reader(src_path.url, **self.reader_options) as src_dst:
                if not options.kwargs.get("assets"):
                    return src_dst.assets

                info = src_dst.info(**options.kwargs)
            return info

    ############################################################################
    # /info - without assets
    ############################################################################
    def _info(self):
        """Register /info endpoint to router."""

        @self.router.get(
            "/info",
            response_model=Union[List[str], Dict[str, cogInfo], cogInfo],
            response_model_exclude={"minzoom", "maxzoom", "center"},
            response_model_exclude_none=True,
            responses={200: {"description": "Return dataset's basic info."}},
            name=f"{self.router_prefix}info",
        )
        def info(src_path=Depends(self.path_dependency)):
            """Return basic info."""
            reader = src_path.reader or self.reader
            with reader(src_path.url, **self.reader_options) as src_dst:
                info = src_dst.info()
            return info

    ############################################################################
    # /metadata
    ############################################################################
    def _metadata(self):
        """Register /metadata endpoint to router."""

        @self.router.get(
            "/metadata",
            response_model=Union[cogMetadata, Dict[str, cogMetadata]],
            response_model_exclude={"minzoom", "maxzoom", "center"},
            response_model_exclude_none=True,
            responses={200: {"description": "Return dataset's metadata."}},
            name=f"{self.router_prefix}metadata",
        )
        def metadata(
            src_path=Depends(self.path_dependency),
            params=Depends(self.metadata_dependency),
            options=Depends(self.options),
        ):
            """Return metadata."""
            reader = src_path.reader or self.reader
            with reader(src_path.url, **self.reader_options) as src_dst:
                info = src_dst.metadata(
                    params.pmin,
                    params.pmax,
                    nodata=params.nodata,
                    indexes=params.indexes,
                    max_size=params.max_size,
                    hist_options=params.hist_options,
                    bounds=params.bounds,
                    resampling_method=params.resampling_method.name,
                    **options.kwargs,
                )
            return info

    ############################################################################
    # /tiles
    ############################################################################
    def _tile(self):
        """Register /tiles, /tilejson.json and /WMTSCapabilities.xml endpoints."""
        tile_endpoint_params = img_endpoint_params.copy()
        tile_endpoint_params["name"] = f"{self.router_prefix}tile"

        @self.router.get(r"/tiles/{z}/{x}/{y}", **tile_endpoint_params)
        @self.router.get(r"/tiles/{z}/{x}/{y}.{format}", **tile_endpoint_params)
        @self.router.get(r"/tiles/{z}/{x}/{y}@{scale}x", **tile_endpoint_params)
        @self.router.get(r"/tiles/{z}/{x}/{y}@{scale}x.{format}", **tile_endpoint_params)
        @self.router.get(r"/tiles/{TileMatrixSetId}/{z}/{x}/{y}", **tile_endpoint_params)
        @self.router.get(
            r"/tiles/{TileMatrixSetId}/{z}/{x}/{y}.{format}", **tile_endpoint_params
        )
        @self.router.get(
            r"/tiles/{TileMatrixSetId}/{z}/{x}/{y}@{scale}x", **tile_endpoint_params
        )
        @self.router.get(
            r"/tiles/{TileMatrixSetId}/{z}/{x}/{y}@{scale}x.{format}",
            **tile_endpoint_params,
        )
        def tile(
            z: int = Path(..., ge=0, le=30, description="Mercator tiles's zoom level"),
            x: int = Path(..., description="Mercator tiles's column"),
            y: int = Path(..., description="Mercator tiles's row"),
            tms=Depends(self.tms_dependency),
            scale: int = Query(
                1, gt=0, lt=4, description="Tile size scale. 1=256x256, 2=512x512..."
            ),
            format: ImageType = Query(None, description="Output image type. Default is auto."),
            src_path=Depends(self.path_dependency),
            params=Depends(self.tiles_dependency),
            options=Depends(self.options),
            cache_client: CacheLayer = Depends(utils.get_cache),
            request_id: str = Depends(request_hash),
        ):
            """Create map tile from a dataset."""
            timings = []
            headers: Dict[str, str] = {}

            tilesize = scale * 256

            content = None
            if cache_client:
                # Best effort: any cache failure falls through to rendering.
                try:
                    content, ext = cache_client.get_image_from_cache(request_id)
                    format = ImageType[ext]
                    headers["X-Cache"] = "HIT"
                except Exception:
                    content = None

            if not content:
                with utils.Timer() as t:
                    reader = src_path.reader or self.reader
                    with reader(src_path.url, tms=tms, **self.reader_options) as src_dst:
                        tile, mask = src_dst.tile(
                            x,
                            y,
                            z,
                            tilesize=tilesize,
                            indexes=params.indexes,
                            expression=params.expression,
                            nodata=params.nodata,
                            resampling_method=params.resampling_method.name,
                            **options.kwargs,
                        )
                        colormap = (
                            params.colormap or getattr(src_dst, "colormap", None)
                        )
                timings.append(("Read", t.elapsed))

                # Auto format: JPEG when every pixel is valid, PNG otherwise.
                if not format:
                    format = ImageType.jpg if mask.all() else ImageType.png

                with utils.Timer() as t:
                    tile = utils.postprocess(
                        tile,
                        mask,
                        rescale=params.rescale,
                        color_formula=params.color_formula,
                    )
                timings.append(("Post-process", t.elapsed))

                bounds = tms.xy_bounds(x, y, z)
                dst_transform = from_bounds(*bounds, tilesize, tilesize)
                with utils.Timer() as t:
                    content = utils.reformat(
                        tile,
                        mask,
                        format,
                        colormap=colormap,
                        transform=dst_transform,
                        crs=tms.crs,
                    )
                timings.append(("Format", t.elapsed))

                if cache_client and content:
                    cache_client.set_image_cache(request_id, (content, format.value))

            if timings:
                headers["X-Server-Timings"] = self._timings_header(timings)

            return Response(
                content, media_type=ImageMimeTypes[format.value].value, headers=headers,
            )

        @self.router.get(
            "/tilejson.json",
            response_model=TileJSON,
            responses={200: {"description": "Return a tilejson"}},
            response_model_exclude_none=True,
            name=f"{self.router_prefix}tilejson",
        )
        @self.router.get(
            "/{TileMatrixSetId}/tilejson.json",
            response_model=TileJSON,
            responses={200: {"description": "Return a tilejson"}},
            response_model_exclude_none=True,
            name=f"{self.router_prefix}tilejson",
        )
        def tilejson(
            request: Request,
            tms=Depends(self.tms_dependency),
            src_path=Depends(self.path_dependency),
            tile_format: Optional[ImageType] = Query(
                None, description="Output image type. Default is auto."
            ),
            tile_scale: int = Query(
                1, gt=0, lt=4, description="Tile size scale. 1=256x256, 2=512x512..."
            ),
            minzoom: Optional[int] = Query(None, description="Overwrite default minzoom."),
            maxzoom: Optional[int] = Query(None, description="Overwrite default maxzoom."),
        ):
            """Return TileJSON document for a dataset."""
            kwargs = {
                "z": "{z}",
                "x": "{x}",
                "y": "{y}",
                "scale": tile_scale,
                "TileMatrixSetId": tms.identifier,
            }
            if tile_format:
                kwargs["format"] = tile_format.value

            # Forward the remaining query parameters to the tile URL, dropping
            # the ones that only configure this TileJSON document.
            q = dict(request.query_params)
            for key in ("TileMatrixSetId", "tile_format", "tile_scale", "minzoom", "maxzoom"):
                q.pop(key, None)
            qs = urlencode(list(q.items()))

            tiles_url = request.url_for(f"{self.router_prefix}tile", **kwargs)
            if qs:
                tiles_url += f"?{qs}"

            reader = src_path.reader or self.reader
            with reader(src_path.url, tms=tms, **self.reader_options) as src_dst:
                center = list(src_dst.center)
                # `is not None` so that an explicit minzoom=0 also sets the
                # center zoom (previously `if minzoom:` skipped it).
                if minzoom is not None:
                    center[-1] = minzoom
                tjson = {
                    "bounds": src_dst.bounds,
                    "center": tuple(center),
                    "minzoom": minzoom if minzoom is not None else src_dst.minzoom,
                    "maxzoom": maxzoom if maxzoom is not None else src_dst.maxzoom,
                    "name": os.path.basename(src_path.url),
                    "tiles": [tiles_url],
                }

            return tjson

        @self.router.get(
            "/WMTSCapabilities.xml",
            response_class=XMLResponse,
            name=f"{self.router_prefix}wmts",
            tags=["OGC"],
        )
        @self.router.get(
            "/{TileMatrixSetId}/WMTSCapabilities.xml",
            response_class=XMLResponse,
            name=f"{self.router_prefix}wmts",
            tags=["OGC"],
        )
        def wmts(
            request: Request,
            tms=Depends(self.tms_dependency),
            src_path=Depends(self.path_dependency),
            tile_format: ImageType = Query(
                ImageType.png, description="Output image type. Default is png."
            ),
            tile_scale: int = Query(
                1, gt=0, lt=4, description="Tile size scale. 1=256x256, 2=512x512..."
            ),
            minzoom: Optional[int] = Query(None, description="Overwrite default minzoom."),
            maxzoom: Optional[int] = Query(None, description="Overwrite default maxzoom."),
        ):
            """OGC WMTS endpoint."""
            kwargs = {
                "z": "{TileMatrix}",
                "x": "{TileCol}",
                "y": "{TileRow}",
                "scale": tile_scale,
                "format": tile_format.value,
                "TileMatrixSetId": tms.identifier,
            }
            tiles_endpoint = request.url_for(f"{self.router_prefix}tile", **kwargs)

            # Forward the remaining query parameters to the tile URL, dropping
            # the ones that only configure this capabilities document.
            q = dict(request.query_params)
            for key in (
                "TileMatrixSetId",
                "tile_format",
                "tile_scale",
                "minzoom",
                "maxzoom",
                "SERVICE",
                "REQUEST",
            ):
                q.pop(key, None)
            qs = urlencode(list(q.items()))
            if qs:
                tiles_endpoint += f"?{qs}"

            reader = src_path.reader or self.reader
            with reader(src_path.url, tms=tms, **self.reader_options) as src_dst:
                bounds = src_dst.bounds
                minzoom = minzoom if minzoom is not None else src_dst.minzoom
                maxzoom = maxzoom if maxzoom is not None else src_dst.maxzoom

            media_type = ImageMimeTypes[tile_format.value].value

            tileMatrix = []
            for zoom in range(minzoom, maxzoom + 1):
                matrix = tms.matrix(zoom)
                tm = f"""
                        <TileMatrix>
                            <ows:Identifier>{matrix.identifier}</ows:Identifier>
                            <ScaleDenominator>{matrix.scaleDenominator}</ScaleDenominator>
                            <TopLeftCorner>{matrix.topLeftCorner[0]} {matrix.topLeftCorner[1]}</TopLeftCorner>
                            <TileWidth>{matrix.tileWidth}</TileWidth>
                            <TileHeight>{matrix.tileHeight}</TileHeight>
                            <MatrixWidth>{matrix.matrixWidth}</MatrixWidth>
                            <MatrixHeight>{matrix.matrixHeight}</MatrixHeight>
                        </TileMatrix>"""
                tileMatrix.append(tm)

            return templates.TemplateResponse(
                "wmts.xml",
                {
                    "request": request,
                    "tiles_endpoint": tiles_endpoint,
                    "bounds": bounds,
                    "tileMatrix": tileMatrix,
                    "tms": tms,
                    "title": "Cloud Optimized GeoTIFF",
                    "layer_name": "cogeo",
                    "media_type": media_type,
                },
                media_type=MimeTypes.xml.value,
            )

    ############################################################################
    # /point
    ############################################################################
    def _point(self):
        """Register /point endpoint to router."""

        @self.router.get(
            r"/point/{lon},{lat}",
            responses={200: {"description": "Return a value for a point"}},
            name=f"{self.router_prefix}point",
        )
        def point(
            response: Response,
            lon: float = Path(..., description="Longitude"),
            lat: float = Path(..., description="Latitude"),
            src_path=Depends(self.path_dependency),
            params=Depends(self.point_dependency),
            options=Depends(self.options),
        ):
            """Get Point value for a dataset.

            The timing header is now set on the injected ``response``; it used
            to be built into a local dict that was never returned.
            """
            timings = []

            with utils.Timer() as t:
                reader = src_path.reader or self.reader
                with reader(src_path.url, **self.reader_options) as src_dst:
                    values = src_dst.point(
                        lon,
                        lat,
                        indexes=params.indexes,
                        expression=params.expression,
                        nodata=params.nodata,
                        **options.kwargs,
                    )
            timings.append(("Read", t.elapsed))

            if timings:
                response.headers["X-Server-Timings"] = self._timings_header(timings)

            return {"coordinates": [lon, lat], "values": values}

    ############################################################################
    # /preview (Optional)
    ############################################################################
    def _preview(self):
        """Register /preview endpoint to router."""
        prev_endpoint_params = img_endpoint_params.copy()
        prev_endpoint_params["name"] = f"{self.router_prefix}preview"

        @self.router.get(r"/preview", **prev_endpoint_params)
        @self.router.get(r"/preview.{format}", **prev_endpoint_params)
        def preview(
            format: ImageType = Query(
                None, description="Output image type. Default is auto."
            ),
            src_path=Depends(self.path_dependency),
            params=Depends(self.img_dependency),
            options=Depends(self.options),
        ):
            """Create preview of a dataset."""
            timings = []
            headers: Dict[str, str] = {}

            with utils.Timer() as t:
                reader = src_path.reader or self.reader
                with reader(src_path.url, **self.reader_options) as src_dst:
                    data, mask = src_dst.preview(
                        height=params.height,
                        width=params.width,
                        max_size=params.max_size,
                        indexes=params.indexes,
                        expression=params.expression,
                        nodata=params.nodata,
                        resampling_method=params.resampling_method.name,
                        **options.kwargs,
                    )
                    colormap = (
                        params.colormap or getattr(src_dst, "colormap", None)
                    )
            timings.append(("Read", t.elapsed))

            if not format:
                format = ImageType.jpg if mask.all() else ImageType.png

            with utils.Timer() as t:
                data = utils.postprocess(
                    data,
                    mask,
                    rescale=params.rescale,
                    color_formula=params.color_formula,
                )
            timings.append(("Post-process", t.elapsed))

            with utils.Timer() as t:
                content = utils.reformat(data, mask, format, colormap=colormap)
            timings.append(("Format", t.elapsed))

            if timings:
                headers["X-Server-Timings"] = self._timings_header(timings)

            return Response(
                content, media_type=ImageMimeTypes[format.value].value, headers=headers,
            )

    ############################################################################
    # /crop (Optional)
    ############################################################################
    def _part(self):
        """Register /crop endpoint to router."""
        part_endpoint_params = img_endpoint_params.copy()
        part_endpoint_params["name"] = f"{self.router_prefix}part"

        # Only the route with an explicit `.{format}` suffix is registered.
        @self.router.get(
            r"/crop/{minx},{miny},{maxx},{maxy}.{format}",
            **part_endpoint_params,
        )
        def part(
            minx: float = Path(..., description="Bounding box min X"),
            miny: float = Path(..., description="Bounding box min Y"),
            maxx: float = Path(..., description="Bounding box max X"),
            maxy: float = Path(..., description="Bounding box max Y"),
            format: ImageType = Query(None, description="Output image type."),
            src_path=Depends(self.path_dependency),
            params=Depends(self.img_dependency),
            options=Depends(self.options),
        ):
            """Create image from part of a dataset."""
            timings = []
            headers: Dict[str, str] = {}

            with utils.Timer() as t:
                reader = src_path.reader or self.reader
                with reader(src_path.url, **self.reader_options) as src_dst:
                    data, mask = src_dst.part(
                        [minx, miny, maxx, maxy],
                        height=params.height,
                        width=params.width,
                        max_size=params.max_size,
                        indexes=params.indexes,
                        expression=params.expression,
                        nodata=params.nodata,
                        resampling_method=params.resampling_method.name,
                        **options.kwargs,
                    )
                    colormap = (
                        params.colormap or getattr(src_dst, "colormap", None)
                    )
            timings.append(("Read", t.elapsed))

            if not format:
                format = ImageType.jpg if mask.all() else ImageType.png

            with utils.Timer() as t:
                data = utils.postprocess(
                    data,
                    mask,
                    rescale=params.rescale,
                    color_formula=params.color_formula,
                )
            timings.append(("Post-process", t.elapsed))

            with utils.Timer() as t:
                content = utils.reformat(data, mask, format, colormap=colormap)
            timings.append(("Format", t.elapsed))

            if timings:
                headers["X-Server-Timings"] = self._timings_header(timings)

            return Response(
                content,
                media_type=ImageMimeTypes[format.value].value,
                headers=headers,
            )
"""API."""
from dataclasses import dataclass, field
import os
from urllib.parse import urlencode
from typing import Callable, Dict, Optional
import pkg_resources
from rasterio.transform import from_bounds
from cogeo_mosaic.mosaic import MosaicJSON
from cogeo_mosaic.utils import get_footprints
from cogeo_mosaic.backends import BaseBackend, MosaicBackend
from rio_tiler.io import BaseReader
from rio_tiler_crs import COGReader
from rio_tiler.constants import MAX_THREADS
from fastapi import Depends, Path, Query
from starlette.requests import Request
from starlette.responses import Response
from starlette.templating import Jinja2Templates
from .. import utils
from ..db.memcache import CacheLayer
from ..models.cog import cogBounds
from ..models.mapbox import TileJSON
from ..models.mosaic import CreateMosaicJSON, UpdateMosaicJSON, mosaicInfo
from ..ressources.common import img_endpoint_params
from ..ressources.enums import ImageMimeTypes, ImageType, MimeTypes, PixelSelectionMethod
from ..ressources.responses import XMLResponse
from ..errors import BadRequestError, TileNotFoundError
from ..dependencies import MosaicTMSParams, request_hash
from .factory import BaseFactory
# Jinja2 templates shipped in the titiler package's "templates" directory.
template_dir: str = pkg_resources.resource_filename("titiler", "templates")
templates = Jinja2Templates(directory=template_dir)
@dataclass
class MosaicTilerFactory(BaseFactory):
reader: BaseBackend = field(default=MosaicBackend)
dataset_reader: BaseReader = field(default=COGReader)
tms_dependency: Callable = field(default=MosaicTMSParams)
add_asset_deps: bool = True # We add if by default
def __post_init__(self):
super().__post_init__()
def register_routes(self):
"""
This Method register routes to the router.
Because we wrap the endpoints in a class we cannot define the routes as
methods (because of the self argument). The HACK is to define routes inside
the class method and register them after the class initialisation.
"""
self._read()
self._create()
self._update()
self._bounds()
self._info()
self._tile()
self._point()
############################################################################
# /read
############################################################################
def _read(self):
"""Add / - GET (Read) route."""
@self.router.get(
"",
response_model=MosaicJSON,
response_model_exclude_none=True,
responses={200: {"description": "Return MosaicJSON definition"}},
name=f"{self.router_prefix}read",
)
def read(src_path=Depends(self.path_dependency),):
"""Read a MosaicJSON"""
with self.reader(src_path.url) as mosaic:
return mosaic.mosaic_def
############################################################################
# /create
############################################################################
def _create(self):
"""Add / - POST (create) route."""
@self.router.post(
"",
response_model=MosaicJSON,
response_model_exclude_none=True,
name=f"{self.router_prefix}create",
)
def create(body: CreateMosaicJSON):
"""Create a MosaicJSON"""
mosaic = MosaicJSON.from_urls(
body.files,
minzoom=body.minzoom,
maxzoom=body.maxzoom,
max_threads=body.max_threads,
)
src_path = self.path_dependency(body.url)
reader = src_path.reader or self.dataset_reader
with self.reader(src_path.url, mosaic_def=mosaic, reader=reader) as mosaic:
try:
mosaic.write()
except NotImplementedError:
raise BadRequestError(
f"{mosaic.__class__.__name__} does not support write operations"
)
return mosaic.mosaic_def
############################################################################
# /update
############################################################################
def _update(self):
"""Add / - PUT (update) route."""
@self.router.put(
"",
response_model=MosaicJSON,
response_model_exclude_none=True,
name=f"{self.router_prefix}update",
)
def update_mosaicjson(body: UpdateMosaicJSON):
"""Update an existing MosaicJSON"""
src_path = self.path_dependency(body.url)
reader = src_path.reader or self.dataset_reader
with self.reader(src_path.url, reader=reader) as mosaic:
features = get_footprints(body.files, max_threads=body.max_threads)
try:
mosaic.update(features, add_first=body.add_first, quiet=True)
except NotImplementedError:
raise BadRequestError(
f"{mosaic.__class__.__name__} does not support update operations"
)
return mosaic.mosaic_def
############################################################################
# /bounds
############################################################################
def _bounds(self):
"""Register /bounds endpoint to router."""
@self.router.get(
"/bounds",
response_model=cogBounds,
responses={200: {"description": "Return the bounds of the MosaicJSON"}},
name=f"{self.router_prefix}bounds",
)
def bounds(src_path=Depends(self.path_dependency)):
"""Return the bounds of the COG."""
with self.reader(src_path.url) as src_dst:
return {"bounds": src_dst.bounds}
############################################################################
# /info
############################################################################
def _info(self):
"""Register /info endpoint to router."""
@self.router.get(
"/info",
response_model=mosaicInfo,
responses={200: {"description": "Return info about the MosaicJSON"}},
name=f"{self.router_prefix}info",
)
def info(src_path=Depends(self.path_dependency)):
"""Return basic info."""
with self.reader(src_path.url) as src_dst:
info = src_dst.info()
return info
############################################################################
# /tiles
############################################################################
def _tile(self):
    """Register the tile, TileJSON and WMTS endpoints on ``self.router``.

    Three groups of routes are added:

    * ``/tiles/...`` (optionally prefixed by ``{TileMatrixSetId}`` and
      suffixed by ``@{scale}x`` and/or ``.{format}``), rendering one mosaic
      tile as an image;
    * ``/tilejson.json`` and ``/{TileMatrixSetId}/tilejson.json``, returning
      a TileJSON document whose tile URL points at the tile route;
    * ``/WMTSCapabilities.xml`` and ``/{TileMatrixSetId}/WMTSCapabilities.xml``,
      returning an OGC WMTS capabilities document rendered from ``wmts.xml``.
    """
    tile_endpoint_params = img_endpoint_params.copy()
    tile_endpoint_params["name"] = f"{self.router_prefix}tile"

    @self.router.get(r"/tiles/{z}/{x}/{y}", **tile_endpoint_params)
    @self.router.get(r"/tiles/{z}/{x}/{y}.{format}", **tile_endpoint_params)
    @self.router.get(r"/tiles/{z}/{x}/{y}@{scale}x", **tile_endpoint_params)
    @self.router.get(r"/tiles/{z}/{x}/{y}@{scale}x.{format}", **tile_endpoint_params)
    @self.router.get(r"/tiles/{TileMatrixSetId}/{z}/{x}/{y}", **tile_endpoint_params)
    @self.router.get(
        r"/tiles/{TileMatrixSetId}/{z}/{x}/{y}.{format}", **tile_endpoint_params
    )
    @self.router.get(
        r"/tiles/{TileMatrixSetId}/{z}/{x}/{y}@{scale}x", **tile_endpoint_params
    )
    @self.router.get(
        r"/tiles/{TileMatrixSetId}/{z}/{x}/{y}@{scale}x.{format}",
        **tile_endpoint_params,
    )
    def tile(
        z: int = Path(..., ge=0, le=30, description="Mercator tiles's zoom level"),
        x: int = Path(..., description="Mercator tiles's column"),
        y: int = Path(..., description="Mercator tiles's row"),
        tms=Depends(self.tms_dependency),
        scale: int = Query(
            1, gt=0, lt=4, description="Tile size scale. 1=256x256, 2=512x512..."
        ),
        format: Optional[ImageType] = Query(
            None, description="Output image type. Default is auto."
        ),
        src_path=Depends(self.path_dependency),
        params=Depends(self.tiles_dependency),
        options=Depends(self.options),
        pixel_selection: PixelSelectionMethod = Query(
            PixelSelectionMethod.first, description="Pixel selection method."
        ),
        cache_client: CacheLayer = Depends(utils.get_cache),
        request_id: str = Depends(request_hash),
    ):
        """Create map tile from a MosaicJSON."""
        timings = []
        headers: Dict[str, str] = {}
        # Bound up front so the X-Assets check below never hits an unbound
        # name: on a cache hit the tile is not re-read, so no assets are known.
        assets_used = []
        tilesize = scale * 256

        content = None
        if cache_client:
            try:
                content, ext = cache_client.get_image_from_cache(request_id)
                # Entries are written with ``format.value`` but looked up by
                # member name here. NOTE(review): assumes ImageType names equal
                # their values — confirm.
                format = ImageType[ext]
                headers["X-Cache"] = "HIT"
            except Exception:
                # Best-effort cache: any failure (miss, unpack error, unknown
                # extension) falls back to rendering the tile below.
                content = None

        if not content:
            with utils.Timer() as t:
                reader = src_path.reader or self.dataset_reader
                reader_options = {**self.reader_options, "tms": tms}
                threads = int(os.getenv("MOSAIC_CONCURRENCY", MAX_THREADS))
                with self.reader(
                    src_path.url, reader=reader, reader_options=reader_options
                ) as src_dst:
                    (data, mask), assets_used = src_dst.tile(
                        x,
                        y,
                        z,
                        pixel_selection=pixel_selection.method(),
                        threads=threads,
                        tilesize=tilesize,
                        indexes=params.indexes,
                        expression=params.expression,
                        nodata=params.nodata,
                        resampling_method=params.resampling_method.name,
                        **options.kwargs,
                    )
            timings.append(("Read-tile", t.elapsed))

            if data is None:
                raise TileNotFoundError(f"Tile {z}/{x}/{y} was not found")

            # Auto format: jpg when every mask value is set, png otherwise
            # (presumably so partially-empty tiles keep transparency).
            if not format:
                format = ImageType.jpg if mask.all() else ImageType.png

            with utils.Timer() as t:
                data = utils.postprocess(
                    data,
                    mask,
                    rescale=params.rescale,
                    color_formula=params.color_formula,
                )
            timings.append(("Post-process", t.elapsed))

            bounds = tms.xy_bounds(x, y, z)
            dst_transform = from_bounds(*bounds, tilesize, tilesize)
            with utils.Timer() as t:
                # Encode the post-processed array (previously the ``tile``
                # endpoint function itself was passed here by mistake).
                content = utils.reformat(
                    data,
                    mask,
                    format,
                    colormap=params.colormap,
                    transform=dst_transform,
                    crs=tms.crs,
                )
            timings.append(("Format", t.elapsed))

            # Only freshly rendered tiles are written back to the cache.
            if cache_client and content:
                cache_client.set_image_cache(request_id, (content, format.value))

        if timings:
            headers["X-Server-Timings"] = "; ".join(
                ["{} - {:0.2f}".format(name, time * 1000) for (name, time) in timings]
            )

        if assets_used:
            headers["X-Assets"] = ",".join(assets_used)

        return Response(
            content, media_type=ImageMimeTypes[format.value].value, headers=headers,
        )

    @self.router.get(
        "/tilejson.json",
        response_model=TileJSON,
        responses={200: {"description": "Return a tilejson"}},
        response_model_exclude_none=True,
        name=f"{self.router_prefix}tilejson",
    )
    @self.router.get(
        "/{TileMatrixSetId}/tilejson.json",
        response_model=TileJSON,
        responses={200: {"description": "Return a tilejson"}},
        response_model_exclude_none=True,
        name=f"{self.router_prefix}tilejson",
    )
    def tilejson(
        request: Request,
        tms=Depends(self.tms_dependency),
        src_path=Depends(self.path_dependency),
        tile_format: Optional[ImageType] = Query(
            None, description="Output image type. Default is auto."
        ),
        tile_scale: int = Query(
            1, gt=0, lt=4, description="Tile size scale. 1=256x256, 2=512x512..."
        ),
        minzoom: Optional[int] = Query(None, description="Overwrite default minzoom."),
        maxzoom: Optional[int] = Query(None, description="Overwrite default maxzoom."),
    ):
        """Return TileJSON document for a MosaicJSON."""
        # Literal "{z}"/"{x}"/"{y}" placeholders become the URL template.
        kwargs = {
            "z": "{z}",
            "x": "{x}",
            "y": "{y}",
            "scale": tile_scale,
            "TileMatrixSetId": tms.identifier,
        }
        if tile_format:
            kwargs["format"] = tile_format.value

        # Forward every other query parameter (e.g. rendering options) to the
        # tile URL, dropping the ones only meaningful to this endpoint.
        q = dict(request.query_params)
        q.pop("TileMatrixSetId", None)
        q.pop("tile_format", None)
        q.pop("tile_scale", None)
        q.pop("minzoom", None)
        q.pop("maxzoom", None)
        qs = urlencode(list(q.items()))
        tiles_url = request.url_for(f"{self.router_prefix}tile", **kwargs)
        tiles_url += f"?{qs}"

        with self.reader(src_path.url) as src_dst:
            center = list(src_dst.center)
            # ``is not None`` so an explicit minzoom=0 is honoured too.
            if minzoom is not None:
                center[-1] = minzoom
            tjson = {
                "bounds": src_dst.bounds,
                "center": tuple(center),
                "minzoom": minzoom if minzoom is not None else src_dst.minzoom,
                "maxzoom": maxzoom if maxzoom is not None else src_dst.maxzoom,
                "name": os.path.basename(src_path.url),
                "tiles": [tiles_url],
            }

        return tjson

    @self.router.get(
        "/WMTSCapabilities.xml",
        response_class=XMLResponse,
        name=f"{self.router_prefix}wmts",
        tags=["OGC"],
    )
    @self.router.get(
        "/{TileMatrixSetId}/WMTSCapabilities.xml",
        response_class=XMLResponse,
        name=f"{self.router_prefix}wmts",
        tags=["OGC"],
    )
    def wmts(
        request: Request,
        tms=Depends(self.tms_dependency),
        src_path=Depends(self.path_dependency),
        tile_format: ImageType = Query(
            ImageType.png, description="Output image type. Default is png."
        ),
        tile_scale: int = Query(
            1, gt=0, lt=4, description="Tile size scale. 1=256x256, 2=512x512..."
        ),
        minzoom: Optional[int] = Query(None, description="Overwrite default minzoom."),
        maxzoom: Optional[int] = Query(None, description="Overwrite default maxzoom."),
    ):
        """OGC WMTS endpoint."""
        # WMTS-style placeholders become the ResourceURL template.
        kwargs = {
            "z": "{TileMatrix}",
            "x": "{TileCol}",
            "y": "{TileRow}",
            "scale": tile_scale,
            "format": tile_format.value,
            "TileMatrixSetId": tms.identifier,
        }
        tiles_endpoint = request.url_for(f"{self.router_prefix}tile", **kwargs)

        q = dict(request.query_params)
        q.pop("TileMatrixSetId", None)
        q.pop("tile_format", None)
        q.pop("tile_scale", None)
        q.pop("minzoom", None)
        q.pop("maxzoom", None)
        q.pop("SERVICE", None)
        q.pop("REQUEST", None)
        qs = urlencode(list(q.items()))
        tiles_endpoint += f"?{qs}"

        with self.reader(src_path.url) as src_dst:
            bounds = src_dst.bounds
            minzoom = minzoom if minzoom is not None else src_dst.minzoom
            maxzoom = maxzoom if maxzoom is not None else src_dst.maxzoom

        media_type = ImageMimeTypes[tile_format.value].value

        # One <TileMatrix> XML fragment per zoom level in [minzoom, maxzoom].
        tileMatrix = []
        for zoom in range(minzoom, maxzoom + 1):
            matrix = tms.matrix(zoom)
            tm = f"""
<TileMatrix>
<ows:Identifier>{matrix.identifier}</ows:Identifier>
<ScaleDenominator>{matrix.scaleDenominator}</ScaleDenominator>
<TopLeftCorner>{matrix.topLeftCorner[0]} {matrix.topLeftCorner[1]}</TopLeftCorner>
<TileWidth>{matrix.tileWidth}</TileWidth>
<TileHeight>{matrix.tileHeight}</TileHeight>
<MatrixWidth>{matrix.matrixWidth}</MatrixWidth>
<MatrixHeight>{matrix.matrixHeight}</MatrixHeight>
</TileMatrix>"""
            tileMatrix.append(tm)

        # NOTE(review): "title"/"layer_name" below look carried over from the
        # COG factory; confirm they are intended for a MosaicJSON layer.
        return templates.TemplateResponse(
            "wmts.xml",
            {
                "request": request,
                "tiles_endpoint": tiles_endpoint,
                "bounds": bounds,
                "tileMatrix": tileMatrix,
                "tms": tms,
                "title": "Cloud Optimized GeoTIFF",
                "layer_name": "cogeo",
                "media_type": media_type,
            },
            media_type=MimeTypes.xml.value,
        )
############################################################################
# /point (Optional)
############################################################################
def _point(self):
    """Register the ``/point/{lon},{lat}`` endpoint on ``self.router``.

    The endpoint reads the mosaic value(s) at one coordinate through
    ``src_dst.point`` and returns them with the requested coordinates. Read
    timing is reported in an ``X-Server-Timings`` response header.
    """

    @self.router.get(
        r"/point/{lon},{lat}",
        responses={200: {"description": "Return a value for a point"}},
        name=f"{self.router_prefix}point",
    )
    def point(
        response: Response,
        lon: float = Path(..., description="Longitude"),
        lat: float = Path(..., description="Latitude"),
        src_path=Depends(self.path_dependency),
        params=Depends(self.point_dependency),
        options=Depends(self.options),
    ):
        """Get Point value for a Mosaic."""
        timings = []
        threads = int(os.getenv("MOSAIC_CONCURRENCY", MAX_THREADS))

        with utils.Timer() as t:
            reader = src_path.reader or self.dataset_reader
            with self.reader(
                src_path.url,
                reader=reader,
                reader_options=self.reader_options,
            ) as src_dst:
                values = src_dst.point(
                    lon,
                    lat,
                    threads=threads,
                    indexes=params.indexes,
                    expression=params.expression,
                    nodata=params.nodata,
                    **options.kwargs,
                )
        timings.append(("Read", t.elapsed))

        if timings:
            # Previously these headers were built and then discarded; setting
            # them on the injected Response makes FastAPI merge them into the
            # JSON response built from the returned dict.
            response.headers["X-Server-Timings"] = "; ".join(
                ["{} - {:0.2f}".format(name, time * 1000) for (name, time) in timings]
            )

        return {"coordinates": [lon, lat], "values": values}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment