Forked from ryanwang520/gist:4ad82ff754f266f645f7f4aa1611b317
Last active
February 25, 2019 08:52
-
-
Save Fity/02548e765d2023b37fc1ec5cda9da3e2 to your computer and use it in GitHub Desktop.
Analysis For GraphQL Requests
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| import functools | |
| import graphene | |
| from graphene.utils.str_converters import to_snake_case | |
| from graphql import get_default_backend | |
| import logging | |
| import time | |
| from flask import request, current_app | |
| from typing import Tuple | |
| import graphql_server | |
| from graphql.execution import ExecutionResult | |
| from graphql.language.ast import ( | |
| FragmentSpread, | |
| FragmentDefinition, | |
| OperationDefinition, | |
| InlineFragment, | |
| IntValue, | |
| Variable, | |
| ) | |
| from graphql_server import HttpQueryError | |
| from gemini.globals import statsd | |
| from gemini.graphql.fields import Field | |
# Module-level logger, named after this module.
logger = logging.getLogger(name=__name__)
def deep_get_field(obj, path: Tuple):
    """Follow ``path`` through graphene type attributes starting at ``obj``.

    At each step the current item is first unwrapped from a single
    ``graphene.NonNull``. A ``graphene.List`` is then resolved through its
    (NonNull-unwrapped) element type, and a ``graphene.Field`` through its
    (NonNull-unwrapped) field type. Anything else is used as-is. The next
    name in ``path`` is looked up on the result with ``getattr``.

    Args:
        obj: the starting object, e.g. the schema's root query type.
        path: attribute names to follow, in order.

    Returns:
        The attribute reached at the end of ``path``, or ``obj`` itself when
        ``path`` is empty.

    Raises:
        AttributeError: if a step names an attribute that does not exist.
    """

    # Defined once, outside the loop; it does not depend on the iteration.
    def extract_non_null(item):
        if isinstance(item, graphene.NonNull):
            return item.of_type
        return item

    item = obj
    for attr_name in path:
        item = extract_non_null(item)
        if isinstance(item, graphene.List):
            owner = extract_non_null(item.of_type)
        elif isinstance(item, graphene.Field):
            owner = extract_non_null(item.type)
        else:
            owner = item
        item = getattr(owner, attr_name)
    return item
def analyze_cost(document, variables):
    """Estimate the execution cost of the query operations in ``document``.

    A selected field whose schema attribute is a ``Field`` (the project field
    class from ``gemini.graphql.fields``) contributes ``field.cost.complexity``
    times the product of the multipliers collected on the path from the
    operation root down to it.

    When such a field has ``cost.use_multipliers`` set, the values passed to
    its arguments listed in ``cost.multipliers`` are summed and appended as a
    new multiplier. These values are either literal ints or variables
    resolved from ``variables``, falling back to the argument's default. The
    new multiplier applies to the field itself and to everything below it.

    Other fields contribute ``0``, but their children are still visited.
    Named fragment spreads and inline fragments contribute the sum of their
    contents.

    Args:
        document: the parsed document; ``document.schema`` is the graphene
            schema and ``document.document_ast`` the parsed AST.
        variables: the request's variable values keyed by variable name, or
            ``None`` when the request carries none.

    Returns:
        The highest cost among the document's query operations, ignoring
        operations named ``IntrospectionQuery``; ``0`` when there are none.
    """
    default_complexity = 0
    query = document.schema._query
    schema_types = document.schema.types or ()
    variables = variables or {}
    fragments = {
        definition.name.value: definition
        for definition in document.document_ast.definitions
        if isinstance(definition, FragmentDefinition)
    }

    def resolve_field(name, current_path, parent):
        """Return the schema attribute for field ``name`` (``None`` if unknown)."""
        if parent is not None and parent.type_condition is not None:
            # Inside ``... on SomeType`` the field is looked up on the named
            # type, by its snake_case attribute name like deep_get_field uses.
            type_name = parent.type_condition.name.value
            for graphene_type in schema_types:
                if graphene_type.__name__ == type_name:
                    return getattr(graphene_type, name)
            return None
        return deep_get_field(query, current_path)

    def argument_multiplier(field, selection):
        """Sum the values passed to ``field``'s multiplier arguments."""
        total = 0
        for arg in selection.arguments:
            snake_name = to_snake_case(arg.name.value)
            if snake_name not in field.cost.multipliers:
                continue
            value = arg.value
            if isinstance(value, IntValue):
                amount = int(value.value)
            elif isinstance(value, Variable):
                # A variable's value is keyed by the variable's own name
                # (``$limit``), not by the argument it is passed to (``first``).
                amount = variables.get(value.name.value)
                if amount is None:
                    amount = field.args[snake_name].default_value
            else:
                continue
            if amount is not None:
                # Count by magnitude so a negative argument can never lower
                # the estimate.
                total += abs(amount)
        return total

    def calc_cost(selection, current_path, parent, multipliers, active_fragments):
        """Return the cost of ``selection`` and everything selected below it."""
        if isinstance(selection, InlineFragment):
            return sum(
                calc_cost(s, current_path, selection, multipliers, active_fragments)
                for s in selection.selection_set.selections
            )
        if isinstance(selection, FragmentSpread):
            fragment_name = selection.name.value
            fragment = fragments.get(fragment_name)
            if fragment is None or fragment_name in active_fragments:
                # Undefined or cyclic fragment: adds nothing here. Such
                # documents are invalid GraphQL and are expected to fail
                # validation when executed.
                return 0
            # Every field of the fragment gets resolved, so their costs add up.
            nested_fragments = active_fragments | {fragment_name}
            return sum(
                calc_cost(s, current_path, None, multipliers, nested_fragments)
                for s in fragment.selection_set.selections
            )

        costs = 0
        if not isinstance(selection, OperationDefinition):
            if selection.name.value.startswith("__"):
                # Meta fields such as ``__typename`` are not attributes of
                # graphene types; treat them (and their subtree) as free.
                return 0
            name = to_snake_case(selection.name.value)
            current_path = current_path + (name,)
            field = resolve_field(name, current_path, parent)
            if isinstance(field, Field):
                complexity = field.cost.complexity or default_complexity
                if field.cost.use_multipliers and field.cost.multipliers:
                    multiply_by = argument_multiplier(field, selection)
                    if multiply_by:
                        multipliers = multipliers + (multiply_by,)
            else:
                complexity = default_complexity
            multiply = functools.reduce(lambda x, y: x * y, multipliers)
            costs += multiply * complexity
        if selection.selection_set:
            # NOTE(review): children are resolved from the query root via
            # ``current_path`` even below an inline fragment's concrete type.
            costs += sum(
                calc_cost(s, current_path, None, multipliers, active_fragments)
                for s in selection.selection_set.selections
            )
        return costs

    operations = [
        definition
        for definition in document.document_ast.definitions
        if isinstance(definition, OperationDefinition)
        and definition.operation == "query"
        and (definition.name is None or definition.name.value != "IntrospectionQuery")
    ]
    if not operations:
        return 0
    return max(
        calc_cost(operation, (), None, (1,), frozenset()) for operation in operations
    )
def get_max_depth(document):
    """Return the deepest field nesting among the document's operations.

    Depth counts field levels: ``{ a }`` has depth 1 and ``{ a { b } }`` has
    depth 2. Inline fragments and fragment spreads are transparent: they
    contribute the depth of their contents without adding a level of their
    own, so a query measures the same with or without fragments.

    A spread of an undefined fragment, or one that re-enters a fragment
    already being expanded (a cycle), counts as depth 0. Such documents are
    invalid GraphQL and are expected to fail validation when executed.
    Operations named ``IntrospectionQuery`` are ignored.

    Args:
        document: the parsed document; ``document.document_ast`` is the AST.

    Returns:
        The maximum depth, or ``0`` when there is no operation to measure.
    """
    fragments = {
        definition.name.value: definition
        for definition in document.document_ast.definitions
        if isinstance(definition, FragmentDefinition)
    }

    def depth_of_selections(selection_set, active_fragments):
        """Deepest depth among the selections of ``selection_set`` (0 if none)."""
        if not selection_set:
            return 0
        return max(
            (depth_for_selection(s, active_fragments) for s in selection_set.selections),
            default=0,
        )

    def depth_for_selection(selection, active_fragments):
        """Depth contributed by a single selection."""
        if isinstance(selection, FragmentSpread):
            fragment_name = selection.name.value
            fragment = fragments.get(fragment_name)
            if fragment is None or fragment_name in active_fragments:
                return 0
            return depth_of_selections(
                fragment.selection_set, active_fragments | {fragment_name}
            )
        if isinstance(selection, InlineFragment):
            return depth_of_selections(selection.selection_set, active_fragments)
        # A field: one level for itself plus its deepest child.
        return 1 + depth_of_selections(selection.selection_set, active_fragments)

    operations = [
        definition
        for definition in document.document_ast.definitions
        if isinstance(definition, OperationDefinition)
        and (definition.name is None or definition.name.value != "IntrospectionQuery")
    ]
    if not operations:
        return 0
    return max(
        depth_of_selections(operation.selection_set, frozenset())
        for operation in operations
    )
def execute_graphql_request(
    schema, params, allow_only_query=False, backend=None, **kwargs
):
    """Parse, vet and execute one GraphQL request.

    This function is installed over ``graphql_server.execute_graphql_request``
    by ``apply()``. Beyond parsing and executing, it:

    - rejects documents containing an anonymous operation;
    - rejects documents whose field depth exceeds 10;
    - only when ``current_app.debug`` is set, logs the estimated query cost
      and rejects queries costing more than 5000.

    Every request that parses is timed and reported to statsd as
    ``<blueprint>.<operation type>.<operation name>`` for the operation that
    is actually executed.

    Args:
        schema: the schema to execute against.
        params: request parameters exposing ``query``, ``operation_name`` and
            ``variables``.
        allow_only_query: when true, non-query operations raise a 405.
        backend: GraphQL backend; the default backend is used when falsy.
        **kwargs: forwarded to ``document.execute``.

    Returns:
        The ``ExecutionResult`` of the operation, or an invalid
        ``ExecutionResult`` wrapping the error when parsing, vetting or
        execution fails.

    Raises:
        HttpQueryError: 400 when no query string is given; 405 when
            ``allow_only_query`` is set and the operation is not a query.
    """
    if not params.query:
        raise HttpQueryError(400, "Must provide query string.")

    try:
        if not backend:
            backend = get_default_backend()
        document = backend.document_from_string(schema, params.query)
    except Exception as e:
        return ExecutionResult(errors=[e], invalid=True)

    operations = document.operations_map
    operation_type = document.get_operation_type(params.operation_name)
    if allow_only_query and operation_type and operation_type != "query":
        raise HttpQueryError(
            405,
            "Can only perform a {} operation from a POST request.".format(
                operation_type
            ),
            headers={"Allow": "POST"},
        )

    # Name of the operation that will run, resolved the same way as its type
    # above. The metric and logs label the executed operation, not whichever
    # operation happens to come last in the document.
    operation_name = params.operation_name
    if not operation_name and len(operations) == 1:
        operation_name = next(iter(operations))

    start = time.perf_counter()
    try:
        # Every operation must be named; an anonymous one is keyed by None.
        if None in operations:
            logger.error("%s: Query %s 没有命名", request.blueprint, params.query)
            raise ValueError("请求非法")

        # Reject overly deep documents.
        try:
            max_depth = get_max_depth(document)
        except Exception as e:
            logger.exception(e)
            raise
        if max_depth > 10:
            logger.error("%s: Query %s 深度超过10", request.blueprint, params.query)
            raise ValueError("查询过于复杂")

        # Cost analysis (and its limit) is only enforced in debug mode.
        if current_app.debug:
            try:
                cost_start = time.perf_counter()
                cost = analyze_cost(document, params.variables)
                logger.info(
                    "calc cost use for %s %s s",
                    operation_name,
                    time.perf_counter() - cost_start,
                )
                logger.info("cost %s", cost)
            except Exception as e:
                logger.exception(e)
                raise
            # Reject requests costing more than 5000.
            if cost > 5000:
                raise ValueError("请求节点过多")

        return document.execute(
            operation_name=params.operation_name, variables=params.variables, **kwargs
        )
    except Exception as e:
        return ExecutionResult(errors=[e], invalid=True)
    finally:
        # NOTE(review): statsd timers conventionally take milliseconds, but this
        # reports seconds. Confirm against the client in gemini.globals before
        # changing, since existing dashboards depend on the current unit.
        statsd.timing(
            f"{request.blueprint}.{operation_type}.{operation_name}",
            time.perf_counter() - start,
        )
def apply():
    """Monkey-patch ``graphql_server`` to use this module's request executor.

    Only the ``graphql_server.execute_graphql_request`` module attribute is
    rebound. Code that already holds a direct reference to the original
    function (for example, imported by name before this call) keeps using it.
    """
    setattr(graphql_server, "execute_graphql_request", execute_graphql_request)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment