Created
January 14, 2016 18:16
-
-
Save prschmid/c56e4aa9da58e42b175f to your computer and use it in GitHub Desktop.
Decorator to validate HTTP parameters and submitted JSON for Flask routes
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from marshmallow import ( | |
fields, | |
Schema) | |
def marshmallow_schema_to_dict(schema):
    """Convert a :class:`marshmallow.Schema` to a dict definition.

    The result describes every declared field of the schema and is suitable
    for returning to API clients (e.g. alongside a validation error).

    :param schema: The :class:`marshmallow.Schema` class (or instance) to
                   convert. Only its ``_declared_fields`` mapping is read.
    :returns: A dict of the form ``{'fields': [{name: {'type': ...,
              'required': ..., 'allow_none': ...}}, ...]}``, one single-key
              dict per declared field.
    """
    # ``.items()`` works on both Python 2 and 3 (``iteritems`` is Py2 only).
    field_items = schema._declared_fields.items()
    return {
        'fields': [
            {
                name: {
                    'type': type(field).__name__,
                    'required': field.required,
                    'allow_none': field.allow_none,
                }
            }
            for name, field in field_items
        ]
    }
def build_schema_from_dict(d, allow_nested=True):
    """Build a Marshmallow schema class from a dictionary of fields.

    Values of ``d`` that are neither a ``dict`` nor a ``tuple`` (i.e. regular
    marshmallow field instances) are used unchanged. A ``dict`` value, or a
    ``(dict, options)`` tuple, describes a nested schema: it is built
    recursively and replaced by ``fields.Nested(nested_schema, **options)``.

    Note: ``d`` (and any nested dicts) are modified in place, so callers
    that need to keep the original definition should pass a copy.

    :param d: The dict of parameters to use to build the Schema
    :param allow_nested: Whether or not nested schemas are allowed. If
                         ``True`` then a fields.Nested() will be created
                         when there is a nested value.
    :raises ValueError: If a nested definition is found and ``allow_nested``
                        is ``False``.
    :return: A new subclass of :class:`marshmallow.Schema` named ``Schema``
    """
    # Snapshot the items so assigning into ``d`` below is safe on Py2 and Py3.
    for key, value in list(d.items()):
        if isinstance(value, tuple):
            nested_definition = value[0]
            # Options are optional; never reuse options left over from a
            # previously processed key.
            nested_options = value[1] if len(value) > 1 else {}
        elif isinstance(value, dict):
            nested_definition = value
            nested_options = {}
        else:
            # A regular field instance: keep it as-is
            continue

        if not allow_nested:
            raise ValueError("Nested attributes not allowed.")

        # Recursively generate the nested schema(s)
        nested_schema = build_schema_from_dict(
            nested_definition, allow_nested=allow_nested)
        # Update the current dict with the Nested schema
        d[key] = fields.Nested(nested_schema, **nested_options)
    return type('Schema', (Schema, ), d)
def ensure(params=None, input=None):
    """Decorator to validate HTTP parameters and submitted JSON.

    ``params`` is validated against ``request.args`` and the loaded data is
    stored on ``request.params`` (``{}`` when no ``params`` schema is given).
    ``input`` is validated against ``request.json`` and the loaded data is
    stored on ``request.input``. When validation fails a
    ``BadRequestApiError`` is raised carrying the errors and a description
    of the schema (see :func:`marshmallow_schema_to_dict`).

    Each of ``params`` / ``input`` may be:

    * a :class:`marshmallow.Schema` subclass,
    * a dict of field names to marshmallow fields (converted with
      :func:`build_schema_from_dict`), or
    * a ``(schema_or_dict, options)`` tuple; ``options`` are passed to
      ``Schema.load()`` (e.g. ``{'many': True}``). The extra option
      ``'required': True`` makes empty loaded data a validation error.
      The options element may be omitted: ``(schema_or_dict,)``.

    Nested definitions are only allowed for ``input``; a nested definition in
    ``params`` raises ``ValueError`` when the decorator is created.

    NOTE(review): the ``data, errors = schema().load(...)`` unpacking assumes
    the marshmallow 2.x API, where ``load()`` returns ``(data, errors)``.

    Usage:

        # Using this by just defining the attributes and no explicit schema
        from marshmallow import fields

        @route('/foo')
        @ensure(
            input={
                'bar': fields.Str(required=True),
                'baz': fields.Str()
            })
        def foo():
            pass

        # Support for options on loading (e.g. loading a list with many=True)
        @route('/foo')
        @ensure(
            input=(
                {
                    'bar': fields.Str(required=True),
                    'baz': fields.Str(),
                },
                {
                    'many': True
                }))
        def foo():
            pass

        # Support for nested "schemas"
        @route('/foo')
        @ensure(
            input={
                'bar': fields.Str(required=True),
                'baz': {
                    'bam': fields.Str()
                }
            })
        def foo():
            pass

        # Support for nested "schemas" with options
        @route('/foo')
        @ensure(
            input={
                'bar': fields.Str(required=True),
                'baz': (
                    {
                        'bam': fields.Str()
                    },
                    {
                        'required': True,
                        'many': True
                    })
            })
        def foo():
            pass

        # Defining an explicit schema for the validation
        from marshmallow import fields, Schema

        class FooSchema(Schema):
            bar = fields.Str(required=True)
            baz = fields.Str()

        @route('/foo')
        @ensure(input=FooSchema)
        def foo():
            pass

    Note: This method makes use of a context manager called `ignored` to
    ignore expected exceptions:

        @contextmanager
        def ignored(*exceptions):
            try:
                yield
            except exceptions:
                pass

    :param params: The input :class:`marshmallow.Schema` or a dict of the fields
                   of the schema to use for the request parameters
    :param input: The input :class:`marshmallow.Schema` or a dict of the fields
                  of the schema to use for the input JSON data
    """
    # A simple named tuple to keep track of schemas and their loading options
    SchemaDefinition = namedtuple("SchemaDefinition", ["schema", "options"])

    def to_definition(spec, allow_nested):
        """Normalize a user-supplied schema spec into a SchemaDefinition."""
        if isinstance(spec, tuple):
            schema = spec[0]
            options = spec[1] if len(spec) > 1 else {}
        else:
            schema = spec
            options = {}
        if isinstance(schema, dict):
            # Deep copy because build_schema_from_dict mutates its argument
            schema = build_schema_from_dict(
                copy.deepcopy(schema), allow_nested=allow_nested)
        return SchemaDefinition(schema, options)

    # Nested definitions are disallowed for flat URL query parameters
    params_definition = (
        to_definition(params, allow_nested=False) if params else None)
    input_definition = (
        to_definition(input, allow_nested=True) if input else None)

    def load(args, schema, options=None):
        """Perform the loading of the data into the given schema

        :param args: The arguments provided by the user from the endpoint
        :param schema: The :class:`marshmallow.Schema` class to load the data
                       into
        :param options: A dict of options to pass to the load() method, plus
                        the optional ``'required'`` flag handled here
        :raises ValueError: If ``required`` is set and no data was loaded
        """
        # Work on a copy: the options dict lives in the SchemaDefinition and
        # is shared by every request, so it must never be mutated here.
        options = dict(options or {})
        # 'required' is not a top level load() option, but we still want to
        # allow a user to validate against it.
        required = options.pop('required', False)

        # If we allow many, but only a singleton was provided, convert the
        # input args to a list
        if options.get('many', False) and not isinstance(args, list):
            args = [args]

        data, errors = schema().load(args, **options)
        if required and not data:
            raise ValueError("No data provided")
        return data, errors

    def validate(raw, definition, label):
        """Load ``raw`` with ``definition`` or raise ``BadRequestApiError``.

        :param label: Human-readable name used in the parsing error message
        :returns: The loaded data
        """
        try:
            data, errors = load(raw, definition.schema, definition.options)
        except (ValueError, ValidationError) as exc:
            # ``exc.message`` only exists on Python 2; otherwise fall back to
            # formatting the exception itself.
            errors = {
                'parsing':
                    "Could not validate {}. {}".format(
                        label, getattr(exc, 'message', exc))
            }
        except Exception:
            # Deliberate boundary: any other failure becomes a 400 response.
            errors = {'parsing': "Could not validate {}".format(label)}

        if errors:
            raise BadRequestApiError(
                message=errors,
                schema=marshmallow_schema_to_dict(definition.schema))
        return data

    def wrap(f):
        @functools.wraps(f)
        def wrapper(*args, **kwargs):
            # Ensure the URL parameters
            if params_definition:
                request.params = validate(
                    request.args, params_definition, "HTTP arguments")
            else:
                request.params = {}

            # Ensure the input
            if input_definition:
                require_json_content_type()
                json_data = {}
                # Don't fail on requests with no JSON
                with ignored(BadRequest):
                    json_data = request.json
                request.input = validate(json_data, input_definition, "input")

            # Ok, do what we came to do
            return f(*args, **kwargs)
        return wrapper
    return wrap
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment