Skip to content

Instantly share code, notes, and snippets.

@udhos
Forked from silkeh/grace.go
Created October 28, 2023 21:22
Show Gist options
  • Save udhos/eeaf13b262150c486757296888768c41 to your computer and use it in GitHub Desktop.
Save udhos/eeaf13b262150c486757296888768c41 to your computer and use it in GitHub Desktop.
Golang graceful restart with TCP connections
package main
import (
	"encoding/json"
	"flag"
	"io/ioutil"
	"log"
	"net"
	"os"
	"os/signal"
	"sync"
	"syscall"
	"time"
)
const (
// RestartSignal represents a signal that triggers a graceful restart.
// setupSignals subscribes to it and reacts by calling fork.
RestartSignal = syscall.SIGHUP
// CacheFile to store state to be transferred: writeCache stores the JSON
// list of connection IDs here and readCache loads it again. The path is
// relative, so it is resolved against the process's working directory.
CacheFile = "cache.json"
)
var (
// address is the -addr command-line flag: the address and port to listen on.
address = flag.String("addr", "[::1]:1234", "Address and port to listen on")
// now is the time at which this process initialised, formatted as RFC 3339
// and followed by a newline. It is computed once, so every message written
// by Client.Run carries the same timestamp for the life of the process.
now = time.Now().Format(time.RFC3339) + "\n"
// state is the process-wide listener and connection state shared by main,
// fork and every Client.
state = NewState()
)
// firstInheritedFD is the descriptor number at which inherited sockets start
// in a resumed process. forkExec lists them directly after stdin, stdout and
// stderr, with the listener first.
const firstInheritedFD = 3

// State represents the TCP file descriptor state of this program: the active
// listener, the client connections, and for each of them a duplicated
// descriptor that can be handed to a successor process on restart.
//
// The methods serialise on an internal mutex, so they may be called from the
// accept loop, from client goroutines and from the signal handler at the same
// time. Code that touches the exported fields directly is not protected.
type State struct {
	// Listener is the listener new connections are accepted from.
	Listener *net.TCPListener
	// ListenFD is the descriptor number handed over for the listener.
	ListenFD uintptr
	// Conns maps a client name to the descriptor number handed over for it.
	Conns map[string]uintptr

	// mu serialises the methods below.
	mu sync.Mutex
	// listenFile and connFiles own the descriptors recorded in ListenFD and
	// Conns. An *os.File that becomes unreachable may have its descriptor
	// closed by the garbage collector, which would leave a dangling number
	// behind, so the files stay referenced for as long as their numbers are
	// stored and are closed explicitly afterwards.
	listenFile *os.File
	connFiles  map[string]*os.File
}

// NewState creates a usable state.
func NewState() *State {
	return &State{
		Conns:     make(map[string]uintptr),
		connFiles: make(map[string]*os.File),
	}
}

// FDs returns the connection mapping as a slice of IDs and a slice of file
// descriptors in matching order. Index 0 of both slices belongs to the
// listener: fds[0] is ListenFD and ids[0] is an empty placeholder.
func (s *State) FDs() ([]string, []uintptr) {
	s.mu.Lock()
	defer s.mu.Unlock()

	ids := make([]string, 1, len(s.Conns)+1)
	fds := make([]uintptr, 1, len(s.Conns)+1)
	fds[0] = s.ListenFD
	// Map iteration order is random, but both slices are filled in the same
	// pass, so position i of ids always matches position i of fds.
	for id, fd := range s.Conns {
		ids = append(ids, id)
		fds = append(fds, fd)
	}
	return ids, fds
}

// SetListener sets the listener and listener file descriptor from a
// TCPListener pointer. The descriptor is a duplicate obtained from l.File();
// a duplicate recorded by an earlier call is closed. The process exits if the
// descriptor cannot be duplicated, like the other listener set-up failures.
func (s *State) SetListener(l *net.TCPListener) {
	f, err := l.File()
	if err != nil {
		log.Fatalf("Error duplicating listener descriptor: %s", err)
	}

	s.mu.Lock()
	defer s.mu.Unlock()
	if s.listenFile != nil {
		// Superseded duplicate; a close failure leaves nothing to recover.
		_ = s.listenFile.Close()
	}
	s.Listener = l
	s.listenFile = f
	s.ListenFD = f.Fd()
}

// AddClient adds a new connection to the state. If the connection's
// descriptor cannot be duplicated the client is left untracked: it keeps
// working but is not handed over on restart.
func (s *State) AddClient(c *Client) {
	f, err := c.File()
	if err != nil {
		log.Printf("Error duplicating descriptor of %s (not kept across restart): %s", c.Name, err)
		return
	}

	s.mu.Lock()
	defer s.mu.Unlock()
	if s.Conns == nil {
		s.Conns = make(map[string]uintptr)
	}
	if s.connFiles == nil {
		s.connFiles = make(map[string]*os.File)
	}
	if old, ok := s.connFiles[c.Name]; ok {
		// A previous entry with the same name is replaced; release its
		// duplicate so it does not leak.
		_ = old.Close()
	}
	s.connFiles[c.Name] = f
	s.Conns[c.Name] = f.Fd()
}

// DelClient removes a connection from the state and closes the duplicated
// descriptor that was kept for it.
func (s *State) DelClient(c *Client) {
	s.mu.Lock()
	defer s.mu.Unlock()
	if f, ok := s.connFiles[c.Name]; ok {
		// Best effort: the entry is dropped whether or not the close works.
		_ = f.Close()
		delete(s.connFiles, c.Name)
	}
	delete(s.Conns, c.Name)
}

// NewListener starts a new listener on address and registers it in the
// state. The process exits if the address cannot be resolved or bound.
func (s *State) NewListener(address string) {
	addr, err := net.ResolveTCPAddr("tcp", address)
	if err != nil {
		log.Fatalf("Error resolving address: %s", err)
	}
	ln, err := net.ListenTCP("tcp", addr)
	if err != nil {
		log.Fatalf("Error listening on %s: %s", address, err)
	}
	s.SetListener(ln)
}

// restoreFDs restores the connection map from a list of IDs (ordered the same
// as the FDs given when forking). ids[0] stands for the listener and is
// skipped; ids[i] is mapped to descriptor firstInheritedFD+i. An empty list
// restores the listener descriptor only.
func (s *State) restoreFDs(ids []string) {
	s.mu.Lock()
	defer s.mu.Unlock()

	s.ListenFD = firstInheritedFD
	if s.Conns == nil {
		s.Conns = make(map[string]uintptr)
	}
	for i, id := range ids {
		if i == 0 {
			continue // slot 0 is the listener
		}
		s.Conns[id] = firstInheritedFD + uintptr(i)
	}
}

// Resume resumes a TCP listener and TCP connections from a list of IDs.
// These IDs are ordered the same way as the FDs passed when forking.
//
// Every inherited descriptor is turned back into a net object, re-registered
// through SetListener/AddClient, and then closed. Connections that cannot be
// resumed are logged and dropped; the process exits if the listener cannot
// be resumed.
func (s *State) Resume(ids []string) {
	s.restoreFDs(ids)

	// Move the inherited numbers out of the state. Each one is replaced by a
	// tracked duplicate below, or dropped when the connection cannot be
	// resumed, so Conns never keeps a number whose descriptor was closed.
	s.mu.Lock()
	listenFD := s.ListenFD
	inherited := s.Conns
	s.Conns = make(map[string]uintptr, len(inherited))
	s.mu.Unlock()

	lf := os.NewFile(listenFD, "listener")
	ln, err := net.FileListener(lf)
	if err != nil {
		log.Fatalf("Cannot open listener: %s", err)
	}
	tcpLn, ok := ln.(*net.TCPListener)
	if !ok {
		log.Fatalf("Cannot open listener: descriptor %d is a %T, not a TCP listener", listenFD, ln)
	}

	for name, fd := range inherited {
		f := os.NewFile(fd, name)
		conn, err := net.FileConn(f)
		// FileConn works on its own copy of the descriptor and leaves f to
		// the caller, so the inherited one is released either way.
		_ = f.Close()
		if err != nil {
			log.Printf("Error resuming connection %s: %s", name, err)
			continue
		}
		tcpConn, ok := conn.(*net.TCPConn)
		if !ok {
			log.Printf("Error resuming connection %s: descriptor is a %T, not a TCP connection", name, conn)
			_ = conn.Close()
			continue
		}
		client := NewClient(name, tcpConn)
		s.AddClient(client)
		go client.Run()
	}

	s.SetListener(tcpLn)
	// The inherited listener descriptor is kept open until the replacement
	// has been recorded, so ListenFD never names a closed descriptor.
	_ = lf.Close()
}
// Client is a TCP connection paired with the name it is tracked under.
// The connection is embedded, so its methods are available on the client.
type Client struct {
	*net.TCPConn
	Name string
}

// NewClient wraps conn in a Client registered under name.
func NewClient(name string, conn *net.TCPConn) *Client {
	client := new(Client)
	client.TCPConn = conn
	client.Name = name
	return client
}
// Run serves a single client: it writes the client's name followed by the
// package-level timestamp, then waits an hour before writing again. The first
// failed write ends the loop; the connection is then closed and the client is
// removed from the shared state.
func (c *Client) Run() {
	for {
		line := c.Name + " " + now
		if _, werr := c.Write([]byte(line)); werr != nil {
			log.Printf("Error sending data: %s", werr)
			c.Close()
			state.DelClient(c)
			return
		}
		time.Sleep(time.Hour)
	}
}
// forkExec starts a new instance of this program (os.Args[0] with the current
// arguments and environment). The child receives stdin, stdout and stderr
// followed by fds, in that order. It returns the child's pid.
func forkExec(fds []uintptr) (int, error) {
	env := os.Environ()

	// Standard streams first, then the sockets being handed over.
	files := make([]uintptr, 0, 3+len(fds))
	files = append(files, os.Stdin.Fd(), os.Stdout.Fd(), os.Stderr.Fd())
	files = append(files, fds...)

	procAttr := syscall.ProcAttr{Env: env, Files: files}
	return syscall.ForkExec(os.Args[0], os.Args, &procAttr)
}
// fork writes the cache and runs a new version of this program.
// File descriptors are given to the new process.
func fork() {
ids, fds := state.FDs()
if err := writeCache(ids); err != nil {
log.Printf("Error while writing cache (not forking): %s", err)
return
}
pid, err := forkExec(fds)
if err != nil {
log.Printf("Error while forking: %s", err)
return
}
log.Printf("Forked to %v", pid)
os.Exit(0)
}
// writeCache creates the cache file: ids encoded as a JSON array and written
// to CacheFile with mode 0644. It returns the encoding or write error, if any.
func writeCache(ids []string) error {
	data, err := json.Marshal(ids)
	if err != nil {
		return err
	}
	return ioutil.WriteFile(CacheFile, data, 0644)
}
// readCache reads the information from the cache file: it opens CacheFile and
// decodes a JSON array of connection IDs from it. On any error the returned
// slice is nil. The file is closed before returning.
func readCache() (ids []string, err error) {
	file, err := os.Open(CacheFile)
	if err != nil {
		return nil, err
	}
	// Read-only handle: a close error cannot lose data, so it is not reported.
	defer file.Close()

	if err := json.NewDecoder(file).Decode(&ids); err != nil {
		return nil, err
	}
	return ids, nil
}
// setupSignals configures signal handling: every RestartSignal received
// triggers fork. The handling goroutine runs for the life of the process.
func setupSignals() {
	// signal.Notify never blocks when delivering, so an unbuffered channel
	// can drop a signal that arrives while nobody is receiving. One slot is
	// enough for a single signal value.
	c := make(chan os.Signal, 1)
	signal.Notify(c, RestartSignal)
	go func() {
		// fork only returns when the restart failed; keep listening so a
		// later signal can try again.
		for range c {
			fork()
		}
	}()
}
// main parses flags, installs the restart signal handler, and then either
// resumes the sockets described by the cache file or opens a fresh listener.
// After that it accepts connections forever, starting one goroutine per
// client.
func main() {
	// acceptRetryDelay keeps a persistently failing Accept from spinning.
	const acceptRetryDelay = 100 * time.Millisecond

	flag.Parse()
	setupSignals()

	if _, err := os.Stat(CacheFile); err == nil {
		ids, err := readCache()
		// The cache describes descriptors passed to this process only.
		// Remove it once read, whether or not it was usable, so a later
		// ordinary start does not try to resume sockets it never received.
		if rmErr := os.Remove(CacheFile); rmErr != nil {
			log.Printf("Error removing cache: %s", rmErr)
		}
		if err != nil {
			log.Fatalf("Error reading cache: %s", err)
		}
		state.Resume(ids)
	} else {
		state.NewListener(*address)
	}

	for {
		conn, err := state.Listener.AcceptTCP()
		if err != nil {
			log.Printf("Error accepting connection: %s", err)
			// conn is nil here; skip it rather than dereference it.
			time.Sleep(acceptRetryDelay)
			continue
		}
		c := NewClient(conn.RemoteAddr().String(), conn)
		state.AddClient(c)
		go c.Run()
	}
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment