Last active
January 27, 2016 02:56
-
-
Save eliquious/90baf4ce72be7a17aac4 to your computer and use it in GitHub Desktop.
Simple key-value store in Go
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
// CLIENT: load generator that pipelines GET requests to the key-value server on localhost:9022.
package main | |
import ( | |
"bufio" | |
"fmt" | |
"io" | |
"math/rand" | |
"net" | |
"runtime" | |
"sync" | |
"sync/atomic" | |
"time" | |
) | |
// Benchmark configuration and shared counters.
var (
	// requestCount is the number of replies received over all connections.
	// It is only touched through sync/atomic.
	requestCount uint64

	// totalPingsPerConnection is how many requests each connection sends
	// (and how many replies it waits for).
	totalPingsPerConnection uint64 = 1000000

	// concurrentConnections is how many connections are opened in parallel.
	concurrentConnections uint64 = 128

	// totalPings is the number of replies expected across all connections.
	totalPings = concurrentConnections * totalPingsPerConnection
)
func monitor(done chan bool) chan bool { | |
out := make(chan bool) | |
go func() { | |
var last uint64 | |
start := time.Now() | |
var elapsed time.Duration | |
OUTER: | |
for { | |
select { | |
case <-done: | |
break OUTER | |
case <-time.After(1 * time.Second): | |
current := atomic.LoadUint64(&requestCount) | |
fmt.Printf("%d combined requests per second (%d)\n", current-last, current) | |
if current >= uint64(totalPings) || current-last == 0 { | |
break OUTER | |
} | |
last = current | |
} | |
} | |
elapsed = time.Since(start) | |
fmt.Printf("%f ns\n", float64(elapsed)/float64(requestCount)) | |
fmt.Printf("%d requests\n", requestCount) | |
fmt.Printf("%f requests per second\n", float64(time.Second)/(float64(elapsed)/float64(requestCount))) | |
fmt.Printf("elapsed: %s\r\n", elapsed) | |
out <- true | |
return | |
}() | |
return out | |
} | |
func (c *client) readLoop(wg *sync.WaitGroup) { | |
defer wg.Done() | |
rd := bufio.NewReader(c.conn) | |
// buf := make([]byte, 1024) | |
for atomic.LoadUint64(&c.revcd) < totalPingsPerConnection { | |
line, _, err := rd.ReadLine() | |
if err != nil { | |
fmt.Println(err) | |
return | |
} | |
if len(line) > 0 && line[0] == '-' { | |
fmt.Println(string(line)) | |
return | |
} | |
atomic.AddUint64(&c.revcd, 1) | |
atomic.AddUint64(&requestCount, 1) | |
// n, err := rd.Read(buf) | |
// if n > 0 { | |
// // fmt.Println(string(buf[:n])) | |
// atomic.AddUint64(&c.revcd, 1) | |
// atomic.AddUint64(&requestCount, 1) | |
// } else if err == io.EOF { | |
// return | |
// } | |
// if err != nil && err != io.EOF { | |
// fmt.Println(err) | |
// return | |
// } | |
} | |
} | |
func (c *client) writeLoop(wg *sync.WaitGroup) { | |
defer wg.Done() | |
wr := bufio.NewWriterSize(c.conn, 65536) | |
outBuf := []byte(fmt.Sprintf("GET key%d\r\n", rand.Intn(8))) | |
// outBuf := []byte(fmt.Sprintf("SET key%d 0 value\r\n", rand.Intn(int(concurrentConnections)*8))) | |
for atomic.LoadUint64(&c.sent) < totalPingsPerConnection { | |
n, err := wr.Write(outBuf) | |
if n > 0 { | |
// wr.Flush() | |
} | |
if err != nil && err != io.EOF { | |
fmt.Println(err) | |
return | |
} | |
atomic.AddUint64(&c.sent, 1) | |
} | |
wr.Flush() | |
} | |
// client is one benchmark connection plus its progress counters. writeLoop and
// readLoop run concurrently on the same value and access the counters with
// sync/atomic.
// NOTE(review): keep the uint64 counters as the first fields. On 32-bit
// platforms 64-bit atomic operations need 8-byte alignment (see sync/atomic docs).
type client struct {
	sent  uint64       // requests written so far (writeLoop)
	revcd uint64       // replies received so far (readLoop)
	conn  *net.TCPConn // connection shared by writeLoop and readLoop
}
func NewClient(wg *sync.WaitGroup) { | |
defer wg.Done() | |
tcpAddr, _ := net.ResolveTCPAddr("tcp4", "localhost:9022") | |
conn, err := net.DialTCP("tcp", nil, tcpAddr) | |
if err != nil { | |
fmt.Println(err) | |
return | |
} | |
var w sync.WaitGroup | |
c := client{conn: conn} | |
w.Add(2) | |
go c.writeLoop(&w) | |
go c.readLoop(&w) | |
w.Wait() | |
conn.Close() | |
} | |
func main() { | |
runtime.GOMAXPROCS(2) | |
var wg sync.WaitGroup | |
done := make(chan bool) | |
c := monitor(done) | |
for i := uint64(0); i < concurrentConnections; i++ { | |
wg.Add(1) | |
go NewClient(&wg) | |
} | |
wg.Wait() | |
done <- true | |
<-c | |
} |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
package main | |
import ( | |
"bufio" | |
"fmt" | |
"io" | |
"log" | |
"net" | |
"os" | |
"runtime" | |
"strconv" | |
"time" | |
// "github.com/pkg/profile" | |
"github.com/coocood/freecache" | |
disruptor "github.com/smartystreets/go-disruptor" | |
"golang.org/x/net/context" | |
) | |
func main() { | |
runtime.GOMAXPROCS(8) | |
// defer profile.Start(profile.CPUProfile, profile.ProfilePath(".")).Stop() | |
cache := freecache.NewCache(0) | |
cache.Set([]byte("key0"), []byte("value"), 0) | |
cache.Set([]byte("key1"), []byte("value"), 0) | |
cache.Set([]byte("key2"), []byte("value"), 0) | |
cache.Set([]byte("key3"), []byte("value"), 0) | |
cache.Set([]byte("key4"), []byte("value"), 0) | |
cache.Set([]byte("key5"), []byte("value"), 0) | |
cache.Set([]byte("key6"), []byte("value"), 0) | |
cache.Set([]byte("key7"), []byte("value"), 0) | |
cache.Set([]byte("key8"), []byte("value"), 0) | |
server := Server{Cache: cache} | |
server.Start(":9022") | |
} | |
// Server handles all the incoming connections as well as handler dispatch.
// Start populates every field except Cache.
type Server struct {
	Cache    *freecache.Cache   // store shared by every connection handler
	Logger   *log.Logger        // logger used by the server and its handlers (see Start)
	Addr     *net.TCPAddr       // address the listener is actually bound to
	listener *net.TCPListener   // accepting socket; closed when listen returns
	context  context.Context    // parent of every handler's context; cancelled by Stop
	cancel   context.CancelFunc // cancels context
}
// Start starts accepting client connections. This method is non-blocking. | |
func (s *Server) Start(addr string) (err error) { | |
// Validate the ssh bind addr | |
if addr == "" { | |
err = fmt.Errorf("server: Empty bind address") | |
return | |
} | |
// Open SSH socket listener | |
netAddr, e := net.ResolveTCPAddr("tcp", addr) | |
if e != nil { | |
err = fmt.Errorf("server: Invalid tcp address") | |
return | |
} | |
// Create listener | |
listener, e := net.ListenTCP("tcp", netAddr) | |
if e != nil { | |
err = e | |
return | |
} | |
s.Logger = log.New(os.Stdout, "logger: ", log.Lshortfile) | |
s.listener = listener | |
s.Addr = listener.Addr().(*net.TCPAddr) | |
s.Logger.Println("Starting server", "addr", addr) | |
c, cancel := context.WithCancel(context.Background()) | |
s.context = c | |
s.cancel = cancel | |
go s.listen(c) | |
<-c.Done() | |
return | |
} | |
// Stop stops the server and kills all goroutines. This method is blocking. | |
func (s *Server) Stop() { | |
s.Logger.Println("[INFO] Shutting down server...") | |
s.cancel() | |
} | |
// listen accepts new connections and handles the conversion from TCP to SSH connections. | |
func (s *Server) listen(c context.Context) { | |
defer s.listener.Close() | |
for { | |
// Accepts will only block for 1s | |
s.listener.SetDeadline(time.Now().Add(time.Second)) | |
select { | |
// Stop server on channel receive | |
case <-c.Done(): | |
s.Logger.Println("[DEBUG] Context Completed") | |
return | |
default: | |
// Accept new connection | |
tcpConn, err := s.listener.Accept() | |
if err != nil { | |
if neterr, ok := err.(net.Error); ok && neterr.Timeout() { | |
// s.Logger.Println("[DBG] Connection timeout...") | |
} else { | |
s.Logger.Println("[WRN] Connection failed", "error", err) | |
} | |
continue | |
} | |
// Handle connection | |
s.Logger.Println("[INF] Successful TCP connection:", tcpConn.RemoteAddr().String()) | |
h := NewTcpHandler(s.Cache, tcpConn, s.Logger, s.context) | |
go h.Execute() | |
} | |
} | |
} | |
// Size of each connection's byte ring, also used as the disruptor capacity.
// RingBufferCapacity must stay a power of two: RingBufferMask maps a sequence
// number onto a ring index with a bitwise AND.
const (
	RingBufferCapacity = 1 << 22 // 4 MiB
	RingBufferMask     = RingBufferCapacity - 1
)
func NewTcpHandler(cache *freecache.Cache, conn net.Conn, logger *log.Logger, ctx context.Context) *tcpHandler { | |
ring := [RingBufferCapacity]byte{} | |
controller := disruptor. | |
Configure(RingBufferCapacity). | |
WithConsumerGroup(&ByteConsumer{ | |
Writer: bufio.NewWriterSize(conn, 128*1024), | |
Closer: conn, | |
ring: &ring, | |
cache: cache, | |
logger: logger, | |
}).Build() | |
controller.Start() | |
c, cancel := context.WithCancel(ctx) | |
return &tcpHandler{cache, | |
logger, conn, &ring, &controller, c, cancel, | |
} | |
} | |
// tcpHandler owns one client connection. Its read loop copies incoming bytes
// into ring and publishes them through controller. The consumer group
// configured in NewTcpHandler parses and answers the requests.
type tcpHandler struct {
	cache      *freecache.Cache
	logger     *log.Logger
	conn       net.Conn
	ring       *[RingBufferCapacity]byte // same ring the ByteConsumer reads from
	controller *disruptor.Disruptor
	context    context.Context    // derived from the server context; Execute returns when it is done
	cancel     context.CancelFunc // called by createReadLoop when it returns
}
// Execute serves the connection. It starts the read loop in its own goroutine
// and blocks until the handler's context is done. That happens either when
// createReadLoop returns (it cancels the context) or when the parent context
// passed to NewTcpHandler is cancelled.
//
// On return the disruptor is stopped first and the connection is closed second,
// because deferred calls run last-in, first-out.
func (t *tcpHandler) Execute() {
	defer t.conn.Close()
	defer t.controller.Stop()
	// Read from connection
	go t.createReadLoop()
	<-t.context.Done()
}
func (t *tcpHandler) createReadLoop() { | |
defer t.cancel() | |
writer := t.controller.Writer() | |
buffer := make([]byte, 64*1024) | |
var sequence, reservations int64 | |
var idx int | |
rd := bufio.NewReaderSize(t.conn, 1024*1024) | |
for { | |
select { | |
case <-t.context.Done(): | |
return | |
default: | |
n, err := rd.Read(buffer) | |
// t.logger.Printf("n: %d; err: %s\r\n", n, err) | |
// t.logger.Printf("body: %s\r\n", string(buffer[:n])) | |
if n > 0 { | |
idx = 0 | |
reservations = int64(n) | |
sequence = writer.Reserve(reservations) | |
for lower := sequence - reservations + 1; lower <= sequence; lower++ { | |
t.ring[lower&RingBufferMask] = buffer[idx] | |
idx++ | |
} | |
writer.Commit(sequence-reservations+1, sequence) | |
} else if err == io.EOF { | |
return | |
} | |
if err != nil && err != io.EOF { | |
return | |
} | |
} | |
} | |
} | |
// Error replies written to clients. Each is one protocol line: a '-' prefix,
// an error code, a human-readable message and a CRLF terminator.
var (
	// Request-level errors.
	ErrMaxSize             = []byte("-ERRMAXSIZE Request too large\r\n")
	ErrUnknownCmd          = []byte("-ERRPARSE Unknown command\r\n")
	ErrIncompleteCmd       = []byte("-ERRPARSE Incomplete command\r\n")
	ErrEmptyRequest        = []byte("-ERRPARSE Empty request\r\n")
	ErrInvalidCmdDelimiter = []byte("-ERRPARSE Missing tab character after command\r\n")
	ErrInvalidKeyDelimiter = []byte("-ERRPARSE Missing tab character after key\r\n")

	// Cache-level errors.
	ErrLargeKey          = []byte("-ERRLARGEKEY The key is larger than 65535\r\n")
	ErrLargeEntry        = []byte("-ERRLARGEENTRY The entry size is larger than 1/1024 of cache size\r\n")
	ErrNotFound          = []byte("-ERRNOTFOUND Entry not found\r\n")
	ErrInvalidExpiration = []byte("-ERRINVEXP Invalid key expiration\r\n")
)
// ByteConsumer is the disruptor consumer for one connection. Consume rebuilds
// request lines from the bytes published into ring, and parse executes each
// line against cache. Replies are written to Writer. Closer is closed when a
// request is rejected.
type ByteConsumer struct {
	Writer *bufio.Writer             // buffered writer over the connection; replies go here
	Closer io.Closer                 // the connection itself; closed on fatal request errors
	logger *log.Logger
	cache  *freecache.Cache
	ring   *[RingBufferCapacity]byte // ring shared with the handler's read loop
	// Current request line. Requests that do not fit are rejected with ErrMaxSize.
	// NOTE(review): 65336 looks like a typo for 65536 (64 KiB); confirm the intended limit.
	buffer      [65336 * 4]byte
	closed      bool // set once the connection has been closed; Consume then ignores input
	requestSize int  // number of bytes of the current request stored in buffer
}
func (b *ByteConsumer) Consume(lower, upper int64) { | |
if b.closed { | |
return | |
} | |
// b.logger.Printf("Consuming %d-%d\r\n", lower, upper) | |
var char byte | |
for sequence := lower; sequence <= upper; sequence++ { | |
if b.requestSize >= len(b.buffer) { | |
b.Writer.Write(ErrMaxSize) | |
b.logger.Printf("ERR %s\r\n", string(ErrMaxSize)) | |
b.Writer.Flush() | |
b.Closer.Close() | |
b.closed = true | |
return | |
} | |
char = b.ring[sequence&RingBufferMask] | |
// b.logger.Printf("char '%s'\r\n", char) | |
// end of request | |
if char == '\n' { | |
line := b.buffer[:b.requestSize] | |
ok := b.parse(line, b.Writer) | |
b.Writer.Flush() | |
b.closed = !ok | |
if b.closed { | |
b.Closer.Close() | |
return | |
} | |
// reset request size to 0 | |
b.requestSize = 0 | |
// also skip the new line that follows the \r | |
// sequence += 1 | |
} else if char == '\r' { | |
continue | |
} else { | |
b.buffer[b.requestSize] = char | |
b.requestSize++ | |
} | |
} | |
} | |
func (b *ByteConsumer) parse(line []byte, w *bufio.Writer) bool { | |
if len(line) == 0 { | |
w.Write(ErrEmptyRequest) | |
return false | |
} | |
// b.logger.Printf("Parsing line: %s\r\n", strconv.Quote(string(line))) | |
var i, expiration int | |
var c byte | |
state := OP_START | |
var e error | |
var key, value, err []byte | |
// Move to loop instead of range syntax to allow jumping of i | |
for i = 0; i < len(line); i++ { | |
c = line[i] | |
switch state { | |
case OP_START: | |
switch c { | |
case 'G', 'g': | |
state = OP_G | |
case 'S', 's': | |
state = OP_S | |
default: | |
b.logger.Printf("OP_START: (%s) %s\r\n", c, strconv.Quote(string(line))) | |
err = ErrUnknownCmd | |
goto PARSE_ERR | |
} | |
case OP_G: | |
switch c { | |
case 'E', 'e': | |
state = OP_GE | |
default: | |
b.logger.Printf("OP_G: (%s) %s\r\n", c, strconv.Quote(string(line))) | |
err = ErrUnknownCmd | |
goto PARSE_ERR | |
} | |
case OP_GE: | |
switch c { | |
case 'T', 't': | |
state = OP_GET | |
default: | |
b.logger.Printf("OP_GE: (%s) %s\r\n", c, strconv.Quote(string(line))) | |
err = ErrUnknownCmd | |
goto PARSE_ERR | |
} | |
case OP_GET: | |
switch c { | |
case '\t', ' ': | |
key = (line)[i+1:] | |
// b.logger.Printf("KEY: %s\r\n", strconv.Quote(string(key))) | |
goto PERFORM_GET | |
default: | |
err = ErrInvalidCmdDelimiter | |
goto PARSE_ERR | |
} | |
case OP_S: | |
switch c { | |
case 'E', 'e': | |
state = OP_SE | |
default: | |
b.logger.Printf("OP_S: (%s) %s\r\n", c, strconv.Quote(string(line))) | |
err = ErrUnknownCmd | |
goto PARSE_ERR | |
} | |
case OP_SE: | |
switch c { | |
case 'T', 't': | |
state = OP_SET | |
default: | |
b.logger.Printf("OP_GE: (%s) %s\r\n", c, strconv.Quote(string(line))) | |
err = ErrUnknownCmd | |
goto PARSE_ERR | |
} | |
case OP_SET: | |
switch c { | |
case '\t', ' ': | |
state = OP_SET_KEY | |
default: | |
err = ErrInvalidCmdDelimiter | |
goto PARSE_ERR | |
} | |
case OP_SET_KEY: | |
offset := i | |
for i < len(line) { | |
if line[i] == '\t' || line[i] == ' ' { | |
break | |
} | |
i++ | |
} | |
// end of input? | |
// key empty? | |
if i == len(line) || i-offset == 0 { | |
err = ErrIncompleteCmd | |
goto PARSE_ERR | |
} | |
// set key | |
key = line[offset:i] | |
// skip space | |
i++ | |
// parse expiration | |
offset = i | |
for i < len(line) { | |
if line[i] == '\t' || line[i] == ' ' { | |
i++ | |
break | |
} | |
i++ | |
} | |
// end of input? empty? | |
if i >= len(line) || i-offset == 0 { | |
err = ErrIncompleteCmd | |
goto PARSE_ERR | |
} | |
exp, e := strconv.Atoi(string(line[offset : i-1])) | |
if e != nil { | |
err = ErrInvalidExpiration | |
goto PARSE_ERR | |
} | |
expiration = exp | |
value = line[i-1:] | |
goto PERFORM_SET | |
// key = line[offset:i] | |
// i++ | |
// switch cmd { | |
// case OP_GET: | |
// goto PERFORM_GET | |
// default: | |
// return ErrUnknownCmd, false | |
// } | |
} | |
} | |
PARSE_ERR: | |
// Ignoring all write errors here, because we are going to return false | |
// and close the connection due to the parse error anyway. | |
w.Write(err) | |
b.logger.Printf("%s (%s)\r\n", string(err), strconv.Quote(string(line))) | |
return false | |
PERFORM_GET: | |
value, e = b.cache.Get(key) | |
if e == freecache.ErrLargeKey { | |
err = ErrLargeKey | |
goto PARSE_ERR | |
} else if e == freecache.ErrLargeEntry { | |
err = ErrLargeEntry | |
goto PARSE_ERR | |
} else if e == freecache.ErrNotFound { | |
err = ErrNotFound | |
goto PARSE_ERR | |
} else if e != nil { | |
err = []byte("-ERRCACHE Unknown cache error\r\n") | |
goto PARSE_ERR | |
} else { | |
if _, err := w.Write([]byte("+VALUE ")); err != nil { | |
return false | |
} | |
if _, err := w.Write(value); err != nil { | |
return false | |
} | |
if _, err := w.Write([]byte("\r\n")); err != nil { | |
return false | |
} | |
return true | |
} | |
PERFORM_SET: | |
e = b.cache.Set(key, value, expiration) | |
if e == freecache.ErrLargeKey { | |
err = ErrLargeKey | |
goto PARSE_ERR | |
} else if e == freecache.ErrLargeEntry { | |
err = ErrLargeEntry | |
goto PARSE_ERR | |
} else if e != nil { | |
err = []byte("-ERRCACHE Unknown cache error\r\n") | |
goto PARSE_ERR | |
} else { | |
if _, err := w.Write([]byte("+OK\r\n")); err != nil { | |
return false | |
} | |
return true | |
} | |
b.logger.Printf("END %s\r\n", c) | |
err = ErrUnknownCmd | |
goto PARSE_ERR | |
} | |
// Request parser states, numbered in the order a request's bytes are recognised.
const (
	OP_START   int = 0 // expecting the first letter of a command
	OP_G       int = 1 // read "G"
	OP_GE      int = 2 // read "GE"
	OP_GET     int = 3 // read "GET", expecting a separator
	OP_S       int = 4 // read "S"
	OP_SE      int = 5 // read "SE"
	OP_SET     int = 6 // read "SET", expecting a separator
	OP_SET_KEY int = 7 // reading SET arguments
)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment