-
-
Save udhos/eeaf13b262150c486757296888768c41 to your computer and use it in GitHub Desktop.
Golang graceful restart with TCP connections
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
package main | |
import ( | |
"encoding/json" | |
"flag" | |
"io/ioutil" | |
"log" | |
"net" | |
"os" | |
"os/signal" | |
"syscall" | |
"time" | |
) | |
// Settings shared by the running process and the copy of itself that it
// launches on restart.
const (
	// RestartSignal is the signal that asks the process to hand its sockets
	// over to a freshly started copy of itself.
	RestartSignal = syscall.SIGHUP

	// CacheFile is the (working-directory relative) path of the JSON file
	// that carries the connection names from the old process to the new one.
	CacheFile = "cache.json"
)
var ( | |
address = flag.String("addr", "[::1]:1234", "Address and port to listen on") | |
now = time.Now().Format(time.RFC3339) + "\n" | |
state = NewState() | |
) | |
// State records the sockets owned by this process so they can be handed to a
// successor process: one listening socket plus one descriptor per client.
type State struct {
	// Listener is the TCP listener that new connections are accepted from.
	Listener *net.TCPListener
	// ListenFD is the descriptor number recorded for the listening socket.
	ListenFD uintptr
	// Conns maps a client name to the descriptor number recorded for it.
	Conns map[string]uintptr
}
// NewState creates a usable state | |
func NewState() *State { | |
return &State{Conns: make(map[string]uintptr)} | |
} | |
// FDs returns the connection mapping a a slice of strings and one of file descriptors | |
func (s *State) FDs() ([]string, []uintptr) { | |
ids := make([]string, 1, len(s.Conns)+1) | |
fds := make([]uintptr, 1, len(s.Conns)+1) | |
fds[0] = s.ListenFD | |
for id, fd := range s.Conns { | |
ids = append(ids, id) | |
fds = append(fds, fd) | |
} | |
return ids, fds | |
} | |
// SetListener sets the listener and listener file descriptor from a TCPListener pointer | |
func (s *State) SetListener(l *net.TCPListener) { | |
f, _ := l.File() | |
s.Listener = l | |
s.ListenFD = f.Fd() | |
} | |
// AddClient adds a new connection to the state | |
func (s *State) AddClient(c *Client) { | |
f, _ := c.File() | |
s.Conns[c.Name] = f.Fd() | |
} | |
// DelClient removes a connection from the state | |
func (s *State) DelClient(c *Client) { | |
delete(s.Conns, c.Name) | |
} | |
// NewListener starts a new listener and registers this in the state | |
func (s *State) NewListener(address string) { | |
addr, err := net.ResolveTCPAddr("tcp", address) | |
if err != nil { | |
log.Fatalf("Error resolving address: %s", err) | |
} | |
ln, err := net.ListenTCP("tcp", addr) | |
if err != nil { | |
log.Fatalf("Error listening on %s: %s", address, err) | |
} | |
state.SetListener(ln) | |
} | |
// restoreFDs restores the connection map from a list of IDs (ordered the same as the given FDs) | |
func (s *State) restoreFDs(ids []string) { | |
s.ListenFD = 3 | |
for i, id := range ids[1:] { | |
s.Conns[id] = uintptr(i) + 4 | |
} | |
} | |
// Resume resumes a TCP listener and TCP connections from a list of IDs. | |
// These IDs are ordered the same way as the FDs passed when forking. | |
func (s *State) Resume(ids []string) { | |
s.restoreFDs(ids) | |
ln, err := net.FileListener(os.NewFile(s.ListenFD, "")) | |
if err != nil { | |
log.Fatalf("Cannot open listener: %s", err) | |
} | |
for k, fd := range state.Conns { | |
c, err := net.FileConn(os.NewFile(fd, "")) | |
if err != nil { | |
log.Printf("Error resuming connection: %s") | |
continue | |
} | |
go NewClient(k, c.(*net.TCPConn)).Run() | |
} | |
state.SetListener(ln.(*net.TCPListener)) | |
} | |
// Client is a TCP connection paired with the name it is tracked under.
// The connection is embedded, so its methods are available directly on the
// client.
type Client struct {
	*net.TCPConn
	// Name identifies the client in the state's connection map.
	Name string
}
// NewClient creates a named client | |
func NewClient(name string, conn *net.TCPConn) *Client { | |
return &Client{Name: name, TCPConn: conn} | |
} | |
// Run handles the interaction for a client | |
func (c *Client) Run() { | |
for { | |
_, err := c.Write([]byte(c.Name + " " + now)) | |
if err != nil { | |
log.Printf("Error sending data: %s", err) | |
break | |
} | |
time.Sleep(3600 * time.Second) | |
} | |
c.Close() | |
state.DelClient(c) | |
} | |
// forkExec starts a new instance of this program (os.Args[0], with the same
// arguments and environment). The child receives stdin, stdout and stderr as
// its first three descriptors, followed by fds in the order given. It returns
// the child's PID, or the error reported by syscall.ForkExec.
func forkExec(fds []uintptr) (int, error) {
	files := make([]uintptr, 0, 3+len(fds))
	files = append(files, os.Stdin.Fd(), os.Stdout.Fd(), os.Stderr.Fd())
	files = append(files, fds...)

	attr := new(syscall.ProcAttr)
	attr.Env = os.Environ()
	attr.Files = files

	return syscall.ForkExec(os.Args[0], os.Args, attr)
}
// fork writes the cache and runs a new version of this program. | |
// File descriptors are given to the new process. | |
func fork() { | |
ids, fds := state.FDs() | |
if err := writeCache(ids); err != nil { | |
log.Printf("Error while writing cache (not forking): %s", err) | |
return | |
} | |
pid, err := forkExec(fds) | |
if err != nil { | |
log.Printf("Error while forking: %s", err) | |
return | |
} | |
log.Printf("Forked to %v", pid) | |
os.Exit(0) | |
} | |
// writeCache creates the cache file | |
func writeCache(ids []string) error { | |
data, _ := json.Marshal(ids) | |
return ioutil.WriteFile(CacheFile, data, 0644) | |
} | |
// readCache reads the information from the cache file | |
func readCache() (ids []string, err error) { | |
file, err := os.Open(CacheFile) | |
if err != nil { | |
return nil, err | |
} | |
err = json.NewDecoder(file).Decode(&ids) | |
if err != nil { | |
return nil, err | |
} | |
return | |
} | |
// setupSignals configures signal handling | |
func setupSignals() { | |
c := make(chan os.Signal) | |
signal.Notify(c, RestartSignal) | |
go func() { | |
<-c | |
fork() | |
}() | |
} | |
func main() { | |
flag.Parse() | |
setupSignals() | |
if _, err := os.Stat(CacheFile); err == nil { | |
ids, err := readCache() | |
if err != nil { | |
log.Fatalf("Error reading cache: %s", err) | |
} | |
state.Resume(ids) | |
} else { | |
state.NewListener(*address) | |
} | |
for { | |
conn, err := state.Listener.AcceptTCP() | |
if err != nil { | |
log.Printf("Error accepting connection: %s", err) | |
} | |
c := NewClient(conn.RemoteAddr().String(), conn) | |
state.AddClient(c) | |
go c.Run() | |
} | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment