Skip to content

Instantly share code, notes, and snippets.

@ryanwang520
Created January 30, 2019 03:14
Show Gist options
  • Select an option

  • Save ryanwang520/4ad82ff754f266f645f7f4aa1611b317 to your computer and use it in GitHub Desktop.

Select an option

Save ryanwang520/4ad82ff754f266f645f7f4aa1611b317 to your computer and use it in GitHub Desktop.
import functools
import graphene
from graphene.utils.str_converters import to_snake_case
from graphql import get_default_backend
import logging
import time
from flask import request, current_app
from typing import Tuple
import graphql_server
from graphql.execution import ExecutionResult
from graphql.language.ast import (
FragmentSpread,
FragmentDefinition,
OperationDefinition,
InlineFragment,
IntValue,
Variable,
)
from graphql_server import HttpQueryError
from gemini.globals import statsd
from gemini.graphql.fields import Field
logger = logging.getLogger(__name__)
def deep_get_field(obj, path: Tuple):
    """Follow ``path`` attribute-by-attribute starting from ``obj``.

    Before each attribute lookup, graphene wrappers around the current item are
    peeled off: ``graphene.NonNull`` and ``graphene.List`` are replaced by their
    ``of_type`` and ``graphene.Field`` by its ``type``. This repeats until a
    non-wrapper remains, so arbitrarily nested wrappers are supported, e.g.
    ``Field(NonNull(List(NonNull(X))))``.

    Args:
        obj: Starting object, typically the schema's root query type.
        path: Sequence of attribute names (snake_case) to walk.

    Returns:
        The attribute reached at the end of ``path``, returned as-is without
        unwrapping. With an empty ``path`` this is ``obj``.

    Raises:
        AttributeError: If a path element does not exist on the unwrapped item.
    """

    def unwrap(item):
        # Peel every wrapper layer; each step strips one layer, so this terminates.
        while True:
            if isinstance(item, (graphene.NonNull, graphene.List)):
                item = item.of_type
            elif isinstance(item, graphene.Field):
                item = item.type
            else:
                return item

    item = obj
    for attr_name in path:
        item = getattr(unwrap(item), attr_name)
    return item
def _multiplier_from_arguments(selection, field, variables):
    """Sum the values of the cost-multiplier arguments passed to ``selection``.

    Only arguments whose snake_case name appears in ``field.cost.multipliers``
    contribute. Integer literals are used directly. A variable is looked up by
    the *variable's* own name in ``variables``, falling back to the argument's
    declared ``default_value``. Missing or null values contribute nothing.
    """
    # NOTE(review): a default declared on the operation's variable definition
    # (``$first: Int = 10``) is not consulted here — TODO confirm if needed.
    variables = variables or {}
    total = 0
    for arg in selection.arguments:
        snake_name = to_snake_case(arg.name.value)
        if snake_name not in field.cost.multipliers:
            continue
        value = arg.value
        if isinstance(value, IntValue):
            total += int(value.value)
        elif isinstance(value, Variable):
            resolved = variables.get(
                value.name.value, field.args[snake_name].default_value
            )
            if resolved is not None:
                total += int(resolved)
    return total


def analyze_cost(document, variables):
    """Estimate the cost of the ``query`` operations in a parsed document.

    Each field declared with the project ``Field`` type contributes
    ``field.cost.complexity`` times the product of every multiplier gathered
    on the way down from the root. The multipliers are the multiplier
    arguments of that field and its ancestors, when
    ``field.cost.use_multipliers`` is set. Any other field contributes 0.
    Meta fields (names starting with ``__``, e.g. ``__typename``) are not
    counted. Named fragments and inline fragments contribute the sum of their
    selections.

    Args:
        document: Parsed document exposing ``schema`` and ``document_ast``.
        variables: Request variables mapping, or None.

    Returns:
        The largest cost among the query operations, excluding one named
        ``IntrospectionQuery``. Returns 0 when there are none.

    Raises:
        ValueError: If fragment spreads form a cycle.
    """
    default_complexity = 0
    query = document.schema._query
    definitions = document.document_ast.definitions

    # The first definition wins for duplicate fragment names, matching a linear scan.
    fragments = {}
    for definition in definitions:
        if isinstance(definition, FragmentDefinition):
            fragments.setdefault(definition.name.value, definition)
    expanding = set()  # fragment names currently being expanded (cycle guard)

    def resolve_field(name, current_path, parent):
        # Inside "... on Type { }" the field is looked up on that explicitly
        # registered schema type, using the same snake_case name as the path walk.
        if parent is not None and parent.type_condition is not None:
            type_name = parent.type_condition.name.value
            for schema_type in document.schema.types or ():
                if schema_type.__name__ == type_name:
                    return getattr(schema_type, name)
            return None  # type not registered: treated as default complexity
        return deep_get_field(query, current_path)

    def sum_children(selections, current_path, parent, multipliers):
        return sum(
            calc_cost(child, current_path, parent, multipliers) for child in selections
        )

    def calc_cost(selection, current_path, parent=None, multipliers=(1,)):
        if isinstance(selection, OperationDefinition):
            # The operation itself costs nothing; only its fields are counted.
            return sum_children(
                selection.selection_set.selections, current_path, None, multipliers
            )
        if isinstance(selection, InlineFragment):
            return sum_children(
                selection.selection_set.selections, current_path, selection, multipliers
            )
        if isinstance(selection, FragmentSpread):
            fragment_name = selection.name.value
            fragment = fragments.get(fragment_name)
            if fragment is None:
                return 0  # unknown fragment; rejected by validation at execution
            if fragment_name in expanding:
                raise ValueError("Fragment cycle detected: {}".format(fragment_name))
            expanding.add(fragment_name)
            try:
                return sum_children(
                    fragment.selection_set.selections, current_path, None, multipliers
                )
            finally:
                expanding.discard(fragment_name)

        raw_name = selection.name.value
        if raw_name.startswith("__"):
            return 0  # introspection meta field, not a schema attribute
        name = to_snake_case(raw_name)
        current_path = current_path + (name,)
        field = resolve_field(name, current_path, parent)

        complexity = default_complexity
        if isinstance(field, Field):
            complexity = field.cost.complexity
            if field.cost.use_multipliers and field.cost.multipliers:
                multiply_by = _multiplier_from_arguments(selection, field, variables)
                if multiply_by:
                    # Applied to this field and inherited by all its descendants.
                    multipliers = multipliers + (multiply_by,)

        cost = functools.reduce(lambda x, y: x * y, multipliers) * complexity
        if selection.selection_set:
            cost += sum_children(
                selection.selection_set.selections, current_path, None, multipliers
            )
        return cost

    operations = [
        definition
        for definition in definitions
        if isinstance(definition, OperationDefinition)
        and definition.operation == "query"
        and not (definition.name and definition.name.value == "IntrospectionQuery")
    ]
    if not operations:
        return 0
    return max(calc_cost(operation, ()) for operation in operations)
def get_max_depth(document):
    """Return the selection-set nesting depth of the deepest operation.

    Every node with a selection set adds one level, and a leaf field counts as
    0. So ``{ a }`` has depth 1 and ``{ a { b } }`` has depth 2. A fragment
    spread takes the depth of its fragment definition, which itself adds a
    level, as does an inline fragment. An operation named
    ``IntrospectionQuery`` is ignored.

    Each fragment's depth is computed once and memoized. This keeps the walk
    linear even when fragments are spread many times, which matters because
    it runs before the document is validated.

    Args:
        document: Parsed document exposing ``document_ast``.

    Returns:
        The maximum depth over the operations, or 0 when there are none.

    Raises:
        ValueError: If fragment spreads form a cycle.
    """
    definitions = document.document_ast.definitions

    # The first definition wins for duplicate fragment names, matching a linear scan.
    fragments = {}
    for definition in definitions:
        if isinstance(definition, FragmentDefinition):
            fragments.setdefault(definition.name.value, definition)

    fragment_depths = {}
    in_progress = set()

    def fragment_depth(name):
        if name in fragment_depths:
            return fragment_depths[name]
        fragment = fragments.get(name)
        if fragment is None:
            return 0  # unknown fragment; rejected by validation at execution
        if name in in_progress:
            raise ValueError("Fragment cycle detected: {}".format(name))
        in_progress.add(name)
        try:
            depth = depth_for_selection(fragment)
        finally:
            in_progress.discard(name)
        fragment_depths[name] = depth
        return depth

    def depth_for_selection(selection):
        if isinstance(selection, FragmentSpread):
            return fragment_depth(selection.name.value)
        if not selection.selection_set:
            return 0
        return 1 + max(
            depth_for_selection(s) for s in selection.selection_set.selections
        )

    operations = [
        definition
        for definition in definitions
        if isinstance(definition, OperationDefinition)
        and not (definition.name and definition.name.value == "IntrospectionQuery")
    ]
    if not operations:
        return 0
    return max(depth_for_selection(operation) for operation in operations)
def execute_graphql_request(
    schema, params, allow_only_query=False, backend=None, **kwargs
):
    """Replacement for ``graphql_server.execute_graphql_request`` with extra guards.

    The query is parsed first. Before execution, the following are rejected:

    * documents containing an anonymous operation;
    * documents whose depth (``get_max_depth``) exceeds 10;
    * documents whose estimated cost (``analyze_cost``) exceeds 5000, checked
      only when the Flask app runs in debug mode.

    Elapsed time is reported via ``statsd.timing`` under
    ``<blueprint>.<operation type>.<operation name>``.

    Args:
        schema: Schema to parse and execute against.
        params: Request parameters exposing ``query``, ``operation_name`` and
            ``variables``.
        allow_only_query: When true, reject non-query operations with HTTP 405.
        backend: GraphQL backend; the default backend is used when falsy.
        **kwargs: Forwarded to ``document.execute``.

    Returns:
        The ``ExecutionResult`` of executing the document. A parse failure,
        guard violation or execution exception yields an ``ExecutionResult``
        with that error and ``invalid=True``.

    Raises:
        HttpQueryError: 400 if no query string was given; 405 if
            ``allow_only_query`` is set and the operation is not a query.
    """
    if not params.query:
        raise HttpQueryError(400, "Must provide query string.")
    try:
        if not backend:
            backend = get_default_backend()
        document = backend.document_from_string(schema, params.query)
    except Exception as e:
        return ExecutionResult(errors=[e], invalid=True)

    operation_type = document.get_operation_type(params.operation_name)
    if allow_only_query and operation_type and operation_type != "query":
        raise HttpQueryError(
            405,
            "Can only perform a {} operation from a POST request.".format(
                operation_type
            ),
            headers={"Allow": "POST"},
        )

    operations = document.operations_map  # operation name -> operation type
    # Resolve the name of the operation that will actually run, for logs and
    # metrics. When no name was requested, a single-operation document runs
    # its only operation.
    operation_name = params.operation_name
    if operation_name is None and len(operations) == 1:
        operation_name = next(iter(operations))

    start = time.monotonic()
    try:
        # Every operation in the document must be named.
        if None in operations:
            logger.error("%s: Query %s 没有命名", request.blueprint, params.query)
            raise ValueError("请求非法")

        # Reject documents nested deeper than 10 levels.
        try:
            max_depth = get_max_depth(document)
        except Exception as e:
            logger.exception(e)
            raise
        if max_depth > 10:
            logger.error("%s: Query %s 深度超过10", request.blueprint, params.query)
            raise ValueError("查询过于复杂")

        # Cost analysis is only computed (and enforced) in debug mode.
        if current_app.debug:
            try:
                cost_start = time.monotonic()
                cost = analyze_cost(document, params.variables)
                logger.info(
                    "calc cost use for %s %s s",
                    operation_name,
                    time.monotonic() - cost_start,
                )
                logger.info("cost %s", cost)
            except Exception as e:
                logger.exception(e)
                raise
            # Reject requests whose estimated cost exceeds 5000.
            if cost > 5000:
                raise ValueError("请求节点过多")

        return document.execute(
            operation_name=params.operation_name, variables=params.variables, **kwargs
        )
    except Exception as e:
        return ExecutionResult(errors=[e], invalid=True)
    finally:
        # NOTE(review): the value is in seconds; many statsd clients expect
        # milliseconds for timings — confirm against gemini.globals.statsd.
        statsd.timing(
            f"{request.blueprint}.{operation_type}.{operation_name}",
            time.monotonic() - start,
        )
def apply():
    """Install this module's ``execute_graphql_request`` into ``graphql_server``.

    This replaces the module attribute ``graphql_server.execute_graphql_request``.
    Callers that look the function up through the ``graphql_server`` module at
    call time use the replacement. Names bound earlier via
    ``from graphql_server import execute_graphql_request`` keep the original.
    """
    setattr(graphql_server, "execute_graphql_request", execute_graphql_request)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment