Skip to content

Instantly share code, notes, and snippets.

@nonamenix
Last active April 5, 2019 19:37
Show Gist options
  • Save nonamenix/cbec59644eacc7661c86b5caa75cfb7e to your computer and use it in GitHub Desktop.
Save nonamenix/cbec59644eacc7661c86b5caa75cfb7e to your computer and use it in GitHub Desktop.
Swagger (OpenAPI 2.0) specification generator for hug — rough draft.
import inspect
import collections
from collections import OrderedDict
import hug
import logging
from apispec import APISpec
from apispec.ext.marshmallow.swagger import field2parameter
from copy import copy
from defaultsettings import DefaultSettings
import importlib
from marshmallow import fields, Schema
from marshmallow.schema import SchemaMeta
from hug_extensions.hug_swagger.testingschemas import TestingSchema
from . import swagger
# Module-level logger used by the swagger helpers below.
logger = logging.getLogger(__name__)


class Settings(DefaultSettings):
    """Default values for the Swagger extension settings."""
    # Passed to APISpec as ``host`` by swagger_json.
    HOST = 'localhost:9001'
    # Passed to APISpec as ``schemes`` by swagger_json.
    SCHEMES = ['http']
    # Passed to APISpec as ``version`` by swagger_json.
    VERSION = '0.1'
    # Passed to APISpec as ``title`` by swagger_json.
    TITLE = 'Swagger Application'
    # Dotted module path; when not None, swagger_json imports it and registers
    # every attribute named ``<Something>Schema`` as a spec definition.
    DEFINITIONS_PATH = None
    # When true, the demo endpoints guarded by ``if settings.TESTING_MODE`` are
    # registered.
    TESTING_MODE = False


# NOTE(review): 'SWAGGER_' is presumably the prefix under which DefaultSettings
# looks up overrides for the attributes above (e.g. SWAGGER_HOST) — confirm.
settings = Settings('SWAGGER_')
# Only the configured instance is used afterwards; drop the class name.
del Settings
def get_summary(description):
    """Return the first non-blank line of *description*, stripped.

    Used as the Swagger ``summary`` of an operation, while the full text is
    used as its ``description``. Skipping blank lines matters for docstrings
    whose text starts on the line after the opening quotes, which would
    otherwise produce an empty summary.

    :param description: free-form text, typically a handler docstring.
    :returns: the first line that contains non-whitespace characters, with
        surrounding whitespace removed, or ``''`` if there is none.
    """
    for line in description.splitlines():
        stripped = line.strip()
        if stripped:
            return stripped
    return ''
def where_is_parameter(name, url):
    """Return the Swagger ``in`` location of parameter *name* for *url*.

    ``'path'`` when the URL template contains the ``{name}`` placeholder,
    ``'query'`` in every other case. Header and body locations are not
    detected here.
    """
    # TODO: body, header
    placeholder = '{{{0}}}'.format(name)
    if placeholder in url:
        return 'path'
    return 'query'
def get_parameters(url, interface):
    """Build Swagger parameter objects from a hug interface's annotations.

    * Parameters annotated with a marshmallow ``fields.Field`` become path or
      query parameters. The location comes from ``where_is_parameter``.
      A parameter is required unless it has a default, and its default is
      copied into the result.
    * A parameter named ``body`` annotated with a marshmallow ``Schema``
      instance or class becomes a required body parameter. That parameter
      references ``#/definitions/<SchemaClassName>``.
    * hug directives are skipped (logged at INFO), and unannotated parameters
      are ignored. Any other annotation is logged as an error and left out.

    The annotation objects themselves are never modified.

    :param url: route URL template, e.g. ``/items/{item_id}/``.
    :param interface: hug interface providing ``defaults``, ``parameters``
        and ``interface.spec`` (the wrapped handler).
    :returns: dict mapping parameter name to its Swagger parameter dict.
    """
    defaults = interface.defaults
    sig = inspect.signature(interface.interface.spec)
    parameters = {}
    for name in interface.parameters:
        parameter_type = sig.parameters[name].annotation
        if getattr(parameter_type, 'directive', False):
            logger.info('Skip directive: %s for url: %s ', name, url)
            continue
        if parameter_type is inspect.Parameter.empty:
            # Unannotated parameters are not documented.
            continue

        if isinstance(parameter_type, fields.Field):
            # path and query
            parameter_place = where_is_parameter(name, url)
            # Work on a copy: the annotation object is shared by every call
            # (i.e. every request for swagger.json), so it must not be mutated.
            # Existing metadata (e.g. a description) is kept, not overwritten.
            field = copy(parameter_type)
            field.metadata = dict(getattr(parameter_type, 'metadata', None) or {},
                                  location=parameter_place)
            field.required = name not in defaults
            parameter = field2parameter(field, name=name, default_in=parameter_place)
            if name in defaults:
                parameter['default'] = defaults[name]
            parameters[name] = parameter
        elif name == 'body' and isinstance(parameter_type, (Schema, SchemaMeta)):
            # body: reference the schema by its class name.
            if isinstance(parameter_type, Schema):
                schema_name = parameter_type.__class__.__name__
            else:
                schema_name = parameter_type.__name__
            # NOTE(review): the definition itself is not registered here (no
            # spec access); it has to come from DEFINITIONS_PATH or a response
            # schema, otherwise the $ref dangles — confirm.
            parameters['body'] = {
                "in": "body",
                "name": "body",
                "required": True,
                "schema": {"$ref": "#/definitions/{}".format(schema_name)},
            }
        else:
            logger.error('Use marshmallow fields in url: %s instead of hug: %s %s', url, name, parameter_type)
    return parameters
def get_operation_and_define_response_schemas(interface, spec):
    """Collect a handler's Swagger responses and register their schemas.

    Responses come from the handler's ``swagger_responses`` attribute, when
    present. A return annotation, if any, becomes the ``schema`` of the
    ``200`` response and overrides a schema declared there.

    A response ``schema`` may take one of three forms:

    * a schema name (``str``);
    * a marshmallow ``Schema`` instance;
    * a marshmallow ``Schema`` class.

    Instances and classes are registered on *spec* via ``spec.definition``.
    Every valid schema is then replaced by a
    ``{'$ref': '#/definitions/<name>'}`` reference. Any other value is logged
    and removed from that response.

    :param interface: hug interface whose ``interface.spec`` is the handler.
    :param spec: APISpec that receives the schema definitions.
    :returns: mapping of status code to Swagger response dict. The handler's
        ``swagger_responses`` mapping and its response dicts are left
        unmodified.
    """
    handler = interface.interface.spec
    annotated_response_schema = inspect.signature(handler).return_annotation
    # ``copy`` is shallow: the per-code response dicts are still shared with
    # the handler attribute, so each one is copied before it is changed.
    responses = copy(getattr(handler, 'swagger_responses', OrderedDict()))
    if annotated_response_schema is not inspect.Parameter.empty:
        ok_response = copy(responses.get(200, {}))
        ok_response['schema'] = annotated_response_schema
        responses[200] = ok_response

    for code in list(responses):
        response = copy(responses[code])
        responses[code] = response
        if 'schema' not in response:
            continue
        schema = response['schema']
        if isinstance(schema, str):  # schema name provided
            name = schema
        elif isinstance(schema, Schema):  # schema instance provided
            name = schema.__class__.__name__
            spec.definition(name, schema=schema)
        elif isinstance(schema, SchemaMeta):  # schema class provided
            name = schema.__name__
            spec.definition(name, schema=schema())
        else:
            logger.error('Wrong response schema %s', schema)
            # Drop the value: it is not a valid Swagger schema and may not
            # even be JSON-serializable.
            del response['schema']
            continue
        response['schema'] = {'$ref': '#/definitions/{}'.format(name)}
    return responses
@hug.get('/swagger.json')
def swagger_json(hug_api):
    # Build the Swagger 2.0 document for every HTTP route of `hug_api`
    # (presumably supplied by hug's `hug_api` directive) and return it as a dict.
    # Kept as comments rather than a docstring: hug presumably derives
    # documentation()['usage'] from the docstring, which would change this
    # endpoint's own entry in the generated spec.

    # ``collections.Iterable`` was deprecated in Python 3.3 and removed in
    # 3.10; the ABC lives in ``collections.abc``.
    from collections.abc import Iterable

    spec = APISpec(
        title=settings.TITLE,
        version=settings.VERSION,
        plugins=(
            'apispec.ext.marshmallow',
        ),
        schemes=settings.SCHEMES,
        host=settings.HOST,
    )

    if settings.DEFINITIONS_PATH is not None:
        # Register every marshmallow schema named ``<Something>Schema`` from the
        # configured module. Bare ``Schema`` (e.g. an import) is excluded by the
        # length check. Other objects whose names merely end in "Schema" are
        # skipped instead of being handed to ``spec.definition``.
        definitions = importlib.import_module(settings.DEFINITIONS_PATH)
        for name, schema in vars(definitions).items():
            if (name.endswith('Schema') and len(name) > len('Schema')
                    and isinstance(schema, (Schema, SchemaMeta))):
                spec.definition(name, schema=schema)

    routes = hug_api.http.routes['']
    for url, route in routes.items():
        for method, versioned_interfaces in route.items():
            for versions, interface in versioned_interfaces.items():
                methods_data = {}
                documentation = interface.documentation()
                methods_data['content_type'] = documentation['outputs']['content_type']
                try:
                    methods_data['summary'] = get_summary(documentation['usage'])
                    methods_data['description'] = documentation['usage']
                except KeyError:
                    # No usage text: leave summary/description out.
                    pass

                parameters = get_parameters(url, interface)
                if parameters:
                    # NOTE(review): a name -> parameter dict (not a list) is passed
                    # to add_path and every parameter is also registered globally
                    # by name; if apispec keys these by name only, routes sharing
                    # a parameter name overwrite each other — confirm.
                    methods_data['parameters'] = parameters
                    for name, parameter in parameters.items():
                        spec.add_parameter(name, parameter['in'], **parameter)

                responses = get_operation_and_define_response_schemas(interface, spec)
                if responses:
                    methods_data['responses'] = responses

                # A single (non-iterable) version key is wrapped so both shapes
                # share the loop below; a falsy version means "unversioned".
                if not isinstance(versions, Iterable):
                    versions = [versions]
                for version in versions:
                    versioned_url = '/v{}{}'.format(version, url) if version else url
                    spec.add_path(versioned_url, operations={
                        method.lower(): methods_data
                    })
    return spec.to_dict()
# Demo endpoints for checking the generated spec; registered only when the
# TESTING_MODE setting is true.
if settings.TESTING_MODE:
    # Path parameters annotated with hug.types rather than marshmallow fields:
    # get_parameters does not turn these into Swagger parameters (it logs an
    # error for each). No response schema is declared.
    # NOTE(review): `request` and `hug_timer` are presumably filled in by hug
    # by parameter name (directives) — confirm.
    @hug.get('/swagger/hug/{hug_types_number}/{hug_types_greater_than_5}/')
    def openapi_test(
            request,
            hug_timer,
            hug_types_number: hug.types.number,
            hug_types_greater_than_5: hug.types.GreaterThan(5)):
        """Endpoint with hug.types
        Not versioned api method"""
        return {
            'hug_types_number': hug_types_number,
            'hug_types_greater_than_5': hug_types_greater_than_5
        }

    # Marshmallow-annotated variant. `swagger_types_number` appears in the URL
    # template, so get_parameters makes it a path parameter. The other argument
    # becomes a query parameter whose default (3) is documented when hug
    # reports it in `interface.defaults`. The 200 schema is given both by the
    # decorator and by the return annotation; get_operation_and_define_response_schemas
    # lets the annotation win.
    # NOTE(review): swagger.response presumably records its arguments in the
    # handler's `swagger_responses` attribute — confirm in the swagger module.
    @hug.get('/swagger/marshmallow/{swagger_types_number}/', versions=[2, 3])
    @swagger.response(200, description='Good response', schema=TestingSchema)
    @swagger.response(400, description='Bad response')
    def openapi_test_swagger_types(
            request,
            hug_timer,
            swagger_types_number: fields.Integer(),
            swagger_types_number_in_query: fields.Integer() = 3) -> TestingSchema():
        """Endpoint with marshmallow types"""
        return {
            'swagger_types_number': swagger_types_number,
            'swagger_types_number_in_query': swagger_types_number_in_query
        }

    # POST endpoint whose `body` parameter is documented as a reference to
    # #/definitions/TestingSchema; it returns the body unchanged.
    @hug.post('/swagger/marshmallow/post-body')  # TODO: check with last slash
    @swagger.response(200, description='Created', schema=TestingSchema())
    def openapi_post_body(body: TestingSchema()) -> TestingSchema():
        return body
from marshmallow import Schema, fields
class TestingFieldsSchema(Schema):
    """Schema declaring one field for each of several basic marshmallow types.

    Attribute names are the serialized keys, so they stay as they are even
    where they shadow builtins (``float``) inside the class body.
    """
    integer = fields.Int()
    float = fields.Float()
    boolean = fields.Bool()
    datetime = fields.DateTime()
    timedelta = fields.TimeDelta()
    dictionary = fields.Dict()
    url = fields.URL()
    email = fields.Email()
class TestingSchema(Schema):
    """Payload schema used by the swagger demo endpoints.

    ``hug_types_number`` and ``hug_types_greater_than_5`` match the keys
    returned by ``openapi_test``. ``swagger_types_number`` and
    ``swagger_types_number_in_query`` match the keys returned by
    ``openapi_test_swagger_types``, which declares this schema as its 200
    response and return annotation.
    """
    hug_types_number = fields.Integer()
    hug_types_greater_than_5 = fields.Integer()
    hug_types_in_range_1_5 = fields.Integer()
    # Needed so the documented 200 response of openapi_test_swagger_types
    # describes what that endpoint actually returns.
    # NOTE(review): hug presumably also dumps that endpoint's output through
    # its `-> TestingSchema()` annotation, which would drop undeclared keys.
    swagger_types_number = fields.Integer()
    swagger_types_number_in_query = fields.Integer()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment