Created
December 28, 2018 06:24
-
-
Save adamgoose/005cda8d9ce4030047e05e5b07953fd4 to your computer and use it in GitHub Desktop.
A sloppy attempt at a GraphQL Subscription Client for golang
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
package subs | |
import ( | |
"context" | |
"encoding/json" | |
"log" | |
"github.com/gorilla/websocket" | |
uuid "github.com/satori/go.uuid" | |
) | |
type Subs struct { | |
conn *websocket.Conn | |
Context context.Context | |
subs map[uuid.UUID]chan json.RawMessage | |
} | |
func New(url string) (s *Subs, err error) { | |
ctx, cancel := context.WithCancel(context.Background()) | |
s = &Subs{ | |
subs: make(map[uuid.UUID]chan json.RawMessage), | |
Context: ctx, | |
} | |
d := websocket.DefaultDialer | |
d.Subprotocols = []string{"graphql-ws"} | |
s.conn, _, err = d.DialContext(s.Context, url, nil) | |
if err != nil { | |
return nil, err | |
} | |
go func() { | |
for { | |
m := SubscriptionMessage{} | |
if err := s.conn.ReadJSON(&m); err != nil { | |
log.Println(err) | |
cancel() | |
return | |
} | |
id, err := uuid.FromString(m.ID) | |
if err != nil { | |
continue | |
} | |
if m.Type != DATA { | |
continue | |
} | |
if c, ok := s.subs[id]; ok { | |
c <- m.Payload.Data | |
} | |
} | |
}() | |
if err := s.conn.WriteMessage(websocket.TextMessage, []byte(`{"type":"connection_init","payload":{}}`)); err != nil { | |
return nil, err | |
} | |
return s, nil | |
} | |
// SubStruct enables you to provide your query in a shurcooL/graphql-compliant way. | |
// Values passed on the return chan can safely be type-asserted to the type of query | |
func (s *Subs) SubStruct(query interface{}, variables map[string]interface{}) (chan interface{}, context.CancelFunc) { | |
newChan := make(chan interface{}, 50) | |
q := constructSubscription(query, variables) | |
c, cancel := s.SubJSON("", q, variables) | |
go func() { | |
for { | |
msg := <-c | |
if err := UnmarshalGraphQL(msg, query); err == nil { | |
newChan <- query | |
} else { | |
log.Println("Unable to unmarshal graphql: ", err) | |
} | |
} | |
}() | |
return newChan, cancel | |
} | |
// SubJSON enables you to provide your query in a string format. | |
// Values passed on the return chan contain a JSON representation of | |
// the "data" attribute returned by the GraphQL server | |
func (s *Subs) SubJSON(oprationName, subscription string, variables map[string]interface{}) (chan json.RawMessage, context.CancelFunc) { | |
id := uuid.Must(uuid.NewV4()) | |
sub := Subscription{ | |
ID: id.String(), | |
Type: START, | |
Payload: GraphQLPayload{ | |
OperationName: oprationName, | |
Query: subscription, | |
Variables: variables, | |
}, | |
} | |
if err := s.conn.WriteJSON(sub); err != nil { | |
log.Fatal(err) | |
} | |
c := make(chan json.RawMessage, 1) | |
s.subs[id] = c | |
ctx, cancel := context.WithCancel(context.Background()) | |
go func() { | |
<-ctx.Done() | |
sub.Type = STOP | |
delete(s.subs, id) | |
close(c) | |
}() | |
return c, cancel | |
} | |
func (s Subs) Done() <-chan struct{} { | |
return s.Context.Done() | |
} |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
// Package jsonutil provides a function for decoding JSON | |
// into a GraphQL query data structure. | |
// ====== borrowed from https://github.com/shurcooL/graphql/blob/master/internal/jsonutil/graphql.go ====== | |
package subs | |
import ( | |
"bytes" | |
"encoding/json" | |
"errors" | |
"fmt" | |
"io" | |
"reflect" | |
"strings" | |
) | |
// UnmarshalGraphQL parses the JSON-encoded GraphQL response data and stores | |
// the result in the GraphQL query data structure pointed to by v. | |
// | |
// The implementation is created on top of the JSON tokenizer available | |
// in "encoding/json".Decoder. | |
func UnmarshalGraphQL(data []byte, v interface{}) error { | |
dec := json.NewDecoder(bytes.NewReader(data)) | |
dec.UseNumber() | |
err := (&decoder{tokenizer: dec}).Decode(v) | |
if err != nil { | |
return err | |
} | |
tok, err := dec.Token() | |
switch err { | |
case io.EOF: | |
// Expect to get io.EOF. There shouldn't be any more | |
// tokens left after we've decoded v successfully. | |
return nil | |
case nil: | |
return fmt.Errorf("invalid token '%v' after top-level value", tok) | |
default: | |
return err | |
} | |
} | |
// decoder is a JSON decoder that performs custom unmarshaling behavior | |
// for GraphQL query data structures. It's implemented on top of a JSON tokenizer. | |
type decoder struct { | |
tokenizer interface { | |
Token() (json.Token, error) | |
} | |
// Stack of what part of input JSON we're in the middle of - objects, arrays. | |
parseState []json.Delim | |
// Stacks of values where to unmarshal. | |
// The top of each stack is the reflect.Value where to unmarshal next JSON value. | |
// | |
// The reason there's more than one stack is because we might be unmarshaling | |
// a single JSON value into multiple GraphQL fragments or embedded structs, so | |
// we keep track of them all. | |
vs [][]reflect.Value | |
} | |
// Decode decodes a single JSON value from d.tokenizer into v. | |
func (d *decoder) Decode(v interface{}) error { | |
rv := reflect.ValueOf(v) | |
if rv.Kind() != reflect.Ptr { | |
return fmt.Errorf("cannot decode into non-pointer %T", v) | |
} | |
d.vs = [][]reflect.Value{{rv.Elem()}} | |
return d.decode() | |
} | |
// decode decodes a single JSON value from d.tokenizer into d.vs. | |
func (d *decoder) decode() error { | |
// The loop invariant is that the top of each d.vs stack | |
// is where we try to unmarshal the next JSON value we see. | |
for len(d.vs) > 0 { | |
tok, err := d.tokenizer.Token() | |
if err == io.EOF { | |
return errors.New("unexpected end of JSON input") | |
} else if err != nil { | |
return err | |
} | |
switch { | |
// Are we inside an object and seeing next key (rather than end of object)? | |
case d.state() == '{' && tok != json.Delim('}'): | |
key, ok := tok.(string) | |
if !ok { | |
return errors.New("unexpected non-key in JSON input") | |
} | |
someFieldExist := false | |
for i := range d.vs { | |
v := d.vs[i][len(d.vs[i])-1] | |
if v.Kind() == reflect.Ptr { | |
v = v.Elem() | |
} | |
var f reflect.Value | |
if v.Kind() == reflect.Struct { | |
f = fieldByGraphQLName(v, key) | |
if f.IsValid() { | |
someFieldExist = true | |
} | |
} | |
d.vs[i] = append(d.vs[i], f) | |
} | |
if !someFieldExist { | |
return fmt.Errorf("struct field for %s doesn't exist in any of %v places to unmarshal", key, len(d.vs)) | |
} | |
// We've just consumed the current token, which was the key. | |
// Read the next token, which should be the value, and let the rest of code process it. | |
tok, err = d.tokenizer.Token() | |
if err == io.EOF { | |
return errors.New("unexpected end of JSON input") | |
} else if err != nil { | |
return err | |
} | |
// Are we inside an array and seeing next value (rather than end of array)? | |
case d.state() == '[' && tok != json.Delim(']'): | |
someSliceExist := false | |
for i := range d.vs { | |
v := d.vs[i][len(d.vs[i])-1] | |
if v.Kind() == reflect.Ptr { | |
v = v.Elem() | |
} | |
var f reflect.Value | |
if v.Kind() == reflect.Slice { | |
v.Set(reflect.Append(v, reflect.Zero(v.Type().Elem()))) // v = append(v, T). | |
f = v.Index(v.Len() - 1) | |
someSliceExist = true | |
} | |
d.vs[i] = append(d.vs[i], f) | |
} | |
if !someSliceExist { | |
return fmt.Errorf("slice doesn't exist in any of %v places to unmarshal", len(d.vs)) | |
} | |
} | |
switch tok := tok.(type) { | |
case string, json.Number, bool, nil: | |
// Value. | |
for i := range d.vs { | |
v := d.vs[i][len(d.vs[i])-1] | |
if !v.IsValid() { | |
continue | |
} | |
err := unmarshalValue(tok, v) | |
if err != nil { | |
return err | |
} | |
} | |
d.popAllVs() | |
case json.Delim: | |
switch tok { | |
case '{': | |
// Start of object. | |
d.pushState(tok) | |
frontier := make([]reflect.Value, len(d.vs)) // Places to look for GraphQL fragments/embedded structs. | |
for i := range d.vs { | |
v := d.vs[i][len(d.vs[i])-1] | |
frontier[i] = v | |
// TODO: Do this recursively or not? Add a test case if needed. | |
if v.Kind() == reflect.Ptr && v.IsNil() { | |
v.Set(reflect.New(v.Type().Elem())) // v = new(T). | |
} | |
} | |
// Find GraphQL fragments/embedded structs recursively, adding to frontier | |
// as new ones are discovered and exploring them further. | |
for len(frontier) > 0 { | |
v := frontier[0] | |
frontier = frontier[1:] | |
if v.Kind() == reflect.Ptr { | |
v = v.Elem() | |
} | |
if v.Kind() != reflect.Struct { | |
continue | |
} | |
for i := 0; i < v.NumField(); i++ { | |
if isGraphQLFragment(v.Type().Field(i)) || v.Type().Field(i).Anonymous { | |
// Add GraphQL fragment or embedded struct. | |
d.vs = append(d.vs, []reflect.Value{v.Field(i)}) | |
frontier = append(frontier, v.Field(i)) | |
} | |
} | |
} | |
case '[': | |
// Start of array. | |
d.pushState(tok) | |
for i := range d.vs { | |
v := d.vs[i][len(d.vs[i])-1] | |
// TODO: Confirm this is needed, write a test case. | |
//if v.Kind() == reflect.Ptr && v.IsNil() { | |
// v.Set(reflect.New(v.Type().Elem())) // v = new(T). | |
//} | |
// Reset slice to empty (in case it had non-zero initial value). | |
if v.Kind() == reflect.Ptr { | |
v = v.Elem() | |
} | |
if v.Kind() != reflect.Slice { | |
continue | |
} | |
v.Set(reflect.MakeSlice(v.Type(), 0, 0)) // v = make(T, 0, 0). | |
} | |
case '}', ']': | |
// End of object or array. | |
d.popAllVs() | |
d.popState() | |
default: | |
return errors.New("unexpected delimiter in JSON input") | |
} | |
default: | |
return errors.New("unexpected token in JSON input") | |
} | |
} | |
return nil | |
} | |
// pushState pushes a new parse state s onto the stack. | |
func (d *decoder) pushState(s json.Delim) { | |
d.parseState = append(d.parseState, s) | |
} | |
// popState pops a parse state (already obtained) off the stack. | |
// The stack must be non-empty. | |
func (d *decoder) popState() { | |
d.parseState = d.parseState[:len(d.parseState)-1] | |
} | |
// state reports the parse state on top of stack, or 0 if empty. | |
func (d *decoder) state() json.Delim { | |
if len(d.parseState) == 0 { | |
return 0 | |
} | |
return d.parseState[len(d.parseState)-1] | |
} | |
// popAllVs pops from all d.vs stacks, keeping only non-empty ones. | |
func (d *decoder) popAllVs() { | |
var nonEmpty [][]reflect.Value | |
for i := range d.vs { | |
d.vs[i] = d.vs[i][:len(d.vs[i])-1] | |
if len(d.vs[i]) > 0 { | |
nonEmpty = append(nonEmpty, d.vs[i]) | |
} | |
} | |
d.vs = nonEmpty | |
} | |
// fieldByGraphQLName returns a struct field of struct v that matches GraphQL name, | |
// or invalid reflect.Value if none found. | |
func fieldByGraphQLName(v reflect.Value, name string) reflect.Value { | |
for i := 0; i < v.NumField(); i++ { | |
if hasGraphQLName(v.Type().Field(i), name) { | |
return v.Field(i) | |
} | |
} | |
return reflect.Value{} | |
} | |
// hasGraphQLName reports whether struct field f has GraphQL name. | |
func hasGraphQLName(f reflect.StructField, name string) bool { | |
value, ok := f.Tag.Lookup("graphql") | |
if !ok { | |
// TODO: caseconv package is relatively slow. Optimize it, then consider using it here. | |
//return caseconv.MixedCapsToLowerCamelCase(f.Name) == name | |
return strings.EqualFold(f.Name, name) | |
} | |
value = strings.TrimSpace(value) // TODO: Parse better. | |
if strings.HasPrefix(value, "...") { | |
// GraphQL fragment. It doesn't have a name. | |
return false | |
} | |
if i := strings.Index(value, "("); i != -1 { | |
value = value[:i] | |
} | |
if i := strings.Index(value, ":"); i != -1 { | |
value = value[:i] | |
} | |
return strings.TrimSpace(value) == name | |
} | |
// isGraphQLFragment reports whether struct field f is a GraphQL fragment. | |
func isGraphQLFragment(f reflect.StructField) bool { | |
value, ok := f.Tag.Lookup("graphql") | |
if !ok { | |
return false | |
} | |
value = strings.TrimSpace(value) // TODO: Parse better. | |
return strings.HasPrefix(value, "...") | |
} | |
// unmarshalValue unmarshals JSON value into v. | |
func unmarshalValue(value json.Token, v reflect.Value) error { | |
b, err := json.Marshal(value) // TODO: Short-circuit (if profiling says it's worth it). | |
if err != nil { | |
return err | |
} | |
if !v.CanAddr() { | |
return fmt.Errorf("value %v is not addressable", v) | |
} | |
return json.Unmarshal(b, v.Addr().Interface()) | |
} |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
// ======= borrowed from https://github.com/shurcooL/graphql/blob/master/query.go ======= | |
package subs | |
import ( | |
"bytes" | |
"encoding/json" | |
"io" | |
"reflect" | |
"sort" | |
"github.com/shurcooL/graphql/ident" | |
) | |
func constructQuery(v interface{}, variables map[string]interface{}) string { | |
query := query(v) | |
if len(variables) > 0 { | |
return "query(" + queryArguments(variables) + ")" + query | |
} | |
return query | |
} | |
func constructMutation(v interface{}, variables map[string]interface{}) string { | |
query := query(v) | |
if len(variables) > 0 { | |
return "mutation(" + queryArguments(variables) + ")" + query | |
} | |
return "mutation" + query | |
} | |
func constructSubscription(v interface{}, variables map[string]interface{}) string { | |
query := query(v) | |
if len(variables) > 0 { | |
return "subscription(" + queryArguments(variables) + ")" + query | |
} | |
return "subscription" + query | |
} | |
// queryArguments constructs a minified arguments string for variables. | |
// | |
// E.g., map[string]interface{}{"a": Int(123), "b": NewBoolean(true)} -> "$a:Int!$b:Boolean". | |
func queryArguments(variables map[string]interface{}) string { | |
// Sort keys in order to produce deterministic output for testing purposes. | |
// TODO: If tests can be made to work with non-deterministic output, then no need to sort. | |
keys := make([]string, 0, len(variables)) | |
for k := range variables { | |
keys = append(keys, k) | |
} | |
sort.Strings(keys) | |
var buf bytes.Buffer | |
for _, k := range keys { | |
io.WriteString(&buf, "$") | |
io.WriteString(&buf, k) | |
io.WriteString(&buf, ":") | |
writeArgumentType(&buf, reflect.TypeOf(variables[k]), true) | |
// Don't insert a comma here. | |
// Commas in GraphQL are insignificant, and we want minified output. | |
// See https://facebook.github.io/graphql/October2016/#sec-Insignificant-Commas. | |
} | |
return buf.String() | |
} | |
// writeArgumentType writes a minified GraphQL type for t to w. | |
// value indicates whether t is a value (required) type or pointer (optional) type. | |
// If value is true, then "!" is written at the end of t. | |
func writeArgumentType(w io.Writer, t reflect.Type, value bool) { | |
if t.Kind() == reflect.Ptr { | |
// Pointer is an optional type, so no "!" at the end of the pointer's underlying type. | |
writeArgumentType(w, t.Elem(), false) | |
return | |
} | |
switch t.Kind() { | |
case reflect.Slice, reflect.Array: | |
// List. E.g., "[Int]". | |
io.WriteString(w, "[") | |
writeArgumentType(w, t.Elem(), true) | |
io.WriteString(w, "]") | |
default: | |
// Named type. E.g., "Int". | |
name := t.Name() | |
if name == "string" { // HACK: Workaround for https://github.com/shurcooL/githubv4/issues/12. | |
name = "ID" | |
} | |
io.WriteString(w, name) | |
} | |
if value { | |
// Value is a required type, so add "!" to the end. | |
io.WriteString(w, "!") | |
} | |
} | |
// query uses writeQuery to recursively construct | |
// a minified query string from the provided struct v. | |
// | |
// E.g., struct{Foo Int, BarBaz *Boolean} -> "{foo,barBaz}". | |
func query(v interface{}) string { | |
var buf bytes.Buffer | |
writeQuery(&buf, reflect.TypeOf(v), false) | |
return buf.String() | |
} | |
// writeQuery writes a minified query for t to w. | |
// If inline is true, the struct fields of t are inlined into parent struct. | |
func writeQuery(w io.Writer, t reflect.Type, inline bool) { | |
switch t.Kind() { | |
case reflect.Ptr, reflect.Slice: | |
writeQuery(w, t.Elem(), false) | |
case reflect.Struct: | |
// If the type implements json.Unmarshaler, it's a scalar. Don't expand it. | |
if reflect.PtrTo(t).Implements(jsonUnmarshaler) { | |
return | |
} | |
if !inline { | |
io.WriteString(w, "{") | |
} | |
for i := 0; i < t.NumField(); i++ { | |
if i != 0 { | |
io.WriteString(w, ",") | |
} | |
f := t.Field(i) | |
value, ok := f.Tag.Lookup("graphql") | |
inlineField := f.Anonymous && !ok | |
if !inlineField { | |
if ok { | |
io.WriteString(w, value) | |
} else { | |
io.WriteString(w, ident.ParseMixedCaps(f.Name).ToLowerCamelCase()) | |
} | |
} | |
writeQuery(w, f.Type, inlineField) | |
} | |
if !inline { | |
io.WriteString(w, "}") | |
} | |
} | |
} | |
var jsonUnmarshaler = reflect.TypeOf((*json.Unmarshaler)(nil)).Elem() |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
package subs | |
import "encoding/json" | |
type GraphQLPayload struct { | |
OperationName string `json:"operationName"` | |
Query string `json:"query"` | |
Variables map[string]interface{} `json:"variables"` | |
Extensions map[string]interface{} `json:"extensions"` | |
} | |
type SubscriptionAction string | |
const ( | |
START SubscriptionAction = "start" | |
STOP SubscriptionAction = "stop" | |
DATA SubscriptionAction = "data" | |
ERROR SubscriptionAction = "error" | |
) | |
type Subscription struct { | |
ID string `json:"id"` | |
Type SubscriptionAction `json:"type"` | |
Payload GraphQLPayload `json:"payload"` | |
} | |
type SubscriptionMessage struct { | |
ID string `json:"id"` | |
Type SubscriptionAction `json:"type"` | |
Payload struct { | |
Data json.RawMessage `json:"data"` | |
} `json:"payload"` | |
// Payload json.RawMessage `json:"payload"` | |
Errors json.RawMessage `json:"errors"` | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment