Skip to content

Instantly share code, notes, and snippets.

@boeboe
Created September 28, 2023 05:17
Show Gist options
  • Save boeboe/88c30b1df3ba645614411f4c6803e2e4 to your computer and use it in GitHub Desktop.
Save boeboe/88c30b1df3ba645614411f4c6803e2e4 to your computer and use it in GitHub Desktop.
// Copyright (c) Tetrate, Inc 2023 All Rights Reserved.
package service
import (
"context"
"encoding/json"
"errors"
"fmt"
"io"
"math/rand"
"net/http"
"net/http/httptest"
"net/http/httputil"
"net/url"
"strconv"
"strings"
"sync"
"time"
"filippo.io/mostly-harmless/cryptosource"
"github.com/gorilla/mux"
"github.com/openzipkin/zipkin-go"
zmw "github.com/openzipkin/zipkin-go/middleware/http"
)
// rnd is the package-level random number generator the handlers in this file
// consult to decide whether a request gets an injected error response or
// double headers.
//
// Skipping G404 (CWE-338): Use of weak random number generator (math/rand instead of crypto/rand) (Confidence: MEDIUM, Severity: HIGH)
// since we use filippo.io/mostly-harmless/cryptosource.
//
// NOTE(review): rnd is shared by every request goroutine; this presumes the
// cryptosource Source is safe for concurrent use — confirm.
var rnd = rand.New(cryptosource.New()) // #nosec G404
// setErrors configures how often this service answers with an injected error.
// The "percentage" route variable must be an integer between 0 and 100
// (inclusive); a missing, non-numeric or out-of-range value is answered with
// a 400 carrying errPercentage. The stored value is read by echoHandler and
// proxy.
func (ep *Endpoints) setErrors(w http.ResponseWriter, r *http.Request) {
	ctx := r.Context()
	invalid := response{
		Code:  http.StatusBadRequest,
		Error: errPercentage,
	}

	raw, present := mux.Vars(r)["percentage"]
	if !present {
		ep.writeResponse(ctx, w, invalid)
		return
	}

	pct, convErr := strconv.Atoi(raw)
	if convErr != nil || pct < 0 || pct > 100 {
		ep.writeResponse(ctx, w, invalid)
		return
	}

	// Only the assignment is guarded; the lock is released before responding.
	ep.mtx.Lock()
	ep.errors = int32(pct)
	ep.mtx.Unlock()

	ep.writeResponse(ctx, w, response{
		Code:    http.StatusOK,
		Message: fmt.Sprintf("errors percentage set to: %d%%", pct),
	})
}
// setDoubleHeaders configures how often echoHandler emits duplicated
// Content-Type headers. The "percentage" route variable must be an integer
// between 0 and 100 (inclusive); a missing, non-numeric or out-of-range value
// is answered with a 400 carrying errPercentage.
func (ep *Endpoints) setDoubleHeaders(w http.ResponseWriter, r *http.Request) {
	ctx := r.Context()
	invalid := response{
		Code:  http.StatusBadRequest,
		Error: errPercentage,
	}

	raw, present := mux.Vars(r)["percentage"]
	if !present {
		ep.writeResponse(ctx, w, invalid)
		return
	}

	pct, convErr := strconv.Atoi(raw)
	if convErr != nil || pct < 0 || pct > 100 {
		ep.writeResponse(ctx, w, invalid)
		return
	}

	// Only the assignment is guarded; the lock is released before responding.
	ep.mtx.Lock()
	ep.headers = int32(pct)
	ep.mtx.Unlock()

	ep.writeResponse(ctx, w, response{
		Code:    http.StatusOK,
		Message: fmt.Sprintf("double headers percentage set to: %d%%", pct),
	})
}
// setLatency configures the artificial latency this service injects before
// answering in echoHandler and proxy. The "duration" route variable is either
// a Go duration string (e.g. "250ms", "1.5s") or a bare integer, which is
// interpreted as milliseconds. Missing, unparsable or negative values are
// answered with a 400 carrying errDuration.
func (ep *Endpoints) setLatency(w http.ResponseWriter, r *http.Request) {
	ctx := r.Context()
	reject := func() {
		ep.writeResponse(ctx, w, response{
			Code:  http.StatusBadRequest,
			Error: errDuration,
		})
	}

	raw, present := mux.Vars(r)["duration"]
	if !present {
		reject()
		return
	}

	latency, parseErr := time.ParseDuration(raw)
	if parseErr != nil {
		// Not a duration string; fall back to a raw millisecond count.
		millis, convErr := strconv.Atoi(raw)
		if convErr != nil {
			reject()
			return
		}
		latency = time.Duration(millis) * time.Millisecond
	}
	if latency < 0 {
		reject()
		return
	}

	// Only the assignment is guarded; the lock is released before responding.
	ep.mtx.Lock()
	ep.duration = latency
	ep.mtx.Unlock()

	ep.writeResponse(ctx, w, response{
		Code:    http.StatusOK,
		Message: fmt.Sprintf("duration set to: %s", latency.String()),
	})
}
// setHandleFailures toggles how the proxy handler treats downstream failures.
// When enabled, a non-200 answer from the downstream service is not relayed;
// the service instead reports success and includes the downstream error in
// its message, mimicking a service that is resilient to downstream issues.
// When disabled, the downstream reply is forwarded unchanged.
//
// The "handleFailures" route variable is matched case-insensitively against
// 1/on/yes/y/true/t and 0/off/no/n/false/f. Anything else, or a missing
// variable, is answered with a 400 carrying errHandleFailures.
func (ep *Endpoints) setHandleFailures(w http.ResponseWriter, r *http.Request) {
	ctx := r.Context()
	reject := func() {
		ep.writeResponse(ctx, w, response{
			Code:  http.StatusBadRequest,
			Error: errHandleFailures,
		})
	}

	raw, present := mux.Vars(r)["handleFailures"]
	if !present {
		reject()
		return
	}

	// Accepted spellings and the boolean each one stands for.
	spellings := map[string]bool{
		"1": true, "on": true, "yes": true, "y": true, "true": true, "t": true,
		"0": false, "off": false, "no": false, "n": false, "false": false, "f": false,
	}
	enabled, known := spellings[strings.ToLower(raw)]
	if !known {
		reject()
		return
	}

	// Only the assignment is guarded; the lock is released before responding.
	ep.mtx.Lock()
	ep.handleFailures = enabled
	ep.mtx.Unlock()

	ep.writeResponse(ctx, w, response{
		Code:    http.StatusOK,
		Message: fmt.Sprintf("handle failures set to: %t", enabled),
	})
}
// crash acknowledges the request with a 200 and then brings the whole process
// down 5 seconds later by panicking outside of the request goroutine. The
// optional "message" route variable is appended to the panic text.
func (ep *Endpoints) crash(w http.ResponseWriter, r *http.Request) {
	reason := mux.Vars(r)["message"]

	ep.writeResponse(r.Context(), w, response{
		Code:    http.StatusOK,
		Message: "crashing in 5 seconds",
	})

	// The timer callback runs in its own goroutine, so the panic is not
	// recovered by the HTTP server and terminates the process.
	time.AfterFunc(5*time.Second, func() {
		panic("crash requested: " + reason)
	})
}
// emulateConcurrency runs 8 fake "heavy" local procedures, each of which
// sleeps for the requested duration inside its own span (named proc-0 through
// proc-7, child of the request span) so they show up in the trace graph.
//
// Route variables:
//   - "duration" (optional): a Go duration string or a bare integer taken as
//     milliseconds; invalid or negative values yield a 400 with errDuration.
//     When absent, the procedures sleep for the zero duration.
//   - "concurrency": "serial" runs the procedures one after another,
//     "parallel" runs all of them in goroutines, and "mixed" runs the
//     even-numbered ones in goroutines and the odd-numbered ones inline. Any
//     other value yields a 400 with errConcurrency.
//
// The handler answers only after all 8 procedures have finished.
func (ep *Endpoints) emulateConcurrency(w http.ResponseWriter, r *http.Request) {
	const procs = 8

	ctx := r.Context()
	params := mux.Vars(r)
	parent := zipkin.SpanFromContext(r.Context()).Context()

	var runFor time.Duration
	if raw, given := params["duration"]; given {
		parsed, parseErr := time.ParseDuration(raw)
		if parseErr != nil {
			// Not a duration string; fall back to a raw millisecond count.
			millis, convErr := strconv.Atoi(raw)
			if convErr != nil {
				ep.writeResponse(ctx, w, response{
					Code:  http.StatusBadRequest,
					Error: errDuration,
				})
				return
			}
			parsed = time.Duration(millis) * time.Millisecond
		}
		if parsed < 0 {
			ep.writeResponse(ctx, w, response{
				Code:  http.StatusBadRequest,
				Error: errDuration,
			})
			return
		}
		runFor = parsed
	}

	// inBackground reports whether procedure i runs in its own goroutine.
	var inBackground func(i int) bool
	switch strings.ToLower(params["concurrency"]) {
	case "serial":
		inBackground = func(int) bool { return false }
	case "mixed":
		inBackground = func(i int) bool { return i%2 == 0 }
	case "parallel":
		inBackground = func(int) bool { return true }
	default:
		ep.writeResponse(ctx, w, response{
			Code:  http.StatusBadRequest,
			Error: errConcurrency,
		})
		return
	}

	var wg sync.WaitGroup
	wg.Add(procs)
	work := func(i int) {
		defer wg.Done()
		span := ep.tracer.StartSpan(fmt.Sprintf("proc-%d", i), zipkin.Parent(parent))
		defer span.Finish()
		span.Tag("duration", runFor.String())
		time.Sleep(runFor)
	}

	for i := 0; i < procs; i++ {
		if inBackground(i) {
			go work(i)
			continue
		}
		work(i)
	}

	// Wait until every procedure, inline or background, has finished.
	wg.Wait()

	ep.writeResponse(ctx, w, response{
		Code:    http.StatusOK,
		Message: "ran several local spans",
	})
}
// proxy allows us to hop from service to service, optionally fanning out to
// several services in parallel, by providing path chunks that reference the
// services.
//
// The "proxyChains" route variable holds everything after /proxy/. It is
// either a single chain (service[:port] followed by the path to request on
// that service) or several chains, each enclosed in "[]". The first directive
// of every chain is stripped and the remaining path is reverse proxied to the
// targeted service. Sending multiple requests at each hop allows us to mimic
// more complex real life microservice topologies.
//
// Example path: /proxy/svcf/proxy/svcd/proxy/svcb/errors/50
// This path will hop from app ingress to svcf, svcd, svcb, where this final
// svcb will receive an /errors/50 request to handle.
//
// Example path: /proxy/[svcf/proxy/[svcd][svcd][svcd]][svcb/errors/50]
// This path will hop from app ingress to svcf and svcb in parallel and the
// responses from both requests will be combined. svcf will call svcd three
// times in parallel and all three responses will be combined. svcb will
// receive an /errors/50 request to handle.
//
// Before proxying, the handler sleeps for the configured latency and, with
// the configured error percentage, answers with a 500 instead of proxying. A
// chain specification that yields no chains is answered with a 400.
//
// A single chain is proxied straight into w. Multiple chains are each
// recorded separately and returned as the Responses of one combined 200
// response: a recorded JSON body that decodes into a response is used as is,
// anything else is summarized (service host, status code, trace ID, headers).
func (ep *Endpoints) proxy(w http.ResponseWriter, r *http.Request) {
	ctx := r.Context()
	proxyChains, ok := mux.Vars(r)["proxyChains"]
	if !ok {
		ep.writeResponse(ctx, w, response{
			Code:  http.StatusBadRequest,
			Error: errProxyService,
		})
		return
	}

	// retrieve our behavioral config
	ep.mtx.RLock()
	d := ep.duration
	e := ep.errors
	handleFailure := ep.handleFailures
	ep.mtx.RUnlock()

	// inject configured latency
	time.Sleep(d)

	// Skipping G404 (CWE-338): Use of weak random number generator (math/rand instead of crypto/rand) (Confidence: MEDIUM, Severity: HIGH)
	// since we use filippo.io/mostly-harmless/cryptosource.
	if rnd.Int31n(100) < e { // #nosec G404
		// return error response...
		ep.writeResponse(ctx, w, response{
			Code:  http.StatusInternalServerError,
			Error: errInternal,
		})
		return
	}

	chains := ep.parseProxyChains(proxyChains)
	if len(chains) == 0 {
		// return error response...
		ep.writeResponse(ctx, w, response{
			Code:  http.StatusBadRequest,
			Error: errors.New("invalid proxy chain format"),
		})
		return
	}

	// wg only satisfies proxyRequestChain's contract: it calls wg.Done when
	// it returns, so every call must be preceded by wg.Add(1).
	var wg sync.WaitGroup

	if len(chains) == 1 {
		// Single hop: stream the downstream reply straight to our caller.
		wg.Add(1)
		ep.proxyRequestChain(ctx, chains[0], w, r.Clone(ctx), handleFailure, &wg)
		return
	}

	type requestResponseRecorder struct {
		request  *http.Request
		response *httptest.ResponseRecorder
	}

	// Fan out: every chain gets its own request clone and recorder and runs
	// in its own goroutine, so the chains really are called in parallel.
	// NOTE(review): Request.Clone only shallow-copies Body, so the clones
	// share the incoming request body; fine for body-less requests.
	var (
		recorders = make([]requestResponseRecorder, len(chains))
		panics    = make([]interface{}, len(chains))
		fanOut    sync.WaitGroup
	)
	for i, chain := range chains {
		recorders[i] = requestResponseRecorder{
			request:  r.Clone(ctx),
			response: httptest.NewRecorder(),
		}
		wg.Add(1)
		fanOut.Add(1)
		go func(i int, chain string) {
			defer fanOut.Done()
			// A panic in a bare goroutine would take the whole process down
			// (the reverse proxy panics with http.ErrAbortHandler when
			// copying a response body fails). Capture it so it can be
			// re-raised on the handler goroutine, where net/http recovers it.
			defer func() { panics[i] = recover() }()
			ep.proxyRequestChain(ctx, chain, recorders[i].response, recorders[i].request, handleFailure, &wg)
		}(i, chain)
	}
	fanOut.Wait()
	for _, p := range panics {
		if p != nil {
			panic(p)
		}
	}

	responses := make([]response, 0, len(recorders))
	for _, recorder := range recorders {
		res := recorder.response.Result()

		// Accept parameters on the media type, e.g. "; charset=utf-8".
		contentType := res.Header.Get("Content-Type")
		if contentType == "application/json" || strings.HasPrefix(contentType, "application/json;") {
			var resp response
			decodeErr := json.NewDecoder(res.Body).Decode(&resp)
			if decodeErr == nil {
				_ = res.Body.Close()
				responses = append(responses, resp)
				continue
			}
			ep.Logger.Debug("decoding into response failed")
		}
		_ = res.Body.Close()

		// Not a decodable response: summarize what we got back instead.
		responses = append(responses, response{
			Service: recorder.request.URL.Host,
			Code:    res.StatusCode,
			TraceID: traceID(recorder.request.Context()),
			Message: fmt.Sprintf("response from calling %s", recorder.request.URL),
			Headers: res.Header,
		})
	}

	ep.writeResponse(ctx, w, response{
		Code:      http.StatusOK,
		Headers:   r.Header,
		Responses: responses,
	})
}
// proxyRequestChain reverse proxies r to the service described by chain and
// writes the outcome to w.
//
// chain is resolved to a target URL by parseService; its host becomes the
// proxy target and its path and query become the path and query of the
// outgoing request. r is modified in place (Host, URL and a "Proxied-By"
// header), so callers are expected to pass a clone of the incoming request.
// Headers listed in ep.dropHeaders are removed from the outgoing request and
// the transport is wrapped with the zipkin HTTP middleware.
//
// When handleFailure is true, any downstream answer other than 200 is not
// relayed: a 200 response describing the downstream error is written to w
// instead. An unparsable chain is answered with a 500.
//
// wg.Done is called when the function returns; the caller must have called
// wg.Add(1) beforehand.
func (ep *Endpoints) proxyRequestChain(ctx context.Context, chain string, w http.ResponseWriter, r *http.Request, handleFailure bool, wg *sync.WaitGroup) {
	defer wg.Done()

	svc, err := ep.parseService(chain)
	if err != nil {
		ep.Logger.Error("parse service %s", err, chain)
		// return error response...
		ep.writeResponse(ctx, w, response{
			Code:  http.StatusInternalServerError,
			Error: errInternal,
		})
		return
	}
	u, err := url.Parse(svc)
	if err != nil {
		ep.Logger.Error("parse service %s", err, svc)
		// return error response...
		ep.writeResponse(ctx, w, response{
			Code:  http.StatusInternalServerError,
			Error: errInternal,
		})
		return
	}

	r.Host = u.Host // this is needed or Envoy will get confused where to route it
	r.Header.Add("Proxied-By", ep.ServiceName)

	// The outgoing request carries the complete target URL...
	target := *u
	r.URL = &target

	// ...so the proxy target must not carry the path or the query as well:
	// the default director joins the target path with the request path and
	// merges both queries, which would send each of them twice.
	u.Path = ""
	u.RawQuery = ""
	p := httputil.NewSingleHostReverseProxy(u)

	d := p.Director
	p.Director = func(req *http.Request) {
		d(req)
		// Remove the excluded headers
		for _, h := range ep.dropHeaders {
			req.Header.Del(h)
		}
	}

	// If the tracing transport cannot be built, keep the proxy's own
	// transport so the request is still proxied, just without client spans.
	if transport, err := zmw.NewTransport(ep.tracer, zmw.RoundTripper(p.Transport)); err != nil {
		ep.Logger.Error("creating tracing transport for %s", err, svc)
	} else {
		p.Transport = transport
	}

	if handleFailure {
		// errBail makes the reverse proxy stop after ModifyResponse has
		// already produced the downstream-facing response itself.
		errBail := errors.New("bail")

		p.ModifyResponse = func(res *http.Response) error {
			if res.StatusCode == http.StatusOK {
				// proceed unaltered
				return nil
			}
			// let's mimick a service that did a client request which failed,
			// but due to nice business logic it is still able to handle
			// the failure gracefully and return success status itself.
			// Reading the body is best effort: whatever could be read is
			// reported.
			raw, _ := io.ReadAll(res.Body)
			_ = res.Body.Close()
			ep.writeResponse(ctx, w, response{
				Code: http.StatusOK,
				Message: fmt.Sprintf(
					"%s called %s and got error return: %s",
					ep.ServiceName, svc, string(raw)),
			})
			// bail proxy logic, we returned details downstream ourselves
			return errBail
		}

		// Any error returned by ModifyResponse is handed to the proxy's error
		// handler, whose default logs it and answers 502. For our deliberate
		// bail the response has already been produced, so do nothing; real
		// proxy errors keep the 502 behavior.
		p.ErrorHandler = func(rw http.ResponseWriter, _ *http.Request, err error) {
			if errors.Is(err, errBail) {
				return
			}
			ep.Logger.Error("proxying request to %s", err, svc)
			rw.WriteHeader(http.StatusBadGateway)
		}
	}

	p.ServeHTTP(w, r)
}
// echoHandler returns the received request headers, potentially setting double
// headers (for testing Envoy sidecars), or fails with an error. The method
// will take at least as long as the set latency. Double headers and errors
// will occur with the set percentages in the service.
func (ep *Endpoints) echoHandler(w http.ResponseWriter, r *http.Request) {
	ctx := r.Context()

	// retrieve our behavioral config
	ep.mtx.RLock()
	d := ep.duration
	h := ep.headers
	e := ep.errors
	ep.mtx.RUnlock()

	// inject configured latency
	time.Sleep(d)

	// Skipping G404 (CWE-338): Use of weak random number generator (math/rand instead of crypto/rand) (Confidence: MEDIUM, Severity: HIGH)
	// since we use filippo.io/mostly-harmless/cryptosource.
	if rnd.Int31n(100) < e { // #nosec G404
		// return error response...
		ep.writeResponse(ctx, w, response{
			Code:  http.StatusInternalServerError,
			Error: errInternal,
		})
		return
	}

	// Skipping G404 (CWE-338): Use of weak random number generator (math/rand instead of crypto/rand) (Confidence: MEDIUM, Severity: HIGH)
	// since we use filippo.io/mostly-harmless/cryptosource.
	if rnd.Int31n(100) < h { // #nosec G404
		// set some double headers
		//
		// The headers have to be in place before WriteHeader is called:
		// changes to the header map after WriteHeader are not sent. Writing
		// the status right away commits the duplicated Content-Type values,
		// so nothing done afterwards can replace them.
		w.Header().Add("Content-Type", "text/html")
		w.Header().Add("Content-Type", "application/json")
		w.WriteHeader(http.StatusOK)
	}

	// emulate successful response, sending request headers received
	ep.writeResponse(ctx, w, response{
		Code:    http.StatusOK,
		Headers: r.Header,
	})
}
// parseService turns the first directive of a proxy chain into the URL that
// has to be called.
//
// service is query-unescaped first. The result is then interpreted as follows:
//   - a full URL with scheme and host (e.g. "http://google.com") is returned
//     in its normalized string form;
//   - a value without a host, i.e. a bare hostname, a host:port pair or either
//     of those followed by a path (e.g. "svcb", "svcb:8080",
//     "svcb/errors/50"), gets "http://" prepended;
//   - a scheme-less value that url.Parse rejects outright, such as an IP
//     address with a port ("10.0.0.1:8080", refused because the first path
//     segment contains a colon), is retried with "http://" prepended and
//     accepted if that parses;
//   - anything else is returned unchanged.
//
// An error is returned when service cannot be unescaped, or when it cannot be
// parsed as a URL with or without the "http://" prefix.
func (ep *Endpoints) parseService(service string) (string, error) {
	ep.Logger.Debug("service param is: %s", service)
	unScapedHost, err := url.QueryUnescape(service)
	if err != nil {
		return "", fmt.Errorf("unscaping service: %w", err)
	}
	ep.Logger.Debug("parsed host is: %s", unScapedHost)

	svc := unScapedHost
	destURL, err := url.Parse(unScapedHost)
	switch {
	case err != nil:
		// A value that already names a scheme is simply malformed.
		if strings.Contains(unScapedHost, "://") {
			return "", fmt.Errorf("parsing URL from unscapped host: %w", err)
		}
		// Otherwise it may be a host:port pair that is not valid as a
		// relative URL; see whether it is valid as an http URL.
		withScheme := "http://" + unScapedHost
		if _, retryErr := url.Parse(withScheme); retryErr != nil {
			return "", fmt.Errorf("parsing URL from unscapped host: %w", err)
		}
		svc = withScheme
	case destURL.Host == "":
		// Host is empty, so we don't have a full URL
		// - If scheme is empty, most likely result of parsing a single hostname
		//   as url.Parse adds that as Path.
		// - If there is a colon, it will probably be a host:port pair
		if destURL.Scheme == "" || strings.Contains(unScapedHost, ":") {
			svc = "http://" + unScapedHost
		}
	case destURL.Scheme != "":
		// Scheme is set, we will assume an URL has been provided, eg. http://google.com
		svc = destURL.String()
	}

	ep.Logger.Debug("svc : %s", svc)
	return svc, nil
}
// parseProxyChains splits the proxy specification into the chains that have
// to be called.
//
// A specification that does not start with '[' is a single chain and is
// returned as is. Otherwise every top-level "[...]" group is one chain;
// brackets nested inside a group belong to that chain and are left untouched
// for the next hop to parse. Empty groups are skipped.
//
// For example "[svcf/proxy/[svcd][svcd]][svcb/errors/50]" yields
// "svcf/proxy/[svcd][svcd]" and "svcb/errors/50".
//
// Malformed input is logged and handled as follows:
//   - characters outside of any group (e.g. "[a]x[b]") or a ']' without a
//     matching '[' make the whole specification invalid: nil is returned;
//   - a group that is never closed is dropped, the complete groups before it
//     are still returned.
//
// An empty specification yields nil.
func (ep *Endpoints) parseProxyChains(proxyChains string) []string {
	if len(proxyChains) == 0 {
		return nil
	}

	const (
		chainStart = '['
		chainEnd   = ']'
	)

	if proxyChains[0] != chainStart {
		return []string{proxyChains}
	}

	var (
		chains []string
		depth  int // bracket nesting level at the current position
		start  int // index of the first byte of the current top-level group
	)
	for i := 0; i < len(proxyChains); i++ {
		switch proxyChains[i] {
		case chainStart:
			if depth == 0 {
				start = i + 1
			}
			depth++
		case chainEnd:
			depth--
			if depth < 0 {
				ep.Logger.Error("invalid proxy chains: %w", fmt.Errorf("unexpected proxy chain ending char ']' at position %d: %s", i, proxyChains))
				return nil
			}
			if depth == 0 {
				if chain := proxyChains[start:i]; chain != "" {
					chains = append(chains, chain)
				}
			}
		default:
			if depth == 0 {
				// Only groups are allowed at the top level.
				ep.Logger.Error("invalid proxy chains: %w", fmt.Errorf("unexpected char %q outside of '[]' at position %d: %s", proxyChains[i], i, proxyChains))
				return nil
			}
		}
	}

	if depth != 0 {
		ep.Logger.Error("invalid proxy chains: %w", fmt.Errorf("missing proxy chain ending char ']': %s", proxyChains))
	}

	return chains
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment