Forked from ilyaigpetrov/how-to-use-golang-framer.go
Last active
October 8, 2022 12:48
-
-
Save jochumdev/2dd7af64d368ed97c947b800eea1ee2a to your computer and use it in GitHub Desktop.
How To Use Golang Framer?
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
// This is based on https://github.com/nghttp2/nghttp2/blob/master/integration-tests/server_tester.go | |
package main | |
import ( | |
"bytes" | |
"crypto/tls" | |
"errors" | |
"fmt" | |
"log" | |
"net" | |
"net/http" | |
"os" | |
"sort" | |
"strconv" | |
"time" | |
"golang.org/x/net/http2" | |
"golang.org/x/net/http2/hpack" | |
) | |
func pair(name, value string) hpack.HeaderField { | |
return hpack.HeaderField{ | |
Name: name, | |
Value: value, | |
} | |
} | |
// cloneHeader returns a deep copy of h: a new map whose value slices are new
// backing arrays, so mutating the copy never affects h. A nil h yields an
// empty, non-nil header.
func cloneHeader(h http.Header) http.Header {
	out := make(http.Header, len(h))
	for name, values := range h {
		out[name] = append(make([]string, 0, len(values)), values...)
	}
	return out
}
func streamEnded(mainSr *serverResponse, streams map[uint32]*serverResponse, sr *serverResponse) bool { | |
delete(streams, sr.streamID) | |
if mainSr.streamID != sr.streamID { | |
mainSr.pushResponse = append(mainSr.pushResponse, sr) | |
} | |
return len(streams) == 0 | |
} | |
type ByStreamID []*serverResponse | |
func (b ByStreamID) Len() int { | |
return len(b) | |
} | |
func (b ByStreamID) Swap(i, j int) { | |
b[i], b[j] = b[j], b[i] | |
} | |
func (b ByStreamID) Less(i, j int) bool { | |
return b[i].streamID < b[j].streamID | |
} | |
type RequestParam struct { | |
StreamID uint32 // stream ID, automatically assigned if 0 | |
Method string // method, defaults to GET | |
Scheme string // scheme, defaults to https | |
Path string // path, defaults to / | |
Header []hpack.HeaderField // additional request header fields | |
Body []byte // request body | |
Trailer []hpack.HeaderField // trailer part | |
HttpUpgrade bool // true if upgraded to HTTP/2 through HTTP Upgrade | |
NoEndStream bool // true if END_STREAM should not be sent | |
} | |
type serverResponse struct { | |
status int // HTTP status code | |
header http.Header // response header fields | |
body []byte // response body | |
streamID uint32 // stream ID in HTTP/2 | |
errCode http2.ErrCode // error code received in HTTP/2 RST_STREAM or GOAWAY | |
connErr bool // true if HTTP/2 connection error | |
reqHeader http.Header // http request header, currently only sotres pushed request header | |
pushResponse []*serverResponse // pushed response | |
} | |
func NewClient(authority string, tlsConfig *tls.Config) (*Client, error) { | |
c := &Client{ | |
authority: authority, | |
nextStreamID: 1, | |
frCh: make(chan http2.Frame), | |
errCh: make(chan error), | |
} | |
// Setup tls.Config | |
if tlsConfig == nil { | |
tlsConfig = &tls.Config{ | |
InsecureSkipVerify: true, | |
} | |
} | |
tlsConfig.NextProtos = []string{http2.NextProtoTLS} | |
conn, err := tls.Dial("tcp", c.authority, tlsConfig) | |
if err != nil { | |
return nil, err | |
} | |
c.conn = conn | |
err = conn.Handshake() | |
if err == nil { | |
return nil, err | |
} | |
c.fr = http2.NewFramer(c.conn, c.conn) | |
c.enc = hpack.NewEncoder(&c.headerBlkBuf) | |
c.dec = hpack.NewDecoder(4096, func(f hpack.HeaderField) { | |
c.header.Add(f.Name, f.Value) | |
}) | |
return c, nil | |
} | |
type Client struct { | |
conn net.Conn // connection to frontend server | |
h2PrefaceSent bool // HTTP/2 preface was sent in conn | |
nextStreamID uint32 // next stream ID | |
fr *http2.Framer // HTTP/2 framer | |
headerBlkBuf bytes.Buffer // buffer to store encoded header block | |
enc *hpack.Encoder // HTTP/2 HPACK encoder | |
header http.Header // received header fields | |
dec *hpack.Decoder // HTTP/2 HPACK decoder | |
authority string // server's host:port | |
frCh chan http2.Frame // used for incoming HTTP/2 frame | |
errCh chan error | |
} | |
func (c *Client) Close() { | |
if c.conn != nil { | |
c.conn.Close() | |
} | |
} | |
func (c *Client) readFrame() (http2.Frame, error) { | |
go func() { | |
f, err := c.fr.ReadFrame() | |
if err != nil { | |
c.errCh <- err | |
return | |
} | |
c.frCh <- f | |
}() | |
select { | |
case f := <-c.frCh: | |
return f, nil | |
case err := <-c.errCh: | |
return nil, err | |
case <-time.After(5 * time.Second): | |
return nil, errors.New("timeout waiting for frame") | |
} | |
} | |
func (c *Client) Do(rp RequestParam) (*serverResponse, error) { | |
c.headerBlkBuf.Reset() | |
c.header = make(http.Header) | |
var id uint32 | |
if rp.StreamID != 0 { | |
id = rp.StreamID | |
if id >= c.nextStreamID && id%2 == 1 { | |
c.nextStreamID = id + 2 | |
} | |
} else { | |
id = c.nextStreamID | |
c.nextStreamID += 2 | |
} | |
if !c.h2PrefaceSent { | |
c.h2PrefaceSent = true | |
fmt.Fprint(c.conn, "PRI * HTTP/2.0\r\n\r\nSM\r\n\r\n") | |
if err := c.fr.WriteSettings(); err != nil { | |
return nil, err | |
} | |
} | |
res := &serverResponse{ | |
streamID: id, | |
} | |
streams := make(map[uint32]*serverResponse) | |
streams[id] = res | |
if !rp.HttpUpgrade { | |
method := "GET" | |
if rp.Method != "" { | |
method = rp.Method | |
} | |
_ = c.enc.WriteField(pair(":method", method)) | |
scheme := "https" | |
if rp.Scheme != "" { | |
scheme = rp.Scheme | |
} | |
_ = c.enc.WriteField(pair(":scheme", scheme)) | |
_ = c.enc.WriteField(pair(":authority", c.authority)) | |
path := "/" | |
if rp.Path != "" { | |
path = rp.Path | |
} | |
_ = c.enc.WriteField(pair(":path", path)) | |
for _, h := range rp.Header { | |
_ = c.enc.WriteField(h) | |
} | |
err := c.fr.WriteHeaders(http2.HeadersFrameParam{ | |
StreamID: id, | |
EndStream: len(rp.Body) == 0 && len(rp.Trailer) == 0 && !rp.NoEndStream, | |
EndHeaders: true, | |
BlockFragment: c.headerBlkBuf.Bytes(), | |
}) | |
if err != nil { | |
return nil, err | |
} | |
if len(rp.Body) != 0 { | |
// TODO we assume rp.body fits in 1 frame | |
if err := c.fr.WriteData(id, len(rp.Trailer) == 0 && !rp.NoEndStream, rp.Body); err != nil { | |
return nil, err | |
} | |
} | |
if len(rp.Trailer) != 0 { | |
c.headerBlkBuf.Reset() | |
for _, h := range rp.Trailer { | |
_ = c.enc.WriteField(h) | |
} | |
err := c.fr.WriteHeaders(http2.HeadersFrameParam{ | |
StreamID: id, | |
EndStream: true, | |
EndHeaders: true, | |
BlockFragment: c.headerBlkBuf.Bytes(), | |
}) | |
if err != nil { | |
return nil, err | |
} | |
} | |
} | |
loop: | |
for { | |
fr, err := c.readFrame() | |
if err != nil { | |
return res, err | |
} | |
switch f := fr.(type) { | |
case *http2.HeadersFrame: | |
_, err := c.dec.Write(f.HeaderBlockFragment()) | |
if err != nil { | |
return res, err | |
} | |
sr, ok := streams[f.FrameHeader.StreamID] | |
if !ok { | |
c.header = make(http.Header) | |
break | |
} | |
sr.header = cloneHeader(c.header) | |
var status int | |
status, err = strconv.Atoi(sr.header.Get(":status")) | |
if err != nil { | |
return res, fmt.Errorf("error parsing status code: %w", err) | |
} | |
sr.status = status | |
if f.StreamEnded() { | |
if streamEnded(res, streams, sr) { | |
break loop | |
} | |
} | |
case *http2.PushPromiseFrame: | |
_, err := c.dec.Write(f.HeaderBlockFragment()) | |
if err != nil { | |
return res, err | |
} | |
sr := &serverResponse{ | |
streamID: f.PromiseID, | |
reqHeader: cloneHeader(c.header), | |
} | |
streams[sr.streamID] = sr | |
case *http2.DataFrame: | |
sr, ok := streams[f.FrameHeader.StreamID] | |
if !ok { | |
break | |
} | |
sr.body = append(sr.body, f.Data()...) | |
if f.StreamEnded() { | |
if streamEnded(res, streams, sr) { | |
break loop | |
} | |
} | |
case *http2.RSTStreamFrame: | |
sr, ok := streams[f.FrameHeader.StreamID] | |
if !ok { | |
break | |
} | |
sr.errCode = f.ErrCode | |
if streamEnded(res, streams, sr) { | |
break loop | |
} | |
case *http2.GoAwayFrame: | |
if f.ErrCode == http2.ErrCodeNo { | |
break | |
} | |
res.errCode = f.ErrCode | |
res.connErr = true | |
break loop | |
case *http2.SettingsFrame: | |
if f.IsAck() { | |
break | |
} | |
if err := c.fr.WriteSettingsAck(); err != nil { | |
return res, err | |
} | |
} | |
} | |
sort.Sort(ByStreamID(res.pushResponse)) | |
return res, nil | |
} | |
func main() { | |
os.Setenv("GODEBUG", "http2debug=1") | |
c, err := NewClient("www.google.com:443", nil) | |
if err != nil { | |
log.Fatal(err) | |
} | |
sr, err := c.Do(RequestParam{ | |
Method: http.MethodGet, | |
Path: "/", | |
}) | |
if err != nil { | |
log.Fatal(err) | |
} | |
if sr.status != http.StatusOK { | |
fmt.Println("Bad status") | |
} | |
fmt.Print(string(sr.body)) | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment