-
-
Save aclisp/aeecf4035fb2cb2c7b54d916cf5dc3be to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
package gziphandler | |
import ( | |
"compress/gzip" | |
"log" | |
"net/http" | |
"strings" | |
"sync" | |
) | |
// gzipResponseWriter is an http.ResponseWriter that gzip-compresses the body
// written through it. The embedded ResponseWriter is the real destination;
// Header and any other method not defined on this type are promoted from it.
type gzipResponseWriter struct {
	http.ResponseWriter
	// w compresses body bytes; New resets it to write into the embedded
	// ResponseWriter before handing this value to a handler.
	w *gzip.Writer
	// statusCode is the status passed to WriteHeader (0 until it is called).
	statusCode int
	// headerWritten reports whether WriteHeader has been called for the
	// current response.
	headerWritten bool
}
var ( | |
pool = sync.Pool{ | |
New: func() interface{} { | |
w, _ := gzip.NewWriterLevel(nil, gzip.BestSpeed) | |
return &gzipResponseWriter{ | |
w: w, | |
} | |
}, | |
} | |
) | |
func (gzr *gzipResponseWriter) WriteHeader(statusCode int) { | |
gzr.statusCode = statusCode | |
gzr.headerWritten = true | |
if gzr.statusCode != http.StatusNotModified && gzr.statusCode != http.StatusNoContent { | |
gzr.ResponseWriter.Header().Del("Content-Length") | |
gzr.ResponseWriter.Header().Set("Content-Encoding", "gzip") | |
} | |
gzr.ResponseWriter.WriteHeader(statusCode) | |
} | |
func (gzr *gzipResponseWriter) Write(b []byte) (int, error) { | |
if _, ok := gzr.Header()["Content-Type"]; !ok { | |
// If no content type, apply sniffing algorithm to un-gzipped body. | |
gzr.ResponseWriter.Header().Set("Content-Type", http.DetectContentType(b)) | |
} | |
if !gzr.headerWritten { | |
// This is exactly what Go would also do if it hasn't been written yet. | |
gzr.WriteHeader(http.StatusOK) | |
} | |
return gzr.w.Write(b) | |
} | |
func (gzr *gzipResponseWriter) Flush() { | |
if gzr.w != nil { | |
gzr.w.Flush() | |
} | |
if fw, ok := gzr.ResponseWriter.(http.Flusher); ok { | |
fw.Flush() | |
} | |
} | |
func New(fn http.HandlerFunc) http.HandlerFunc { | |
return func(w http.ResponseWriter, r *http.Request) { | |
if !strings.Contains(r.Header.Get("Accept-Encoding"), "gzip") { | |
fn(w, r) | |
return | |
} | |
gzr := pool.Get().(*gzipResponseWriter) | |
gzr.statusCode = 0 | |
gzr.headerWritten = false | |
gzr.ResponseWriter = w | |
gzr.w.Reset(w) | |
defer func() { | |
// gzr.w.Close will write a footer even if no data has been written. | |
// StatusNotModified and StatusNoContent expect an empty body so don't close it. | |
if gzr.statusCode != http.StatusNotModified && gzr.statusCode != http.StatusNoContent { | |
if err := gzr.w.Close(); err != nil { | |
log.Printf("[ERR] %v", err) | |
} | |
} | |
pool.Put(gzr) | |
}() | |
fn(gzr, r) | |
} | |
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
package gziphandler | |
import ( | |
"compress/gzip" | |
"io/ioutil" | |
"net/http" | |
"net/http/httptest" | |
"net/http/httputil" | |
"testing" | |
) | |
func TestNoGzip(t *testing.T) { | |
req, err := http.NewRequest("GET", "http://example.com/", nil) | |
if err != nil { | |
t.Fatal(err) | |
} | |
rec := httptest.NewRecorder() | |
New(func(w http.ResponseWriter, r *http.Request) { | |
w.Write([]byte("test")) | |
})(rec, req) | |
if rec.Code != 200 { | |
t.Fatalf("expected 200 got %d", rec.Code) | |
} | |
if rec.HeaderMap.Get("Content-Encoding") != "" { | |
t.Fatalf(`expected Content-Encoding: "" got %s`, rec.HeaderMap.Get("Content-Encoding")) | |
} | |
if rec.Body.String() != "test" { | |
t.Fatalf(`expected "test" go "%s"`, rec.Body.String()) | |
} | |
if testing.Verbose() { | |
b, _ := httputil.DumpResponse(rec.Result(), true) | |
t.Log("\n" + string(b)) | |
} | |
} | |
func TestGzip(t *testing.T) { | |
req, err := http.NewRequest("GET", "http://example.com/", nil) | |
if err != nil { | |
t.Fatal(err) | |
} | |
req.Header.Set("Accept-Encoding", "gzip, deflate") | |
rec := httptest.NewRecorder() | |
New(func(w http.ResponseWriter, r *http.Request) { | |
w.Header().Set("Content-Length", "4") | |
w.Header().Set("Content-Type", "text/test") | |
w.Write([]byte("test")) | |
})(rec, req) | |
if rec.Code != 200 { | |
t.Fatalf("expected 200 got %d", rec.Code) | |
} | |
if rec.HeaderMap.Get("Content-Encoding") != "gzip" { | |
t.Fatalf("expected Content-Encoding: gzip got %s", rec.HeaderMap.Get("Content-Encoding")) | |
} | |
if rec.HeaderMap.Get("Content-Length") != "" { | |
t.Fatalf(`expected Content-Length: "" got %s`, rec.HeaderMap.Get("Content-Length")) | |
} | |
if rec.HeaderMap.Get("Content-Type") != "text/test" { | |
t.Fatalf(`expected Content-Type: "text/test" got %s`, rec.HeaderMap.Get("Content-Type")) | |
} | |
r, err := gzip.NewReader(rec.Body) | |
if err != nil { | |
t.Fatal(err) | |
} | |
body, err := ioutil.ReadAll(r) | |
if err != nil { | |
t.Fatal(err) | |
} | |
if string(body) != "test" { | |
t.Fatalf(`expected "test" go "%s"`, string(body)) | |
} | |
if testing.Verbose() { | |
b, _ := httputil.DumpResponse(rec.Result(), true) | |
t.Log("\n" + string(b)) | |
} | |
} | |
func TestNoBody(t *testing.T) { | |
req, err := http.NewRequest("GET", "http://example.com/", nil) | |
if err != nil { | |
t.Fatal(err) | |
} | |
req.Header.Set("Accept-Encoding", "gzip, deflate") | |
rec := httptest.NewRecorder() | |
New(func(w http.ResponseWriter, r *http.Request) { | |
w.WriteHeader(http.StatusNoContent) | |
})(rec, req) | |
if rec.Code != http.StatusNoContent { | |
t.Fatalf("expected %d got %d", http.StatusNoContent, rec.Code) | |
} | |
if rec.HeaderMap.Get("Content-Encoding") != "" { | |
t.Fatalf(`expected Content-Encoding: "" got %s`, rec.HeaderMap.Get("Content-Encoding")) | |
} | |
if rec.Body.Len() > 0 { | |
t.Logf("%q", rec.Body.String()) | |
t.Fatalf("no body expected for %d bytes", rec.Body.Len()) | |
} | |
if testing.Verbose() { | |
b, _ := httputil.DumpResponse(rec.Result(), true) | |
t.Log("\n" + string(b)) | |
} | |
} | |
func BenchmarkGzip(b *testing.B) { | |
body := []byte("testtesttesttesttesttesttesttesttesttesttesttesttest") | |
req, err := http.NewRequest("GET", "http://example.com/", nil) | |
if err != nil { | |
b.Fatal(err) | |
} | |
req.Header.Set("Accept-Encoding", "gzip, deflate") | |
b.ResetTimer() | |
b.RunParallel(func(pb *testing.PB) { | |
for pb.Next() { | |
rec := httptest.NewRecorder() | |
New(func(w http.ResponseWriter, r *http.Request) { | |
w.Write(body) | |
})(rec, req) | |
if rec.Code != http.StatusOK { | |
b.Fatalf("expected %d got %d", http.StatusOK, rec.Code) | |
} | |
if rec.Body.Len() != 49 { | |
b.Fatalf("expected 49 bytes, got %d bytes", rec.Body.Len()) | |
} | |
} | |
}) | |
} |
Sign up for free to join this conversation on GitHub.
Already have an account? Sign in to comment.