Skip to content

Instantly share code, notes, and snippets.

@alexedwards
Last active May 8, 2025 20:19
Show Gist options
  • Save alexedwards/81c75d0dca1a0943be2d5ee228d2d69b to your computer and use it in GitHub Desktop.
Save alexedwards/81c75d0dca1a0943be2d5ee228d2d69b to your computer and use it in GitHub Desktop.
// TestRouter checks which middleware runs, and in what order, for routes
// registered at the top level, inside groups, and inside nested groups.
//
// Each middleware appends a one-character tag to a shared string when it is
// invoked, so the final string records the exact chain that handled a request.
// The expectations in the table below assert that:
//   - middleware runs in the order it was passed to Use;
//   - a group runs its parent's middleware followed by its own;
//   - middleware added inside a group is not applied to sibling groups or to
//     routes registered on the parent after the group;
//   - requests that match no route (404) or use a disallowed method (405)
//     still run the top-level middleware.
func TestRouter(t *testing.T) {
	// used accumulates the tags of the middleware invoked for the current
	// request. It is reset before every test case.
	used := ""

	// mark returns a middleware that appends tag to used and then calls the
	// next handler in the chain.
	mark := func(tag string) func(http.Handler) http.Handler {
		return func(next http.Handler) http.Handler {
			return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
				used += tag
				next.ServeHTTP(w, r)
			})
		}
	}

	mw1, mw2, mw3 := mark("1"), mark("2"), mark("3")
	mw4, mw5, mw6 := mark("4"), mark("5"), mark("6")

	// hf is a no-op handler; writing nothing yields the default 200 status.
	hf := func(w http.ResponseWriter, r *http.Request) {}

	r := NewRouter()
	r.Use(mw1)
	r.Use(mw2)

	r.HandleFunc("GET /{$}", hf)

	r.Group(func(r *Router) {
		r.Use(mw3, mw4)
		r.HandleFunc("GET /foo", hf)

		r.Group(func(r *Router) {
			r.Use(mw5)
			r.HandleFunc("GET /nested/foo", hf)
		})
	})

	r.Group(func(r *Router) {
		r.Use(mw6)
		r.HandleFunc("GET /bar", hf)
	})

	// Registered after the groups: must only see the top-level middleware.
	r.HandleFunc("GET /baz", hf)

	tests := []struct {
		method     string
		path       string
		wantUsed   string
		wantStatus int
	}{
		{method: "GET", path: "/", wantUsed: "12", wantStatus: http.StatusOK},
		{method: "GET", path: "/foo", wantUsed: "1234", wantStatus: http.StatusOK},
		{method: "GET", path: "/nested/foo", wantUsed: "12345", wantStatus: http.StatusOK},
		{method: "GET", path: "/bar", wantUsed: "126", wantStatus: http.StatusOK},
		{method: "GET", path: "/baz", wantUsed: "12", wantStatus: http.StatusOK},

		// Check global middleware used on errors generated by http.ServeMux
		{method: "GET", path: "/notfound", wantUsed: "12", wantStatus: http.StatusNotFound},
		{method: "POST", path: "/nested/foo", wantUsed: "12", wantStatus: http.StatusMethodNotAllowed},
	}

	for _, test := range tests {
		t.Run(test.method+" "+test.path, func(t *testing.T) {
			used = ""

			rq, err := http.NewRequest(test.method, test.path, nil)
			if err != nil {
				// Fatal, not Error: continuing would pass a nil request to
				// ServeHTTP and panic instead of reporting a clean failure.
				t.Fatalf("NewRequest: %v", err)
			}

			rr := httptest.NewRecorder()
			r.ServeHTTP(rr, rq)

			rs := rr.Result()
			defer rs.Body.Close()

			if rs.StatusCode != test.wantStatus {
				t.Errorf("%s %s: expected status %d but was %d", test.method, test.path, test.wantStatus, rs.StatusCode)
			}

			if used != test.wantUsed {
				t.Errorf("%s %s: middleware used: expected %q; got %q", test.method, test.path, test.wantUsed, used)
			}
		})
	}
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment