Skip to content

Instantly share code, notes, and snippets.

@bejaneps
Last active November 29, 2023 09:29
Show Gist options
  • Save bejaneps/2aef6889c9b3069504ed236b22ddf0d8 to your computer and use it in GitHub Desktop.
Save bejaneps/2aef6889c9b3069504ed236b22ddf0d8 to your computer and use it in GitHub Desktop.
Graceful shutdown
package main
// main wires the cache and database clients into the repository, service and
// network server, then serves on 0.0.0.0:<ServerPort> until
// runtime.RunUntilSignal reports a shutdown signal or the server's accept
// loop returns. Shutdown is given cfg.StopServerTimeout seconds.
//
// Every failure path releases the resources created so far before exiting
// via log.Fatal.
func main() {
	cfg := config{}
	if err := env.Parse(&cfg); err != nil {
		log.Fatal(err.Error())
	}

	ctx, cancelFunc := context.WithCancel(context.Background())
	ctx = log.ContextWithAttributes(ctx, log.Attributes{"serverPort": cfg.ServerPort})

	cacheClient, err := cache.New(cfg.CacheConfig)
	if err != nil {
		log.Fatal(err.Error())
	}

	dbClient, err := db.New(cfg.DBConfig)
	if err != nil {
		logIfError(cacheClient.Close())
		log.Fatal(err.Error())
	}

	repo := repoServer.New(cacheClient, dbClient)
	service := serviceServer.New(repo)

	s, err := server.New(ctx, service, "0.0.0.0:"+cfg.ServerPort)
	if err != nil {
		logIfError(dbClient.Close())
		logIfError(cacheClient.Close())
		log.Fatal(err.Error())
	}

	// closeAll releases everything in reverse order of construction. The
	// server (which, through repo, depends on the DB and cache clients) is
	// cancelled and closed first; only then are the DB and cache closed. DB
	// and cache errors are logged rather than returned so every resource
	// still gets a chance to close.
	closeAll := func() error {
		cancelFunc()
		serverErr := s.Close()
		logIfError(dbClient.Close())
		logIfError(cacheClient.Close())
		return serverErr
	}

	if err := runtime.RunUntilSignal(
		func() error { // start func
			return s.ListenAndAccept(ctx)
		},
		func(context.Context) error { // stop func
			return closeAll()
		},
		time.Duration(cfg.StopServerTimeout)*time.Second,
	); err != nil {
		// RunUntilSignal only returns an error when the start func returned
		// on its own; the stop func has not run, so clean up before exiting.
		logIfError(closeAll())
		log.Fatal(err.Error())
	}
}
// logIfError records a non-nil err through log.Error and does nothing for a
// nil err. It suits best-effort cleanup calls (e.g. Close) whose failure
// should be reported but must not stop the caller.
func logIfError(err error) {
	if err == nil {
		return
	}
	log.Error(err.Error())
}
package runtime
import (
"context"
"os"
"os/signal"
"syscall"
"time"
"log" // custom logger
)
// WaitSignal blocks the calling goroutine until one of the signals registered
// by getSignalChan (SIGHUP, SIGINT, SIGTERM, SIGQUIT) is delivered, and
// returns it.
//
// Signal notification is stopped before returning. Otherwise the process
// would stay subscribed on a channel nobody reads, and later signals would be
// silently swallowed instead of getting their default behaviour.
func WaitSignal() os.Signal {
	sigChan := getSignalChan()
	defer signal.Stop(sigChan)
	return <-sigChan
}
// getSignalChan subscribes to SIGHUP, SIGINT, SIGTERM and SIGQUIT and returns
// the channel those signals are delivered on. The channel has a buffer of
// one, as os/signal recommends, because signal delivery never blocks: a
// signal arriving while no one is receiving is kept rather than dropped.
// Callers should call signal.Stop on the channel when they are done with it.
func getSignalChan() chan os.Signal {
	watched := []os.Signal{
		syscall.SIGHUP,
		syscall.SIGINT,
		syscall.SIGTERM,
		syscall.SIGQUIT,
	}
	ch := make(chan os.Signal, 1)
	signal.Notify(ch, watched...)
	return ch
}
// RunUntilSignal starts run in its own goroutine and blocks until either run
// returns or one of the signals registered by getSignalChan (SIGHUP, SIGINT,
// SIGTERM, SIGQUIT) arrives.
//
// If run returns first, its error is returned as-is and stop is not called.
// This is the only case in which RunUntilSignal returns a non-nil error.
//
// If a signal arrives first:
//   - signal notification is stopped, so a second signal gets the default OS
//     behaviour (e.g. terminates a stuck shutdown);
//   - stop, if non-nil, is called with a context that expires after timeout;
//   - RunUntilSignal waits for stop at most until that context expires, even
//     if stop ignores the context. A stop error or a timeout is logged and
//     nil is returned. run is not waited for.
func RunUntilSignal(run func() error, stop func(context.Context) error, timeout time.Duration) error {
	sigChan := getSignalChan()
	defer signal.Stop(sigChan) // covers the path where run returns first

	// Buffered so the run goroutine can always deliver its result and exit,
	// even when we leave via the signal branch and nobody receives from it.
	errSig := make(chan error, 1)
	go func() {
		errSig <- run()
	}()

	var sig os.Signal
	select {
	case err := <-errSig:
		return err
	case sig = <-sigChan:
	}

	// Restore default handling right away so a repeated signal can kill the
	// process if the graceful shutdown below hangs. (signal.Stop is safe to
	// call again from the deferred call.)
	signal.Stop(sigChan)

	log.Info("received signal", log.String("signal", sig.String()))
	if stop == nil {
		return nil
	}

	ctx, cancel := context.WithTimeout(context.Background(), timeout)
	defer cancel()

	// Run stop in the background so the timeout is enforced even when stop
	// does not honour ctx. The buffer lets the goroutine finish after a timeout.
	stopErr := make(chan error, 1)
	go func() {
		stopErr <- stop(ctx)
	}()

	select {
	case err := <-stopErr:
		if err != nil {
			log.Error("could not stop server:", log.StdError(err))
		}
	case <-ctx.Done():
		log.Error("could not stop server:", log.StdError(ctx.Err()))
	}
	return nil
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment