Skip to content

Instantly share code, notes, and snippets.

@skoowoo
Last active April 9, 2024 02:29
Show Gist options
  • Save skoowoo/c68925e876dd12b48986 to your computer and use it in GitHub Desktop.
Save skoowoo/c68925e876dd12b48986 to your computer and use it in GitHub Desktop.
这是一个用 Go 基于 epoll 写的服务器程序,虽然 epoll 的实现不是很完善,还有许多细节需要改善,但此程序可以作为 Go epoll 开发的样例。这个程序也是在一个特定的需求场景下写的。
package main
import (
"bytes"
"encoding/binary"
"flag"
"io"
"log"
"os"
"runtime"
"serverepoll/ringbuffer"
"strconv"
"syscall"
"time"
)
// Settings populated from command-line flags in main.
var (
	filePath     string // -path: data file opened by the loader
	port         int    // -port: TCP port the listening socket binds to
	readBlock    int    // -rb: bytes requested per read of the data file
	msgQueueSize int    // -queue: capacity of the loader's message channel and ring
)

// loader is the single line source that every worker fetches from.
var loader *Loader
// main parses the command-line flags, starts the file loader and one
// epoll-driven Worker per CPU, and then runs the accept loop. Every
// connection accepted on the listening socket is made non-blocking, gets
// TCP_NODELAY, and is handed to the workers in round-robin order.
func main() {
	flag.StringVar(&filePath, "path", "/tmp/data.txt", "")
	flag.IntVar(&port, "port", 9999, "")
	flag.IntVar(&readBlock, "rb", 8192, "")
	flag.IntVar(&msgQueueSize, "queue", 10000, "")
	flag.Parse()

	cpus := runtime.NumCPU()
	// NOTE(review): the extra 5 Ps are presumably headroom for the loader
	// goroutine and this accept loop next to the thread-locked workers; confirm.
	runtime.GOMAXPROCS(cpus + 5)

	// Start the loader that feeds lines of filePath to the workers.
	loader = NewLoader(filePath, readBlock, msgQueueSize)
	loader.Run()

	// Start one worker goroutine per CPU.
	workers := make([]*Worker, cpus)
	for i := range workers {
		workers[i] = newWorker()
		go workers[i].Run()
	}

	// Create the listening socket on all IPv4 addresses.
	listenFd, err := syscall.Socket(syscall.AF_INET, syscall.SOCK_STREAM, syscall.IPPROTO_TCP)
	if err != nil {
		log.Fatalln(err)
	}
	defer syscall.Close(listenFd)
	if err := syscall.SetsockoptInt(listenFd, syscall.SOL_SOCKET, syscall.SO_REUSEADDR, 1); err != nil {
		log.Fatalln(err)
	}
	if err := syscall.Bind(listenFd, &syscall.SockaddrInet4{Port: port}); err != nil {
		log.Fatalln(err)
	}
	if err := syscall.Listen(listenFd, 4096); err != nil {
		log.Fatalln(err)
	}

	// Watch the listening socket for incoming connections.
	epfd, err := syscall.EpollCreate(1024)
	if err != nil {
		log.Fatalln(err)
	}
	defer syscall.Close(epfd)
	epoll := EpollFdType(epfd)
	if err := epoll.AddEvents(listenFd, syscall.EPOLLIN); err != nil {
		log.Fatalln(err)
	}

	var events [32]syscall.EpollEvent
	next := 0 // index of the worker that receives the next connection
	for {
		nevents, err := syscall.EpollWait(epfd, events[:], -1)
		if err != nil {
			// epoll_wait is never restarted after a signal handler runs, and
			// the Go runtime signals its threads routinely (e.g. for goroutine
			// preemption), so EINTR must be retried rather than treated as fatal.
			if err == syscall.EINTR {
				continue
			}
			log.Fatalln(err)
		}
		for i := 0; i < nevents; i++ {
			connFd, _, err := syscall.Accept(int(events[i].Fd))
			if err != nil {
				log.Println(err)
				continue
			}
			if err := syscall.SetNonblock(connFd, true); err != nil {
				log.Println(err)
				syscall.Close(connFd)
				continue
			}
			if err := syscall.SetsockoptInt(connFd, syscall.IPPROTO_TCP, syscall.TCP_NODELAY, 1); err != nil {
				log.Println(err)
				syscall.Close(connFd)
				continue
			}
			// Hand connFd to the workers round-robin. This send blocks while
			// the chosen worker's Conns buffer is full.
			workers[next].Conns <- connFd
			next = (next + 1) % cpus
		}
	}
}
// Worker owns a private epoll instance and serves the client connections
// that main's accept loop hands to it over Conns.
type Worker struct {
	Conns     chan int             // accepted connection fds from main's accept loop
	events    []syscall.EpollEvent // scratch space for EpollWait results
	epoll     EpollFdType          // this worker's epoll instance
	connCount int                  // connections currently registered with epoll
	pool      map[int]*Connection  // per-fd state, created on the fd's first event
	rbuf      *ringbuffer.Buffer   // read handle threaded through loader.FetchLine
	done      bool                 // set once FetchLine reports no more lines
}

// newWorker allocates a Worker and its epoll instance. It exits the process
// if the epoll instance cannot be created.
func newWorker() *Worker {
	fd, err := syscall.EpollCreate(1024)
	if err != nil {
		log.Fatalln(err)
	}
	return &Worker{
		Conns:  make(chan int, 100),
		events: make([]syscall.EpollEvent, 1024),
		pool:   make(map[int]*Connection, 100),
		epoll:  EpollFdType(fd),
	}
}

// Run is the worker's event loop and does not return. It pins the goroutine
// to an OS thread, registers connections received on Conns for EPOLLIN, and
// dispatches epoll events to Handle.
func (w *Worker) Run() {
	runtime.LockOSThread()
	defer func() {
		syscall.Close(int(w.epoll))
		runtime.UnlockOSThread()
	}()
	for {
		w.acceptPending()
		if w.connCount == 0 {
			// Nothing registered: check Conns again after a short sleep.
			time.Sleep(time.Millisecond * 100)
			continue
		}
		// The 512 ms timeout bounds how long newly handed-over connections
		// wait in Conns while this worker's existing connections are idle.
		n, err := syscall.EpollWait(int(w.epoll), w.events, 512)
		if err != nil {
			// EINTR only means a signal interrupted the wait; retry.
			if err == syscall.EINTR {
				continue
			}
			log.Fatalln(err)
		}
		for i := 0; i < n; i++ {
			w.dispatch(int(w.events[i].Fd), w.events[i].Events)
		}
	}
}

// acceptPending registers, without blocking, every connection currently
// queued on w.Conns. A connection that cannot be registered is closed.
func (w *Worker) acceptPending() {
	for {
		select {
		case fd := <-w.Conns:
			if err := w.epoll.AddEvents(fd, syscall.EPOLLIN); err != nil {
				log.Println(err)
				syscall.Close(fd)
				continue
			}
			w.connCount++
		default:
			return
		}
	}
}

// dispatch routes one epoll event for fd to Handle and forgets the
// connection once it has been closed.
func (w *Worker) dispatch(fd int, ev uint32) {
	conn, ok := w.pool[fd]
	if !ok {
		conn = NewConnection(w.epoll, fd)
		w.pool[fd] = conn
	}
	switch {
	case ev&syscall.EPOLLIN != 0:
		w.Handle(conn, 'r')
	case ev&syscall.EPOLLOUT != 0:
		w.Handle(conn, 'w')
	case ev&(syscall.EPOLLERR|syscall.EPOLLHUP) != 0:
		// Error or hang-up with nothing to read or write. Level-triggered
		// epoll would report this on every wait, so drop the connection.
		conn.Close()
	}
	if conn.down {
		delete(w.pool, fd)
		w.connCount--
	}
}

// Handle services one readiness event for conn.
//
// For op 'r' it does the following:
//   - reads a 3-byte request (the content is not inspected),
//   - fetches the next line from the loader,
//   - marshals it into conn.body,
//   - immediately tries to send it.
//
// For op 'w' it resumes sending a pending response. Once the loader has no
// more lines, every request is answered by closing the connection.
func (w *Worker) Handle(conn *Connection, op byte) {
	if op == 'r' {
		if err := conn.Read(3); err != nil {
			// EAGAIN (request not complete yet) and io.EOF (peer closed)
			// are expected and not worth logging.
			if err != syscall.EAGAIN && err != io.EOF {
				log.Println(err)
			}
			return
		}
		conn.ReadBuf()
		if w.done {
			conn.Close()
			return
		}
		var msg *Message
		msg, w.rbuf = loader.FetchLine(w.rbuf)
		if w.rbuf == nil {
			w.done = true
			conn.Close()
			return
		}
		conn.w = msg.Marshal(conn.body)
		// Send the response right away.
		op = 'w'
	}
	if op == 'w' {
		if err := conn.Write(); err != nil && err != syscall.EAGAIN {
			log.Println(err)
		}
	}
}
// Connection holds per-connection state for one non-blocking client socket
// served by a Worker.
type Connection struct {
	body   []byte      // read buffer; Worker.Handle also marshals responses into it
	r      int         // number of request bytes currently buffered in body
	w      []byte      // unsent remainder of the current response
	events uint32      // epoll interest currently registered for fd
	fd     int         // the client socket
	epoll  EpollFdType // epoll instance fd is registered with
	down   bool        // set once Close has run
}

// NewConnection returns the state for fd. The caller must already have
// registered fd with epoll for EPOLLIN, as Worker.Run does.
func NewConnection(epoll EpollFdType, fd int) *Connection {
	return &Connection{
		body:   make([]byte, 256),
		events: syscall.EPOLLIN,
		fd:     fd,
		epoll:  epoll,
	}
}

// Close removes the connection from epoll and closes its fd. Only the first
// call has an effect, so a later call can never close a recycled fd number
// that now belongs to another connection.
func (c *Connection) Close() {
	if c.down {
		return
	}
	// One EPOLL_CTL_DEL removes the fd whatever events it was registered
	// for. The error is ignored on purpose: closing the fd below also drops
	// it from the epoll set.
	c.epoll.DelEvents(c.fd, c.events)
	syscall.Close(c.fd)
	c.events = 0
	c.down = true
}

// Read reads from the socket until at least bytes bytes are buffered in
// c.body. It returns:
//   - nil once enough bytes are buffered;
//   - syscall.EAGAIN if the socket would block first (bytes read so far
//     stay buffered for the next call);
//   - io.ErrShortBuffer if bytes exceeds the buffer size;
//   - io.EOF if the peer closed the connection;
//   - any other read error.
//
// On io.EOF or a read error other than EAGAIN, the connection is closed.
func (c *Connection) Read(bytes int) error {
	if bytes > len(c.body) {
		return io.ErrShortBuffer
	}
	for c.r < bytes {
		n, err := syscall.Read(c.fd, c.body[c.r:])
		if err != nil {
			if err != syscall.EAGAIN {
				c.Close()
			}
			return err
		}
		if n == 0 {
			// A zero-byte read into a non-empty buffer is end of stream.
			// Without this check the loop would spin forever.
			c.Close()
			return io.EOF
		}
		c.r += n
	}
	return nil
}

// Write sends the pending response in c.w.
//
// If the socket would block, Write switches the fd's epoll interest to
// EPOLLOUT only and returns syscall.EAGAIN. The caller retries on the next
// EPOLLOUT event. Once c.w has been fully sent, interest returns to EPOLLIN.
// Any other error closes the connection.
func (c *Connection) Write() error {
	for len(c.w) > 0 {
		n, err := syscall.Write(c.fd, c.w)
		if err == syscall.EAGAIN {
			// c.w may alias c.body, so stop watching for input (a read would
			// overwrite the unsent bytes) and wait for writability instead.
			if err := c.setInterest(syscall.EPOLLOUT); err != nil {
				c.Close()
				return err
			}
			return syscall.EAGAIN
		}
		if err != nil {
			c.Close()
			return err
		}
		c.w = c.w[n:]
	}
	if c.events != syscall.EPOLLIN {
		if err := c.setInterest(syscall.EPOLLIN); err != nil {
			c.Close()
			return err
		}
	}
	return nil
}

// setInterest replaces the epoll interest set of c.fd with events.
// EPOLL_CTL_MOD is required because the fd is already registered:
// EPOLL_CTL_ADD would fail with EEXIST, and EPOLL_CTL_DEL would drop the fd
// from epoll entirely.
func (c *Connection) setInterest(events uint32) error {
	if err := c.epoll.ModEvents(c.fd, events); err != nil {
		return err
	}
	c.events = events
	return nil
}

// ReadBuf returns the bytes buffered by Read and resets the buffer for the
// next request. The returned slice aliases c.body and is overwritten by the
// next Read or by a response marshalled into c.body.
func (c *Connection) ReadBuf() (buf []byte) {
	buf = c.body[:c.r]
	c.r = 0
	return buf
}

// EpollFdType is the file descriptor of an epoll instance.
type EpollFdType int

// AddEvents registers fd with the epoll instance for events (EPOLL_CTL_ADD).
func (e EpollFdType) AddEvents(fd int, events uint32) error {
	return e.ctl(syscall.EPOLL_CTL_ADD, fd, events)
}

// ModEvents replaces the event mask of an already registered fd
// (EPOLL_CTL_MOD).
func (e EpollFdType) ModEvents(fd int, events uint32) error {
	return e.ctl(syscall.EPOLL_CTL_MOD, fd, events)
}

// DelEvents removes fd from the epoll instance (EPOLL_CTL_DEL). The whole
// registration is removed regardless of events. The event argument is
// passed only because kernels before 2.6.9 required it to be non-nil.
func (e EpollFdType) DelEvents(fd int, events uint32) error {
	return e.ctl(syscall.EPOLL_CTL_DEL, fd, events)
}

// ctl issues one epoll_ctl operation for fd with the given event mask.
func (e EpollFdType) ctl(op, fd int, events uint32) error {
	ev := syscall.EpollEvent{Events: events, Fd: int32(fd)}
	return syscall.EpollCtl(int(e), op, fd, &ev)
}
// Message is one line of the data file together with its sequence number.
type Message struct {
	Seq     uint32 // position of the line in the file, starting at 0
	SeqStr  string // decimal form of Seq, echoed in the response body
	Payload []byte // line contents without the trailing "\r\n"
}

// Marshal encodes m as a response frame into buffer and returns the encoded
// bytes. If buffer is too small, a new slice of the required size is
// allocated instead, so the result aliases buffer only when it fits.
//
// Frame layout (integers little-endian):
//
//	[0:2]  uint16 count of the bytes that follow this field
//	[2:6]  uint32 Seq
//	       SeqStr
//	       Payload[2*d:] reversed, then Payload[:d] reversed, where
//	       d = len(Payload)/3 (the middle third is dropped)
//	       "\r\n"
//
// NOTE(review): frames whose length exceeds 65535 bytes overflow the uint16
// length field; the value is truncated as in the original protocol.
func (m *Message) Marshal(buffer []byte) []byte {
	l := len(m.Payload)
	sl := len(m.SeqStr)
	drop := l / 3
	length := 4 + (sl + l - drop) + 2 // seq + seqstr + payload + \r\n
	total := 2 + length               // plus the length field itself
	if len(buffer) < total {
		buffer = make([]byte, total)
	}
	binary.LittleEndian.PutUint16(buffer[:2], uint16(length))
	binary.LittleEndian.PutUint32(buffer[2:6], m.Seq)
	j := 6 + copy(buffer[6:], m.SeqStr)
	for i := l - 1; i >= 2*drop; i-- {
		buffer[j] = m.Payload[i]
		j++
	}
	for i := drop - 1; i >= 0; i-- {
		buffer[j] = m.Payload[i]
		j++
	}
	buffer[j] = '\r'
	buffer[j+1] = '\n'
	return buffer[:total]
}
// Loader reads the data file and publishes each "\r\n"-terminated line as a
// sequenced Message into a ring buffer that workers consume via FetchLine.
type Loader struct {
	queue chan *Message    // only closed when loading ends; FetchLine reads the ring
	ring  *ringbuffer.Ring // hand-off from the loader goroutine to the workers
	path  string           // data file path
	count uint32           // lines published so far; also the next Seq
	sep   []byte           // line separator, "\r\n"
	rb    int              // bytes requested per file read
}

// NewLoader creates a Loader for the file at path that reads rb bytes at a
// time. queue sizes the message channel and is passed as both arguments of
// ringbuffer.NewRing.
func NewLoader(path string, rb, queue int) (l *Loader) {
	return &Loader{
		queue: make(chan *Message, queue),
		ring:  ringbuffer.NewRing(queue, queue),
		path:  path,
		sep:   []byte("\r\n"),
		rb:    rb,
	}
}

// FetchLine returns the next Message from the ring and the updated read
// handle to pass to the next call. Both results are nil once Ring.Read
// returns a nil handle. That presumably happens after load has called
// Ring.Stop; the ringbuffer package is not visible here, so confirm there.
func (l *Loader) FetchLine(rbuf *ringbuffer.Buffer) (*Message, *ringbuffer.Buffer) {
	e, next := l.ring.Read(rbuf)
	if next == nil {
		return nil, nil
	}
	return e.(*Message), next
}

// Run starts loading the file in a background goroutine and returns
// immediately. The goroutine logs start and end (with the line count) and
// closes l.queue when done.
func (l *Loader) Run() {
	go func() {
		defer close(l.queue)
		log.Println("start")
		if err := l.load(); err != nil {
			log.Println(err)
		}
		log.Println("end", l.count)
	}()
}

// load reads the whole file and publishes every line to the ring. It returns
// nil on reaching end of file, or the open or read error that stopped it.
// A final line without a trailing separator is still published.
func (l *Loader) load() error {
	var wbuf *ringbuffer.Buffer
	// Stop the ring on every exit path, not only at EOF, so FetchLine
	// callers learn the data has ended even when loading fails.
	defer func() { l.ring.Stop(wbuf) }()

	file, err := os.Open(l.path)
	if err != nil {
		return err
	}
	defer file.Close()

	blockSize := l.rb
	if blockSize <= 0 {
		blockSize = 8192 // the -rb flag's default
	}
	buffer := make([]byte, blockSize)
	remain := 0 // bytes of an unterminated line carried at the front of buffer
	for {
		n, err := file.Read(buffer[remain:])
		data := buffer[:remain+n]
		for {
			i := bytes.Index(data, l.sep)
			if i < 0 {
				break
			}
			wbuf = l.publish(wbuf, data[:i])
			data = data[i+len(l.sep):]
		}
		if err == io.EOF {
			if len(data) > 0 {
				wbuf = l.publish(wbuf, data)
			}
			return nil
		}
		if err != nil {
			return err
		}
		// Published payloads alias buffer, so carry the unterminated tail
		// into a fresh buffer instead of overwriting this one. The new
		// buffer must have spare room after the tail. Otherwise a line
		// longer than a block would leave Read an empty slice, and the loop
		// would never advance.
		size := blockSize
		if len(data) >= size {
			size = 2 * len(data)
		}
		buffer = make([]byte, size)
		remain = copy(buffer, data)
	}
}

// publish wraps payload in the next sequenced Message, writes it to the
// ring, and returns the ring's updated write handle.
func (l *Loader) publish(wbuf *ringbuffer.Buffer, payload []byte) *ringbuffer.Buffer {
	msg := &Message{
		Seq:     l.count,
		SeqStr:  strconv.FormatUint(uint64(l.count), 10),
		Payload: payload,
	}
	wbuf = l.ring.Write(wbuf, msg)
	l.count++
	return wbuf
}
@skoowoo
Copy link
Author

skoowoo commented Sep 17, 2014

可以忽略 Loader 部分的实现,此程序也需要 ringbuffer 库才能编译。

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment