Based on: https://www.rodrigoaraujo.me/posts/golang-pattern-graceful-shutdown-of-concurrent-events/
Gist zacharysyoung/12b500701cc32bdae09535f272f521c0 — last active July 5, 2023 17:27.
Description: Start an HTTP server and listen for a response, but only for so long.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
package main | |
import ( | |
"context" | |
"fmt" | |
"log" | |
"net/http" | |
"time" | |
) | |
// startServer runs a one-shot HTTP server on :8888 and waits up to 10s for a
// request carrying a "code" query parameter (e.g. GET /?code=abc). It returns
// the first code received, or "" if none arrives before the timeout or if the
// server fails to start. The server is shut down before startServer returns.
//
// Shutdown-by-endpoint idea: https://medium.com/@int128/shutdown-http-server-by-endpoint-in-go-2a0e2d7f9b8c
// context.WithTimeout(): https://stackoverflow.com/a/46511560
func startServer() string {
	const timeout = 10 * time.Second
	ctx, cancel := context.WithTimeout(context.Background(), timeout)
	defer cancel()
	log.Printf("set timeout for %q", timeout)

	// The handler hands the code to this goroutine over a channel instead of
	// writing a shared variable and calling Shutdown itself: Shutdown waits
	// for active connections to become idle, so calling it from inside a
	// handler blocks that handler (and its response) indefinitely.
	codeCh := make(chan string, 1)
	m := http.NewServeMux()
	m.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
		log.Println("received request", r.URL)
		vals, ok := r.URL.Query()["code"]
		if !ok {
			return
		}
		log.Println("got code", vals[0])
		// Non-blocking: only the first code is kept; later ones are dropped
		// rather than stalling the handler (which would also stall Shutdown).
		select {
		case codeCh <- vals[0]:
		default:
		}
	})
	s := &http.Server{Addr: ":8888", Handler: m}

	serveErr := make(chan error, 1)
	go func() {
		log.Println("starting server")
		serveErr <- s.ListenAndServe()
	}()

	var code string
	select {
	case code = <-codeCh:
		log.Println("endpoint delivered code, shutting down server")
	case err := <-serveErr:
		// Before Shutdown is called, ListenAndServe only returns on failure,
		// e.g. when the address is already in use; nothing to shut down.
		log.Println("server error:", err)
		return ""
	case <-ctx.Done():
		log.Println(ctx.Err())
	}

	// Always stop the server (also on timeout) so it does not keep :8888
	// bound; bound the wait so a stuck connection cannot hang the caller.
	shutdownCtx, cancelShutdown := context.WithTimeout(context.Background(), 5*time.Second)
	defer cancelShutdown()
	if err := s.Shutdown(shutdownCtx); err != nil {
		log.Println("server shutdown:", err)
	}
	return code
}
func main() { | |
fmt.Println(startServer()) | |
} |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
package main | |
import ( | |
"context" | |
"fmt" | |
"log" | |
"net/http" | |
"os" | |
"os/signal" | |
"time" | |
) | |
// listenForCode sits as the third leg in a 3-legged OAuth2 flow: it serves
// srv and looks for the code in the "code" query parameter of the redirect
// request from the issuing service.
//
// Each request offers exactly one value on code_ch: the code when the request
// carries one, otherwise "". The send is non-blocking, so code_ch should be
// buffered (or have a receiver already waiting). Values that cannot be
// delivered immediately are dropped instead of blocking the handler, because
// a blocked handler would also stall srv.Shutdown.
//
// The handler is registered on http.DefaultServeMux, so srv.Handler must be
// nil for it to be reached. listenForCode can only be called once per process,
// because registering "/" twice panics.
//
// listenForCode blocks until srv stops serving. Any error other than
// http.ErrServerClosed is logged.
//
// See https://gobyexample.com/timeouts (simpler than context.WithTimeout() ...
// context.Done()) and https://pkg.go.dev/net/http#example-Server.Shutdown
func listenForCode(srv *http.Server, code_ch chan<- string) {
	codeHandler := func(w http.ResponseWriter, r *http.Request) {
		log.Println("received request", r.URL)
		var code string // stays "" when the request has no code
		if vals, ok := r.URL.Query()["code"]; ok {
			code = vals[0]
		}
		select {
		case code_ch <- code:
		default:
			log.Println("a value is already pending; ignoring request")
		}
	}
	http.HandleFunc("/", codeHandler)
	log.Println("starting server")
	if err := srv.ListenAndServe(); err != http.ErrServerClosed {
		log.Println("server error:", err)
	}
}
func main() { | |
timeout := 10 * time.Second | |
sigint_ch := make(chan os.Signal, 1) | |
signal.Notify(sigint_ch, os.Interrupt) | |
srv := &http.Server{ | |
Addr: ":8888", | |
} | |
code_ch := make(chan string, 1) | |
go listenForCode(srv, code_ch) | |
select { | |
case code := <-code_ch: | |
log.Println("done:", code) | |
case <-time.After(timeout): | |
log.Println("done: timeout after", timeout) | |
case <-sigint_ch: | |
fmt.Println("") // deal with "^C" echo'ing in terminal | |
log.Println("done: user interrupted") | |
} | |
srv.Shutdown(context.Background()) | |
} |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
package main | |
import ( | |
"context" | |
"fmt" | |
"net/http" | |
"os" | |
"os/signal" | |
"sync" | |
"syscall" | |
"time" | |
"log" | |
) | |
// runSlowJob simulates job number id for name. It prints a start line, sleeps
// for d, prints a finish line, and then calls wg.Done. The caller must have
// done wg.Add(1) for this job beforehand.
func runSlowJob(id int, name string, d time.Duration, wg *sync.WaitGroup) {
	defer wg.Done()
	fmt.Printf("starting job %d for %s\n", id, name)
	time.Sleep(d)
	fmt.Printf("finished job %d for %s\n", id, name)
}

// slowJob1 simulates a 5-second job for name and calls wg.Done when finished.
func slowJob1(name string, wg *sync.WaitGroup) {
	runSlowJob(1, name, 5*time.Second, wg)
}

// slowJob2 simulates a 4-second job for name and calls wg.Done when finished.
func slowJob2(name string, wg *sync.WaitGroup) {
	runSlowJob(2, name, 4*time.Second, wg)
}

// slowJob3 simulates a 3-second job for name and calls wg.Done when finished.
func slowJob3(name string, wg *sync.WaitGroup) {
	runSlowJob(3, name, 3*time.Second, wg)
}
func consumer(ctx context.Context, jobQueue chan string, doneChan chan interface{}) { | |
wg := &sync.WaitGroup{} | |
for { | |
select { | |
// If the context was cancelled, a SIGTERM was captured | |
// So we wait for the jobs to finish, write to the done channel and return | |
case <-ctx.Done(): | |
// Note that the waiting time here is unbounded and can take a long time. | |
// If that's an issue you can: | |
// (1) issue a SIGKILL after a certain time or | |
// (2) use a context with timeout | |
wg.Wait() | |
fmt.Println("writing to done channel") | |
doneChan <- struct{}{} | |
log.Println("Done, shutting down the consumer") | |
return | |
case job := <-jobQueue: | |
wg.Add(3) | |
go slowJob1(job, wg) | |
go slowJob2(job, wg) | |
go slowJob3(job, wg) | |
} | |
} | |
} | |
// CustomHandler is the HTTP entry point for job requests: it forwards the job
// name taken from the request path to jobQueue, where a consumer picks it up.
type CustomHandler struct {
	jobQueue chan string
}

// NewCustomHandler returns a CustomHandler that sends job names on jobQueue.
// jobQueue should be non-nil: a send on a nil channel never proceeds, so every
// request would wait until its client gave up.
func NewCustomHandler(jobQueue chan string) *CustomHandler {
	return &CustomHandler{jobQueue: jobQueue}
}

// ServeHTTP queues the job named by the request path without its leading
// slash (GET /build queues "build") and replies "job <name> started".
//
// The send waits until a receiver takes the job. If the request's context is
// cancelled first (e.g. the client disconnects), the job is not queued and no
// reply is written, so the handler goroutine is not left blocked.
func (h *CustomHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
	jobName := r.URL.Path
	// Strip the leading "/" only when present; slicing [1:] on an empty path
	// (possible when ServeHTTP is called without a ServeMux) would panic.
	if len(jobName) > 0 && jobName[0] == '/' {
		jobName = jobName[1:]
	}
	select {
	case h.jobQueue <- jobName:
	case <-r.Context().Done():
		return
	}
	fmt.Fprintf(w, "job %s started", jobName)
}
func main() { | |
jobQueue := make(chan string) | |
customHandler := NewCustomHandler(jobQueue) | |
ctx, cancel := context.WithCancel(context.Background()) | |
httpServer := &http.Server{ | |
Addr: ":8889", | |
} | |
http.Handle("/", customHandler) | |
// Handle sigterm and await termChan signal | |
termChan := make(chan os.Signal, 1) | |
signal.Notify(termChan, syscall.SIGTERM, syscall.SIGINT) | |
go func() { | |
if err := httpServer.ListenAndServe(); err != nil { | |
if err != http.ErrServerClosed { | |
log.Printf("HTTP server closed with: %v\n", err) | |
} | |
log.Printf("HTTP server shut down") | |
} | |
}() | |
// doneChan will be the channel we'll be listening on | |
// to know all already started jobs have finished | |
// before we actually exit the program | |
doneChan := make(chan interface{}) | |
go consumer(ctx, jobQueue, doneChan) | |
// Wait for SIGTERM to be captured | |
<-termChan | |
log.Println("SIGTERM received. Shutdown process initiated") | |
// Shutdown the HTTP server | |
if err := httpServer.Shutdown(ctx); err != nil { | |
log.Fatalf("Server Shutdown Failed:%+v", err) | |
} | |
// Cancel the context, this will make the consumer stop | |
cancel() | |
// Wait for the consumer's jobs to finish | |
log.Println("waiting consumer to finish its jobs") | |
<-doneChan | |
log.Println("done. returning.") | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment