Skip to content

Instantly share code, notes, and snippets.

@kirk91
Last active January 9, 2020 10:43
Show Gist options
  • Save kirk91/4e20f71d4d7144918172e61d9f315d41 to your computer and use it in GitHub Desktop.
Save kirk91/4e20f71d4d7144918172e61d9f315d41 to your computer and use it in GitHub Desktop.
Replay redis request traffic
package main
import (
"bufio"
"context"
"flag"
"fmt"
"io"
"io/ioutil"
"log"
"net"
"os"
"os/signal"
"strconv"
"sync"
"syscall"
"github.com/google/gopacket"
"github.com/google/gopacket/layers"
"github.com/google/gopacket/pcap"
"github.com/google/gopacket/tcpassembly"
"github.com/google/gopacket/tcpassembly/tcpreader"
)
var (
file = flag.String("file", "", "pcap file which contains redis traffic")
times = flag.Int("times", 1, "replay times")
targetHost = flag.String("target-host", "127.0.0.1", "redis target host")
targetPort = flag.Int("target-port", 6379, "redis target port")
)
// redisStreamFactory is the stream factory handed to the TCP assembler; its
// New method creates a redisStream for every flow it is asked about. It holds
// no state of its own.
type redisStreamFactory struct{}

// redisStream is one reassembled TCP byte stream that is being replayed.
type redisStream struct {
	// net and transport are the network- and transport-layer flows the
	// stream was created for; Run includes them in its error log.
	net       gopacket.Flow
	transport gopacket.Flow
	// r receives the reassembled payload bytes and is read by Run.
	r tcpreader.ReaderStream
}
// New satisfies the stream-factory contract required by
// tcpassembly.NewStreamPool. It records the given flows, starts a goroutine
// that replays the stream via Run, and returns the ReaderStream the assembler
// should feed with reassembled data.
//
// NOTE(review): the assembler presumably calls New once per direction of each
// TCP connection, so server->client replies are replayed as well. Array replies
// ("*N\r\n...") would then be sent to the target as commands. TODO consider
// filtering on the original server port.
//
// The started goroutine is not tracked; nothing waits for it to finish.
func (*redisStreamFactory) New(net, transport gopacket.Flow) tcpassembly.Stream {
	s := &redisStream{r: tcpreader.NewReaderStream()}
	s.net, s.transport = net, transport

	go s.Run()

	return &s.r
}
// Run replays one reassembled stream against the target Redis server.
//
// It works in four steps:
//   - discard data until skipPartialRequest finds the start of a request;
//   - connect to the address given by -target-host/-target-port;
//   - write that request header followed by the rest of the stream;
//   - discard every byte the server sends back.
//
// Stream data still unread when Run returns is drained by the deferred
// io.Copy. This matters because a tcpreader.ReaderStream holds up the
// assembler until its data has been consumed.
func (r *redisStream) Run() {
	br := bufio.NewReader(&r.r)
	defer io.Copy(ioutil.Discard, br)

	nextReqHead, err := skipPartialRequest(br)
	if err != nil {
		log.Println("Error reading stream", r.net, r.transport, ":", err)
		return
	}

	// redo the following requests
	// JoinHostPort brackets IPv6 literals such as "::1". A plain "%s:%d" format
	// would produce an address net.Dial cannot parse.
	addr := net.JoinHostPort(*targetHost, strconv.Itoa(*targetPort))
	conn, err := net.Dial("tcp", addr)
	if err != nil {
		log.Printf("Error dialing to %s: %v\n", addr, err)
		return
	}
	defer conn.Close()

	if _, err := conn.Write([]byte(nextReqHead)); err != nil {
		log.Printf("Error write to %s: %v\n", addr, err)
		return
	}

	var wg sync.WaitGroup
	wg.Add(2)

	// Forward the remaining requests, then half-close the connection.
	// A full Close here would leave the server's replies arriving at a closed
	// socket. The local kernel could then reset the connection before the
	// server has processed the trailing pipelined requests. After the
	// half-close the server still reads everything up to EOF.
	go func() {
		defer wg.Done()
		if _, err := io.Copy(conn, br); err != nil {
			log.Printf("Error write to %s: %v\n", addr, err)
		}
		if tcpConn, ok := conn.(*net.TCPConn); ok {
			// Best-effort: the deferred conn.Close releases the socket regardless.
			tcpConn.CloseWrite()
		} else {
			conn.Close()
		}
	}()

	// Replies are not inspected; read them until the server closes its side.
	// NOTE(review): this relies on the target closing the connection once it
	// reads EOF, as Redis is expected to; otherwise this copy would block.
	go func() {
		defer wg.Done()
		io.Copy(ioutil.Discard, conn)
	}()

	wg.Wait()
}
// skipPartialRequest discards data from br until it reaches what looks like
// the start of a RESP request. That is a '*' at the beginning of a line,
// followed by a positive element count and CRLF, e.g. "*3\r\n".
//
// A captured stream may begin in the middle of a request, so this
// resynchronizes on the next request boundary. It is best-effort: payload
// bytes that happen to form "\r\n*<n>\r\n" look exactly like a real header.
//
// A '*' counts as starting a line in three cases:
//   - it is the first byte of the stream;
//   - it follows a lone leading '\n' (whose '\r' was presumably not captured);
//   - the two bytes consumed just before it are "\r\n".
//
// The check spans read boundaries, so a '*' that merely follows another
// '*' inside a payload is not mistaken for a header.
//
// On success it returns the header re-encoded as "*<n>\r\n", with br
// positioned just after the original header. Otherwise it returns the first
// read error (io.EOF when the stream ends before a header is found).
func skipPartialRequest(br *bufio.Reader) (string, error) {
	// consumed counts the bytes read so far; beforeLast and last are the two
	// most recently consumed bytes.
	consumed := 0
	var beforeLast, last byte
	for {
		c, err := br.ReadByte()
		if err != nil {
			return "", err
		}
		lineStart := consumed == 0 ||
			(consumed == 1 && last == '\n') ||
			(beforeLast == '\r' && last == '\n')
		beforeLast, last = last, c
		consumed++
		if c != '*' || !lineStart {
			continue
		}

		// The element count must follow immediately, terminated by CRLF.
		line, err := br.ReadBytes('\n')
		if err != nil {
			return "", err
		}
		consumed += len(line)
		n := len(line)
		if n >= 2 {
			beforeLast = line[n-2]
		} else {
			beforeLast = '*'
		}
		last = '\n'

		if n <= 2 || line[n-2] != '\r' {
			continue
		}
		count, err := strconv.Atoi(string(line[:n-2]))
		// A non-positive count ("*0", or the reply-only null array "*-1") does
		// not introduce a command, so keep searching.
		if err != nil || count <= 0 {
			continue
		}
		return "*" + strconv.Itoa(count) + "\r\n", nil
	}
}
// replay sends every packet of the pcap file named by -file through a TCP
// assembler. The assembler's per-flow streams come from redisStreamFactory,
// which replays each one in its own goroutine.
//
// It returns when the capture is exhausted or ctx is cancelled. On return,
// the deferred FlushAll flushes buffered data and closes the remaining
// streams. The Run goroutines started by redisStreamFactory.New are not
// waited for.
//
// Packets lacking a network layer or a TCP transport layer are logged and
// skipped. The process exits via log.Fatal if the file cannot be opened.
func replay(ctx context.Context) {
	handle, err := pcap.OpenOffline(*file)
	if err != nil {
		log.Fatal(err)
	}
	defer handle.Close()

	// Set up assembly
	streamFactory := &redisStreamFactory{}
	streamPool := tcpassembly.NewStreamPool(streamFactory)
	assembler := tcpassembly.NewAssembler(streamPool)
	// Deferred after handle.Close, so it runs first (defers are LIFO).
	defer assembler.FlushAll()

	// Create packet source
	packetSource := gopacket.NewPacketSource(handle, handle.LinkType())
	packets := packetSource.Packets()
	for {
		// select picks pseudo-randomly among ready cases. Without this check,
		// a cancelled ctx could lose to a ready packet channel, and more
		// packets would be processed before returning.
		if ctx.Err() != nil {
			return
		}
		select {
		case <-ctx.Done():
			return
		case packet, ok := <-packets:
			// A nil packet (closed channel) indicates the end of a pcap file.
			if !ok || packet == nil {
				return
			}
			if packet.NetworkLayer() == nil ||
				packet.TransportLayer() == nil ||
				packet.TransportLayer().LayerType() != layers.LayerTypeTCP {
				log.Println("Unusable packet")
				continue
			}
			tcp := packet.TransportLayer().(*layers.TCP)
			assembler.AssembleWithTimestamp(packet.NetworkLayer().NetworkFlow(), tcp, packet.Metadata().Timestamp)
		}
	}
}
// main parses the flags, then replays the capture -times times in a row.
//
// SIGINT or SIGTERM cancels the pass in progress and prevents any further
// passes. After the first signal, default signal handling is restored, so a
// second signal terminates the process immediately.
//
// NOTE(review): main does not wait for the per-stream goroutines started in
// redisStreamFactory.New. The process may therefore exit while the final
// streams are still being written to the target. TODO consider tracking them
// (e.g. with a sync.WaitGroup) and waiting before returning.
func main() {
	flag.Parse()
	if *file == "" {
		// Diagnostics go to stderr so they don't mix with regular output.
		fmt.Fprintln(os.Stderr, "Please specify the pcap file")
		os.Exit(1)
	}

	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()

	signalCh := make(chan os.Signal, 1)
	signal.Notify(signalCh, syscall.SIGINT, syscall.SIGTERM)
	go func() {
		select {
		case <-signalCh:
			// Stop relaying so a second signal falls back to the default
			// behavior (terminate) instead of being swallowed.
			signal.Stop(signalCh)
			cancel()
		case <-ctx.Done():
		}
	}()

	// Once ctx is cancelled, don't start any further passes.
	for i := 0; i < *times && ctx.Err() == nil; i++ {
		replay(ctx)
	}
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment