Last active
July 11, 2018 18:07
-
-
Save ghetzel/d780121a7e5a887f4c0a8ce15e9db1e5 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
package main | |
import ( | |
"bufio" | |
"bytes" | |
) | |
// SubsequenceHandlerFunc is called each time a registered subsequence is seen in the
// scanned stream. seq holds the matched bytes in a fresh slice that the handler may
// retain or modify.
type SubsequenceHandlerFunc func(seq []byte)

// A ScanInterceptor is used as a SplitFunc on a bufio.Scanner. It looks at the stream of
// bytes being scanned for specific subsequences and calls the handler registered for a
// subsequence whenever it is seen. Tokenization is delegated unchanged to a passthrough
// SplitFunc, so a stream can be split and processed while also being inspected for
// specific content, allowing the user to react to that content as it comes by.
//
// Each non-overlapping occurrence is reported exactly once, even though bufio.Scanner may
// present the same bytes to the SplitFunc several times, and even when an occurrence
// straddles a token boundary.
//
// Positions are tracked as offsets into one continuous stream, so an interceptor reused
// with several scanners treats their inputs as if they were concatenated. It holds no
// locks and is not safe for concurrent use.
type ScanInterceptor struct {
	passthrough  bufio.SplitFunc
	subsequences map[string]SubsequenceHandlerFunc
	// length of the longest registered subsequence
	longestSubsequence int
	// the last longestSubsequence-1 bytes already consumed by passthrough; they sit
	// immediately before the data handed to the next Scan call
	carry []byte
	// stream offset of the first byte not yet consumed by passthrough
	totalWritten int64
	// per subsequence: every occurrence starting before this stream offset has already
	// been reported (or was consumed before the subsequence was registered)
	highWaterMark map[string]int64
}

// NewScanInterceptor returns a ScanInterceptor that delegates tokenization to passthrough
// and reports the subsequences registered in intercepts. When several maps are given they
// are merged, later maps overriding earlier ones for the same subsequence. The maps are
// copied: later calls to Intercept do not modify them, and edits the caller makes to them
// afterwards are not seen by the interceptor.
func NewScanInterceptor(passthrough bufio.SplitFunc, intercepts ...map[string]SubsequenceHandlerFunc) *ScanInterceptor {
	si := &ScanInterceptor{
		passthrough:   passthrough,
		subsequences:  make(map[string]SubsequenceHandlerFunc),
		highWaterMark: make(map[string]int64),
	}

	for _, intercept := range intercepts {
		for sequence, handler := range intercept {
			// go through Intercept so longestSubsequence is kept up to date
			si.Intercept(sequence, handler)
		}
	}

	return si
}

// Intercept adds an intercept sequence and handler. If the sequence is already registered,
// its handler function is replaced with this one. An empty sequence never matches.
func (si *ScanInterceptor) Intercept(sequence string, handler SubsequenceHandlerFunc) {
	if si.subsequences == nil {
		si.subsequences = make(map[string]SubsequenceHandlerFunc)
	}
	si.subsequences[sequence] = handler

	if len(sequence) > si.longestSubsequence {
		si.longestSubsequence = len(sequence)
	}
}

// Scan implements the bufio.SplitFunc interface. It first reports every not-yet-reported
// occurrence of each registered subsequence that ends within data. It then calls the
// passthrough SplitFunc and returns its results unchanged.
//
// It relies on bufio.Scanner's contract that data always begins with the first byte the
// previous call did not consume, i.e. at stream offset BytesScanned().
func (si *ScanInterceptor) Scan(data []byte, atEOF bool) (advance int, token []byte, err error) {
	if si.highWaterMark == nil {
		si.highWaterMark = make(map[string]int64)
	}

	for sequence, handler := range si.subsequences {
		// nil handlers are skipped rather than called (calling one would panic)
		if len(sequence) == 0 || handler == nil {
			continue
		}
		si.reportMatches(sequence, handler, data)
	}

	advance, token, err = si.passthrough(data, atEOF)

	// Remember the tail of what was just consumed so an occurrence that starts in it and
	// finishes in later data can still be found. bufio.Scanner rejects out-of-range
	// advances on its own, so only a valid advance updates the carry.
	if advance > 0 && advance <= len(data) {
		si.carry = retainTail(si.carry, data[:advance], si.longestSubsequence-1)
	}
	si.totalWritten += int64(advance)

	return advance, token, err
}

// reportMatches calls handler once for each occurrence of sequence that has not been
// reported yet and ends inside data. It then advances the sequence's high water mark past
// everything searched.
func (si *ScanInterceptor) reportMatches(sequence string, handler SubsequenceHandlerFunc, data []byte) {
	seqLen := int64(len(sequence))
	dataStart := si.totalWritten
	dataEnd := dataStart + int64(len(data))
	carryStart := dataStart - int64(len(si.carry))

	// The earliest stream offset at which an unreported occurrence can start:
	//  - not before the high water mark;
	//  - not so early that it would end inside already-consumed bytes (those were searched
	//    while still unconsumed, or predate the sequence's registration);
	//  - not before the retained carry.
	from := si.highWaterMark[sequence]
	if lo := dataStart - seqLen + 1; lo > from {
		from = lo
	}
	if carryStart > from {
		from = carryStart
	}
	if from >= dataEnd {
		return
	}

	// window holds the stream bytes [from, dataEnd).
	var window []byte
	if from >= dataStart {
		window = data[int(from-dataStart):]
	} else {
		// the search begins inside the carry: stitch its tail onto data
		tail := si.carry[int(from-carryStart):]
		window = make([]byte, 0, len(tail)+len(data))
		window = append(append(window, tail...), data...)
	}

	needle := []byte(sequence)
	pos := 0
	for {
		i := bytes.Index(window[pos:], needle)
		if i < 0 {
			break
		}
		pos += i + len(needle)
		// pass a private copy: window may alias the Scanner's reusable buffer
		handler([]byte(sequence))
	}

	// No unreported occurrence can start before pos. None can start at any offset where a
	// complete occurrence would already have fit inside window either.
	next := int64(len(window)) - seqLen + 1
	if p := int64(pos); p > next {
		next = p
	}
	si.highWaterMark[sequence] = from + next
}

// retainTail returns a fresh copy of the last keep bytes of prev followed by next (all of
// them if there are fewer). A keep of zero or less yields nil.
func retainTail(prev, next []byte, keep int) []byte {
	if keep <= 0 {
		return nil
	}
	if len(next) >= keep {
		return append([]byte(nil), next[len(next)-keep:]...)
	}
	if need := keep - len(next); len(prev) > need {
		prev = prev[len(prev)-need:]
	}
	out := make([]byte, 0, len(prev)+len(next))
	return append(append(out, prev...), next...)
}

// BytesScanned returns the total number of bytes consumed by the passthrough SplitFunc
// so far, across every scanner this interceptor has been used with.
func (si *ScanInterceptor) BytesScanned() int64 {
	return si.totalWritten
}
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
package main | |
import ( | |
"bufio" | |
"bytes" | |
"testing" | |
"github.com/stretchr/testify/require" | |
) | |
func TestScanInterceptorNothing(t *testing.T) { | |
assert := require.New(t) | |
var lines []string | |
splitter := NewScanInterceptor(bufio.ScanLines) | |
data := bytes.NewBuffer([]byte("first\nsecond\nthird\n")) | |
scanner := bufio.NewScanner(data) | |
scanner.Split(splitter.Scan) | |
for scanner.Scan() { | |
lines = append(lines, scanner.Text()) | |
} | |
assert.NoError(scanner.Err()) | |
assert.Equal([]string{ | |
`first`, | |
`second`, | |
`third`, | |
}, lines) | |
} | |
// test single subsequence | |
// --------------------------------------------------------------------------------------------- | |
func TestScanInterceptorSingle(t *testing.T) { | |
assert := require.New(t) | |
errors := 0 | |
prompts := 0 | |
var lines []string | |
splitter := NewScanInterceptor(bufio.ScanLines, map[string]SubsequenceHandlerFunc{ | |
`[error] `: func(seq []byte) { | |
errors += 1 | |
}, | |
` password: `: func(seq []byte) { | |
prompts += 1 | |
}, | |
`Password: `: func(seq []byte) { | |
prompts += 1 | |
}, | |
}) | |
data := bytes.NewBuffer([]byte( | |
"Warning: Permanently added '[127.0.0.1]:2200' (ECDSA) to the list of known hosts.\n" + | |
"[email protected]'s password: ", | |
)) | |
scanner := bufio.NewScanner(data) | |
scanner.Split(splitter.Scan) | |
for scanner.Scan() { | |
lines = append(lines, scanner.Text()) | |
} | |
assert.NoError(scanner.Err()) | |
assert.Equal(0, errors) | |
assert.Equal(1, prompts) | |
assert.Equal([]string{ | |
`Warning: Permanently added '[127.0.0.1]:2200' (ECDSA) to the list of known hosts.`, | |
`[email protected]'s password: `, | |
}, lines) | |
} | |
// test multiple subsequences | |
// --------------------------------------------------------------------------------------------- | |
func TestScanInterceptorMultiple(t *testing.T) { | |
assert := require.New(t) | |
errors := 0 | |
prompts := 0 | |
var lines []string | |
splitter := NewScanInterceptor(bufio.ScanLines, map[string]SubsequenceHandlerFunc{ | |
`[error] `: func(seq []byte) { | |
errors += 1 | |
}, | |
` password: `: func(seq []byte) { | |
prompts += 1 | |
}, | |
`Password: `: func(seq []byte) { | |
prompts += 1 | |
}, | |
}) | |
data := bytes.NewBuffer([]byte( | |
"Password: [error] something cool went wrong\n" + | |
"[email protected]'s password: ", | |
)) | |
scanner := bufio.NewScanner(data) | |
scanner.Split(splitter.Scan) | |
for scanner.Scan() { | |
lines = append(lines, scanner.Text()) | |
} | |
assert.NoError(scanner.Err()) | |
assert.Equal(1, errors) | |
assert.Equal(2, prompts) | |
assert.Equal([]string{ | |
`Password: [error] something cool went wrong`, | |
`[email protected]'s password: `, | |
}, lines) | |
} | |
// test add intercept after the fact | |
// --------------------------------------------------------------------------------------------- | |
func TestScanInterceptorAddIntercept(t *testing.T) { | |
assert := require.New(t) | |
errors := 0 | |
warnings := 0 | |
var lines []string | |
splitter := NewScanInterceptor(bufio.ScanLines, map[string]SubsequenceHandlerFunc{ | |
`[error] `: func(seq []byte) { | |
errors += 1 | |
}, | |
}) | |
data := bytes.NewBuffer([]byte( | |
"Warning: Permanently added '[127.0.0.1]:2200' (ECDSA) to the list of known hosts.\n" + | |
"[error] something cool went wrong\n", | |
)) | |
scanner := bufio.NewScanner(data) | |
scanner.Split(splitter.Scan) | |
for scanner.Scan() { | |
lines = append(lines, scanner.Text()) | |
} | |
assert.NoError(scanner.Err()) | |
assert.Equal(1, errors) | |
assert.Equal(0, warnings) | |
assert.Equal([]string{ | |
`Warning: Permanently added '[127.0.0.1]:2200' (ECDSA) to the list of known hosts.`, | |
`[error] something cool went wrong`, | |
}, lines) | |
// new scanner, same interceptor, add new data | |
scanner = bufio.NewScanner(data) | |
scanner.Split(splitter.Scan) | |
splitter.Intercept(`Warning:`, func(seq []byte) { | |
warnings += 1 | |
}) | |
lines = nil | |
data.WriteString("some cool stuff going on OH NOOOO Warning: NOOOOOOO\n") | |
for scanner.Scan() { | |
lines = append(lines, scanner.Text()) | |
} | |
assert.NoError(scanner.Err()) | |
assert.Equal(1, warnings) | |
assert.Equal([]string{ | |
`some cool stuff going on OH NOOOO Warning: NOOOOOOO`, | |
}, lines) | |
} | |
func TestScanInterceptorBinarySubsequence(t *testing.T) { | |
assert := require.New(t) | |
terminators := 0 | |
splitter := NewScanInterceptor(bufio.ScanBytes) | |
data := bytes.NewBuffer([]byte{ | |
0x71, 0x00, 0x5d, 0x13, 0xfe, 0x05, 0xff, 0xff, | |
0xe7, 0xfe, 0x00, 0x16, 0x20, 0x02, 0x07, 0x5d, | |
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, | |
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, | |
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, | |
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, | |
0xaa, 0x55, | |
}) | |
splitter.Intercept(string([]byte{0xAA, 0x55}), func(seq []byte) { | |
terminators += 1 | |
}) | |
scanner := bufio.NewScanner(data) | |
scanner.Split(splitter.Scan) | |
for scanner.Scan() { | |
continue | |
} | |
assert.NoError(scanner.Err()) | |
assert.Equal(1, terminators) | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment