Skip to content

Instantly share code, notes, and snippets.

@GiedriusS
Created September 9, 2021 16:25
Show Gist options
  • Save GiedriusS/313b3be1543f543487df5f624f84d836 to your computer and use it in GitHub Desktop.
Save GiedriusS/313b3be1543f543487df5f624f84d836 to your computer and use it in GitHub Desktop.
// Copyright (c) The Thanos Authors.
// Licensed under the Apache License 2.0.
package extgrpc
import (
"context"
"fmt"
"sync"
"github.com/prometheus/prometheus/pkg/labels"
"github.com/thanos-io/thanos/pkg/store/labelpb"
"github.com/thanos-io/thanos/pkg/store/storepb"
"google.golang.org/grpc"
)
// SingleflightSeries provides, via StreamClientInterceptor, a gRPC client
// stream interceptor that is meant to let concurrent Series() calls carrying
// identical requests share a single RPC instead of each sending the request.
type SingleflightSeries struct {
	// queriesMtx guards queriesInProgress.
	queriesMtx sync.Mutex
	// queriesInProgress maps the String() form of an in-flight
	// storepb.SeriesRequest to the callers waiting on its responses.
	queriesInProgress map[string]*listenerValue
}
// listenerChanValue is a single item the owning stream forwards to a
// listening stream: either a Series response or an error, after which the
// owner closes the channel.
type listenerChanValue struct {
	// resp is a copy of a response the owner received.
	resp *storepb.SeriesResponse
	// err is the error the owner's stream ended with.
	err error
}
// listenerValue is the value stored per in-flight request in
// SingleflightSeries.queriesInProgress: the channels of the streams that
// wait for the owner's results.
type listenerValue struct {
	// listeners holds one channel per waiting stream.
	listeners []chan *listenerChanValue
	// listenersMtx guards listeners.
	listenersMtx sync.Mutex
}
// copySeriesResponse returns a new SeriesResponse holding a copy of r's Series
// so that it can be handed to another receiver. The labels are rebuilt into a
// newly allocated slice and the chunks are copied into a new slice; the chunk
// elements themselves are copied by value (shallowly), so any pointer or slice
// fields inside a storepb.AggrChunk are still shared with r.
// If r does not carry a Series, r itself is returned without copying.
func copySeriesResponse(r *storepb.SeriesResponse) *storepb.SeriesResponse {
	originalSeries := r.GetSeries()
	if originalSeries == nil {
		return r
	}

	// Pre-size: the final label count is known up front.
	newLabels := make(labels.Labels, 0, len(originalSeries.Labels))
	for _, lbl := range originalSeries.Labels {
		newLabels = append(newLabels, labels.Label{Name: lbl.Name, Value: lbl.Value})
	}

	series := &storepb.Series{
		Labels: labelpb.ZLabelsFromPromLabels(newLabels),
	}
	if len(originalSeries.Chunks) > 0 {
		series.Chunks = make([]storepb.AggrChunk, len(originalSeries.Chunks))
		copy(series.Chunks, originalSeries.Chunks)
	}

	return &storepb.SeriesResponse{
		Result: &storepb.SeriesResponse_Series{Series: series},
	}
}
// singleflightClientStream wraps a grpc.ClientStream so that concurrent
// Series() calls carrying identical requests reach the server only once.
//
// The first stream to call SendMsg with a given request becomes the owner: it
// sends the request over its own ClientStream and, in RecvMsg, forwards a copy
// of every response (and finally the terminal error) to the streams that
// registered the same request in the meantime. Those streams are listeners:
// they send nothing and read their results from respListener instead.
type singleflightClientStream struct {
	grpc.ClientStream
	sfs *SingleflightSeries

	// originalReq is the request given to SendMsg; its String() form is the
	// key into sfs.queriesInProgress.
	originalReq *storepb.SeriesRequest
	// queryOwner is true when this stream performs the real RPC.
	queryOwner bool

	// inflight is the queriesInProgress entry this owner registered. It lets
	// the owner delete only its own entry, never one created by a later owner.
	inflight *listenerValue
	// reqListeners holds the listener channels this owner forwards to. It is
	// populated once, by detachListeners.
	reqListeners []chan *listenerChanValue
	// listenersDetached reports whether detachListeners has already run.
	listenersDetached bool

	// respListener is the channel a listener receives forwarded results on.
	respListener chan *listenerChanValue

	// finalErr is the terminal error already returned from RecvMsg (or from a
	// failed owner SendMsg); subsequent RecvMsg calls return it again instead
	// of touching channels that have already been closed.
	finalErr error
}

// StreamClientInterceptor returns a gRPC stream client interceptor that
// deduplicates concurrent "/thanos.Store/Series" calls with identical
// requests. Streams for every other method are returned unwrapped.
//
// NOTE(review): the underlying stream is still opened for every call,
// including calls that end up as listeners and never send their request on it.
func (s *SingleflightSeries) StreamClientInterceptor() func(ctx context.Context, desc *grpc.StreamDesc, cc *grpc.ClientConn, method string, streamer grpc.Streamer, opts ...grpc.CallOption) (grpc.ClientStream, error) {
	return func(ctx context.Context, desc *grpc.StreamDesc, cc *grpc.ClientConn, method string, streamer grpc.Streamer, opts ...grpc.CallOption) (grpc.ClientStream, error) {
		clientStream, err := streamer(ctx, desc, cc, method, opts...)
		if err != nil {
			return nil, err
		}
		if method != "/thanos.Store/Series" {
			return clientStream, nil
		}
		return &singleflightClientStream{sfs: s, ClientStream: clientStream}, nil
	}
}

// SendMsg registers m, which must be a *storepb.SeriesRequest, as a query in
// progress. If no identical query is in progress, this stream becomes the
// owner and sends m to the server; otherwise it becomes a listener of the
// existing owner and sends nothing. If the owner's send fails, its entry is
// removed and every listener registered so far receives the same error.
func (s *singleflightClientStream) SendMsg(m interface{}) error {
	req, ok := m.(*storepb.SeriesRequest)
	if !ok {
		panic(fmt.Sprintf("expected only a SeriesRequest to go via the singleflight mechanism, got %T", m))
	}
	s.originalReq = req

	// TODO(GiedriusS): don't allocate a bunch of strings. Instead
	// do what previously has been done and add matchers at the end with some separator.
	// Wipe the separator and see if there is a match.
	key := req.String()

	// Ownership is decided and listeners are registered while holding
	// queriesMtx, the same lock the owner holds while detaching its
	// listeners, so no listener can be appended after that snapshot.
	s.sfs.queriesMtx.Lock()
	if s.sfs.queriesInProgress == nil {
		// Make the zero value of SingleflightSeries usable.
		s.sfs.queriesInProgress = map[string]*listenerValue{}
	}
	val, inProgress := s.sfs.queriesInProgress[key]
	if inProgress {
		// Capacity 1 lets a failing owner deliver its error and close the
		// channel from SendMsg without blocking on a listener that has not
		// called RecvMsg yet.
		s.respListener = make(chan *listenerChanValue, 1)
		val.listenersMtx.Lock()
		val.listeners = append(val.listeners, s.respListener)
		val.listenersMtx.Unlock()
	} else {
		s.inflight = &listenerValue{}
		s.sfs.queriesInProgress[key] = s.inflight
	}
	s.queryOwner = !inProgress
	s.sfs.queriesMtx.Unlock()

	if !s.queryOwner {
		return nil
	}
	if err := s.ClientStream.SendMsg(m); err != nil {
		// Stop accepting new listeners and fail the ones already waiting.
		s.detachListeners()
		s.finish(err)
		return err
	}
	return nil
}

// RecvMsg receives the next Series response into m.
//
// An owner reads from its own ClientStream and forwards a copy of each
// response to its listeners; the terminal error (io.EOF included) is forwarded
// too, after which the listener channels are closed. A listener receives
// whatever its owner forwarded. Once a terminal error has been returned, later
// calls return that same error.
func (s *singleflightClientStream) RecvMsg(m interface{}) error {
	if s.finalErr != nil {
		return s.finalErr
	}
	if !s.queryOwner {
		return s.recvFromOwner(m)
	}

	// Detach on the first call so that identical requests arriving from now
	// on start their own RPC instead of joining one that is already streaming.
	s.detachListeners()

	if err := s.ClientStream.RecvMsg(m); err != nil {
		s.finish(err)
		return err
	}
	resp, ok := m.(*storepb.SeriesResponse)
	if !ok {
		panic(fmt.Sprintf("only responses should pass through here, got %T", m))
	}
	// NOTE(review): each listener buffers at most one message, so a listener
	// that stops calling RecvMsg stalls the owner here.
	for _, l := range s.reqListeners {
		l <- &listenerChanValue{resp: copySeriesResponse(resp)}
	}
	return nil
}

// detachListeners removes this owner's entry from queriesInProgress, so later
// identical requests start a new RPC, and moves the listeners registered on
// it into s.reqListeners. Only the first call has any effect.
func (s *singleflightClientStream) detachListeners() {
	if s.listenersDetached {
		return
	}
	s.listenersDetached = true

	s.sfs.queriesMtx.Lock()
	defer s.sfs.queriesMtx.Unlock()

	key := s.originalReq.String()
	// Defensive: never delete an entry that belongs to another owner.
	if s.sfs.queriesInProgress[key] == s.inflight {
		delete(s.sfs.queriesInProgress, key)
	}
	s.inflight.listenersMtx.Lock()
	s.reqListeners = s.inflight.listeners
	s.inflight.listeners = nil
	s.inflight.listenersMtx.Unlock()
}

// finish forwards the terminal error err to every listener, closes their
// channels and records err so that later RecvMsg calls return it again.
func (s *singleflightClientStream) finish(err error) {
	for _, l := range s.reqListeners {
		l <- &listenerChanValue{err: err}
		close(l)
	}
	s.reqListeners = nil
	s.finalErr = err
}

// recvFromOwner copies the next result forwarded by the owner into m.
func (s *singleflightClientStream) recvFromOwner(m interface{}) error {
	// The owner always sends a terminal value before closing the channel, and
	// RecvMsg short-circuits on finalErr once that value has been consumed, so
	// this receive never observes a closed channel.
	// TODO(GiedriusS): add timeout on reading.
	passedResp := <-s.respListener
	if passedResp.err != nil {
		s.finalErr = passedResp.err
		return passedResp.err
	}
	receiver, ok := m.(*storepb.SeriesResponse)
	if !ok {
		panic(fmt.Sprintf("only responses should pass through here, got %T", m))
	}
	*receiver = *passedResp.resp
	return nil
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment