Created January 23, 2016 19:47
-
-
Save eliquious/b4add2899fa6107d0941 to your computer and use it in GitHub Desktop.
LMAX Disruptor TCP Server
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
//CLIENT | |
package main | |
import ( | |
"bufio" | |
"fmt" | |
"io" | |
"net" | |
"runtime" | |
"sync" | |
"sync/atomic" | |
"time" | |
) | |
// Benchmark parameters and the shared progress counter.
var (
	// requestCount is the number of echoed pings received across all
	// connections; readers increment it with sync/atomic.
	requestCount uint64

	// Each connection sends this many 4-byte pings.
	totalPingsPerConnection = uint64(10000000)
	// Number of connections opened in parallel.
	concurrentConnections = uint64(32)
	// totalPings is the number of echoes the whole run expects to receive.
	totalPings = concurrentConnections * totalPingsPerConnection
)
func monitor(done chan bool) chan bool { | |
out := make(chan bool) | |
go func() { | |
var last uint64 | |
start := time.Now() | |
for { | |
select { | |
case <-done: | |
elapsed := time.Since(start) | |
fmt.Printf("%f ns\n", float64(elapsed)/float64(requestCount)) | |
fmt.Printf("%d requests\n", requestCount) | |
fmt.Printf("%f requests per second\n", float64(time.Second)/(float64(elapsed)/float64(requestCount))) | |
fmt.Printf("elapsed: %s\r\n", elapsed) | |
out <- true | |
return | |
case <-time.After(1 * time.Second): | |
current := atomic.LoadUint64(&requestCount) | |
fmt.Printf("%d combined requests per second (%d)\n", current-last, current) | |
last = current | |
if current >= uint64(totalPings) { | |
return | |
} | |
} | |
} | |
}() | |
return out | |
} | |
func (c *client) readLoop(wg *sync.WaitGroup) { | |
defer wg.Done() | |
rd := bufio.NewReader(c.conn) | |
buf := make([]byte, 4) | |
for atomic.LoadUint64(&c.revcd) < totalPingsPerConnection { | |
n, err := rd.Read(buf) | |
if n > 0 { | |
atomic.AddUint64(&c.revcd, 1) | |
atomic.AddUint64(&requestCount, 1) | |
} else if err == io.EOF { | |
return | |
} | |
if err != nil && err != io.EOF { | |
fmt.Println(err) | |
return | |
} | |
} | |
// fmt.Printf("total recvd: %d\r\n", c.revcd) | |
} | |
func (c *client) writeLoop(wg *sync.WaitGroup) { | |
defer wg.Done() | |
wr := bufio.NewWriterSize(c.conn, 65536) | |
outBuf := []byte("Ping") | |
// var buffered int | |
for atomic.LoadUint64(&c.sent) < totalPingsPerConnection { | |
n, err := wr.Write(outBuf) | |
if n > 0 { | |
} | |
if err != nil && err != io.EOF { | |
fmt.Println(err) | |
return | |
} | |
atomic.AddUint64(&c.sent, 1) | |
} | |
wr.Flush() | |
// fmt.Printf("total sent: %d\r\n", c.sent) | |
} | |
// RingBufferCapacity is 1 MiB (2^20). NOTE(review): no client code visible
// here refers to it; it looks like a leftover from the server side.
const RingBufferCapacity = 1 << 20
// client is the state of one benchmark connection.
//
// sent and revcd are accessed with sync/atomic. They stay first in the struct
// so they remain 64-bit aligned, which 64-bit atomics require on 32-bit
// platforms. Do not move conn ahead of them.
type client struct {
	sent, revcd uint64 // pings written / echoes received ("revcd" = received)
	conn        *net.TCPConn
}
func NewClient(wg *sync.WaitGroup) { | |
defer wg.Done() | |
tcpAddr, _ := net.ResolveTCPAddr("tcp4", "localhost:9022") | |
conn, err := net.DialTCP("tcp", nil, tcpAddr) | |
if err != nil { | |
fmt.Println(err) | |
return | |
} | |
var w sync.WaitGroup | |
c := client{conn: conn} | |
w.Add(2) | |
go c.writeLoop(&w) | |
go c.readLoop(&w) | |
w.Wait() | |
conn.Close() | |
} | |
func main() { | |
runtime.GOMAXPROCS(8) | |
var wg sync.WaitGroup | |
done := make(chan bool) | |
c := monitor(done) | |
for i := uint64(0); i < concurrentConnections; i++ { | |
wg.Add(1) | |
go NewClient(&wg) | |
} | |
wg.Wait() | |
done <- true | |
<-c | |
} |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
package main | |
import ( | |
"fmt" | |
"io" | |
"log" | |
"net" | |
"os" | |
"runtime" | |
"time" | |
disruptor "github.com/smartystreets/go-disruptor" | |
"golang.org/x/net/context" | |
) | |
func main() { | |
runtime.GOMAXPROCS(8) | |
server := Server{} | |
server.Start(":9022") | |
} | |
// Server handles all the incoming connections as well as handler dispatch. | |
type Server struct { | |
Logger *log.Logger | |
Addr *net.TCPAddr | |
listener *net.TCPListener | |
context context.Context | |
cancel context.CancelFunc | |
} | |
// Start starts accepting client connections. This method is non-blocking. | |
func (s *Server) Start(addr string) (err error) { | |
// Validate the ssh bind addr | |
if addr == "" { | |
err = fmt.Errorf("server: Empty bind address") | |
return | |
} | |
// Open SSH socket listener | |
netAddr, e := net.ResolveTCPAddr("tcp", addr) | |
if e != nil { | |
err = fmt.Errorf("server: Invalid tcp address") | |
return | |
} | |
// Create listener | |
listener, e := net.ListenTCP("tcp", netAddr) | |
if e != nil { | |
err = e | |
return | |
} | |
s.Logger = log.New(os.Stdout, "logger: ", log.Lshortfile) | |
s.listener = listener | |
s.Addr = listener.Addr().(*net.TCPAddr) | |
s.Logger.Println("Starting server", "addr", addr) | |
c, cancel := context.WithCancel(context.Background()) | |
s.context = c | |
s.cancel = cancel | |
go s.listen(c) | |
<-c.Done() | |
return | |
} | |
// Stop stops the server and kills all goroutines. This method is blocking. | |
func (s *Server) Stop() { | |
s.Logger.Println("[INFO] Shutting down server...") | |
s.cancel() | |
} | |
// listen accepts new connections and handles the conversion from TCP to SSH connections. | |
func (s *Server) listen(c context.Context) { | |
defer s.listener.Close() | |
for { | |
// Accepts will only block for 1s | |
s.listener.SetDeadline(time.Now().Add(time.Second)) | |
select { | |
// Stop server on channel receive | |
case <-c.Done(): | |
s.Logger.Println("[DEBUG] Context Completed") | |
return | |
default: | |
// Accept new connection | |
tcpConn, err := s.listener.Accept() | |
if err != nil { | |
if neterr, ok := err.(net.Error); ok && neterr.Timeout() { | |
// s.Logger.Println("[DBG] Connection timeout...") | |
} else { | |
s.Logger.Println("[WRN] Connection failed", "error", err) | |
} | |
continue | |
} | |
// Handle connection | |
s.Logger.Println("[INF] Successful TCP connection:", tcpConn.RemoteAddr().String()) | |
h := NewTcpHandler(tcpConn, s.Logger, s.context) | |
go h.Execute() | |
} | |
} | |
} | |
// The ring size must be a power of two so that RingBufferMask maps any
// sequence number onto a slot with a single bitwise AND. Every connection
// handler allocates its own ring of this many bytes.
const (
	RingBufferCapacity = 1 << 24 // 16 MiB
	RingBufferMask     = RingBufferCapacity - 1
)
func NewTcpHandler(conn net.Conn, logger *log.Logger, ctx context.Context) *tcpHandler { | |
ring := [RingBufferCapacity]byte{} | |
controller := disruptor. | |
Configure(RingBufferCapacity). | |
WithConsumerGroup(&ByteConsumer{Writer: conn, ring: &ring}). | |
Build() | |
controller.Start() | |
c, cancel := context.WithCancel(ctx) | |
return &tcpHandler{ | |
logger, conn, &ring, &controller, c, cancel, | |
} | |
} | |
type tcpHandler struct { | |
logger *log.Logger | |
conn net.Conn | |
ring *[RingBufferCapacity]byte | |
controller *disruptor.Disruptor | |
context context.Context | |
cancel context.CancelFunc | |
} | |
func (t *tcpHandler) Execute() { | |
defer t.conn.Close() | |
defer t.controller.Stop() | |
// Read from connection | |
go t.createReadLoop() | |
<-t.context.Done() | |
} | |
func (t *tcpHandler) createReadLoop() { | |
defer t.cancel() | |
writer := t.controller.Writer() | |
buffer := make([]byte, 65336) | |
var seq int64 | |
for { | |
select { | |
case <-t.context.Done(): | |
return | |
default: | |
n, err := t.conn.Read(buffer) | |
// fmt.Printf("n: %d; err: %s\r\n", n, err) | |
if n > 0 { | |
seq = writer.Reserve(int64(n)) | |
for i := 0; i < n; i++ { | |
t.ring[seq&RingBufferMask] = buffer[i] | |
} | |
writer.Commit(seq, seq+int64(n)) | |
} else if err == io.EOF { | |
return | |
} | |
if err != nil && err != io.EOF { | |
return | |
} | |
} | |
} | |
} | |
type ByteConsumer struct { | |
Writer io.Writer | |
ring *[RingBufferCapacity]byte | |
buffer [65336]byte | |
} | |
func (b *ByteConsumer) Consume(lower, upper int64) { | |
var offset int | |
for lower <= upper { | |
if offset >= 65336 { | |
b.Writer.Write(b.buffer[:]) | |
offset = 0 | |
} | |
b.buffer[offset] = b.ring[lower&RingBufferMask] | |
offset++ | |
lower++ | |
} | |
if offset >= 0 { | |
b.Writer.Write(b.buffer[:offset]) | |
} | |
} |
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.