-
-
Save Hunsin/26b2021757e831554d4f59a52a5c9152 to your computer and use it in GitHub Desktop.
package router | |
import ( | |
"context" | |
"net/http" | |
gpath "path" | |
"github.com/julienschmidt/httprouter" | |
) | |
// Param returns the named URL parameter from a request context. | |
func Param(ctx context.Context, name string) string { | |
if p := httprouter.ParamsFromContext(ctx); p != nil { | |
return p.ByName(name) | |
} | |
return "" | |
} | |
// A Middleware chains http.Handlers. | |
type Middleware func(http.Handler) http.Handler | |
// A Router is a http.Handler which supports routing and middlewares. | |
type Router struct { | |
middlewares []Middleware | |
path string | |
root *httprouter.Router | |
} | |
// New creates a new Router. | |
func New() *Router { | |
return &Router{root: httprouter.New(), path: "/"} | |
} | |
// Group returns a new Router with given path and middlewares. | |
// It should be used for handlers which have same path prefix or | |
// common middlewares. | |
func (r *Router) Group(path string, m ...Middleware) *Router { | |
return &Router{ | |
middlewares: append(m, r.middlewares...), | |
path: gpath.Join(r.path, path), | |
root: r.root, | |
} | |
} | |
// Use appends new middlewares to current Router. | |
func (r *Router) Use(m ...Middleware) *Router { | |
r.middlewares = append(m, r.middlewares...) | |
return r | |
} | |
// Handle registers a new request handler combined with middlewares. | |
func (r *Router) Handle(method, path string, handler http.Handler) { | |
for _, v := range r.middlewares { | |
handler = v(handler) | |
} | |
r.root.Handler(method, gpath.Join(r.path, path), handler) | |
} | |
// GET is a shortcut for r.Handle("GET", path, handler) | |
func (r *Router) GET(path string, handler http.HandlerFunc) { | |
r.Handle(http.MethodGet, path, handler) | |
} | |
// HEAD is a shortcut for r.Handle("HEAD", path, handler) | |
func (r *Router) HEAD(path string, handler http.HandlerFunc) { | |
r.Handle(http.MethodHead, path, handler) | |
} | |
// OPTIONS is a shortcut for r.Handle("OPTIONS", path, handler) | |
func (r *Router) OPTIONS(path string, handler http.HandlerFunc) { | |
r.Handle(http.MethodOptions, path, handler) | |
} | |
// POST is a shortcut for r.Handle("POST", path, handler) | |
func (r *Router) POST(path string, handler http.HandlerFunc) { | |
r.Handle(http.MethodPost, path, handler) | |
} | |
// PUT is a shortcut for r.Handle("PUT", path, handler) | |
func (r *Router) PUT(path string, handler http.HandlerFunc) { | |
r.Handle(http.MethodPut, path, handler) | |
} | |
// PATCH is a shortcut for r.Handle("PATCH", path, handler) | |
func (r *Router) PATCH(path string, handler http.HandlerFunc) { | |
r.Handle(http.MethodPatch, path, handler) | |
} | |
// DELETE is a shortcut for r.Handle("DELETE", path, handler) | |
func (r *Router) DELETE(path string, handler http.HandlerFunc) { | |
r.Handle(http.MethodDelete, path, handler) | |
} | |
// HandleFunc is an adapter for http.HandlerFunc. | |
func (r *Router) HandleFunc(method, path string, handler http.HandlerFunc) { | |
r.Handle(method, path, handler) | |
} | |
// NotFound sets the handler which is called if the request path doesn't match | |
// any routes. It overwrites the previous setting. | |
func (r *Router) NotFound(handler http.Handler) { | |
r.root.NotFound = handler | |
} | |
// Static serves files from given root directory. | |
func (r *Router) Static(path, root string) { | |
if len(path) < 10 || path[len(path)-10:] != "/*filepath" { | |
panic("path should end with '/*filepath' in path '" + path + "'.") | |
} | |
base := gpath.Join(r.path, path[:len(path)-9]) | |
fileServer := http.StripPrefix(base, http.FileServer(http.Dir(root))) | |
r.Handle(http.MethodGet, path, fileServer) | |
} | |
// File serves the named file. | |
func (r *Router) File(path, name string) { | |
r.HandleFunc(http.MethodGet, path, func(w http.ResponseWriter, req *http.Request) { | |
http.ServeFile(w, req, name) | |
}) | |
} | |
func (r *Router) ServeHTTP(w http.ResponseWriter, req *http.Request) { | |
r.root.ServeHTTP(w, req) | |
} |
package router | |
import ( | |
"io" | |
"net/http" | |
"net/http/httptest" | |
"os" | |
"testing" | |
) | |
func TestHandle(t *testing.T) { | |
router := New() | |
h := func(w http.ResponseWriter, _ *http.Request) { | |
w.WriteHeader(http.StatusTeapot) | |
} | |
router.Handle("GET", "/", http.HandlerFunc(h)) | |
r := httptest.NewRequest("GET", "/", nil) | |
w := httptest.NewRecorder() | |
router.ServeHTTP(w, r) | |
if w.Code != http.StatusTeapot { | |
t.Error("Handle() failed") | |
} | |
} | |
func TestHandleFunc(t *testing.T) { | |
router := New() | |
h := func(w http.ResponseWriter, _ *http.Request) { | |
w.WriteHeader(http.StatusTeapot) | |
} | |
router.HandleFunc("GET", "/", h) | |
r := httptest.NewRequest("GET", "/", nil) | |
w := httptest.NewRecorder() | |
router.ServeHTTP(w, r) | |
if w.Code != http.StatusTeapot { | |
t.Error("HandlerFunc() failed") | |
} | |
} | |
func TestMethod(t *testing.T) { | |
router := New() | |
router.DELETE("/delete", func(w http.ResponseWriter, _ *http.Request) { | |
w.WriteHeader(http.StatusTeapot) | |
}) | |
router.GET("/get", func(w http.ResponseWriter, _ *http.Request) { | |
w.WriteHeader(http.StatusTeapot) | |
}) | |
router.HEAD("/head", func(w http.ResponseWriter, _ *http.Request) { | |
w.WriteHeader(http.StatusTeapot) | |
}) | |
router.OPTIONS("/options", func(w http.ResponseWriter, _ *http.Request) { | |
w.WriteHeader(http.StatusTeapot) | |
}) | |
router.PATCH("/patch", func(w http.ResponseWriter, _ *http.Request) { | |
w.WriteHeader(http.StatusTeapot) | |
}) | |
router.POST("/post", func(w http.ResponseWriter, _ *http.Request) { | |
w.WriteHeader(http.StatusTeapot) | |
}) | |
router.PUT("/put", func(w http.ResponseWriter, _ *http.Request) { | |
w.WriteHeader(http.StatusTeapot) | |
}) | |
samples := map[string]string{ | |
"DELETE": "/delete", | |
"GET": "/get", | |
"HEAD": "/head", | |
"OPTIONS": "/options", | |
"PATCH": "/patch", | |
"POST": "/post", | |
"PUT": "/put", | |
} | |
for method, path := range samples { | |
r := httptest.NewRequest(method, path, nil) | |
w := httptest.NewRecorder() | |
router.ServeHTTP(w, r) | |
if w.Code != http.StatusTeapot { | |
t.Errorf("Path %s not registered", path) | |
} | |
} | |
} | |
func TestGroup(t *testing.T) { | |
router := New() | |
foo := router.Group("/foo") | |
bar := router.Group("/bar") | |
baz := foo.Group("/baz") | |
foo.HandleFunc("GET", "", func(w http.ResponseWriter, _ *http.Request) { | |
w.WriteHeader(http.StatusTeapot) | |
}) | |
foo.HandleFunc("GET", "/group", func(w http.ResponseWriter, _ *http.Request) { | |
w.WriteHeader(http.StatusTeapot) | |
}) | |
bar.HandleFunc("GET", "/group", func(w http.ResponseWriter, _ *http.Request) { | |
w.WriteHeader(http.StatusTeapot) | |
}) | |
baz.HandleFunc("GET", "/group", func(w http.ResponseWriter, _ *http.Request) { | |
w.WriteHeader(http.StatusTeapot) | |
}) | |
samples := []string{"/foo", "/foo/group", "/foo/baz/group", "/bar/group"} | |
for _, path := range samples { | |
r := httptest.NewRequest("GET", path, nil) | |
w := httptest.NewRecorder() | |
router.ServeHTTP(w, r) | |
if w.Code != http.StatusTeapot { | |
t.Errorf("Grouped path %s not registered", path) | |
} | |
} | |
} | |
func TestMiddleware(t *testing.T) { | |
var use, group bool | |
router := New().Use(func(next http.Handler) http.Handler { | |
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { | |
use = true | |
next.ServeHTTP(w, r) | |
}) | |
}) | |
foo := router.Group("/foo", func(next http.Handler) http.Handler { | |
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { | |
group = true | |
next.ServeHTTP(w, r) | |
}) | |
}) | |
foo.HandleFunc("GET", "/bar", func(w http.ResponseWriter, _ *http.Request) { | |
w.WriteHeader(http.StatusTeapot) | |
}) | |
r := httptest.NewRequest("GET", "/foo/bar", nil) | |
w := httptest.NewRecorder() | |
router.ServeHTTP(w, r) | |
if !use { | |
t.Error("Middleware registered by Use() under \"/\" not touched") | |
} | |
if !group { | |
t.Error("Middleware registered by Group() under \"/foo\" not touched") | |
} | |
} | |
// createTemp creates (or truncates) the file name, writes content to it and
// syncs it to disk. It returns the first error from creating, writing,
// syncing or closing the file.
func createTemp(name, content string) (err error) {
	f, err := os.Create(name)
	if err != nil {
		return err
	}
	defer func() {
		// A failing Close can be the only report of a lost write, so it
		// must not be dropped; earlier errors take precedence.
		if cerr := f.Close(); cerr != nil && err == nil {
			err = cerr
		}
	}()
	if _, err = f.WriteString(content); err != nil {
		return err
	}
	return f.Sync()
}
func TestStatic(t *testing.T) { | |
files := []string{"temp_1", "temp_2"} | |
strs := []string{"test content", "static contents"} | |
for i := range files { | |
err := createTemp(files[i], strs[i]) | |
if err != nil { | |
t.Fatal("failed creating temp file:", err) | |
} | |
defer os.Remove(files[i]) | |
} | |
pwd, _ := os.Getwd() | |
router := New() | |
router.Static("/*filepath", pwd) | |
for i := range files { | |
r := httptest.NewRequest("GET", "/"+files[i], nil) | |
w := httptest.NewRecorder() | |
router.ServeHTTP(w, r) | |
body := w.Result().Body | |
defer body.Close() | |
file, _ := io.ReadAll(body) | |
if string(file) != strs[i] { | |
t.Error("Static() failed") | |
} | |
} | |
} | |
func TestFile(t *testing.T) { | |
name := "temp_file" | |
str := "test_content" | |
if err := createTemp(name, str); err != nil { | |
t.Fatal("failed creating temp file:", err) | |
} | |
defer os.Remove(name) | |
router := New() | |
router.File("/file", name) | |
r := httptest.NewRequest("GET", "/file", nil) | |
w := httptest.NewRecorder() | |
router.ServeHTTP(w, r) | |
body := w.Result().Body | |
defer body.Close() | |
file, _ := io.ReadAll(body) | |
if string(file) != str { | |
t.Error("File() failed") | |
} | |
} |
Hello @saeidakbari, may I see your sample code?
Hello @Hunsin
Sorry for the late reply.
Here is my code:
package main
import (
"github.com/julienschmidt/httprouter"
"./router"
"net/http"
"log"
)
func main() {
// NOTE(review): r is a plain *httprouter.Router, which has no group or
// middleware support of its own; the reply in this thread suggests
// creating it with router.New() instead.
r := httprouter.New()
// BUG(review): mid1 ignores its handle argument and calls r.ServeHTTP.
// That sends the same request back through the router to the same route,
// whose handler is this middleware again, so the request never reaches
// the real handler and loops forever. It should call
// handle(writer, request, params) instead.
mid1 := func(handle httprouter.Handle) httprouter.Handle {
return func(writer http.ResponseWriter, request *http.Request, params httprouter.Params) {
log.Print("Hello 1")
r.ServeHTTP(writer, request)
}
}
// BUG(review): same re-entry problem as mid1.
mid2 := func(handle httprouter.Handle) httprouter.Handle {
return func(writer http.ResponseWriter, request *http.Request, params httprouter.Params) {
log.Print("Hello 2")
r.ServeHTTP(writer, request)
}
}
// NOTE(review): the router package shown above has New and Router.Group,
// not NewGroup, and its Middleware type is func(http.Handler) http.Handler
// rather than func(httprouter.Handle) httprouter.Handle. This example
// presumably targets an earlier revision of that package; confirm.
v1 := router.NewGroup(r, "/v1")
v2 := router.NewGroup(r, "/v2")
v21 := v2.Group("/v1", mid1, mid2)
r.GET("/", func(writer http.ResponseWriter, request *http.Request, params httprouter.Params) {
writer.Write([]byte("Hello from Index"))
})
v1.GET("/test", func(writer http.ResponseWriter, request *http.Request, params httprouter.Params) {
writer.Write([]byte("Hello from v1"))
})
v2.GET("/test", func(writer http.ResponseWriter, request *http.Request, params httprouter.Params) {
writer.Write([]byte("Hello from v2"))
})
v21.GET("/test", func(writer http.ResponseWriter, request *http.Request, params httprouter.Params) {
writer.Write([]byte("Hello from v21"))
})
log.Fatal(http.ListenAndServe(":8000", r))
}
The router package is exactly as you wrote above.
Hello @saeidakbari
I found that GitHub doesn't send notifications here, so feel free to contact me at: asky[dot]hunsin[at]gmail[dot]com
First
I modified the code again: I changed `func NewRouter()` to `New()`. It doesn't matter in your case; the rename is just for convenience.
Second
It is recommended to put your source code in this structure:
gopath <-- the folder where environment variable GOPATH is set to
└─ src
├─ github.com
│ └─ saeidakbari
│ ├─ router
│ │ └─ router.go <-- package router
│ └─ project
│ └─ main.go <-- your main package
└─ others
Then you can import the `router` package in main.go:
import (
// ...others
"github.com/saeidakbari/router"
)
For more information please check here
Third
Declare the `r` variable with `router.New()`. `httprouter.Router` doesn't support middleware and groups, but `router.Router` does:
r := router.New()
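Your code gets stuck in a loop because you call `r.ServeHTTP` inside the middlewares. That sends the request back through the router to the same route, which runs the middleware again, and so on forever. Instead, you need to call the wrapped handler, `handle(writer, request, params)`: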
mid1 := func(handle httprouter.Handle) httprouter.Handle {
return func(writer http.ResponseWriter, request *http.Request, params httprouter.Params) {
log.Print("Hello 1")
// Call the wrapped handle, not r.ServeHTTP: re-entering the router would
// dispatch back to this middleware and loop forever.
handle(writer, request, params)
}
}
Thank you for sharing such a great thing!
But there is a problem: when I use a middleware, routing gets stuck in an endless loop and only the middleware runs.
Can you help me find the problem?