Last active
February 6, 2024 21:28
-
-
Save dtaniwaki/341ec184eed20d965e154f71a826ee73 to your computer and use it in GitHub Desktop.
gRPC Server Interceptor for https://github.com/envoyproxy/protoc-gen-validate
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import collections
import collections.abc
import functools
import logging
from typing import Any, Callable, Iterable, Union

from google.protobuf.message import Message
from grpc import HandlerCallDetails, RpcMethodHandler, ServerInterceptor, ServicerContext, StatusCode
from grpc.experimental import wrap_server_method_handler
from validator import ValidationFailed, validate
# Module-level logger; response-validation failures are reported here as
# warnings instead of failing the RPC.
logger: logging.Logger = logging.getLogger(__name__)

# Either a single protobuf message (the unary side of an RPC) or an iterable
# of messages (the streaming side).
MESSAGE_TYPE = Union[Message, Iterable[Message]]
def _validate_iter(message_iter: Iterable[Message]) -> Iterable[Message]:
    """Yield the messages of ``message_iter`` after validating each one.

    Validation is lazy: a message is checked only when the consumer pulls it,
    so a streaming request is not validated until the handler iterates it.

    Raises:
        ValidationFailed: propagated from the first message that violates its
            protoc-gen-validate rules, at the moment that message is requested.
    """
    for message in message_iter:
        # validate() builds a checker for this message; calling it enforces the rules.
        validate(message)(message)
        yield message
def _validate_response_iter(response_iter: Iterable[Message]) -> Iterable[Message]:
    """Yield every message of a response stream, logging the invalid ones.

    Response validation is advisory: a message that fails validation is still
    yielded (and therefore sent), and the failure is only logged as a warning.
    """
    for res in response_iter:
        v = validate(res)
        try:
            v(res)
        except ValidationFailed as e:
            # Lazy %-style args: the message is formatted only if it is emitted.
            logger.warning("Response validation failed: %s", e)
        yield res
def _abort_on_validation_failure(
    response_iter: Iterable[Message], context: ServicerContext
) -> Iterable[Message]:
    """Re-yield ``response_iter``, turning ``ValidationFailed`` into INVALID_ARGUMENT.

    For stream-response RPCs the handler body runs lazily, after ``wrapper`` has
    already returned. A request-stream validation error raised while the handler
    pulls from its lazily validated request iterator would otherwise escape as an
    unhandled exception rather than an INVALID_ARGUMENT status.
    """
    try:
        yield from response_iter
    except ValidationFailed as e:
        context.abort(StatusCode.INVALID_ARGUMENT, str(e))


def _wrapper(
    behavior: Callable[[MESSAGE_TYPE, ServicerContext], MESSAGE_TYPE]
) -> Callable[[MESSAGE_TYPE, ServicerContext], MESSAGE_TYPE]:
    """Wrap a gRPC method behavior with protoc-gen-validate checks.

    The returned callable has the same ``(request, context)`` signature as
    ``behavior``. It behaves as follows:

    * A unary request is validated up front. On failure the RPC is aborted with
      ``INVALID_ARGUMENT``.
    * A streaming request is validated lazily, message by message, as
      ``behavior`` consumes it. Failures also map to ``INVALID_ARGUMENT``, for
      both unary and streaming responses.
    * Responses are validated, but a failure is only logged as a warning and the
      response is still returned.

    Args:
        behavior: the handler function taken from an ``RpcMethodHandler``.

    Returns:
        The wrapping callable, as expected by
        ``grpc.experimental.wrap_server_method_handler``.
    """

    @functools.wraps(behavior)
    def wrapper(request: MESSAGE_TYPE, context: ServicerContext) -> MESSAGE_TYPE:
        # collections.abc.Iterable: the bare collections.Iterable alias was
        # removed in Python 3.10.
        if isinstance(request, collections.abc.Iterable):
            # No validation until the actual iteration in behavior.
            request = _validate_iter(request)
        else:
            try:
                v = validate(request)
                v(request)
            except ValidationFailed as e:
                # abort() raises, so behavior is never called for an invalid request.
                context.abort(StatusCode.INVALID_ARGUMENT, str(e))

        try:
            response = behavior(request, context)
        except ValidationFailed as e:
            # Reached when behavior consumed an invalid streaming request eagerly
            # (stream-unary RPCs).
            context.abort(StatusCode.INVALID_ARGUMENT, str(e))
            return  # type: ignore  # unreachable in practice: abort() raises

        if isinstance(response, collections.abc.Iterable):
            # Streaming response: validate lazily, and still map request-stream
            # validation errors raised during iteration to INVALID_ARGUMENT.
            return _abort_on_validation_failure(_validate_response_iter(response), context)

        v = validate(response)
        try:
            v(response)
        except ValidationFailed as e:
            logger.warning("Response validation failed: %s", e)
        return response

    return wrapper
class ProtocValidationServerInterceptor(ServerInterceptor):  # type: ignore
    """Server interceptor that runs protoc-gen-validate checks on every RPC.

    Each method handler resolved by the server is re-wrapped so that its
    behavior runs through ``_wrapper``. Invalid requests are rejected with
    ``INVALID_ARGUMENT``, and invalid responses are logged.
    """

    def intercept_service(
        self,
        continuation: Callable[[HandlerCallDetails], RpcMethodHandler],
        handler_call_details: HandlerCallDetails,
    ) -> RpcMethodHandler:
        """Resolve the next handler in the chain and return it wrapped with validation."""
        return wrap_server_method_handler(_wrapper, continuation(handler_call_details))
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment