-
-
Save amitsaha/903ab955f5347f2e6042b84660401d2a to your computer and use it in GitHub Desktop.
gRPC client and server interceptors in Python and Go
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
package main | |
import ( | |
"context" | |
pb "example/helloworld" | |
"io" | |
"log" | |
"time" | |
"google.golang.org/grpc" | |
"google.golang.org/grpc/codes" | |
"google.golang.org/grpc/credentials/insecure" | |
"google.golang.org/grpc/metadata" | |
"google.golang.org/grpc/status" | |
) | |
// ClientInterceptor holds the header pair that its interceptors append to
// outgoing RPCs and the response header they expect back from the server.
type ClientInterceptor struct {
	// Header and Value are the metadata key/value appended to every call.
	Header, Value string

	// ResponseHeader is the key expected in the server's response header.
	ResponseHeader string

	// ResponseValue receives the first value found under ResponseHeader.
	// NOTE(review): the interceptors write it without synchronization, so
	// concurrent RPCs sharing one ClientInterceptor race on this field.
	ResponseValue string
}
func (client *ClientInterceptor) unaryInterceptor(ctx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error { | |
log.Println("Add header in interceptor") | |
ctx = metadata.AppendToOutgoingContext(ctx, client.Header, client.Value) | |
var responseMeta metadata.MD | |
opts = append(opts, grpc.Header(&responseMeta)) | |
err := invoker(ctx, method, req, reply, cc, opts...) | |
if err != nil { | |
return err | |
} | |
if val, ok := responseMeta[client.ResponseHeader]; ok { | |
client.ResponseValue = val[0] | |
log.Println("Response header: ", val) | |
} else { | |
return status.Error(codes.DataLoss, "Response without header founded") | |
} | |
return nil | |
} | |
func (client *ClientInterceptor) streamInterceptor(ctx context.Context, desc *grpc.StreamDesc, cc *grpc.ClientConn, method string, streamer grpc.Streamer, opts ...grpc.CallOption) (grpc.ClientStream, error) { | |
log.Println("Add header in interceptor") | |
ctx = metadata.AppendToOutgoingContext(ctx, client.Header, client.Value) | |
s, err := streamer(ctx, desc, cc, method, opts...) | |
if err != nil { | |
return nil, err | |
} | |
log.Println("Access response header") | |
responseMeta, err := s.Header() // Block at here, server not trigger any function | |
if err != nil { | |
return nil, err | |
} | |
if val, ok := responseMeta[client.ResponseHeader]; ok { | |
client.ResponseValue = val[0] | |
log.Println("Response header: ", val[0]) | |
} else { | |
return nil, status.Errorf(codes.DataLoss, "Response without header founded") | |
} | |
return s, nil | |
// return newWrappedStream(s), nil | |
} | |
type wrappedStream struct { | |
grpc.ClientStream | |
getHeader bool | |
} | |
func (w *wrappedStream) RecvMsg(m interface{}) error { | |
if !w.getHeader { | |
responseMeta, err := w.Header() | |
if err != nil { | |
return nil | |
} | |
if val, ok := responseMeta["x-custom-echo"]; ok { | |
log.Println("Response header: ", val[0]) | |
} | |
} | |
return w.ClientStream.RecvMsg(m) | |
} | |
func (w *wrappedStream) SendMsg(m interface{}) error { | |
return w.ClientStream.SendMsg(m) | |
} | |
func newWrappedStream(s grpc.ClientStream) grpc.ClientStream { | |
return &wrappedStream{s, false} | |
} | |
func main() { | |
interceptor := &ClientInterceptor{Header: "x-custom", Value: "hello header", ResponseHeader: "x-custom-echo"} | |
conn, err := grpc.Dial("127.0.0.1:60661", grpc.WithTransportCredentials(insecure.NewCredentials()), grpc.WithUnaryInterceptor(interceptor.unaryInterceptor), grpc.WithStreamInterceptor(interceptor.streamInterceptor)) | |
if err != nil { | |
log.Fatalf("did not connect: %v", err) | |
} | |
defer conn.Close() | |
client := pb.NewGreeterClient(conn) | |
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) | |
defer cancel() | |
stream, err := client.StreamEcho(ctx, &pb.EchoMsg{Msg: "hello world"}) | |
if err != nil { | |
log.Fatalf("could not greet: %v", err) | |
} | |
for { | |
echo, err := stream.Recv() | |
if err == io.EOF { | |
break | |
} | |
if err != nil { | |
log.Fatalln(err.Error()) | |
} | |
log.Println(echo.Msg) | |
} | |
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import logging | |
import grpc | |
import helloworld_pb2 | |
import helloworld_pb2_grpc | |
class ClientInterceptor(
    grpc.UnaryStreamClientInterceptor,
    grpc.UnaryUnaryClientInterceptor,
    grpc.StreamStreamClientInterceptor,
    grpc.StreamUnaryClientInterceptor,
):
    """Client interceptor that adds a request header to every RPC and logs the
    matching header from the server's initial metadata.

    The same logic is applied to all four RPC shapes (unary/stream request
    combined with unary/stream response).
    """

    def __init__(self, header="x-custom", value="hello header", response_header="x-custom-echo"):
        """Configure the header pair to send and the response header to log.

        Args:
            header: Metadata key appended to every outgoing call.
            value: Value sent under ``header``.
            response_header: Initial-metadata key whose first value is logged.
        """
        self.header = header
        self.value = value
        self.response_header = response_header

    def _intercept_call(self, continuation, client_call_details: grpc.ClientCallDetails, request_or_iterator):
        """Forward the call with the extra header and log the response header.

        Returns the call object produced by ``continuation`` unchanged, so
        errors surface to the caller exactly as they would without the
        interceptor.
        """
        new_details = grpc.ClientCallDetails()
        new_details.method = client_call_details.method
        new_details.timeout = client_call_details.timeout
        new_details.credentials = client_call_details.credentials
        # Copy so the caller's metadata sequence is never mutated.
        metadata = list(client_call_details.metadata or ())
        logging.info("Add header in interceptor")
        metadata.append((self.header, self.value))
        new_details.metadata = metadata
        response = continuation(new_details, request_or_iterator)
        # initial_metadata() blocks until the server's header arrives and can
        # return None when the RPC failed before one was received; iterating
        # None would raise TypeError and hide the real error from the caller.
        for key, value in response.initial_metadata() or ():
            if key == self.response_header:
                logging.info(f"Response header: {value}")
                break
        return response

    def intercept_stream_stream(self, continuation, client_call_details, request_iterator):
        return self._intercept_call(continuation, client_call_details, request_iterator)

    def intercept_stream_unary(self, continuation, client_call_details, request_iterator):
        return self._intercept_call(continuation, client_call_details, request_iterator)

    def intercept_unary_stream(self, continuation, client_call_details, request):
        return self._intercept_call(continuation, client_call_details, request)

    def intercept_unary_unary(self, continuation, client_call_details, request):
        return self._intercept_call(continuation, client_call_details, request)
# Client script: open an intercepted channel to the local server, log every
# StreamEcho reply, and close the channel even if the RPC raises
# (e.g. grpc.RpcError when the server is unreachable).
logging.basicConfig(level=logging.DEBUG)
channel = grpc.insecure_channel('localhost:60661')
channel = grpc.intercept_channel(channel, ClientInterceptor())
try:
    stub = helloworld_pb2_grpc.GreeterStub(channel)
    response_iterator = stub.StreamEcho(helloworld_pb2.EchoMsg(msg="hello world"))
    for response in response_iterator:
        logging.info(response.msg)
finally:
    # Closing the intercepted channel also closes the underlying channel.
    channel.close()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
syntax = "proto3"; | |
package helloworld; | |
option go_package = "./helloworld"; | |
// Greeter's StreamEcho takes a single EchoMsg and answers with a
// server-side stream of EchoMsg replies.
service Greeter {
  rpc StreamEcho(EchoMsg) returns (stream EchoMsg) {}
}
// EchoMsg carries the text sent to, and echoed back by, StreamEcho.
message EchoMsg {
  string msg = 1;
}
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import functools
import logging
from concurrent import futures

import grpc

import helloworld_pb2
import helloworld_pb2_grpc
def interceptor(func):
    """Decorate a servicer method so it echoes the client's ``x-custom`` header.

    Before the wrapped method runs, the request metadata is scanned for the
    first ``x-custom`` entry. If one is found, ``("x-custom-echo", "echo <value>")``
    is sent as initial metadata. The wrapped method's return value (for
    example a response generator) is passed through unchanged.

    Args:
        func: Servicer method taking ``(self, request, context)``.

    Returns:
        A wrapper with the same call signature and metadata as ``func``.
    """
    @functools.wraps(func)
    def wrapper(self, request, context: "grpc.ServicerContext"):
        logging.info("Trigger interceptor")
        for key, val in context.invocation_metadata():
            if key == "x-custom":
                logging.info("Found header")
                context.send_initial_metadata([("x-custom-echo", f"echo {val}")])
                break
        return func(self, request, context)
    return wrapper
class Greeter(helloworld_pb2_grpc.GreeterServicer):
    """Greeter service whose StreamEcho replies three times with the request text."""

    @interceptor
    def StreamEcho(self, request, context):
        """Yield three ``EchoMsg`` replies, each reading ``"Echo: <request.msg>"``."""
        logging.info("Trigger StreamEcho")
        reply_text = f"Echo: {request.msg}"
        for _ in range(3):
            yield helloworld_pb2.EchoMsg(msg=reply_text)
def serve(port=60661):
    """Run the Greeter gRPC server on an insecure port until it terminates.

    Args:
        port: TCP port to listen on, on all interfaces (``[::]``). Defaults to
            60661, the port the example clients dial.

    Raises:
        RuntimeError: if the port could not be bound.
    """
    server = grpc.server(futures.ThreadPoolExecutor(max_workers=10))
    helloworld_pb2_grpc.add_GreeterServicer_to_server(Greeter(), server)
    # add_insecure_port returns the bound port. Some grpcio releases signal a
    # bind failure by returning 0 instead of raising, which would otherwise
    # leave a server that silently listens nowhere.
    bound_port = server.add_insecure_port(f"[::]:{port}")
    if bound_port == 0:
        raise RuntimeError(f"could not bind gRPC server to port {port}")
    server.start()
    logging.info(f"Start server at {bound_port}")
    server.wait_for_termination()
if __name__ == "__main__":
    # Enable verbose logging before the server starts emitting messages.
    logging.basicConfig(level=logging.DEBUG)
    serve()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment