Skip to content

Instantly share code, notes, and snippets.

@aclisp
Forked from erikdubbelboer/gziphandler.go
Created July 7, 2017 11:20
Show Gist options
  • Save aclisp/aeecf4035fb2cb2c7b54d916cf5dc3be to your computer and use it in GitHub Desktop.
Save aclisp/aeecf4035fb2cb2c7b54d916cf5dc3be to your computer and use it in GitHub Desktop.
package gziphandler
import (
"compress/gzip"
"log"
"net/http"
"strings"
"sync"
)
// gzipResponseWriter is the http.ResponseWriter handed to handlers wrapped by
// New. Headers and the status go straight to the embedded ResponseWriter,
// while body bytes are compressed by w, which itself writes into that same
// ResponseWriter.
type gzipResponseWriter struct {
	// ResponseWriter is the original writer for the request.
	http.ResponseWriter
	// w compresses the body; New resets it onto ResponseWriter per request.
	w *gzip.Writer

	// headerWritten reports whether WriteHeader has run on this wrapper.
	headerWritten bool
	// statusCode is the status given to WriteHeader (0 before that).
	statusCode int
}
// pool recycles gzipResponseWriters, each carrying its own gzip.Writer, so a
// request does not have to allocate a fresh compressor. New resets both the
// wrapper fields and the compressor before handing one to a handler.
var pool = sync.Pool{
	New: func() interface{} {
		// The error is ignored on purpose: NewWriterLevel only fails for an
		// invalid level, and gzip.BestSpeed is valid.
		zw, _ := gzip.NewWriterLevel(nil, gzip.BestSpeed)
		gzr := new(gzipResponseWriter)
		gzr.w = zw
		return gzr
	},
}
// WriteHeader records the response status and forwards it to the wrapped
// writer. For statuses that may carry a body, it removes any Content-Length
// (which would describe the uncompressed body) and sets
// Content-Encoding: gzip. 204 No Content and 304 Not Modified must not have a
// body, so their headers are left untouched.
//
// Once a final (non-1xx) status has been recorded, later calls no longer
// change the wrapper's state; they are only forwarded so the wrapped writer
// can ignore/report them as superfluous. Otherwise statusCode could end up
// holding a status that was never sent, and the deferred Close in New would
// decide whether to finish the gzip stream based on the wrong value.
func (gzr *gzipResponseWriter) WriteHeader(statusCode int) {
	if gzr.headerWritten && gzr.statusCode >= 200 {
		gzr.ResponseWriter.WriteHeader(statusCode)
		return
	}
	gzr.statusCode = statusCode
	gzr.headerWritten = true
	if statusCode != http.StatusNotModified && statusCode != http.StatusNoContent {
		h := gzr.ResponseWriter.Header()
		h.Del("Content-Length")
		h.Set("Content-Encoding", "gzip")
	}
	gzr.ResponseWriter.WriteHeader(statusCode)
}
// Write compresses b into the response body.
//
// When the handler has not chosen a Content-Type, one is sniffed from b here,
// because this wrapper is the only place that sees the uncompressed bytes.
// A Write before any WriteHeader implies 200 OK, as in net/http. Header
// changes made after the header has been sent are not transmitted, so a
// Content-Type sniffed at that point has no effect on the response.
func (gzr *gzipResponseWriter) Write(b []byte) (int, error) {
	hdr := gzr.ResponseWriter.Header()
	if _, present := hdr["Content-Type"]; !present {
		hdr.Set("Content-Type", http.DetectContentType(b))
	}
	if gzr.headerWritten {
		return gzr.w.Write(b)
	}
	gzr.WriteHeader(http.StatusOK)
	return gzr.w.Write(b)
}
// Flush pushes any data buffered in the gzip compressor to the wrapped writer
// and, when that writer implements http.Flusher, flushes it to the client.
//
// If no header has been written yet, an implicit 200 is sent through this
// wrapper first, as net/http does for Flush. Without that, gzip.Writer.Flush
// would emit the gzip header bytes straight into the wrapped writer. The
// wrapped writer would then send its own implicit 200 without
// Content-Encoding: gzip, leaving the client with an undeclared compressed body.
func (gzr *gzipResponseWriter) Flush() {
	if !gzr.headerWritten {
		gzr.WriteHeader(http.StatusOK)
	}
	if gzr.w != nil {
		// http.Flusher has no error result. gzip.Writer keeps a write error
		// and returns it again from later Write and Close calls.
		gzr.w.Flush()
	}
	if fw, ok := gzr.ResponseWriter.(http.Flusher); ok {
		fw.Flush()
	}
}
// New wraps fn so that its response body is gzip-compressed when the request's
// Accept-Encoding allows gzip; other requests are passed to fn unchanged.
//
// Every response gets "Vary: Accept-Encoding", because its encoding depends on
// that request header. Without it a shared cache could hand the compressed
// variant to a client that cannot decode it.
//
// The gzipResponseWriter and its compressor are taken from pool and returned
// to it once fn has finished.
func New(fn http.HandlerFunc) http.HandlerFunc {
	return func(w http.ResponseWriter, r *http.Request) {
		w.Header().Add("Vary", "Accept-Encoding")
		if !acceptsGzip(r.Header) {
			fn(w, r)
			return
		}

		gzr := pool.Get().(*gzipResponseWriter)
		gzr.statusCode = 0
		gzr.headerWritten = false
		gzr.ResponseWriter = w
		gzr.w.Reset(w)
		defer func() {
			// gzr.w.Close writes a gzip header and footer even if no data has
			// been written, so only close when a compressed body was started:
			//   - if fn never wrote a header, body or flush, Content-Encoding
			//     was never set; closing would send undeclared gzip bytes
			//     instead of the plain empty 200 net/http produces itself;
			//   - 204 No Content and 304 Not Modified must not have a body.
			if gzr.headerWritten &&
				gzr.statusCode != http.StatusNotModified &&
				gzr.statusCode != http.StatusNoContent {
				if err := gzr.w.Close(); err != nil {
					log.Printf("[ERR] %v", err)
				}
			}
			pool.Put(gzr)
		}()
		fn(gzr, r)
	}
}

// acceptsGzip reports whether any Accept-Encoding line in h lists gzip (or its
// legacy alias x-gzip) as acceptable. Codings are matched case-insensitively,
// and one explicitly refused with a zero quality value ("gzip;q=0") does not
// count.
func acceptsGzip(h http.Header) bool {
	for _, line := range h["Accept-Encoding"] {
		for _, coding := range strings.Split(line, ",") {
			params := strings.Split(coding, ";")
			name := strings.ToLower(strings.TrimSpace(params[0]))
			if name != "gzip" && name != "x-gzip" {
				continue
			}
			if !hasZeroQuality(params[1:]) {
				return true
			}
		}
	}
	return false
}

// hasZeroQuality reports whether the coding parameters carry a q value of
// zero, written with any number of zero decimals ("q=0", "q=0.000").
func hasZeroQuality(params []string) bool {
	for _, p := range params {
		p = strings.ToLower(strings.TrimSpace(p))
		if !strings.HasPrefix(p, "q=") {
			continue
		}
		// Trimming every leading/trailing '0' and '.' leaves an empty string
		// only for zero values such as "0", "0.0" or "0.000".
		return strings.Trim(strings.TrimSpace(p[2:]), "0.") == ""
	}
	return false
}
package gziphandler
import (
"compress/gzip"
"io/ioutil"
"net/http"
"net/http/httptest"
"net/http/httputil"
"testing"
)
// TestNoGzip checks that a request without gzip in Accept-Encoding is passed
// through uncompressed.
func TestNoGzip(t *testing.T) {
	req, err := http.NewRequest("GET", "http://example.com/", nil)
	if err != nil {
		t.Fatal(err)
	}
	rec := httptest.NewRecorder()
	New(func(w http.ResponseWriter, r *http.Request) {
		w.Write([]byte("test"))
	})(rec, req)

	// Result().Header is what the client would receive; HeaderMap is
	// deprecated in its favor.
	res := rec.Result()
	if res.StatusCode != http.StatusOK {
		t.Fatalf("expected 200 got %d", res.StatusCode)
	}
	if ce := res.Header.Get("Content-Encoding"); ce != "" {
		t.Fatalf(`expected Content-Encoding: "" got %q`, ce)
	}
	if rec.Body.String() != "test" {
		t.Fatalf(`expected "test" got %q`, rec.Body.String())
	}
	if testing.Verbose() {
		b, _ := httputil.DumpResponse(res, true)
		t.Log("\n" + string(b))
	}
}
// TestGzip checks that a gzip-capable request gets a gzip-encoded body, that
// the handler's Content-Length (which describes the uncompressed body) is
// dropped, and that the handler's Content-Type is kept.
func TestGzip(t *testing.T) {
	req, err := http.NewRequest("GET", "http://example.com/", nil)
	if err != nil {
		t.Fatal(err)
	}
	req.Header.Set("Accept-Encoding", "gzip, deflate")
	rec := httptest.NewRecorder()
	New(func(w http.ResponseWriter, r *http.Request) {
		w.Header().Set("Content-Length", "4")
		w.Header().Set("Content-Type", "text/test")
		w.Write([]byte("test"))
	})(rec, req)

	// Result().Header is what the client would receive; HeaderMap is
	// deprecated in its favor.
	res := rec.Result()
	if testing.Verbose() {
		// Dump before decompressing. DumpResponse puts an equivalent reader
		// back into res.Body, whereas consuming the body first would leave
		// nothing to dump.
		b, _ := httputil.DumpResponse(res, true)
		t.Log("\n" + string(b))
	}
	if res.StatusCode != http.StatusOK {
		t.Fatalf("expected 200 got %d", res.StatusCode)
	}
	if ce := res.Header.Get("Content-Encoding"); ce != "gzip" {
		t.Fatalf("expected Content-Encoding: gzip got %q", ce)
	}
	if cl := res.Header.Get("Content-Length"); cl != "" {
		t.Fatalf(`expected Content-Length: "" got %q`, cl)
	}
	if ct := res.Header.Get("Content-Type"); ct != "text/test" {
		t.Fatalf(`expected Content-Type: "text/test" got %q`, ct)
	}
	zr, err := gzip.NewReader(res.Body)
	if err != nil {
		t.Fatal(err)
	}
	body, err := ioutil.ReadAll(zr)
	if err != nil {
		t.Fatal(err)
	}
	if string(body) != "test" {
		t.Fatalf(`expected "test" got %q`, body)
	}
}
// TestNoBody checks that a 204 No Content response stays empty and is not
// labelled gzip, even when the client accepts gzip.
func TestNoBody(t *testing.T) {
	req, err := http.NewRequest("GET", "http://example.com/", nil)
	if err != nil {
		t.Fatal(err)
	}
	req.Header.Set("Accept-Encoding", "gzip, deflate")
	rec := httptest.NewRecorder()
	New(func(w http.ResponseWriter, r *http.Request) {
		w.WriteHeader(http.StatusNoContent)
	})(rec, req)

	// Result().Header is what the client would receive; HeaderMap is
	// deprecated in its favor.
	res := rec.Result()
	if res.StatusCode != http.StatusNoContent {
		t.Fatalf("expected %d got %d", http.StatusNoContent, res.StatusCode)
	}
	if ce := res.Header.Get("Content-Encoding"); ce != "" {
		t.Fatalf(`expected Content-Encoding: "" got %q`, ce)
	}
	if rec.Body.Len() > 0 {
		t.Fatalf("expected no body, got %d bytes: %q", rec.Body.Len(), rec.Body.String())
	}
	if testing.Verbose() {
		b, _ := httputil.DumpResponse(res, true)
		t.Log("\n" + string(b))
	}
}
// BenchmarkGzip measures serving a small gzip-compressed response through New
// from parallel goroutines.
func BenchmarkGzip(b *testing.B) {
	body := []byte("testtesttesttesttesttesttesttesttesttesttesttesttest")
	req, err := http.NewRequest("GET", "http://example.com/", nil)
	if err != nil {
		b.Fatal(err)
	}
	req.Header.Set("Accept-Encoding", "gzip, deflate")
	// Build the handler once so the timed loop measures serving requests,
	// not constructing the wrapper closure.
	handler := New(func(w http.ResponseWriter, r *http.Request) {
		w.Write(body)
	})
	b.ReportAllocs()
	b.ResetTimer()
	b.RunParallel(func(pb *testing.PB) {
		for pb.Next() {
			rec := httptest.NewRecorder()
			handler(rec, req)
			// FailNow (and therefore Fatalf) must only be called from the
			// benchmark's own goroutine, not from RunParallel workers, so
			// report with Errorf and stop this worker instead.
			if rec.Code != http.StatusOK {
				b.Errorf("expected %d got %d", http.StatusOK, rec.Code)
				return
			}
			// 49 is the size of body gzipped at gzip.BestSpeed by the current
			// compress/flate; update it if the compressor's output changes.
			if rec.Body.Len() != 49 {
				b.Errorf("expected 49 bytes, got %d bytes", rec.Body.Len())
				return
			}
		}
	})
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment