Created
January 9, 2019 11:10
-
-
Save blixt/48f4581437eab9863b977ca7dc3001af to your computer and use it in GitHub Desktop.
Testing ModifyResponse with ReverseProxy and web sockets
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
package main | |
import ( | |
"bufio" | |
"fmt" | |
"io" | |
"io/ioutil" | |
"log" | |
"net/http" | |
"net/http/httptest" | |
"net/http/httputil" | |
"net/url" | |
"strings" | |
"testing" | |
) | |
// TestReverseProxyWebSocketModifyResponse checks that a ReverseProxy's
// ModifyResponse hook is applied to a "101 Switching Protocols" response and
// that data still flows in both directions through the proxy afterwards.
//
// Setup: a backend server that hijacks the connection and speaks a minimal
// upgrade handshake by hand, fronted by an httputil.ReverseProxy whose
// ModifyResponse adds the X-Important-Header response header. The client then
// sends one line over the upgraded connection and expects the backend's echo.
func TestReverseProxyWebSocketModifyResponse(t *testing.T) {
	backendServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		if !strings.EqualFold(r.Header.Get("Upgrade"), "websocket") {
			t.Error("unexpected backend request")
			http.Error(w, "unexpected request", http.StatusBadRequest)
			return
		}
		// Only t.Error/t.Errorf are used in this handler: it runs on a server
		// goroutine, where t.Fatal must not be called.
		hj, ok := w.(http.Hijacker)
		if !ok {
			t.Errorf("backend ResponseWriter is of type %T; does not implement http.Hijacker", w)
			http.Error(w, "hijacking unsupported", http.StatusInternalServerError)
			return
		}
		conn, _, err := hj.Hijack()
		if err != nil {
			t.Error(err)
			return
		}
		defer conn.Close()

		// After hijacking, the handler owns the raw connection and has to
		// write the upgrade response itself.
		if _, err := io.WriteString(conn, "HTTP/1.1 101 Switching Protocols\r\nConnection: upgrade\r\nUpgrade: WebSocket\r\n\r\n"); err != nil {
			t.Errorf("backend failed to write upgrade response: %v", err)
			return
		}
		bs := bufio.NewScanner(conn)
		if !bs.Scan() {
			t.Errorf("backend failed to read line from client: %v", bs.Err())
			return
		}
		if _, err := fmt.Fprintf(conn, "backend got %q\n", bs.Text()); err != nil {
			t.Errorf("backend failed to write reply: %v", err)
		}
	}))
	defer backendServer.Close()

	backURL, err := url.Parse(backendServer.URL)
	if err != nil {
		t.Fatal(err)
	}

	rproxy := httputil.NewSingleHostReverseProxy(backURL)
	rproxy.ErrorLog = log.New(ioutil.Discard, "", 0) // quiet for tests
	rproxy.ModifyResponse = func(resp *http.Response) error {
		resp.Header.Set("X-Important-Header", "HelloWorld")
		return nil
	}

	frontendProxy := httptest.NewServer(rproxy)
	defer frontendProxy.Close()

	req, err := http.NewRequest(http.MethodGet, frontendProxy.URL, nil)
	if err != nil {
		t.Fatal(err)
	}
	req.Header.Set("Connection", "Upgrade")
	req.Header.Set("Upgrade", "websocket")

	res, err := frontendProxy.Client().Do(req)
	if err != nil {
		t.Fatal(err)
	}
	// Close the body on every exit path, including the t.Fatalf calls below,
	// so a failed assertion does not leak the upgraded connection.
	defer res.Body.Close()

	if res.StatusCode != http.StatusSwitchingProtocols {
		t.Fatalf("status = %v; want 101", res.Status)
	}
	if !strings.EqualFold(res.Header.Get("Upgrade"), "websocket") {
		t.Fatalf("not websocket upgrade; got %#v", res.Header)
	}
	if res.Header.Get("X-Important-Header") != "HelloWorld" {
		t.Fatalf("missing/invalid custom header; got %#v", res.Header)
	}

	// The test needs to write to the upgraded connection, so the response
	// body must be usable as a read/write stream.
	rwc, ok := res.Body.(io.ReadWriteCloser)
	if !ok {
		t.Fatalf("response body is of type %T; does not implement ReadWriteCloser", res.Body)
	}

	if _, err := io.WriteString(rwc, "Hello\n"); err != nil {
		t.Fatalf("write to upgraded connection: %v", err)
	}
	bs := bufio.NewScanner(rwc)
	if !bs.Scan() {
		t.Fatalf("Scan: %v", bs.Err())
	}
	got := bs.Text()
	want := `backend got "Hello"`
	if got != want {
		t.Errorf("got %#q, want %#q", got, want)
	}
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment