Created
September 21, 2019 14:04
-
-
Save FZambia/d67f7ff4de9aa8a706c293b8eaf5532c to your computer and use it in GitHub Desktop.
Klauspost compress library as replacement for std lib in Gorilla Websocket
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
diff --git a/compression.go b/compression.go | |
index 813ffb1..4c492e0 100644 | |
--- a/compression.go | |
+++ b/compression.go | |
@@ -41,16 +41,47 @@ func isValidCompressionLevel(level int) bool { | |
return minCompressionLevel <= level && level <= maxCompressionLevel | |
} | |
-func compressNoContextTakeover(w io.WriteCloser, level int) io.WriteCloser { | |
+// FlateWriter abstracts the flate writer held by flateWriteWrapper so that a different DEFLATE implementation can be plugged in. | |
+type FlateWriter interface { | |
+ Write(data []byte) (n int, err error) | |
+ Reset(dst io.Writer) | |
+ Flush() error | |
+ Close() error | |
+} | |
+ | |
+func defaultAcquireFlateWriter(w io.Writer, level int) FlateWriter { | |
p := &flateWriterPools[level-minCompressionLevel] | |
- tw := &truncWriter{w: w} | |
fw, _ := p.Get().(*flate.Writer) | |
if fw == nil { | |
- fw, _ = flate.NewWriter(tw, level) | |
+ fw, _ = flate.NewWriter(w, level) | |
} else { | |
- fw.Reset(tw) | |
+ fw.Reset(w) | |
} | |
- return &flateWriteWrapper{fw: fw, tw: tw, p: p} | |
+ return &poolFlateWriter{fw, p} | |
+} | |
+ | |
+type poolFlateWriter struct { | |
+ *flate.Writer | |
+ p *sync.Pool | |
+} | |
+ | |
+func (w *poolFlateWriter) Close() error { | |
+ w.p.Put(w.Writer) // hand the embedded writer back to its pool for reuse; the flate.Writer's own Close is not called | |
+ return nil | |
+} | |
+ | |
+func compressNoContextTakeoverFlateWriter(acquireFlateWriter func(w io.Writer, level int) FlateWriter) func(w io.WriteCloser, level int) io.WriteCloser { | |
+ return func(w io.WriteCloser, level int) io.WriteCloser { | |
+ tw := &truncWriter{w: w} | |
+ fw := acquireFlateWriter(tw, level) | |
+ return &flateWriteWrapper{fw: fw, tw: tw} | |
+ } | |
+} | |
+ | |
+func compressNoContextTakeover(w io.WriteCloser, level int) io.WriteCloser { | |
+ tw := &truncWriter{w: w} | |
+ fw := defaultAcquireFlateWriter(tw, level) | |
+ return &flateWriteWrapper{fw: fw, tw: tw} | |
} | |
// truncWriter is an io.Writer that writes all but the last four bytes of the | |
@@ -90,9 +121,8 @@ func (w *truncWriter) Write(p []byte) (int, error) { | |
} | |
type flateWriteWrapper struct { | |
- fw *flate.Writer | |
+ fw FlateWriter | |
tw *truncWriter | |
- p *sync.Pool | |
} | |
func (w *flateWriteWrapper) Write(p []byte) (int, error) { | |
@@ -107,16 +137,19 @@ func (w *flateWriteWrapper) Close() error { | |
return errWriteClosed | |
} | |
err1 := w.fw.Flush() | |
- w.p.Put(w.fw) | |
+ err2 := w.fw.Close() // release the flate writer once the final flush has been attempted | |
w.fw = nil | |
if w.tw.p != [4]byte{0, 0, 0xff, 0xff} { | |
return errors.New("websocket: internal error, unexpected bytes at end of flate stream") | |
} | |
- err2 := w.tw.w.Close() | |
+ err3 := w.tw.w.Close() | |
if err1 != nil { | |
return err1 | |
} | |
- return err2 | |
+ if err2 != nil { | |
+ return err2 // report the Close failure; err1 is nil on this path, so returning it would hide the error | |
+ } | |
+ return err3 | |
} | |
type flateReadWrapper struct { | |
type flateReadWrapper struct { | |
diff --git a/compression_test.go b/compression_test.go | |
index 8a26b30..b089e64 100644 | |
--- a/compression_test.go | |
+++ b/compression_test.go | |
@@ -5,7 +5,10 @@ import ( | |
"fmt" | |
"io" | |
"io/ioutil" | |
+ "sync" | |
"testing" | |
+ | |
+ customFlate "github.com/klauspost/compress/flate" | |
) | |
type nopCloser struct{ io.Writer } | |
@@ -65,6 +68,44 @@ func BenchmarkWriteWithCompression(b *testing.B) { | |
b.ReportAllocs() | |
} | |
+var ( | |
+ customFlateWriterPools [maxCompressionLevel - minCompressionLevel + 1]sync.Pool | |
+) | |
+ | |
+func acquireCustomFlateWriter(w io.Writer, level int) FlateWriter { | |
+ p := &customFlateWriterPools[level-minCompressionLevel] | |
+ fw, _ := p.Get().(*customFlate.Writer) | |
+ if fw == nil { | |
+ fw, _ = customFlate.NewWriter(w, level) | |
+ } else { | |
+ fw.Reset(w) | |
+ } | |
+ return &customPoolFlateWriter{fw, p} | |
+} | |
+ | |
+type customPoolFlateWriter struct { | |
+ *customFlate.Writer | |
+ p *sync.Pool | |
+} | |
+ | |
+func (w *customPoolFlateWriter) Close() error { | |
+ w.p.Put(w.Writer) // hand the embedded writer back to its pool for reuse; the writer's own Close is not called | |
+ return nil | |
+} | |
+ | |
+func BenchmarkWriteWithCompressionCustom(b *testing.B) { | |
+ w := ioutil.Discard | |
+ c := newTestConn(nil, w, false) | |
+ messages := textMessages(100) | |
+ c.enableWriteCompression = true | |
+ c.newCompressionWriter = compressNoContextTakeoverFlateWriter(acquireCustomFlateWriter) | |
+ b.ResetTimer() | |
+ for i := 0; i < b.N; i++ { | |
+ c.WriteMessage(TextMessage, messages[i%len(messages)]) | |
+ } | |
+ b.ReportAllocs() | |
+} | |
+ | |
func TestValidCompressionLevel(t *testing.T) { | |
c := newTestConn(nil, nil, false) | |
for _, level := range []int{minCompressionLevel - 1, maxCompressionLevel + 1} { | |
diff --git a/conn_broadcast_test.go b/conn_broadcast_test.go | |
index cb88cbb..75bff5b 100644 | |
--- a/conn_broadcast_test.go | |
+++ b/conn_broadcast_test.go | |
@@ -71,7 +71,7 @@ func (b *broadcastBench) makeConns(numConns int) { | |
c := newTestConn(nil, b.w, true) | |
if b.compression { | |
c.enableWriteCompression = true | |
- c.newCompressionWriter = compressNoContextTakeover | |
+ c.newCompressionWriter = compressNoContextTakeoverFlateWriter(acquireCustomFlateWriter) | |
} | |
conns[i] = newBroadcastConn(c) | |
go func(c *broadcastConn) { | |
diff --git a/server.go b/server.go | |
index 887d558..6e6c5be 100644 | |
--- a/server.go | |
+++ b/server.go | |
@@ -70,6 +70,8 @@ type Upgrader struct { | |
// guarantee that compression will be supported. Currently only "no context | |
// takeover" modes are supported. | |
EnableCompression bool | |
+ | |
+ AcquireFlateWriter func(w io.Writer, level int) FlateWriter | |
} | |
func (u *Upgrader) returnError(w http.ResponseWriter, r *http.Request, status int, reason string) (*Conn, error) { | |
@@ -203,7 +205,11 @@ func (u *Upgrader) Upgrade(w http.ResponseWriter, r *http.Request, responseHeade | |
c.subprotocol = subprotocol | |
if compress { | |
- c.newCompressionWriter = compressNoContextTakeover | |
+ if u.AcquireFlateWriter != nil { | |
+ c.newCompressionWriter = compressNoContextTakeoverFlateWriter(u.AcquireFlateWriter) | |
+ } else { | |
+ c.newCompressionWriter = compressNoContextTakeover | |
+ } | |
c.newDecompressionReader = decompressNoContextTakeover | |
} | |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment