Skip to content

Instantly share code, notes, and snippets.

@manhdaovan
Last active June 11, 2023 08:07
Show Gist options
  • Save manhdaovan/1baf5d6211a941b093eaad631d7b72be to your computer and use it in GitHub Desktop.
Save manhdaovan/1baf5d6211a941b093eaad631d7b72be to your computer and use it in GitHub Desktop.
A simple simulation of gRPC's ChainUnaryServer interceptor chaining
package main
import (
"fmt"
)
// handler is the final step of an interceptor chain: it receives the
// (possibly modified) string and returns the result.
type handler func(string) string

// interceptor wraps a handler call. It receives the incoming string and the
// next handler in the chain, and decides whether and how to call it.
type interceptor func(string, handler) string

// chainInterceptor composes interceptors into a single interceptor, in the
// spirit of gRPC's ChainUnaryServer: interceptors[0] runs outermost, and the
// handler given to the last interceptor is the lastHandler supplied at call
// time.
//
// Every position in the chain gets its own handler, built on demand with no
// shared mutable index. An interceptor may therefore call its handler more
// than once (e.g. to retry), and each call re-runs the remainder of the
// chain. With no interceptors, lastHandler is invoked directly instead of
// panicking on an empty slice.
func chainInterceptor(interceptors ...interceptor) interceptor {
	return func(originStr string, lastHandler handler) string {
		// handlerAt returns the handler that runs interceptors[i:] and then lastHandler.
		var handlerAt func(i int) handler
		handlerAt = func(i int) handler {
			if i == len(interceptors) {
				return lastHandler
			}
			return func(str string) string {
				return interceptors[i](str, handlerAt(i+1))
			}
		}
		return handlerAt(0)(originStr)
	}
}
func main() {
interceptor1 := func(str string, handler1 handler) string {
str += "im in interceptor1 |"
fmt.Println("im in interceptor1, str = ", str)
return handler1(str)
}
interceptor2 := func(str string, handler2 handler) string {
str += "im in interceptor2 |"
fmt.Println("im in interceptor2, str = ", str)
return handler2(str)
}
interceptor3 := func(str string, handler3 handler) string {
str += "im in interceptor3 |"
fmt.Println("im in interceptor3, str = ", str)
return handler3(str)
}
lastHandler := func(str string) string {
fmt.Println("str ------- ", str)
return str
}
chainInterceptor(interceptor1, interceptor2, interceptor3)("aaa", lastHandler)
}
@manhdaovan
Copy link
Author

manhdaovan commented Jun 11, 2023

A simple HTTP handler chain (middleware)

package main

import (
	"fmt"
	"math/rand"
	"net/http"
	"time"
)

// handler is the terminal HTTP handler at the end of a middleware chain.
type handler func(w http.ResponseWriter, r *http.Request)

// middleware runs around a handler: it receives the response writer, the
// request and the next handler in the chain, and decides whether to call it.
type middleware func(w http.ResponseWriter, r *http.Request, h handler)

// chain composes middlewares into a single middleware. middlewares[0] runs
// first (outermost); the handler h supplied at call time runs once the last
// middleware calls its next handler.
//
// Every position gets its own next handler, built on demand with no shared
// mutable index. A middleware may therefore call next more than once
// (e.g. to retry), and each call re-runs the remainder of the chain. With no
// middlewares, h is called directly.
func chain(middlewares ...middleware) middleware {
	return func(w http.ResponseWriter, r *http.Request, h handler) {
		// nextAt returns the handler that runs middlewares[i:] and then h.
		var nextAt func(i int) handler
		nextAt = func(i int) handler {
			if i == len(middlewares) {
				return h
			}
			// w and r are taken from the caller so a middleware can substitute them.
			return func(w http.ResponseWriter, r *http.Request) {
				middlewares[i](w, r, nextAt(i+1))
			}
		}
		nextAt(0)(w, r)
	}
}

func main() {
	// loggerMiddleware prints the request URI and the response headers as they
	// stand before the rest of the chain has run, then continues the chain.
	loggerMiddleware := func(w http.ResponseWriter, r *http.Request, h handler) {
		fmt.Println("logger middleware - RequestURI:", r.RequestURI)
		fmt.Println("logger middleware - Response Header:", w.Header())
		h(w, r)
	}

	// metricsMiddleware prints how long the rest of the chain took, in
	// milliseconds. The print is deferred so it also runs if h panics.
	metricsMiddleware := func(w http.ResponseWriter, r *http.Request, h handler) {
		start := time.Now()
		defer func() {
			fmt.Println("metrics middleware - executed duration:", time.Since(start).Milliseconds())
		}()

		h(w, r)
	}

	m := chain(loggerMiddleware, metricsMiddleware)

	mux := http.NewServeMux()
	mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
		m(w, r, handleRootPath)
	})

	// Explicit timeouts stop slow or stalled clients from holding connections
	// open indefinitely; the handler itself sleeps for less than a second.
	srv := &http.Server{
		Addr:              "0.0.0.0:8080",
		Handler:           mux,
		ReadHeaderTimeout: 5 * time.Second,
		ReadTimeout:       10 * time.Second,
		WriteTimeout:      10 * time.Second,
	}

	fmt.Println("started server at 0.0.0.0:8080")
	// ListenAndServe always returns a non-nil error (e.g. the port is in use);
	// report it instead of exiting silently.
	if err := srv.ListenAndServe(); err != nil {
		fmt.Println("server stopped:", err)
	}
}

// handleRootPath logs the request URI, waits a random delay of under one
// second, then replies with the current time.
func handleRootPath(w http.ResponseWriter, r *http.Request) {
	fmt.Println("handleRootPath - RequestURI:", r.RequestURI)
	// No rand.Seed here: the global math/rand source has been seeded
	// automatically since Go 1.20, and rand.Seed is deprecated (a no-op since
	// Go 1.24). Reseeding on every request was unnecessary.
	time.Sleep(time.Duration(rand.Intn(1000)) * time.Millisecond)
	now := time.Now().String()
	if _, err := w.Write([]byte("handleRootPath - response:" + now)); err != nil {
		// The client has likely gone away; nothing more can be sent, so just record it.
		fmt.Println("handleRootPath - write response:", err)
	}
}

Or, an even simpler version of the chain function:

// chain composes middlewares into a single middleware. middlewares[0] runs
// first (outermost); the handler h supplied at call time runs last.
//
// The wrappers are built from the last middleware backwards. Each iteration
// wraps the handler built so far, so the outermost wrapper corresponds to
// middlewares[0]. Wrapping in forward order would run the middlewares in
// reverse. With no middlewares, h is called directly.
func chain(middlewares ...middleware) middleware {
	return func(w http.ResponseWriter, r *http.Request, h handler) {
		next := h
		for i := len(middlewares) - 1; i >= 0; i-- {
			// Copy into per-iteration variables: the closure must not capture
			// next itself, which is reassigned below.
			m, inner := middlewares[i], next
			next = func(w http.ResponseWriter, r *http.Request) {
				m(w, r, inner)
			}
		}
		next(w, r)
	}
}

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment