Last active
January 9, 2020 18:56
-
-
Save JonasDoe/51925baeb407040c6b93b172d2425ad3 to your computer and use it in GitHub Desktop.
Middleware timeouts don't help against a faulty client
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
package demo | |
import (
	"context"
	"io"
	"io/ioutil"
	"net/http"
	"testing"
	"time"

	"github.com/go-chi/chi"
	"github.com/go-chi/chi/middleware"
	"github.com/stretchr/testify/assert"
)
// there IS a timeout triggered, but still the test is blocking. http/server#finishRequest keeps waiting ... | |
func TestSlowRequest(t *testing.T) { | |
// | |
// SETUP | |
// | |
var err error | |
timeout := 1 * time.Second | |
port := "1256" | |
router := chi.NewRouter() | |
router.Use(middleware.Timeout(timeout)) | |
router.Handle("/", middlewareHandler{}) | |
//start the server without having any timeout values set -> is is what makes the server vulnerable | |
go func() { | |
err = http.ListenAndServe(":"+port, router) | |
if err != nil { | |
panic(err) | |
} | |
}() | |
//wait until the server is avaible | |
for attempts := 0; ; attempts++ { | |
get, err := http.Get("http://localhost:" + port) | |
if err == nil && get != nil && get.StatusCode == http.StatusOK { | |
break | |
} | |
time.Sleep(50 * time.Millisecond) | |
if attempts == 10 { | |
panic("could not set up server") | |
} | |
} | |
// | |
//TESTS | |
// | |
tests := []struct { | |
name string | |
reader *delayedReader | |
wantErr bool | |
}{ | |
//{"without timeout", &delayedReader{Delay: 0, Content: "hallo=welt"}, false}, | |
{"with timeout", &delayedReader{Delay: timeout * 5, Content: "hallo=welt"}, true}, | |
} | |
for _, tt := range tests { | |
t.Run(tt.name, func(t *testing.T) { | |
println(time.Now().String() + ": test start") | |
reader := tt.reader | |
request, err := http.NewRequest(http.MethodPost, "http://localhost:"+port, reader) | |
if err != nil { | |
panic(err) | |
} | |
request.ContentLength = int64(len([]byte(reader.Content))) | |
client := http.DefaultClient | |
response, err := client.Do(request) | |
if tt.wantErr { | |
assert.NotEqual(t, response.StatusCode, http.StatusOK) | |
} else { | |
assert.Nil(t, err) | |
assert.Equal(t, response.StatusCode, http.StatusOK) | |
} | |
println(time.Now().String() + ": test end") | |
}) | |
} | |
} | |
// delayedReader is an io.Reader that delivers Content slowly. It sleeps
// Delay/len(Content) after each byte, so reading all of Content takes roughly
// Delay in total. It simulates a client that trickles its request body.
type delayedReader struct {
	Delay   time.Duration
	Content string
	done    bool // set once every byte of Content has been returned
	offset  int  // number of bytes of Content already returned
}

// Read copies the next bytes of Content into p, at most len(p) of them, and
// sleeps after each byte. Once Content is exhausted it returns 0, io.EOF, as
// the io.Reader contract requires. An empty Content yields io.EOF immediately.
func (dr *delayedReader) Read(p []byte) (n int, err error) {
	if dr.done || dr.offset >= len(dr.Content) {
		dr.done = true
		return 0, io.EOF
	}
	// Divide as a Duration to avoid truncation through int on 32-bit platforms.
	// len(dr.Content) > 0 is guaranteed by the guard above.
	perByte := dr.Delay / time.Duration(len(dr.Content))
	// Respect len(p): never write past the caller's buffer.
	for n < len(p) && dr.offset < len(dr.Content) {
		p[n] = dr.Content[dr.offset]
		dr.offset++
		n++
		time.Sleep(perByte)
	}
	dr.done = dr.offset >= len(dr.Content)
	return n, nil
}
type middlewareHandler struct{} | |
func (t middlewareHandler) ServeHTTP(writer http.ResponseWriter, request *http.Request) { | |
err := handleWithTimeout(request.Context(), func() error { | |
_, err := ioutil.ReadAll(request.Body) | |
return err | |
}) | |
if err != nil { | |
println(time.Now().String() + ": " + err.Error()) // context deadline exceeded | |
} | |
} | |
// handleWithTimeout runs handler in its own goroutine. It returns handler's
// error, or ctx.Err() if ctx is done first. It does not stop handler: on
// cancellation the handler keeps running until it returns on its own.
func handleWithTimeout(ctx context.Context, handler func() error) (err error) {
	// Capacity 1 lets the worker deliver its result and exit even when no
	// one receives it any more (ctx won the race below).
	result := make(chan error, 1)
	go func() { result <- handler() }()

	select {
	case <-ctx.Done():
		return ctx.Err()
	case err = <-result:
		return err
	}
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment