Skip to content

Commit 7fb9b09

Browse files
subscription: check all fields until find method
1 parent 2198f0b commit 7fb9b09

File tree

1 file changed

+84
-54
lines changed

1 file changed

+84
-54
lines changed

internal/exec/subscribe.go

Lines changed: 84 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ import (
1313
"github.com/golangid/graphql-go/internal/exec/resolvable"
1414
"github.com/golangid/graphql-go/internal/exec/selected"
1515
"github.com/golangid/graphql-go/internal/query"
16+
"github.com/golangid/graphql-go/internal/schema"
1617
)
1718

1819
type Response struct {
@@ -22,66 +23,30 @@ type Response struct {
2223

2324
func (r *Request) Subscribe(ctx context.Context, s *resolvable.Schema, op *query.Operation) <-chan *Response {
2425
var result reflect.Value
25-
var f *fieldToExec
2626
var err *errors.QueryError
27-
func() {
28-
defer r.handlePanic(ctx)
29-
30-
sels := selected.ApplyOperation(&r.Request, s, op)
31-
var fields []*fieldToExec
32-
collectFieldsToResolve(sels, s, s.ResolverSubscription, &fields, make(map[string]*fieldToExec))
33-
34-
// TODO: move this check into validation.Validate
35-
if len(fields) != 1 {
36-
err = errors.Errorf("%s", "can subscribe to at most one subscription at a time")
37-
return
38-
}
39-
f = fields[0]
40-
41-
// TODO: add check all childs
42-
func() {
43-
tmpF := *f
44-
defer func() {
45-
if r := recover(); r != nil {
46-
f = &tmpF
47-
}
48-
}()
49-
50-
if f.resolver.Kind() == reflect.Ptr {
51-
f.resolver = f.resolver.Elem()
52-
}
53-
f.resolver = f.resolver.FieldByIndex(f.field.FieldIndex)
54-
55-
var fieldsDeep []*fieldToExec
56-
collectFieldsToResolve(f.sels, s, f.resolver, &fieldsDeep, make(map[string]*fieldToExec))
57-
if len(fieldsDeep) == 1 {
58-
f = fieldsDeep[0]
59-
} else {
60-
f = &tmpF
61-
}
62-
}()
63-
64-
var in []reflect.Value
65-
if f.field.HasContext {
66-
in = append(in, reflect.ValueOf(ctx))
67-
}
68-
if f.field.ArgsPacker != nil {
69-
in = append(in, f.field.PackedArgs)
70-
}
71-
callOut := f.resolver.Method(f.field.MethodIndex).Call(in)
72-
result = callOut[0]
73-
74-
if f.field.HasError && !callOut[1].IsNil() {
75-
resolverErr := callOut[1].Interface().(error)
76-
err = errors.Errorf("%s", resolverErr)
77-
err.ResolverError = resolverErr
78-
}
79-
}()
27+
sels := selected.ApplyOperation(&r.Request, s, op)
8028

29+
f := r.subscriptionSearchFieldMethod(ctx, sels, nil, s, s.ResolverSubscription)
8130
if f == nil {
8231
return sendAndReturnClosed(&Response{Errors: []*errors.QueryError{err}})
8332
}
8433

34+
var in []reflect.Value
35+
if f.field.HasContext {
36+
in = append(in, reflect.ValueOf(ctx))
37+
}
38+
if f.field.ArgsPacker != nil {
39+
in = append(in, f.field.PackedArgs)
40+
}
41+
callOut := f.resolver.Method(f.field.MethodIndex).Call(in)
42+
result = callOut[0]
43+
44+
if f.field.HasError && !callOut[1].IsNil() {
45+
resolverErr := callOut[1].Interface().(error)
46+
err = errors.Errorf("%s", resolverErr)
47+
err.ResolverError = resolverErr
48+
}
49+
8550
if err != nil {
8651
if _, nonNullChild := f.field.Type.(*common.NonNull); nonNullChild {
8752
return sendAndReturnClosed(&Response{Errors: []*errors.QueryError{err}})
@@ -186,3 +151,68 @@ func sendAndReturnClosed(resp *Response) chan *Response {
186151
close(c)
187152
return c
188153
}
154+
155+
func (r *Request) subscriptionSearchFieldMethod(ctx context.Context, sels []selected.Selection, path *pathSegment, s *resolvable.Schema, resolver reflect.Value) (foundField *fieldToExec) {
156+
157+
var collectFields func(sels []selected.Selection, path *pathSegment, s *resolvable.Schema, resolver reflect.Value)
158+
var execField func(s *resolvable.Schema, f *fieldToExec, path *pathSegment)
159+
var execSelectionSet func(sels []selected.Selection, typ common.Type, path *pathSegment, s *resolvable.Schema, resolver reflect.Value)
160+
161+
collectFields = func(sels []selected.Selection, path *pathSegment, s *resolvable.Schema, resolver reflect.Value) {
162+
var fields []*fieldToExec
163+
collectFieldsToResolve(sels, s, resolver, &fields, make(map[string]*fieldToExec))
164+
165+
for _, f := range fields {
166+
execField(s, f, &pathSegment{path, f.field.Alias})
167+
if f.field.UseMethodResolver() && !f.field.FixedResult.IsValid() {
168+
foundField = f
169+
return
170+
}
171+
}
172+
}
173+
174+
execField = func(s *resolvable.Schema, f *fieldToExec, path *pathSegment) {
175+
var result reflect.Value
176+
177+
if f.field.FixedResult.IsValid() {
178+
result = f.field.FixedResult
179+
return
180+
}
181+
182+
res := f.resolver
183+
if !f.field.UseMethodResolver() {
184+
res = unwrapPtr(res)
185+
result = res.FieldByIndex(f.field.FieldIndex)
186+
}
187+
188+
execSelectionSet(f.sels, f.field.Type, path, s, result)
189+
}
190+
191+
execSelectionSet = func(sels []selected.Selection, typ common.Type, path *pathSegment, s *resolvable.Schema, resolver reflect.Value) {
192+
t, nonNull := unwrapNonNull(typ)
193+
switch t := t.(type) {
194+
case *schema.Object, *schema.Interface, *schema.Union:
195+
if resolver.Kind() == reflect.Invalid || ((resolver.Kind() == reflect.Ptr || resolver.Kind() == reflect.Interface) && resolver.IsNil()) {
196+
if nonNull {
197+
err := errors.Errorf("graphql: got nil for non-null %q", t)
198+
err.Path = path.toSlice()
199+
r.AddError(err)
200+
}
201+
return
202+
}
203+
204+
collectFields(sels, path, s, resolver)
205+
return
206+
}
207+
}
208+
209+
collectFields(sels, path, s, resolver)
210+
return
211+
}
212+
213+
func unwrapPtr(v reflect.Value) reflect.Value {
214+
if v.Kind() == reflect.Ptr {
215+
return v.Elem()
216+
}
217+
return v
218+
}

0 commit comments

Comments
 (0)