Last active
May 14, 2020 17:17
-
-
Save bwplotka/55a383a218f40f0a8f59da7851018c68 to your computer and use it in GitHub Desktop.
Go replayable io.Reader: useful when you want to share a slice of bytes across many sequential io.Reader consumers, e.g. the same request Body in an HTTP server consumed by multiple RoundTrippers (!).
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
package replayable | |
import ( | |
"bytes" | |
"io" | |
) | |
// Reader wraps an io.Reader and records every byte read from it, so that the
// same content can be replayed from the beginning after Rewind.
//
// A Reader mutates its state on every call without synchronization, so it
// must not be used from multiple goroutines at once.
type Reader struct {
	wrapped io.Reader // source of bytes not yet seen; nil is treated as empty
	buf     []byte    // every byte obtained from wrapped so far
	offset  int       // index in buf of the next byte Read will return
}

// Rewind makes the next Read start again from the first byte ever read from
// the wrapped source. It is a no-op on a nil *Reader.
func (b *Reader) Rewind() {
	if b == nil {
		return
	}
	b.offset = 0
}

// Read implements io.Reader.
//
// It first replays buffered bytes that have not been returned since the last
// Rewind. If p still has room, it then performs a single Read on the wrapped
// source into the rest of p and records whatever that call returned. Making
// only one call means a partial result is returned as soon as the source
// yields data; Read does not block until p is full.
//
// io.EOF is reported only when no bytes are returned. If the wrapped source
// returns data together with io.EOF, the data is returned with a nil error.
// The next Read (without Rewind) asks the wrapped source again, and a
// conforming io.Reader reports io.EOF at that point.
//
// A nil *Reader, or a Reader created from a nil source, behaves like an
// empty stream. Every byte read from the source is kept in memory for the
// lifetime of the Reader.
func (b *Reader) Read(p []byte) (n int, err error) {
	if b == nil {
		return 0, io.EOF
	}

	if b.offset < len(b.buf) {
		// Replay previously buffered bytes. bytes.Reader.Read can only fail
		// (with io.EOF) when it has nothing left, which the guard rules out.
		n, _ = bytes.NewReader(b.buf[b.offset:]).Read(p)
		b.offset += n
	}
	if n == len(p) {
		// Either p was satisfied from the buffer alone or p is empty.
		return n, nil
	}

	if b.wrapped == nil {
		err = io.EOF
	} else {
		var m int
		m, err = b.wrapped.Read(p[n:])
		if m > 0 {
			// Remember the fresh bytes so they can be replayed after Rewind.
			b.buf = append(b.buf, p[n:n+m]...)
			b.offset += m
			n += m
		}
	}

	if n > 0 && err == io.EOF {
		// Hand out the data now; EOF is reported by the following call.
		err = nil
	}
	return n, err
}

// NewReader returns a replayable Reader over src.
//
// Content is buffered lazily: bytes are stored only once they have been read
// from src. Memory use therefore matches what has actually been consumed,
// while Rewind can still replay all of it. A nil src yields a Reader that
// reports io.EOF.
func NewReader(src io.Reader) *Reader {
	return &Reader{wrapped: src}
}
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
package replayable | |
import ( | |
"bytes" | |
"io" | |
"testing" | |
"github.com/stretchr/testify/require" | |
) | |
func TestReplayableReader(t *testing.T) { | |
for _, tcase := range []struct { | |
name string | |
src io.Reader | |
sequentialReadBytes []int | |
rewindBeforeRead []bool | |
expectedBytes [][]byte | |
expectedErrs []error | |
}{ | |
{ | |
name: "WrappedNil_Read_ShouldReturnEOF", | |
src: nil, | |
sequentialReadBytes: []int{10}, | |
rewindBeforeRead: []bool{false}, | |
expectedBytes: [][]byte{{}}, | |
expectedErrs: []error{io.EOF}, | |
}, | |
{ | |
name: "WrappedNil_RewindRead_ShouldReturnEOF", | |
src: nil, | |
sequentialReadBytes: []int{10}, | |
rewindBeforeRead: []bool{true}, | |
expectedBytes: [][]byte{{}}, | |
expectedErrs: []error{io.EOF}, | |
}, | |
{ | |
name: "SmallBigBigReads_FinishedWithEOF", | |
src: bytes.NewReader([]byte{1, 2, 3, 4}), | |
sequentialReadBytes: []int{1, 8192, 8192}, | |
rewindBeforeRead: []bool{false, false, false}, | |
expectedBytes: [][]byte{{1}, {2, 3, 4}, {}}, | |
expectedErrs: []error{nil, nil, io.EOF}, | |
}, | |
{ | |
name: "SmallReads_FinishedWithEOF", | |
src: bytes.NewReader([]byte{1, 2, 3, 4}), | |
sequentialReadBytes: []int{1, 2, 4, 1}, | |
rewindBeforeRead: []bool{false, false, false, false}, | |
expectedBytes: [][]byte{{1}, {2, 3}, {4}, {}}, | |
expectedErrs: []error{nil, nil, nil, io.EOF}, | |
}, | |
{ | |
name: "SmallReadsTakingExactBytes", | |
src: bytes.NewReader([]byte{1, 2, 3, 4, 5}), | |
sequentialReadBytes: []int{1, 2, 2}, | |
rewindBeforeRead: []bool{false, false, false}, | |
expectedBytes: [][]byte{{1}, {2, 3}, {4, 5}}, | |
expectedErrs: []error{nil, nil, nil}, | |
}, | |
{ | |
name: "SmallReadsRewindSmallRead", | |
src: bytes.NewReader([]byte{1, 2, 3, 4, 5}), | |
sequentialReadBytes: []int{1, 2, 4, 2}, | |
rewindBeforeRead: []bool{false, false, true, false}, | |
expectedBytes: [][]byte{{1}, {2, 3}, {1, 2, 3, 4}, {5}}, | |
expectedErrs: []error{nil, nil, nil, nil}, | |
}, | |
{ | |
name: "BigReadRewindSmallReads", | |
src: bytes.NewReader([]byte{1, 2, 3, 4}), | |
sequentialReadBytes: []int{8192, 2, 3}, | |
rewindBeforeRead: []bool{false, true, false}, | |
expectedBytes: [][]byte{{1, 2, 3, 4}, {1, 2}, {3, 4}}, | |
expectedErrs: []error{nil, nil, nil}, | |
}, | |
{ | |
name: "BigReadRewindBigReadSmall_FinishedWithEOF", | |
src: bytes.NewReader([]byte{1, 2, 3, 4}), | |
sequentialReadBytes: []int{8192, 8192, 3}, | |
rewindBeforeRead: []bool{false, true, false}, | |
expectedBytes: [][]byte{{1, 2, 3, 4}, {1, 2, 3, 4}, {}}, | |
expectedErrs: []error{nil, nil, io.EOF}, | |
}, | |
} { | |
if ok := t.Run(tcase.name, func(t *testing.T) { | |
b := NewReader(tcase.src) | |
for i, read := range tcase.sequentialReadBytes { | |
if tcase.rewindBeforeRead[i] { | |
b.Rewind() | |
} | |
toRead := make([]byte, read) | |
n, err := b.Read(toRead) | |
require.Equal(t, tcase.expectedErrs[i], err, "read %d", i+1) | |
require.Len(t, tcase.expectedBytes[i], n, "read %d", i+1) | |
require.Equal(t, tcase.expectedBytes[i], toRead[:len(tcase.expectedBytes[i])], "read %d", i+1) | |
} | |
}); !ok { | |
return | |
} | |
} | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment