Skip to content

Instantly share code, notes, and snippets.

@manhdaovan
Last active June 11, 2023 08:07
Show Gist options
  • Save manhdaovan/1baf5d6211a941b093eaad631d7b72be to your computer and use it in GitHub Desktop.
Save manhdaovan/1baf5d6211a941b093eaad631d7b72be to your computer and use it in GitHub Desktop.
Simple simulation of gRPC's ChainUnaryServer interceptor chaining
package main
import (
"fmt"
)
// handler is the final function at the end of an interceptor chain.
type handler func(string) string

// interceptor receives the input and the next handler in the chain; it may
// transform the input, call the next handler (zero, one or several times),
// and transform the result.
type interceptor func(string, handler) string

// chainInterceptor composes interceptors into a single interceptor, in the
// style of gRPC's ChainUnaryServer. interceptors[0] is the outermost and runs
// first; the handler passed to the returned interceptor runs last. With no
// interceptors, the returned interceptor simply calls that handler.
//
// Every "next" handler is bound to a fixed position in the chain rather than
// to a shared counter, so an interceptor that calls next more than once (for
// example to retry) runs the rest of the chain again each time.
func chainInterceptor(interceptors ...interceptor) interceptor {
	switch len(interceptors) {
	case 0:
		return func(str string, last handler) string {
			return last(str)
		}
	case 1:
		return interceptors[0]
	}
	return func(str string, last handler) string {
		return interceptors[0](str, chainHandlerFrom(interceptors, 0, last))
	}
}

// chainHandlerFrom returns the handler that interceptors[curr] should call
// next: the following interceptor (bound to its own successor), or last when
// curr is the final interceptor.
func chainHandlerFrom(interceptors []interceptor, curr int, last handler) handler {
	if curr == len(interceptors)-1 {
		return last
	}
	return func(str string) string {
		return interceptors[curr+1](str, chainHandlerFrom(interceptors, curr+1, last))
	}
}
func main() {
interceptor1 := func(str string, handler1 handler) string {
str += "im in interceptor1 |"
fmt.Println("im in interceptor1, str = ", str)
return handler1(str)
}
interceptor2 := func(str string, handler2 handler) string {
str += "im in interceptor2 |"
fmt.Println("im in interceptor2, str = ", str)
return handler2(str)
}
interceptor3 := func(str string, handler3 handler) string {
str += "im in interceptor3 |"
fmt.Println("im in interceptor3, str = ", str)
return handler3(str)
}
lastHandler := func(str string) string {
fmt.Println("str ------- ", str)
return str
}
chainInterceptor(interceptor1, interceptor2, interceptor3)("aaa", lastHandler)
}
@manhdaovan
Copy link
Author

manhdaovan commented Sep 5, 2020

It can also be expressed mathematically:

package main

import "fmt"

// Type definitions for a math-style model of interceptor chaining.
type (
	// F is a plain function f(x) -> y; it plays the role of the final handler.
	F func(x int) (y int)

	// G is g(x, f) -> y: it receives the input and the next function f and
	// produces the same kind of output as f; it plays the role of an
	// interceptor.
	G func(x int, f F) (y int)

	// H is h(g1, g2, ..., gn) -> g: it folds several Gs into a single G, i.e.
	// it builds a chain of interceptors.
	H func(gs ...G) (g G)

	// K is k(x, f, g) -> f': an adapter that combines one G with the next F
	// into a new F, so an H can be built by folding K over its Gs. When k
	// returns a function that calls g(x, f), folding from the last G to the
	// first (gs = g1..gn) gives:
	//
	//	f0 = f
	//	f1 = k(x, f0, gn)
	//	f2 = k(x, f1, gn-1)
	//	...
	//	fn = k(x, fn-1, g1)
	//
	// and calling fn runs g1 first and the original f last.
	K func(x int, f F, g G) F
)

func main() {
	// handler is the final F: it prints its input and returns it unchanged.
	var handler F = func(x int) int {
		fmt.Println("x == ", x)
		return x
	}

	// plusInterceptor adds 2 before calling the next F.
	var plusInterceptor G = func(x int, next F) int {
		fmt.Println("x before plusInterceptor == ", x)
		x = x + 2
		return next(x)
	}

	// timesInterceptor doubles x before calling the next F.
	var timesInterceptor G = func(x int, next F) int {
		fmt.Println("x before timesInterceptor == ", x)
		x = x * 2
		return next(x)
	}

	var chainInterceptor H = func(interceptors ...G) G {
		// chain wraps g around next, producing a new F. The leading int is
		// part of K's signature but is not needed to build the wrapper; the
		// value actually used is the one the returned F is called with.
		var chain K = func(_ int, next F, g G) F {
			return func(x int) int {
				return g(x, next)
			}
		}

		return func(x int, final F) int {
			// Wrap from the last interceptor to the first so that the first
			// one ends up outermost and therefore runs first:
			// g1(g2(...gn(final)))(x).
			f := final
			for i := len(interceptors) - 1; i >= 0; i-- {
				f = chain(x, f, interceptors[i])
			}
			return f(x)
		}
	}

	// plusInterceptor runs first (3 -> 5), then timesInterceptor (5 -> 10),
	// then handler prints 10.
	interceptor := chainInterceptor(plusInterceptor, timesInterceptor)
	interceptor(3, handler)
}

@manhdaovan
Copy link
Author

manhdaovan commented Jun 11, 2023

A simple HTTP handler chain (middleware):

package main

import (
	"fmt"
	"math/rand"
	"net/http"
	"time"
)

// handler is an HTTP handler function, the innermost step of a middleware
// chain.
type handler func(w http.ResponseWriter, r *http.Request)

// middleware wraps a handler: it receives the request plus the next handler h
// and decides whether, when, and how often to call h.
type middleware func(w http.ResponseWriter, r *http.Request, h handler)

// chain composes middlewares into a single middleware. middlewares[0] is the
// outermost and runs first; the handler passed to the returned middleware runs
// last. With no middlewares, that handler is called directly.
//
// Each "next" handler is bound to a fixed position in the chain instead of a
// shared per-request counter, so a middleware that invokes its handler more
// than once (e.g. a retry) re-runs the remaining middlewares every time rather
// than skipping straight to the final handler.
func chain(middlewares ...middleware) middleware {
	return func(w http.ResponseWriter, r *http.Request, h handler) {
		if len(middlewares) == 0 {
			h(w, r)
			return
		}
		middlewares[0](w, r, middlewareHandlerFrom(middlewares, 0, h))
	}
}

// middlewareHandlerFrom returns the handler that middlewares[curr] should call
// next: the following middleware (bound to its own successor), or final when
// curr is the last middleware.
func middlewareHandlerFrom(middlewares []middleware, curr int, final handler) handler {
	if curr == len(middlewares)-1 {
		return final
	}
	return func(w http.ResponseWriter, r *http.Request) {
		middlewares[curr+1](w, r, middlewareHandlerFrom(middlewares, curr+1, final))
	}
}

func main() {
	// loggerMiddleware logs the request URI and the response headers, then
	// delegates to h.
	loggerMiddleware := func(w http.ResponseWriter, r *http.Request, h handler) {
		fmt.Println("logger middleware - RequestURI:", r.RequestURI)
		// Printed before h runs, so only headers set by earlier middlewares
		// show up here.
		fmt.Println("logger middleware - Response Header:", w.Header())
		h(w, r)
	}

	// metricsMiddleware reports how long h took, in milliseconds.
	metricsMiddleware := func(w http.ResponseWriter, r *http.Request, h handler) {
		// time.Since uses the monotonic clock, so wall-clock adjustments
		// cannot skew (or make negative) the reported duration.
		start := time.Now()
		defer func() {
			fmt.Println("metrics middleware - executed duration:", time.Since(start).Milliseconds())
		}()

		h(w, r)
	}

	m := chain(loggerMiddleware, metricsMiddleware)
	mux := http.NewServeMux()
	mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
		m(w, r, handleRootPath)
	})

	// Explicit timeouts keep slow or stalled clients from holding
	// connections open indefinitely. WriteTimeout must exceed the handler's
	// artificial delay (under one second).
	srv := &http.Server{
		Addr:              "0.0.0.0:8080",
		Handler:           mux,
		ReadHeaderTimeout: 5 * time.Second,
		ReadTimeout:       10 * time.Second,
		WriteTimeout:      10 * time.Second,
	}

	fmt.Println("started server at 0.0.0.0:8080")
	if err := srv.ListenAndServe(); err != nil {
		// e.g. the address is already in use; report it instead of exiting
		// silently.
		fmt.Println("server stopped:", err)
	}
}

// handleRootPath simulates some work by sleeping for a random duration below
// one second, then responds with the current time.
func handleRootPath(w http.ResponseWriter, r *http.Request) {
	fmt.Println("handleRootPath - RequestURI:", r.RequestURI)
	// No rand.Seed here: reseeding the shared global source on every request
	// is unnecessary (it is seeded automatically since Go 1.20, where
	// rand.Seed is deprecated) and lets concurrent requests reset each
	// other's sequence.
	time.Sleep(time.Duration(rand.Intn(1000)) * time.Millisecond)
	now := time.Now().String()
	if _, err := w.Write([]byte("handleRootPath - response:" + now)); err != nil {
		// Typically the client went away; nothing else to do but report it.
		fmt.Println("handleRootPath - write response:", err)
	}
}

Or, even more simply, the chain function can be written as:

// chain composes middlewares into a single middleware by wrapping the final
// handler from the inside out. middlewares[0] ends up outermost, so it runs
// first and the handler passed to the returned middleware runs last; with no
// middlewares the handler is called directly.
func chain(middlewares ...middleware) middleware {
	// wrap turns a middleware plus the handler it should delegate to into a
	// plain handler.
	wrap := func(m middleware, next handler) handler {
		return func(w http.ResponseWriter, r *http.Request) {
			m(w, r, next)
		}
	}

	return func(w http.ResponseWriter, r *http.Request, h handler) {
		// Iterate in reverse: wrapping in forward order would leave the last
		// middleware outermost and run the chain backwards.
		for i := len(middlewares) - 1; i >= 0; i-- {
			h = wrap(middlewares[i], h)
		}
		h(w, r)
	}
}

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment