Skip to content

Instantly share code, notes, and snippets.

@zerbitx
Created April 12, 2017 16:47
Show Gist options
  • Save zerbitx/486d7bc8454054016d55b888165e0ba3 to your computer and use it in GitHub Desktop.
Save zerbitx/486d7bc8454054016d55b888165e0ba3 to your computer and use it in GitHub Desktop.
A small file server that allows CORS by default, inspired by wanting to use Swagger UI from any directory containing a swagger.json.
package main
import (
	"flag"
	"fmt"
	"log"
	"net/http"
	"os"
	"strings"
	"time"
)
// middleware decorates an http.Handler: it receives the next handler in the
// chain and returns a handler that wraps it.
type middleware func(http.Handler) http.Handler
// main serves the files under -working.dir (default: the current working
// directory) over HTTP on -port. Unless -disable.cors is given, every response
// carries a wildcard Access-Control-Allow-Origin header. Extra response
// headers can be supplied via -res.headers as a comma-separated list of
// alternating names and values.
func main() {
	wd, err := os.Getwd()
	if err != nil {
		log.Fatal(err)
	}

	var (
		port        int
		disableCors bool
		resHeaders  string
		workingDir  string
	)
	flag.IntVar(&port, "port", 3000, "Port on which to serve this directory")
	flag.BoolVar(&disableCors, "disable.cors", false, "Disable the wildcard Access-Control-Allow-Origin header")
	flag.StringVar(&resHeaders, "res.headers", "", "CSV of header names and values to add to server responses")
	flag.StringVar(&workingDir, "working.dir", wd, "What directory to serve files from. Defaults to the current working directory.")
	flag.Parse()

	// Build the handler chain from the inside out: the file server, then
	// (optionally) CORS, then the user's custom headers as the outermost
	// layer, so the custom headers are applied to the response first.
	var h http.Handler = http.FileServer(http.Dir(workingDir))
	if !disableCors {
		h = corsMW(h)
	}
	h = responseHeadersMW(headersMap(resHeaders))(h)

	// A dedicated mux keeps ServeMux's path cleaning/redirect behavior without
	// exposing anything else that may be registered on http.DefaultServeMux.
	mux := http.NewServeMux()
	mux.Handle("/", h)

	srv := &http.Server{
		Addr:    fmt.Sprintf(":%d", port),
		Handler: mux,
		// Bound how long a client may take to send request headers and how
		// long idle keep-alive connections may linger, so slow or abandoned
		// clients cannot hold connections open indefinitely. No WriteTimeout
		// is set because it would abort long downloads of large files.
		ReadHeaderTimeout: 10 * time.Second,
		IdleTimeout:       2 * time.Minute,
	}

	log.Printf("Serving %s on %d", workingDir, port)
	log.Fatal(srv.ListenAndServe())
}
// corsMW wraps next so that every response it produces carries the header
// "Access-Control-Allow-Origin: *", permitting cross-origin requests from any
// origin.
func corsMW(next http.Handler) http.Handler {
	allowAnyOrigin := map[string]string{"Access-Control-Allow-Origin": "*"}
	withCORS := responseHeadersMW(allowAnyOrigin)
	return withCORS(next)
}
// responseHeadersMW returns a middleware that writes every name/value pair in
// headers onto the response before delegating to the wrapped handler.
//
// Header.Set is used rather than Header.Add. When several of these middlewares
// are stacked and configure the same header name, the inner one (which runs
// later) replaces the outer one's value instead of appending a duplicate.
// Multiple values are invalid for headers such as
// Access-Control-Allow-Origin, and browsers reject them.
//
// The headers map is read on every request, so callers should not modify it
// after building the middleware.
func responseHeadersMW(headers map[string]string) middleware {
	return func(next http.Handler) http.Handler {
		return http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
			hdr := rw.Header()
			for name, value := range headers {
				hdr.Set(name, value)
			}
			next.ServeHTTP(rw, req)
		})
	}
}
// headersMap parses a flat comma-separated list of alternating header names
// and values ("Name1,Value1,Name2,Value2") into a name -> value map.
//
// Whitespace around each name and value is trimmed, so "A, 1, B, 2" works as
// expected. Pairs with an empty name are skipped because an empty header name
// cannot be written on the wire. Values cannot contain commas, since the comma
// is the field separator. A later duplicate name overwrites an earlier one.
//
// An empty or blank input yields an empty map. An odd number of fields cannot
// be paired unambiguously, so a warning is logged and an empty map is
// returned. The result is never nil.
func headersMap(headersCsv string) map[string]string {
	headers := map[string]string{}
	if strings.TrimSpace(headersCsv) == "" {
		return headers
	}

	tkns := strings.Split(headersCsv, ",")
	if len(tkns)%2 != 0 {
		log.Printf("ignoring response headers %q: expected an even number of comma-separated name,value fields, got %d", headersCsv, len(tkns))
		return headers
	}

	for i := 0; i < len(tkns); i += 2 {
		key := strings.TrimSpace(tkns[i])
		if key == "" {
			continue
		}
		headers[key] = strings.TrimSpace(tkns[i+1])
	}
	return headers
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment