Last active
January 16, 2024 21:18
-
-
Save mariash/cf75a2deff1d16af14ef8020393ccd48 to your computer and use it in GitHub Desktop.
Golang ReverseProxy race condition
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
package main | |
import ( | |
"bytes" | |
"io" | |
"net/http" | |
"net/http/httptrace" | |
"net/http/httputil" | |
"net/url" | |
"time" | |
) | |
func main() { | |
readyCh := make(chan struct{}) | |
backendServer := http.NewServeMux() | |
backendServer.HandleFunc("/", func(rw http.ResponseWriter, r *http.Request) { | |
rw.WriteHeader(http.StatusContinue) | |
hj, _ := rw.(http.Hijacker) | |
conn, _, _ := hj.Hijack() | |
conn.Close() | |
}) | |
go http.ListenAndServe(":8081", backendServer) | |
time.Sleep(1 * time.Second) | |
handler := func(p *httputil.ReverseProxy) func(http.ResponseWriter, *http.Request) { | |
return func(rw http.ResponseWriter, r *http.Request) { | |
r = r.WithContext(httptrace.WithClientTrace(r.Context(), &httptrace.ClientTrace{ | |
Got100Continue: func() { | |
// Delay the 1xx hook | |
<-readyCh | |
}, | |
})) | |
p.ServeHTTP(rw, r) | |
rw.Header().Set("X-Something", "Hello") | |
} | |
} | |
// trigger trace context once, blocking first time | |
go func() { | |
readyCh <- struct{}{} | |
}() | |
target, err := url.Parse("http://localhost:8081") | |
if err != nil { | |
panic(err) | |
} | |
proxy := httputil.NewSingleHostReverseProxy(target) | |
proxy.ErrorHandler = func(rw http.ResponseWriter, req *http.Request, err error) { | |
readyCh <- struct{}{} | |
for i := 0; i < 10000000; i++ { | |
rw.Header().Set("X-Something", "Hello") | |
} | |
} | |
proxy.Transport = &RetryTransport{T: http.DefaultTransport} | |
http.HandleFunc("/", handler(proxy)) | |
go http.ListenAndServe(":8080", nil) | |
time.Sleep(1 * time.Second) | |
data := bytes.NewBufferString("Hello!") | |
req, err := http.NewRequest("POST", "http://localhost:8080", data) | |
if err != nil { | |
panic(err) | |
} | |
req.Header.Set("Expect", "100-continue") | |
resp, err := http.DefaultClient.Do(req) | |
if err != nil { | |
panic(err) | |
} | |
defer resp.Body.Close() | |
_, err = io.ReadAll(resp.Body) | |
if err != nil { | |
panic(err) | |
} | |
time.Sleep(10 * time.Second) | |
println("done") | |
} | |
type RetryTransport struct { | |
T http.RoundTripper | |
} | |
func (tr *RetryTransport) RoundTrip(req *http.Request) (res *http.Response, err error) { | |
for i := 0; i < 3; i++ { | |
res, err = tr.T.RoundTrip(req) | |
if err == nil { | |
return res, err | |
} | |
} | |
return res, err | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment