Skip to content

Instantly share code, notes, and snippets.

@jankuo
Forked from shaunlee/restful.go
Created April 1, 2014 13:29
Show Gist options
  • Save jankuo/9913971 to your computer and use it in GitHub Desktop.
Save jankuo/9913971 to your computer and use it in GitHub Desktop.
package main
import (
"log"
"fmt"
"strings"
"regexp"
"net/http"
)
// DEFAULT_MAX_MEMORY is the maxMemory argument ServeHTTP passes to
// ParseMultipartForm: 32 MiB, the same value as net/http's unexported
// defaultMaxMemory.
const DEFAULT_MAX_MEMORY = 32 * 1024 * 1024

// Route placeholder grammar. RE_URL_PATTERNS recognises <field>,
// <int:field>, <float:field> and <path:field>; submatch 3 holds the optional
// kind and submatch 4 the field name. Each URL_*_PATTERN is a fmt format
// string whose %s receives that field name, producing a named capture group.
var (
	URL_DEFAULT_PATTERN = `(?P<%s>[^/]+)`    // exactly one path segment
	URL_INT_PATTERN     = `(?P<%s>\d+)`      // one or more digits
	URL_FLOAT_PATTERN   = `(?P<%s>\d+\.\d+)` // digits "." digits
	URL_PATH_PATTERN    = `(?P<%s>.+)`       // anything, "/" included

	RE_URL_PATTERNS = regexp.MustCompile(`<(((int|float|path):)?(\w+))>`)
)
type (
	// HandlerFunc is one link in a route's handler chain. The third argument
	// runs the next handler registered for the same route; a handler that
	// does not call it ends the chain.
	HandlerFunc func(http.ResponseWriter, *http.Request, func())

	// PatternMethod maps each compiled, anchored route expression to the
	// http.HandlerFunc that runs that route's handler chain.
	PatternMethod map[*regexp.Regexp]http.HandlerFunc

	// WebHandler is a small regexp-based router; it serves requests through
	// its ServeHTTP method. Build one with NewWebHandler, since the zero value
	// has nil maps.
	WebHandler struct {
		patterns      map[string]PatternMethod // HTTP method -> route table
		errorHandlers map[int]http.HandlerFunc // status code -> custom responder
	}
)
// NewWebHandler returns a router with no routes and no custom error
// handlers; both internal maps are allocated so registration can start
// immediately.
func NewWebHandler() *WebHandler {
	return &WebHandler{
		patterns:      map[string]PatternMethod{},
		errorHandlers: map[int]http.HandlerFunc{},
	}
}
// register compiles pattern into an anchored regular expression and stores it,
// under method, together with a handler that runs fn as a middleware chain.
//
// Placeholder syntax: /<field>/<int:field>/<float:field>/<path:field>
//
//	<field>        one path segment (no "/")
//	<int:field>    one or more digits
//	<float:field>  digits "." digits
//	<path:field>   one or more characters, "/" included
//
// Each placeholder becomes a named capture group called field. Text outside
// placeholders is handed to the regexp engine unchanged. An invalid resulting
// expression makes regexp.MustCompile panic at registration time.
//
// Each handler in fn receives a next func: calling it runs the following
// handler, and the chain stops at the first handler that does not call it.
func (p *WebHandler) register(method, pattern string, fn ...HandlerFunc) {
	routes, ok := p.patterns[method]
	if !ok {
		routes = make(PatternMethod)
		p.patterns[method] = routes
	}

	// Expand all placeholders in one left-to-right pass. Re-running
	// strings.Replace on the partially rewritten pattern could match a later
	// placeholder inside an earlier substitution (e.g. "<x>" inside
	// "(?P<x>[^/]+)" for a route like "/<x>/<x>") and corrupt the expression.
	expanded := RE_URL_PATTERNS.ReplaceAllStringFunc(pattern, func(placeholder string) string {
		m := RE_URL_PATTERNS.FindStringSubmatch(placeholder)
		kind, field := m[3], m[4]
		switch kind {
		case "int":
			return fmt.Sprintf(URL_INT_PATTERN, field)
		case "float":
			return fmt.Sprintf(URL_FLOAT_PATTERN, field)
		case "path":
			return fmt.Sprintf(URL_PATH_PATTERN, field)
		default:
			return fmt.Sprintf(URL_DEFAULT_PATTERN, field)
		}
	})

	routes[regexp.MustCompile("^"+expanded+"$")] = func(w http.ResponseWriter, r *http.Request) {
		// i lives inside this per-request closure, so every request walks
		// the chain from the start independently.
		i := 0
		var next func()
		next = func() {
			if i < len(fn) {
				h := fn[i]
				i++
				h(w, r, next)
			}
		}
		next()
	}
}
// Get registers fn as the handler chain for GET requests matching pattern.
func (p *WebHandler) Get(pattern string, fn ...HandlerFunc) {
	p.register("GET", pattern, fn...)
}

// Post registers fn as the handler chain for POST requests matching pattern.
func (p *WebHandler) Post(pattern string, fn ...HandlerFunc) {
	p.register("POST", pattern, fn...)
}

// Put registers fn as the handler chain for PUT requests matching pattern.
func (p *WebHandler) Put(pattern string, fn ...HandlerFunc) {
	p.register("PUT", pattern, fn...)
}

// Delete registers fn as the handler chain for DELETE requests matching pattern.
func (p *WebHandler) Delete(pattern string, fn ...HandlerFunc) {
	p.register("DELETE", pattern, fn...)
}
// All registers the same handler chain fn for every HTTP method listed in
// methods. A nil methods slice stands for GET, POST, PUT and DELETE; a
// non-nil empty slice registers nothing.
func (p *WebHandler) All(pattern string, methods []string, fn ...HandlerFunc) {
	targets := methods
	if targets == nil {
		targets = []string{"GET", "POST", "PUT", "DELETE"}
	}
	for i := range targets {
		p.register(targets[i], pattern, fn...)
	}
}
// Handle registers fn as the custom responder Raise uses for status code,
// in place of the default http.Error reply. Raise only calls fn(w, r), so fn
// must write the status itself. A later call with the same code overwrites
// the earlier fn.
func (p *WebHandler) Handle(code int, fn http.HandlerFunc) {
	byStatus := p.errorHandlers
	byStatus[code] = fn
}
// Raise answers the request with an error for status code. If a responder
// was registered for code via Handle, it is called and err is unused;
// otherwise err and code are sent with http.Error.
func (p *WebHandler) Raise(w http.ResponseWriter, r *http.Request, err string, code int) {
	custom, registered := p.errorHandlers[code]
	if !registered {
		http.Error(w, err, code)
		return
	}
	custom(w, r)
}
// ServeHTTP dispatches a request. It works in these steps:
//   - parse the request's form data;
//   - look for a route registered for r.Method whose expression matches the
//     request path;
//   - copy that route's named captures into r.Form;
//   - run the route's handler chain.
//
// A request with no matching route is answered through Raise with 404. A
// panic during handling is logged and answered through Raise with 500.
//
// NOTE(review): routes are kept in a map, so when several patterns match the
// same path, which one runs is not deterministic.
func (p *WebHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
	defer func() {
		// TODO: let it go while debugging?
		if rec := recover(); rec != nil {
			log.Println(rec)
			p.Raise(w, r, "500 internal server error", http.StatusInternalServerError)
		}
	}()

	// Errors from form parsing are ignored; handlers see whatever parsed.
	contentType := r.Header.Get("Content-Type")
	if strings.HasPrefix(contentType, "multipart/form-data") {
		r.ParseMultipartForm(DEFAULT_MAX_MEMORY)
	} else {
		r.ParseForm()
	}

	// Routes are matched against the path only, without the query string.
	path := r.URL.RequestURI()
	if q := strings.Index(path, "?"); q >= 0 {
		path = path[:q]
	}

	// Indexing a missing method yields a nil map, which ranges zero times.
	for re, handle := range p.patterns[r.Method] {
		captures := re.FindStringSubmatch(path)
		if captures == nil {
			continue
		}
		for idx, name := range re.SubexpNames() {
			if name != "" {
				r.Form.Set(name, captures[idx])
			}
		}
		handle(w, r)
		return
	}

	p.Raise(w, r, "404 page not found", http.StatusNotFound)
}
// Run serves p on the TCP address addr. It is equivalent to
// http.ListenAndServe(addr, p): it blocks and returns the listener's error.
//
// NOTE(review): the server has no read/write timeouts configured.
func (p *WebHandler) Run(addr string) error {
	server := &http.Server{Addr: addr, Handler: p}
	return server.ListenAndServe()
}
// main sets up and serves a demo application on :8080. It registers:
//   - custom 500 and 404 responses;
//   - a two-step chain on "/";
//   - a blog route for the default method set;
//   - a POST route that echoes all four placeholder kinds.
func main() {
	app := NewWebHandler()

	internalError := func(w http.ResponseWriter, r *http.Request) {
		http.Error(w, "Internal server error!!!", http.StatusInternalServerError)
	}
	notFound := func(w http.ResponseWriter, r *http.Request) {
		http.Error(w, "Page not found!!!", http.StatusNotFound)
	}
	app.Handle(http.StatusInternalServerError, internalError)
	app.Handle(http.StatusNotFound, notFound)

	// "/" reaches the Home handler only when the form value next is "1".
	var middleware HandlerFunc = func(w http.ResponseWriter, r *http.Request, next func()) {
		fmt.Fprintf(w, "Middleware")
		if r.FormValue("next") == "1" {
			next()
		}
	}
	var home HandlerFunc = func(w http.ResponseWriter, r *http.Request, _ func()) {
		fmt.Fprintf(w, "Home")
	}
	app.Get(`/`, middleware, home)

	// A nil method list makes All register GET, POST, PUT and DELETE.
	var blog HandlerFunc = func(w http.ResponseWriter, r *http.Request, _ func()) {
		fmt.Fprintf(w, "Blog %s", r.FormValue("id"))
	}
	app.All(`/blog/<int:id>/`, nil, blog)

	var echo HandlerFunc = func(w http.ResponseWriter, r *http.Request, _ func()) {
		fmt.Fprintf(w, "Test %s %s %s %s", r.FormValue("id"), r.FormValue("name"), r.FormValue("lat"), r.FormValue("filename"))
	}
	app.Post(`/test/<name>/<int:id>/<float:lat>/<path:filename>`, echo)

	log.Println("Listening on :8080 ...")
	log.Fatal(app.Run(":8080"))
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment