Skip to content

Instantly share code, notes, and snippets.

@amitsaha
Forked from codingbaobao/client.go
Created April 12, 2022 01:26
Show Gist options
  • Save amitsaha/903ab955f5347f2e6042b84660401d2a to your computer and use it in GitHub Desktop.
Save amitsaha/903ab955f5347f2e6042b84660401d2a to your computer and use it in GitHub Desktop.
gRPC interceptors in Python and Go
package main
import (
"context"
pb "example/helloworld"
"io"
"log"
"time"
"google.golang.org/grpc"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/credentials/insecure"
"google.golang.org/grpc/metadata"
"google.golang.org/grpc/status"
)
// ClientInterceptor injects a fixed request header into outgoing RPCs and
// checks that the server answers with an expected response header.
type ClientInterceptor struct {
	// Header and Value form the key/value pair appended to outgoing metadata.
	Header, Value string
	// ResponseHeader names the header the server is expected to send back.
	ResponseHeader string
	// ResponseValue receives the first value of ResponseHeader seen in a
	// response. The interceptors write it without any synchronization.
	ResponseValue string
}
// unaryInterceptor appends client.Header=client.Value to the outgoing
// metadata of every unary RPC and then requires the server's response to
// carry the header client.ResponseHeader.
//
// If the RPC itself fails, its error is returned unchanged. If it succeeds
// without the header, a codes.DataLoss status error is returned. Otherwise
// the first value of the header is stored in client.ResponseValue.
func (client *ClientInterceptor) unaryInterceptor(ctx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error {
	log.Println("Add header in interceptor")
	ctx = metadata.AppendToOutgoingContext(ctx, client.Header, client.Value)

	var responseMeta metadata.MD
	// Cap the slice before appending: opts may share its backing array with
	// call options owned by the caller (such as connection-wide defaults),
	// and an in-place append would write into that shared storage.
	opts = append(opts[:len(opts):len(opts)], grpc.Header(&responseMeta))
	if err := invoker(ctx, method, req, reply, cc, opts...); err != nil {
		return err
	}

	// MD.Get lower-cases the key the same way gRPC normalizes received
	// metadata, and it returns nil rather than an empty slice to index.
	vals := responseMeta.Get(client.ResponseHeader)
	if len(vals) == 0 {
		return status.Error(codes.DataLoss, "Response without header founded")
	}
	client.ResponseValue = vals[0]
	log.Println("Response header: ", vals)
	return nil
}
// streamInterceptor appends client.Header=client.Value to the outgoing
// metadata of every streaming RPC. It returns a stream that checks the
// server's response header client.ResponseHeader just before the first
// message is received.
//
// The header must not be read inside the interceptor. For server-streaming
// calls, the generated client sends the request message only after this
// function returns, and ClientStream.Header blocks until the server's
// initial metadata arrives. A server that emits headers only once its
// handler has the request therefore never replies, and the call hangs until
// its deadline.
func (client *ClientInterceptor) streamInterceptor(ctx context.Context, desc *grpc.StreamDesc, cc *grpc.ClientConn, method string, streamer grpc.Streamer, opts ...grpc.CallOption) (grpc.ClientStream, error) {
	log.Println("Add header in interceptor")
	ctx = metadata.AppendToOutgoingContext(ctx, client.Header, client.Value)
	s, err := streamer(ctx, desc, cc, method, opts...)
	if err != nil {
		return nil, err
	}
	return &wrappedStream{ClientStream: s, client: client}, nil
}

// wrappedStream is a grpc.ClientStream that inspects the server's response
// header once, just before the first message is received. All other
// ClientStream methods, including SendMsg, are promoted from the embedded
// stream.
type wrappedStream struct {
	grpc.ClientStream
	// getHeader is set once the response header has been inspected.
	getHeader bool
	// client, when non-nil, names the required response header and receives
	// its value; a missing header then makes RecvMsg fail with
	// codes.DataLoss. When nil, the "x-custom-echo" header is only logged.
	client *ClientInterceptor
}

// RecvMsg inspects the response header on the first call only, then
// receives the next message into m.
func (w *wrappedStream) RecvMsg(m interface{}) error {
	if !w.getHeader {
		w.getHeader = true
		if err := w.inspectHeader(); err != nil {
			return err
		}
	}
	return w.ClientStream.RecvMsg(m)
}

// inspectHeader applies the header check described on wrappedStream.
// If the stream ended before any header was received, it skips the check
// and returns nil. The following RecvMsg then reports the stream's real
// status instead of an empty message.
func (w *wrappedStream) inspectHeader() error {
	md, err := w.Header()
	if err != nil || md == nil {
		return nil
	}
	if w.client == nil {
		if vals := md.Get("x-custom-echo"); len(vals) > 0 {
			log.Println("Response header: ", vals[0])
		}
		return nil
	}
	vals := md.Get(w.client.ResponseHeader)
	if len(vals) == 0 {
		return status.Error(codes.DataLoss, "Response without header founded")
	}
	w.client.ResponseValue = vals[0]
	log.Println("Response header: ", vals[0])
	return nil
}

// newWrappedStream wraps s so that the "x-custom-echo" response header is
// logged, if present, before the first message is received.
func newWrappedStream(s grpc.ClientStream) grpc.ClientStream {
	return &wrappedStream{ClientStream: s}
}
// main connects to the local Greeter server without TLS, with both
// interceptors installed. It then starts a StreamEcho call under a 2-second
// deadline and logs each echoed message until the stream reports io.EOF.
func main() {
	icpt := &ClientInterceptor{
		Header:         "x-custom",
		Value:          "hello header",
		ResponseHeader: "x-custom-echo",
	}
	dialOpts := []grpc.DialOption{
		grpc.WithTransportCredentials(insecure.NewCredentials()),
		grpc.WithUnaryInterceptor(icpt.unaryInterceptor),
		grpc.WithStreamInterceptor(icpt.streamInterceptor),
	}
	conn, err := grpc.Dial("127.0.0.1:60661", dialOpts...)
	if err != nil {
		log.Fatalf("did not connect: %v", err)
	}
	defer conn.Close()

	greeter := pb.NewGreeterClient(conn)
	ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
	defer cancel()

	stream, err := greeter.StreamEcho(ctx, &pb.EchoMsg{Msg: "hello world"})
	if err != nil {
		log.Fatalf("could not greet: %v", err)
	}
	for {
		reply, recvErr := stream.Recv()
		switch {
		case recvErr == io.EOF:
			return
		case recvErr != nil:
			log.Fatalln(recvErr.Error())
		}
		log.Println(reply.Msg)
	}
}
import logging
import grpc
import helloworld_pb2
import helloworld_pb2_grpc
class ClientInterceptor(
    grpc.UnaryStreamClientInterceptor,
    grpc.UnaryUnaryClientInterceptor,
    grpc.StreamStreamClientInterceptor,
    grpc.StreamUnaryClientInterceptor,
):
    """Client interceptor that adds an ``x-custom`` request header to every
    RPC and logs the server's ``x-custom-echo`` response header.

    All four call shapes (unary/stream request x unary/stream response)
    share the same logic in ``_intercept_call``.
    """

    def _intercept_call(self, continuation, client_call_details: grpc.ClientCallDetails, request_or_iterator):
        """Invoke ``continuation`` with the extra request header, then log the
        ``x-custom-echo`` value from the response's initial metadata.

        ``initial_metadata()`` waits until the server's headers are available
        or the call has finished. The object returned by ``continuation`` is
        passed back unchanged, so RPC errors still reach the caller through it.
        """
        new_details = grpc.ClientCallDetails()
        new_details.method = client_call_details.method
        new_details.timeout = client_call_details.timeout
        new_details.credentials = client_call_details.credentials
        # Build a new list so the caller's metadata (possibly a tuple) is
        # never mutated.
        metadata = list(client_call_details.metadata or ())
        logging.info("Add header in interceptor")
        metadata.append(("x-custom", "hello header"))
        new_details.metadata = metadata
        response = continuation(new_details, request_or_iterator)
        # initial_metadata() returns None when the call failed before any
        # metadata arrived. Iterating None would raise TypeError and hide the
        # real RpcError the caller gets from ``response``.
        for key, value in response.initial_metadata() or ():
            if key == "x-custom-echo":
                logging.info(f"Response header: {value}")
                break
        return response

    def intercept_stream_stream(self, continuation, client_call_details, request_iterator):
        return self._intercept_call(continuation, client_call_details, request_iterator)

    def intercept_stream_unary(self, continuation, client_call_details, request_iterator):
        return self._intercept_call(continuation, client_call_details, request_iterator)

    def intercept_unary_stream(self, continuation, client_call_details, request):
        return self._intercept_call(continuation, client_call_details, request)

    def intercept_unary_unary(self, continuation, client_call_details, request):
        return self._intercept_call(continuation, client_call_details, request)
logging.basicConfig(level=logging.DEBUG)
# The with-block closes the intercepted channel, and through it the
# underlying insecure channel, even when the RPC raises. A trailing close()
# call would be skipped in that case.
with grpc.intercept_channel(
    grpc.insecure_channel('localhost:60661'), ClientInterceptor()
) as channel:
    stub = helloworld_pb2_grpc.GreeterStub(channel)
    response_iterator = stub.StreamEcho(helloworld_pb2.EchoMsg(msg="hello world"))
    for response in response_iterator:
        logging.info(response.msg)
syntax = "proto3";
package helloworld;
option go_package = "./helloworld";
// Greeter exposes a single server-streaming echo RPC.
service Greeter {
  // StreamEcho takes one EchoMsg and returns a stream of EchoMsg replies.
  rpc StreamEcho(EchoMsg) returns (stream EchoMsg) {}
}

// EchoMsg carries a single text payload.
message EchoMsg {
  string msg = 1;
}
import logging
from concurrent import futures
import grpc
import helloworld_pb2
import helloworld_pb2_grpc
def interceptor(func):
    """Wrap a servicer method so it echoes the ``x-custom`` request header.

    If the caller's invocation metadata contains ``x-custom``, its first value
    ``v`` is sent back as initial metadata ``("x-custom-echo", "echo v")``
    before the wrapped method runs. The wrapped method's return value (for
    example a response generator) is passed through unchanged.
    """
    import functools

    # functools.wraps keeps the servicer method's __name__/__doc__ so logs
    # and introspection show the real handler instead of "wrapper".
    @functools.wraps(func)
    def wrapper(self, request, context: "grpc.ServicerContext"):
        logging.info("Trigger interceptor")
        for key, val in context.invocation_metadata():
            if key == "x-custom":
                logging.info("Found header")
                context.send_initial_metadata([("x-custom-echo", f"echo {val}")])
                break
        return func(self, request, context)

    return wrapper
class Greeter(helloworld_pb2_grpc.GreeterServicer):
    """Greeter servicer whose StreamEcho replies three times with the request text."""

    @interceptor
    def StreamEcho(self, request, context):
        """Yield three ``EchoMsg`` replies, each of the form ``"Echo: <msg>"``."""
        logging.info("Trigger StreamEcho")
        text = request.msg
        yield from (helloworld_pb2.EchoMsg(msg=f"Echo: {text}") for _ in range(3))
def serve():
    """Run the Greeter server on port 60661 without TLS.

    Uses a 10-worker thread pool and blocks until the server terminates.
    """
    pool = futures.ThreadPoolExecutor(max_workers=10)
    grpc_server = grpc.server(pool)
    helloworld_pb2_grpc.add_GreeterServicer_to_server(Greeter(), grpc_server)
    grpc_server.add_insecure_port('[::]:60661')
    grpc_server.start()
    logging.info("Start server at 60661")
    grpc_server.wait_for_termination()


if __name__ == '__main__':
    # Debug level so the interceptor's and handler's info logs are visible.
    logging.basicConfig(level=logging.DEBUG)
    serve()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment