diff --git a/graphql.go b/graphql.go index a2f21f84..88dbec2a 100644 --- a/graphql.go +++ b/graphql.go @@ -18,6 +18,7 @@ import ( "github.com/graph-gophers/graphql-go/introspection" "github.com/graph-gophers/graphql-go/log" "github.com/graph-gophers/graphql-go/trace" + "github.com/graph-gophers/graphql-go/types" ) // ParseSchema parses a GraphQL schema and attaches the given root resolver. It returns an error if @@ -42,7 +43,7 @@ func ParseSchema(schemaString string, resolver interface{}, opts ...SchemaOpt) ( } } - if err := s.schema.Parse(schemaString, s.useStringDescriptions); err != nil { + if err := schema.Parse(s.schema, schemaString, s.useStringDescriptions); err != nil { return nil, err } if err := s.validateSchema(); err != nil { @@ -69,7 +70,7 @@ func MustParseSchema(schemaString string, resolver interface{}, opts ...SchemaOp // Schema represents a GraphQL schema with an optional resolver. type Schema struct { - schema *schema.Schema + schema *types.Schema res *resolvable.Schema maxDepth int @@ -82,6 +83,10 @@ type Schema struct { subscribeResolverTimeout time.Duration } +func (s *Schema) ASTSchema() *types.Schema { + return s.schema +} + // SchemaOpt is an option to pass to ParseSchema or MustParseSchema. type SchemaOpt func(*Schema) @@ -228,7 +233,7 @@ func (s *Schema) exec(ctx context.Context, queryString string, operationName str } for _, v := range op.Vars { if _, ok := variables[v.Name.Name]; !ok && v.Default != nil { - variables[v.Name.Name] = v.Default.Value(nil) + variables[v.Name.Name] = v.Default.Deserialize(nil) } } @@ -288,7 +293,7 @@ func (t *validationBridgingTracer) TraceValidation(context.Context) trace.TraceV return t.tracer.TraceValidation() } -func validateRootOp(s *schema.Schema, name string, mandatory bool) error { +func validateRootOp(s *types.Schema, name string, mandatory bool) error { t, ok := s.EntryPoints[name] if !ok { if mandatory { @@ -302,7 +307,7 @@ func validateRootOp(s *schema.Schema, name string, mandatory bool) error { return nil } -func getOperation(document *query.Document, operationName string) (*query.Operation, error) { +func getOperation(document *types.ExecutableDefinition, operationName string) (*types.OperationDefinition, error) { if len(document.Operations) == 0 { return nil, fmt.Errorf("no operations in query document") } diff --git a/internal/common/directive.go b/internal/common/directive.go index 62dca47f..f767e28f 100644 --- a/internal/common/directive.go +++ b/internal/common/directive.go @@ -1,32 +1,18 @@ package common -type Directive struct { - Name Ident - Args ArgumentList -} +import "github.com/graph-gophers/graphql-go/types" -func ParseDirectives(l *Lexer) DirectiveList { - var directives DirectiveList +func ParseDirectives(l *Lexer) types.DirectiveList { + var directives types.DirectiveList for l.Peek() == '@' { l.ConsumeToken('@') - d := &Directive{} + d := &types.Directive{} d.Name = l.ConsumeIdentWithLoc() d.Name.Loc.Column-- if l.Peek() == '(' { - d.Args = ParseArguments(l) + d.Arguments = ParseArgumentList(l) } directives = append(directives, d) } return directives } - -type DirectiveList []*Directive - -func (l DirectiveList) Get(name string) *Directive { - for _, d := range l { - if d.Name.Name == name { - return d - } - } - return nil -} diff --git a/internal/common/lexer.go b/internal/common/lexer.go index af385ecc..49130686 100644 --- a/internal/common/lexer.go +++ b/internal/common/lexer.go @@ -8,6 +8,7 @@ import ( "text/scanner" "github.com/graph-gophers/graphql-go/errors" + "github.com/graph-gophers/graphql-go/types" ) type syntaxError string @@ -30,7 +31,6 @@ func NewLexer(s string, useStringDescriptions bool) *Lexer { } sc.Init(strings.NewReader(s)) - l := Lexer{sc: sc, useStringDescriptions: useStringDescriptions} l.sc.Error = l.CatchScannerError @@ -119,11 +119,11 @@ func (l *Lexer) ConsumeIdent() string { return name } -func (l *Lexer) ConsumeIdentWithLoc() Ident { +func (l *Lexer) ConsumeIdentWithLoc() types.Ident { loc := l.Location() name := l.sc.TokenText() l.ConsumeToken(scanner.Ident) - return Ident{name, loc} + return types.Ident{name, loc} } func (l *Lexer) ConsumeKeyword(keyword string) { @@ -133,8 +133,8 @@ func (l *Lexer) ConsumeKeyword(keyword string) { l.ConsumeWhitespace() } -func (l *Lexer) ConsumeLiteral() *BasicLit { - lit := &BasicLit{Type: l.next, Text: l.sc.TokenText()} +func (l *Lexer) ConsumeLiteral() *types.PrimitiveValue { + lit := &types.PrimitiveValue{Type: l.next, Text: l.sc.TokenText()} l.ConsumeWhitespace() return lit } diff --git a/internal/common/literals.go b/internal/common/literals.go index e7bbe263..9f3b979d 100644 --- a/internal/common/literals.go +++ b/internal/common/literals.go @@ -1,160 +1,12 @@ package common import ( - "strconv" - "strings" "text/scanner" - "github.com/graph-gophers/graphql-go/errors" + "github.com/graph-gophers/graphql-go/types" ) -type Literal interface { - Value(vars map[string]interface{}) interface{} - String() string - Location() errors.Location -} - -type BasicLit struct { - Type rune - Text string - Loc errors.Location -} - -func (lit *BasicLit) Value(vars map[string]interface{}) interface{} { - switch lit.Type { - case scanner.Int: - value, err := strconv.ParseInt(lit.Text, 10, 32) - if err != nil { - panic(err) - } - return int32(value) - - case scanner.Float: - value, err := strconv.ParseFloat(lit.Text, 64) - if err != nil { - panic(err) - } - return value - - case scanner.String: - value, err := strconv.Unquote(lit.Text) - if err != nil { - panic(err) - } - return value - - case scanner.Ident: - switch lit.Text { - case "true": - return true - case "false": - return false - default: - return lit.Text - } - - default: - panic("invalid literal") - } -} - -func (lit *BasicLit) String() string { - return lit.Text -} - -func (lit *BasicLit) Location() errors.Location { - return lit.Loc -} - -type ListLit struct { - Entries []Literal - Loc errors.Location -} - -func (lit *ListLit) Value(vars map[string]interface{}) interface{} { - entries := make([]interface{}, len(lit.Entries)) - for i, entry := range lit.Entries { - entries[i] = entry.Value(vars) - } - return entries -} - -func (lit *ListLit) String() string { - entries := make([]string, len(lit.Entries)) - for i, entry := range lit.Entries { - entries[i] = entry.String() - } - return "[" + strings.Join(entries, ", ") + "]" -} - -func (lit *ListLit) Location() errors.Location { - return lit.Loc -} - -type ObjectLit struct { - Fields []*ObjectLitField - Loc errors.Location -} - -type ObjectLitField struct { - Name Ident - Value Literal -} - -func (lit *ObjectLit) Value(vars map[string]interface{}) interface{} { - fields := make(map[string]interface{}, len(lit.Fields)) - for _, f := range lit.Fields { - fields[f.Name.Name] = f.Value.Value(vars) - } - return fields -} - -func (lit *ObjectLit) String() string { - entries := make([]string, 0, len(lit.Fields)) - for _, f := range lit.Fields { - entries = append(entries, f.Name.Name+": "+f.Value.String()) - } - return "{" + strings.Join(entries, ", ") + "}" -} - -func (lit *ObjectLit) Location() errors.Location { - return lit.Loc -} - -type NullLit struct { - Loc errors.Location -} - -func (lit *NullLit) Value(vars map[string]interface{}) interface{} { - return nil -} - -func (lit *NullLit) String() string { - return "null" -} - -func (lit *NullLit) Location() errors.Location { - return lit.Loc -} - -type Variable struct { - Name string - Loc errors.Location -} - -func (v Variable) Value(vars map[string]interface{}) interface{} { - return vars[v.Name] -} - -func (v Variable) String() string { - return "$" + v.Name -} - -func (v *Variable) Location() errors.Location { - return v.Loc -} - -func ParseLiteral(l *Lexer, constOnly bool) Literal { +func ParseLiteral(l *Lexer, constOnly bool) types.Value { loc := l.Location() switch l.Peek() { case '$': @@ -163,12 +15,12 @@ func ParseLiteral(l *Lexer, constOnly bool) Literal { panic("unreachable") } l.ConsumeToken('$') - return &Variable{l.ConsumeIdent(), loc} + return &types.Variable{l.ConsumeIdent(), loc} case scanner.Int, scanner.Float, scanner.String, scanner.Ident: lit := l.ConsumeLiteral() if lit.Type == scanner.Ident && lit.Text == "null" { - return &NullLit{loc} + return &types.NullValue{loc} } lit.Loc = loc return lit @@ -180,24 +32,24 @@ func ParseLiteral(l *Lexer, constOnly bool) Literal { return lit case '[': l.ConsumeToken('[') - var list []Literal + var list []types.Value for l.Peek() != ']' { list = append(list, ParseLiteral(l, constOnly)) } l.ConsumeToken(']') - return &ListLit{list, loc} + return &types.ListValue{list, loc} case '{': l.ConsumeToken('{') - var fields []*ObjectLitField + var fields []*types.ObjectField for l.Peek() != '}' { name := l.ConsumeIdentWithLoc() l.ConsumeToken(':') value := ParseLiteral(l, constOnly) - fields = append(fields, &ObjectLitField{name, value}) + fields = append(fields, &types.ObjectField{name, value}) } l.ConsumeToken('}') - return &ObjectLit{fields, loc} + return &types.ObjectValue{fields, loc} default: l.SyntaxError("invalid value") diff --git a/internal/common/types.go b/internal/common/types.go index a20ca309..4a30f46e 100644 --- a/internal/common/types.go +++ b/internal/common/types.go @@ -2,70 +2,57 @@ package common import ( "github.com/graph-gophers/graphql-go/errors" + "github.com/graph-gophers/graphql-go/types" ) -type Type interface { - Kind() string - String() string -} - -type List struct { - OfType Type -} - -type NonNull struct { - OfType Type -} - -type TypeName struct { - Ident -} - -func (*List) Kind() string { return "LIST" } -func (*NonNull) Kind() string { return "NON_NULL" } -func (*TypeName) Kind() string { panic("TypeName needs to be resolved to actual type") } - -func (t *List) String() string { return "[" + t.OfType.String() + "]" } -func (t *NonNull) String() string { return t.OfType.String() + "!" } -func (*TypeName) String() string { panic("TypeName needs to be resolved to actual type") } - -func ParseType(l *Lexer) Type { +func ParseType(l *Lexer) types.Type { t := parseNullType(l) if l.Peek() == '!' { l.ConsumeToken('!') - return &NonNull{OfType: t} + return &types.NonNull{OfType: t} } return t } -func parseNullType(l *Lexer) Type { +func parseNullType(l *Lexer) types.Type { if l.Peek() == '[' { l.ConsumeToken('[') ofType := ParseType(l) l.ConsumeToken(']') - return &List{OfType: ofType} + return &types.List{OfType: ofType} } - return &TypeName{Ident: l.ConsumeIdentWithLoc()} + return &types.TypeName{Ident: l.ConsumeIdentWithLoc()} } -type Resolver func(name string) Type +type Resolver func(name string) types.Type -func ResolveType(t Type, resolver Resolver) (Type, *errors.QueryError) { +// ResolveType attempts to resolve a type's name against a resolving function. +// This function is used when one needs to check if a TypeName exists in the resolver (typically a Schema). +// +// In the example below, ResolveType would be used to check if the resolving function +// returns a valid type for Dimension: +// +// type Profile { +// picture(dimensions: Dimension): Url +// } +// +// ResolveType recursively unwraps List and NonNull types until a NamedType is reached. +func ResolveType(t types.Type, resolver Resolver) (types.Type, *errors.QueryError) { switch t := t.(type) { - case *List: + case *types.List: ofType, err := ResolveType(t.OfType, resolver) if err != nil { return nil, err } - return &List{OfType: ofType}, nil - case *NonNull: + return &types.List{OfType: ofType}, nil + case *types.NonNull: ofType, err := ResolveType(t.OfType, resolver) if err != nil { return nil, err } - return &NonNull{OfType: ofType}, nil - case *TypeName: + return &types.NonNull{OfType: ofType}, nil + case *types.TypeName: refT := resolver(t.Name) if refT == nil { err := errors.Errorf("Unknown type %q.", t.Name) diff --git a/internal/common/values.go b/internal/common/values.go index f2af39e4..2d6e0b54 100644 --- a/internal/common/values.go +++ b/internal/common/values.go @@ -1,33 +1,11 @@ package common import ( - "github.com/graph-gophers/graphql-go/errors" + "github.com/graph-gophers/graphql-go/types" ) -// http://facebook.github.io/graphql/draft/#InputValueDefinition -type InputValue struct { - Name Ident - Type Type - Default Literal - Desc string - Directives DirectiveList - Loc errors.Location - TypeLoc errors.Location -} - -type InputValueList []*InputValue - -func (l InputValueList) Get(name string) *InputValue { - for _, v := range l { - if v.Name.Name == name { - return v - } - } - return nil -} - -func ParseInputValue(l *Lexer) *InputValue { - p := &InputValue{} +func ParseInputValue(l *Lexer) *types.InputValueDefinition { + p := &types.InputValueDefinition{} p.Loc = l.Location() p.Desc = l.DescComment() p.Name = l.ConsumeIdentWithLoc() @@ -42,38 +20,17 @@ func ParseInputValue(l *Lexer) *InputValue { return p } -type Argument struct { - Name Ident - Value Literal -} - -type ArgumentList []Argument - -func (l ArgumentList) Get(name string) (Literal, bool) { - for _, arg := range l { - if arg.Name.Name == name { - return arg.Value, true - } - } - return nil, false -} - -func (l ArgumentList) MustGet(name string) Literal { - value, ok := l.Get(name) - if !ok { - panic("argument not found") - } - return value -} - -func ParseArguments(l *Lexer) ArgumentList { - var args ArgumentList +func ParseArgumentList(l *Lexer) types.ArgumentList { + var args types.ArgumentList l.ConsumeToken('(') for l.Peek() != ')' { name := l.ConsumeIdentWithLoc() l.ConsumeToken(':') value := ParseLiteral(l, false) - args = append(args, Argument{Name: name, Value: value}) + args = append(args, &types.Argument{ + Name: name, + Value: value, + }) } l.ConsumeToken(')') return args diff --git a/internal/exec/exec.go b/internal/exec/exec.go index 1e409bb8..9ef66c7f 100644 --- a/internal/exec/exec.go +++ b/internal/exec/exec.go @@ -10,13 +10,12 @@ import ( "time" "github.com/graph-gophers/graphql-go/errors" - "github.com/graph-gophers/graphql-go/internal/common" "github.com/graph-gophers/graphql-go/internal/exec/resolvable" "github.com/graph-gophers/graphql-go/internal/exec/selected" "github.com/graph-gophers/graphql-go/internal/query" - "github.com/graph-gophers/graphql-go/internal/schema" "github.com/graph-gophers/graphql-go/log" "github.com/graph-gophers/graphql-go/trace" + "github.com/graph-gophers/graphql-go/types" ) type Request struct { @@ -42,7 +41,7 @@ func makePanicError(value interface{}) *errors.QueryError { return errors.Errorf("panic occurred: %v", value) } -func (r *Request) Execute(ctx context.Context, s *resolvable.Schema, op *query.Operation) ([]byte, []*errors.QueryError) { +func (r *Request) Execute(ctx context.Context, s *resolvable.Schema, op *types.OperationDefinition) ([]byte, []*errors.QueryError) { var out bytes.Buffer func() { defer r.handlePanic(ctx) @@ -98,7 +97,7 @@ func (r *Request) execSelections(ctx context.Context, sels []selected.Selection, // If a non-nullable child resolved to null, an error was added to the // "errors" list in the response, so this field resolves to null. // If this field is non-nullable, the error is propagated to its parent. - if _, ok := f.field.Type.(*common.NonNull); ok && resolvedToNull(f.out) { + if _, ok := f.field.Type.(*types.NonNull); ok && resolvedToNull(f.out) { out.Reset() out.Write([]byte("null")) return @@ -170,7 +169,7 @@ func execFieldSelection(ctx context.Context, r *Request, s *resolvable.Schema, f var result reflect.Value var err *errors.QueryError - traceCtx, finish := r.Tracer.TraceField(ctx, f.field.TraceLabel, f.field.TypeName, f.field.Name, !f.field.Async, f.field.Args) + traceCtx, finish := r.Tracer.TraceField(ctx, f.field.TraceLabel, f.field.TypeName, f.field.Name.Name, !f.field.Async, f.field.Args) defer func() { finish(err) }() @@ -239,7 +238,7 @@ func execFieldSelection(ctx context.Context, r *Request, s *resolvable.Schema, f r.execSelectionSet(traceCtx, f.sels, f.field.Type, path, s, result, f.out) } -func (r *Request) execSelectionSet(ctx context.Context, sels []selected.Selection, typ common.Type, path *pathSegment, s *resolvable.Schema, resolver reflect.Value, out *bytes.Buffer) { +func (r *Request) execSelectionSet(ctx context.Context, sels []selected.Selection, typ types.Type, path *pathSegment, s *resolvable.Schema, resolver reflect.Value, out *bytes.Buffer) { t, nonNull := unwrapNonNull(typ) // a reflect.Value of a nil interface will show up as an Invalid value @@ -257,7 +256,7 @@ func (r *Request) execSelectionSet(ctx context.Context, sels []selected.Selectio } switch t.(type) { - case *schema.Object, *schema.Interface, *schema.Union: + case *types.ObjectTypeDefinition, *types.InterfaceTypeDefinition, *types.Union: r.execSelections(ctx, sels, path, s, resolver, out, false) return } @@ -269,10 +268,10 @@ func (r *Request) execSelectionSet(ctx context.Context, sels []selected.Selectio } switch t := t.(type) { - case *common.List: + case *types.List: r.execList(ctx, sels, t, path, s, resolver, out) - case *schema.Scalar: + case *types.ScalarTypeDefinition: v := resolver.Interface() data, err := json.Marshal(v) if err != nil { @@ -280,21 +279,21 @@ func (r *Request) execSelectionSet(ctx context.Context, sels []selected.Selectio } out.Write(data) - case *schema.Enum: + case *types.EnumTypeDefinition: var stringer fmt.Stringer = resolver if s, ok := resolver.Interface().(fmt.Stringer); ok { stringer = s } name := stringer.String() var valid bool - for _, v := range t.Values { - if v.Name == name { + for _, v := range t.EnumValuesDefinition { + if v.EnumValue == name { valid = true break } } if !valid { - err := errors.Errorf("Invalid value %s.\nExpected type %s, found %s.", name, t.Name, name) + err := errors.Errorf("Invalid value %s.\nExpected type %s, found %s.", name, t.Name.Name, name) err.Path = path.toSlice() r.AddError(err) out.WriteString("null") @@ -309,7 +308,7 @@ func (r *Request) execSelectionSet(ctx context.Context, sels []selected.Selectio } } -func (r *Request) execList(ctx context.Context, sels []selected.Selection, typ *common.List, path *pathSegment, s *resolvable.Schema, resolver reflect.Value, out *bytes.Buffer) { +func (r *Request) execList(ctx context.Context, sels []selected.Selection, typ *types.List, path *pathSegment, s *resolvable.Schema, resolver reflect.Value, out *bytes.Buffer) { l := resolver.Len() entryouts := make([]bytes.Buffer, l) @@ -335,7 +334,7 @@ func (r *Request) execList(ctx context.Context, sels []selected.Selection, typ * } } - _, listOfNonNull := typ.OfType.(*common.NonNull) + _, listOfNonNull := typ.OfType.(*types.NonNull) out.WriteByte('[') for i, entryout := range entryouts { @@ -355,8 +354,8 @@ func (r *Request) execList(ctx context.Context, sels []selected.Selection, typ * out.WriteByte(']') } -func unwrapNonNull(t common.Type) (common.Type, bool) { - if nn, ok := t.(*common.NonNull); ok { +func unwrapNonNull(t types.Type) (types.Type, bool) { + if nn, ok := t.(*types.NonNull); ok { return nn.OfType, true } return t, false diff --git a/internal/exec/packer/packer.go b/internal/exec/packer/packer.go index deadacb8..dd53e048 100644 --- a/internal/exec/packer/packer.go +++ b/internal/exec/packer/packer.go @@ -7,8 +7,7 @@ import ( "strings" "github.com/graph-gophers/graphql-go/errors" - "github.com/graph-gophers/graphql-go/internal/common" - "github.com/graph-gophers/graphql-go/internal/schema" + "github.com/graph-gophers/graphql-go/types" ) type packer interface { @@ -21,7 +20,7 @@ type Builder struct { } type typePair struct { - graphQLType common.Type + graphQLType types.Type resolverType reflect.Type } @@ -47,7 +46,7 @@ func (b *Builder) Finish() error { p.defaultStruct = reflect.New(p.structType).Elem() for _, f := range p.fields { if defaultVal := f.field.Default; defaultVal != nil { - v, err := f.fieldPacker.Pack(defaultVal.Value(nil)) + v, err := f.fieldPacker.Pack(defaultVal.Deserialize(nil)) if err != nil { return err } @@ -59,7 +58,7 @@ func (b *Builder) Finish() error { return nil } -func (b *Builder) assignPacker(target *packer, schemaType common.Type, reflectType reflect.Type) error { +func (b *Builder) assignPacker(target *packer, schemaType types.Type, reflectType reflect.Type) error { k := typePair{schemaType, reflectType} ref, ok := b.packerMap[k] if !ok { @@ -75,13 +74,13 @@ func (b *Builder) assignPacker(target *packer, schemaType common.Type, reflectTy return nil } -func (b *Builder) makePacker(schemaType common.Type, reflectType reflect.Type) (packer, error) { +func (b *Builder) makePacker(schemaType types.Type, reflectType reflect.Type) (packer, error) { t, nonNull := unwrapNonNull(schemaType) if !nonNull { if reflectType.Kind() == reflect.Ptr { elemType := reflectType.Elem() addPtr := true - if _, ok := t.(*schema.InputObject); ok { + if _, ok := t.(*types.InputObject); ok { elemType = reflectType // keep pointer for input objects addPtr = false } @@ -114,7 +113,7 @@ func (b *Builder) makePacker(schemaType common.Type, reflectType reflect.Type) ( return b.makeNonNullPacker(t, reflectType) } -func (b *Builder) makeNonNullPacker(schemaType common.Type, reflectType reflect.Type) (packer, error) { +func (b *Builder) makeNonNullPacker(schemaType types.Type, reflectType reflect.Type) (packer, error) { if u, ok := reflect.New(reflectType).Interface().(Unmarshaler); ok { if !u.ImplementsGraphQLType(schemaType.String()) { return nil, fmt.Errorf("can not unmarshal %s into %s", schemaType, reflectType) @@ -125,12 +124,12 @@ func (b *Builder) makeNonNullPacker(schemaType common.Type, reflectType reflect. } switch t := schemaType.(type) { - case *schema.Scalar: + case *types.ScalarTypeDefinition: return &ValuePacker{ ValueType: reflectType, }, nil - case *schema.Enum: + case *types.EnumTypeDefinition: if reflectType.Kind() != reflect.String { return nil, fmt.Errorf("wrong type, expected %s", reflect.String) } @@ -138,14 +137,14 @@ func (b *Builder) makeNonNullPacker(schemaType common.Type, reflectType reflect. ValueType: reflectType, }, nil - case *schema.InputObject: + case *types.InputObject: e, err := b.MakeStructPacker(t.Values, reflectType) if err != nil { return nil, err } return e, nil - case *common.List: + case *types.List: if reflectType.Kind() != reflect.Slice { return nil, fmt.Errorf("expected slice, got %s", reflectType) } @@ -157,7 +156,7 @@ func (b *Builder) makeNonNullPacker(schemaType common.Type, reflectType reflect. } return p, nil - case *schema.Object, *schema.Interface, *schema.Union: + case *types.ObjectTypeDefinition, *types.InterfaceTypeDefinition, *types.Union: return nil, fmt.Errorf("type of kind %s can not be used as input", t.Kind()) default: @@ -165,7 +164,7 @@ func (b *Builder) makeNonNullPacker(schemaType common.Type, reflectType reflect. } } -func (b *Builder) MakeStructPacker(values common.InputValueList, typ reflect.Type) (*StructPacker, error) { +func (b *Builder) MakeStructPacker(values []*types.InputValueDefinition, typ reflect.Type) (*StructPacker, error) { structType := typ usePtr := false if typ.Kind() == reflect.Ptr { @@ -195,7 +194,7 @@ func (b *Builder) MakeStructPacker(values common.InputValueList, typ reflect.Typ ft := v.Type if v.Default != nil { ft, _ = unwrapNonNull(ft) - ft = &common.NonNull{OfType: ft} + ft = &types.NonNull{OfType: ft} } if err := b.assignPacker(&fe.fieldPacker, ft, sf.Type); err != nil { @@ -222,7 +221,7 @@ type StructPacker struct { } type structPackerField struct { - field *common.InputValue + field *types.InputValueDefinition fieldIndex []int fieldPacker packer } @@ -372,8 +371,8 @@ func unmarshalInput(typ reflect.Type, input interface{}) (interface{}, error) { return nil, fmt.Errorf("incompatible type") } -func unwrapNonNull(t common.Type) (common.Type, bool) { - if nn, ok := t.(*common.NonNull); ok { +func unwrapNonNull(t types.Type) (types.Type, bool) { + if nn, ok := t.(*types.NonNull); ok { return nn.OfType, true } return t, false diff --git a/internal/exec/resolvable/meta.go b/internal/exec/resolvable/meta.go index e9707516..c1c9608d 100644 --- a/internal/exec/resolvable/meta.go +++ b/internal/exec/resolvable/meta.go @@ -4,9 +4,8 @@ import ( "fmt" "reflect" - "github.com/graph-gophers/graphql-go/internal/common" - "github.com/graph-gophers/graphql-go/internal/schema" "github.com/graph-gophers/graphql-go/introspection" + "github.com/graph-gophers/graphql-go/types" ) // Meta defines the details of the metadata schema for introspection. @@ -18,18 +17,18 @@ type Meta struct { Type *Object } -func newMeta(s *schema.Schema) *Meta { +func newMeta(s *types.Schema) *Meta { var err error b := newBuilder(s) - metaSchema := s.Types["__Schema"].(*schema.Object) - so, err := b.makeObjectExec(metaSchema.Name, metaSchema.Fields, nil, false, reflect.TypeOf(&introspection.Schema{})) + metaSchema := s.Types["__Schema"].(*types.ObjectTypeDefinition) + so, err := b.makeObjectExec(metaSchema.Name.Name, metaSchema.Fields, nil, false, reflect.TypeOf(&introspection.Schema{})) if err != nil { panic(err) } - metaType := s.Types["__Type"].(*schema.Object) - t, err := b.makeObjectExec(metaType.Name, metaType.Fields, nil, false, reflect.TypeOf(&introspection.Type{})) + metaType := s.Types["__Type"].(*types.ObjectTypeDefinition) + t, err := b.makeObjectExec(metaType.Name.Name, metaType.Fields, nil, false, reflect.TypeOf(&introspection.Type{})) if err != nil { panic(err) } @@ -39,24 +38,30 @@ func newMeta(s *schema.Schema) *Meta { } fieldTypename := Field{ - Field: schema.Field{ - Name: "__typename", - Type: &common.NonNull{OfType: s.Types["String"]}, + FieldDefinition: types.FieldDefinition{ + Name: types.Ident{ + Name: "__typename", + }, + Type: &types.NonNull{OfType: s.Types["String"]}, }, TraceLabel: fmt.Sprintf("GraphQL field: __typename"), } fieldSchema := Field{ - Field: schema.Field{ - Name: "__schema", + FieldDefinition: types.FieldDefinition{ + Name: types.Ident{ + Name: "__schema", + }, Type: s.Types["__Schema"], }, TraceLabel: fmt.Sprintf("GraphQL field: __schema"), } fieldType := Field{ - Field: schema.Field{ - Name: "__type", + FieldDefinition: types.FieldDefinition{ + Name: types.Ident{ + Name: "__type", + }, Type: s.Types["__Type"], }, TraceLabel: fmt.Sprintf("GraphQL field: __type"), diff --git a/internal/exec/resolvable/resolvable.go b/internal/exec/resolvable/resolvable.go index a3a50481..b076ff64 100644 --- a/internal/exec/resolvable/resolvable.go +++ b/internal/exec/resolvable/resolvable.go @@ -6,14 +6,13 @@ import ( "reflect" "strings" - "github.com/graph-gophers/graphql-go/internal/common" "github.com/graph-gophers/graphql-go/internal/exec/packer" - "github.com/graph-gophers/graphql-go/internal/schema" + "github.com/graph-gophers/graphql-go/types" ) type Schema struct { *Meta - schema.Schema + types.Schema Query Resolvable Mutation Resolvable Subscription Resolvable @@ -31,7 +30,7 @@ type Object struct { } type Field struct { - schema.Field + types.FieldDefinition TypeName string MethodIndex int FieldIndex []int @@ -61,7 +60,7 @@ func (*Object) isResolvable() {} func (*List) isResolvable() {} func (*Scalar) isResolvable() {} -func ApplyResolver(s *schema.Schema, resolver interface{}) (*Schema, error) { +func ApplyResolver(s *types.Schema, resolver interface{}) (*Schema, error) { if resolver == nil { return &Schema{Meta: newMeta(s), Schema: *s}, nil } @@ -103,13 +102,13 @@ func ApplyResolver(s *schema.Schema, resolver interface{}) (*Schema, error) { } type execBuilder struct { - schema *schema.Schema + schema *types.Schema resMap map[typePair]*resMapEntry packerBuilder *packer.Builder } type typePair struct { - graphQLType common.Type + graphQLType types.Type resolverType reflect.Type } @@ -118,7 +117,7 @@ type resMapEntry struct { targets []*Resolvable } -func newBuilder(s *schema.Schema) *execBuilder { +func newBuilder(s *types.Schema) *execBuilder { return &execBuilder{ schema: s, resMap: make(map[typePair]*resMapEntry), @@ -136,7 +135,7 @@ func (b *execBuilder) finish() error { return b.packerBuilder.Finish() } -func (b *execBuilder) assignExec(target *Resolvable, t common.Type, resolverType reflect.Type) error { +func (b *execBuilder) assignExec(target *Resolvable, t types.Type, resolverType reflect.Type) error { k := typePair{t, resolverType} ref, ok := b.resMap[k] if !ok { @@ -152,19 +151,19 @@ func (b *execBuilder) assignExec(target *Resolvable, t common.Type, resolverType return nil } -func (b *execBuilder) makeExec(t common.Type, resolverType reflect.Type) (Resolvable, error) { +func (b *execBuilder) makeExec(t types.Type, resolverType reflect.Type) (Resolvable, error) { var nonNull bool t, nonNull = unwrapNonNull(t) switch t := t.(type) { - case *schema.Object: - return b.makeObjectExec(t.Name, t.Fields, nil, nonNull, resolverType) + case *types.ObjectTypeDefinition: + return b.makeObjectExec(t.Name.Name, t.Fields, nil, nonNull, resolverType) - case *schema.Interface: + case *types.InterfaceTypeDefinition: return b.makeObjectExec(t.Name, t.Fields, t.PossibleTypes, nonNull, resolverType) - case *schema.Union: - return b.makeObjectExec(t.Name, nil, t.PossibleTypes, nonNull, resolverType) + case *types.Union: + return b.makeObjectExec(t.Name.Name, nil, t.UnionMemberTypes, nonNull, resolverType) } if !nonNull { @@ -175,13 +174,13 @@ func (b *execBuilder) makeExec(t common.Type, resolverType reflect.Type) (Resolv } switch t := t.(type) { - case *schema.Scalar: + case *types.ScalarTypeDefinition: return makeScalarExec(t, resolverType) - case *schema.Enum: + case *types.EnumTypeDefinition: return &Scalar{}, nil - case *common.List: + case *types.List: if resolverType.Kind() != reflect.Slice { return nil, fmt.Errorf("%s is not a slice", resolverType) } @@ -196,27 +195,28 @@ func (b *execBuilder) makeExec(t common.Type, resolverType reflect.Type) (Resolv } } -func makeScalarExec(t *schema.Scalar, resolverType reflect.Type) (Resolvable, error) { +func makeScalarExec(t *types.ScalarTypeDefinition, resolverType reflect.Type) (Resolvable, error) { implementsType := false switch r := reflect.New(resolverType).Interface().(type) { case *int32: - implementsType = t.Name == "Int" + implementsType = t.Name.Name == "Int" case *float64: - implementsType = t.Name == "Float" + implementsType = t.Name.Name == "Float" case *string: - implementsType = t.Name == "String" + implementsType = t.Name.Name == "String" case *bool: - implementsType = t.Name == "Boolean" + implementsType = t.Name.Name == "Boolean" case packer.Unmarshaler: - implementsType = r.ImplementsGraphQLType(t.Name) + implementsType = r.ImplementsGraphQLType(t.Name.Name) } + if !implementsType { - return nil, fmt.Errorf("can not use %s as %s", resolverType, t.Name) + return nil, fmt.Errorf("can not use %s as %s", resolverType, t.Name.Name) } return &Scalar{}, nil } -func (b *execBuilder) makeObjectExec(typeName string, fields schema.FieldList, possibleTypes []*schema.Object, +func (b *execBuilder) makeObjectExec(typeName string, fields types.FieldsDefinition, possibleTypes []*types.ObjectTypeDefinition, nonNull bool, resolverType reflect.Type) (*Object, error) { if !nonNull { if resolverType.Kind() != reflect.Ptr && resolverType.Kind() != reflect.Interface { @@ -231,16 +231,16 @@ func (b *execBuilder) makeObjectExec(typeName string, fields schema.FieldList, p fieldsCount := fieldCount(rt, map[string]int{}) for _, f := range fields { var fieldIndex []int - methodIndex := findMethod(resolverType, f.Name) + methodIndex := findMethod(resolverType, f.Name.Name) if b.schema.UseFieldResolvers && methodIndex == -1 { - if fieldsCount[strings.ToLower(stripUnderscore(f.Name))] > 1 { - return nil, fmt.Errorf("%s does not resolve %q: ambiguous field %q", resolverType, typeName, f.Name) + if fieldsCount[strings.ToLower(stripUnderscore(f.Name.Name))] > 1 { + return nil, fmt.Errorf("%s does not resolve %q: ambiguous field %q", resolverType, typeName, f.Name.Name) } - fieldIndex = findField(rt, f.Name, []int{}) + fieldIndex = findField(rt, f.Name.Name, []int{}) } if methodIndex == -1 && len(fieldIndex) == 0 { hint := "" - if findMethod(reflect.PtrTo(resolverType), f.Name) != -1 { + if findMethod(reflect.PtrTo(resolverType), f.Name.Name) != -1 { hint = " (hint: the method exists on the pointer type)" } return nil, fmt.Errorf("%s does not resolve %q: missing method for field %q%s", resolverType, typeName, f.Name, hint) @@ -257,7 +257,7 @@ func (b *execBuilder) makeObjectExec(typeName string, fields schema.FieldList, p if err != nil { return nil, fmt.Errorf("%s\n\tused by (%s).%s", err, resolverType, m.Name) } - Fields[f.Name] = fe + Fields[f.Name.Name] = fe } // Check type assertions when @@ -266,12 +266,12 @@ func (b *execBuilder) makeObjectExec(typeName string, fields schema.FieldList, p typeAssertions := make(map[string]*TypeAssertion) if !b.schema.UseFieldResolvers || resolverType.Kind() != reflect.Interface { for _, impl := range possibleTypes { - methodIndex := findMethod(resolverType, "To"+impl.Name) + methodIndex := findMethod(resolverType, "To"+impl.Name.Name) if methodIndex == -1 { - return nil, fmt.Errorf("%s does not resolve %q: missing method %q to convert to %q", resolverType, typeName, "To"+impl.Name, impl.Name) + return nil, fmt.Errorf("%s does not resolve %q: missing method %q to convert to %q", resolverType, typeName, "To"+impl.Name.Name, impl.Name) } if resolverType.Method(methodIndex).Type.NumOut() != 2 { - return nil, fmt.Errorf("%s does not resolve %q: method %q should return a value and a bool indicating success", resolverType, typeName, "To"+impl.Name) + return nil, fmt.Errorf("%s does not resolve %q: method %q should return a value and a bool indicating success", resolverType, typeName, "To"+impl.Name.Name) } a := &TypeAssertion{ MethodIndex: methodIndex, @@ -279,7 +279,7 @@ func (b *execBuilder) makeObjectExec(typeName string, fields schema.FieldList, p if err := b.assignExec(&a.TypeExec, impl, resolverType.Method(methodIndex).Type.Out(0)); err != nil { return nil, err } - typeAssertions[impl.Name] = a + typeAssertions[impl.Name.Name] = a } } @@ -293,7 +293,7 @@ func (b *execBuilder) makeObjectExec(typeName string, fields schema.FieldList, p var contextType = reflect.TypeOf((*context.Context)(nil)).Elem() var errorType = reflect.TypeOf((*error)(nil)).Elem() -func (b *execBuilder) makeFieldExec(typeName string, f *schema.Field, m reflect.Method, sf reflect.StructField, +func (b *execBuilder) makeFieldExec(typeName string, f *types.FieldDefinition, m reflect.Method, sf reflect.StructField, methodIndex int, fieldIndex []int, methodHasReceiver bool) (*Field, error) { var argsPacker *packer.StructPacker @@ -315,12 +315,12 @@ func (b *execBuilder) makeFieldExec(typeName string, f *schema.Field, m reflect. in = in[1:] } - if len(f.Args) > 0 { + if len(f.Arguments) > 0 { if len(in) == 0 { return nil, fmt.Errorf("must have parameter for field arguments") } var err error - argsPacker, err = b.packerBuilder.MakeStructPacker(f.Args, in[0]) + argsPacker, err = b.packerBuilder.MakeStructPacker(f.Arguments, in[0]) if err != nil { return nil, err } @@ -349,14 +349,14 @@ func (b *execBuilder) makeFieldExec(typeName string, f *schema.Field, m reflect. } fe := &Field{ - Field: *f, - TypeName: typeName, - MethodIndex: methodIndex, - FieldIndex: fieldIndex, - HasContext: hasContext, - ArgsPacker: argsPacker, - HasError: hasError, - TraceLabel: fmt.Sprintf("GraphQL field: %s.%s", typeName, f.Name), + FieldDefinition: *f, + TypeName: typeName, + MethodIndex: methodIndex, + FieldIndex: fieldIndex, + HasContext: hasContext, + ArgsPacker: argsPacker, + HasError: hasError, + TraceLabel: fmt.Sprintf("GraphQL field: %s.%s", typeName, f.Name.Name), } var out reflect.Type @@ -427,8 +427,8 @@ func fieldCount(t reflect.Type, count map[string]int) map[string]int { return count } -func unwrapNonNull(t common.Type) (common.Type, bool) { - if nn, ok := t.(*common.NonNull); ok { +func unwrapNonNull(t types.Type) (types.Type, bool) { + if nn, ok := t.(*types.NonNull); ok { return nn.OfType, true } return t, false diff --git a/internal/exec/selected/selected.go b/internal/exec/selected/selected.go index 1c601253..1698ccef 100644 --- a/internal/exec/selected/selected.go +++ b/internal/exec/selected/selected.go @@ -6,17 +6,16 @@ import ( "sync" "github.com/graph-gophers/graphql-go/errors" - "github.com/graph-gophers/graphql-go/internal/common" "github.com/graph-gophers/graphql-go/internal/exec/packer" "github.com/graph-gophers/graphql-go/internal/exec/resolvable" "github.com/graph-gophers/graphql-go/internal/query" - "github.com/graph-gophers/graphql-go/internal/schema" "github.com/graph-gophers/graphql-go/introspection" + "github.com/graph-gophers/graphql-go/types" ) type Request struct { - Schema *schema.Schema - Doc *query.Document + Schema *types.Schema + Doc *types.ExecutableDefinition Vars map[string]interface{} Mu sync.Mutex Errs []*errors.QueryError @@ -29,7 +28,7 @@ func (r *Request) AddError(err *errors.QueryError) { r.Mu.Unlock() } -func ApplyOperation(r *Request, s *resolvable.Schema, op *query.Operation) []Selection { +func ApplyOperation(r *Request, s *resolvable.Schema, op *types.OperationDefinition) []Selection { var obj *resolvable.Object switch op.Type { case query.Query: @@ -70,10 +69,10 @@ func (*SchemaField) isSelection() {} func (*TypeAssertion) isSelection() {} func (*TypenameField) isSelection() {} -func applySelectionSet(r *Request, s *resolvable.Schema, e *resolvable.Object, sels []query.Selection) (flattenedSels []Selection) { +func applySelectionSet(r *Request, s *resolvable.Schema, e *resolvable.Object, sels []types.Selection) (flattenedSels []Selection) { for _, sel := range sels { switch sel := sel.(type) { - case *query.Field: + case *types.Field: field := sel if skipByDirective(r, field.Directives) { continue @@ -93,7 +92,7 @@ func applySelectionSet(r *Request, s *resolvable.Schema, e *resolvable.Object, s flattenedSels = append(flattenedSels, &SchemaField{ Field: s.Meta.FieldSchema, Alias: field.Alias.Name, - Sels: applySelectionSet(r, s, s.Meta.Schema, field.Selections), + Sels: applySelectionSet(r, s, s.Meta.Schema, field.SelectionSet), Async: true, FixedResult: reflect.ValueOf(introspection.WrapSchema(r.Schema)), }) @@ -102,7 +101,7 @@ func applySelectionSet(r *Request, s *resolvable.Schema, e *resolvable.Object, s case "__type": if !r.DisableIntrospection { p := packer.ValuePacker{ValueType: reflect.TypeOf("")} - v, err := p.Pack(field.Arguments.MustGet("name").Value(r.Vars)) + v, err := p.Pack(field.Arguments.MustGet("name").Deserialize(r.Vars)) if err != nil { r.AddError(errors.Errorf("%s", err)) return nil @@ -116,7 +115,7 @@ func applySelectionSet(r *Request, s *resolvable.Schema, e *resolvable.Object, s flattenedSels = append(flattenedSels, &SchemaField{ Field: s.Meta.FieldType, Alias: field.Alias.Name, - Sels: applySelectionSet(r, s, s.Meta.Type, field.Selections), + Sels: applySelectionSet(r, s, s.Meta.Type, field.SelectionSet), Async: true, FixedResult: reflect.ValueOf(introspection.WrapType(t)), }) @@ -130,7 +129,7 @@ func applySelectionSet(r *Request, s *resolvable.Schema, e *resolvable.Object, s if fe.ArgsPacker != nil { args = make(map[string]interface{}) for _, arg := range field.Arguments { - args[arg.Name.Name] = arg.Value.Value(r.Vars) + args[arg.Name.Name] = arg.Value.Deserialize(r.Vars) } var err error packedArgs, err = fe.ArgsPacker.Pack(args) @@ -140,7 +139,7 @@ func applySelectionSet(r *Request, s *resolvable.Schema, e *resolvable.Object, s } } - fieldSels := applyField(r, s, fe.ValueExec, field.Selections) + fieldSels := applyField(r, s, fe.ValueExec, field.SelectionSet) flattenedSels = append(flattenedSels, &SchemaField{ Field: *fe, Alias: field.Alias.Name, @@ -151,14 +150,14 @@ func applySelectionSet(r *Request, s *resolvable.Schema, e *resolvable.Object, s }) } - case *query.InlineFragment: + case *types.InlineFragment: frag := sel if skipByDirective(r, frag.Directives) { continue } flattenedSels = append(flattenedSels, applyFragment(r, s, e, &frag.Fragment)...) - case *query.FragmentSpread: + case *types.FragmentSpread: spread := sel if skipByDirective(r, spread.Directives) { continue @@ -172,10 +171,10 @@ func applySelectionSet(r *Request, s *resolvable.Schema, e *resolvable.Object, s return } -func applyFragment(r *Request, s *resolvable.Schema, e *resolvable.Object, frag *query.Fragment) []Selection { +func applyFragment(r *Request, s *resolvable.Schema, e *resolvable.Object, frag *types.Fragment) []Selection { if frag.On.Name != e.Name { t := r.Schema.Resolve(frag.On.Name) - face, ok := t.(*schema.Interface) + face, ok := t.(*types.InterfaceTypeDefinition) if !ok && frag.On.Name != "" { a, ok := e.TypeAssertions[frag.On.Name] if !ok { @@ -190,11 +189,11 @@ func applyFragment(r *Request, s *resolvable.Schema, e *resolvable.Object, frag if ok && len(face.PossibleTypes) > 0 { sels := []Selection{} for _, t := range face.PossibleTypes { - if t.Name == e.Name { + if t.Name.Name == e.Name { return applySelectionSet(r, s, e, frag.Selections) } - if a, ok := e.TypeAssertions[t.Name]; ok { + if a, ok := e.TypeAssertions[t.Name.Name]; ok { sels = append(sels, &TypeAssertion{ TypeAssertion: *a, Sels: applySelectionSet(r, s, a.TypeExec.(*resolvable.Object), frag.Selections), @@ -210,7 +209,7 @@ func applyFragment(r *Request, s *resolvable.Schema, e *resolvable.Object, frag return applySelectionSet(r, s, e, frag.Selections) } -func applyField(r *Request, s *resolvable.Schema, e resolvable.Resolvable, sels []query.Selection) []Selection { +func applyField(r *Request, s *resolvable.Schema, e resolvable.Resolvable, sels []types.Selection) []Selection { switch e := e.(type) { case *resolvable.Object: return applySelectionSet(r, s, e, sels) @@ -223,10 +222,10 @@ func applyField(r *Request, s *resolvable.Schema, e resolvable.Resolvable, sels } } -func skipByDirective(r *Request, directives common.DirectiveList) bool { +func skipByDirective(r *Request, directives types.DirectiveList) bool { if d := directives.Get("skip"); d != nil { p := packer.ValuePacker{ValueType: reflect.TypeOf(false)} - v, err := p.Pack(d.Args.MustGet("if").Value(r.Vars)) + v, err := p.Pack(d.Arguments.MustGet("if").Deserialize(r.Vars)) if err != nil { r.AddError(errors.Errorf("%s", err)) } @@ -237,7 +236,7 @@ func skipByDirective(r *Request, directives common.DirectiveList) bool { if d := directives.Get("include"); d != nil { p := packer.ValuePacker{ValueType: reflect.TypeOf(false)} - v, err := p.Pack(d.Args.MustGet("if").Value(r.Vars)) + v, err := p.Pack(d.Arguments.MustGet("if").Deserialize(r.Vars)) if err != nil { r.AddError(errors.Errorf("%s", err)) } diff --git a/internal/exec/subscribe.go b/internal/exec/subscribe.go index a42a8634..0353dfda 100644 --- a/internal/exec/subscribe.go +++ b/internal/exec/subscribe.go @@ -9,10 +9,9 @@ import ( "time" "github.com/graph-gophers/graphql-go/errors" - "github.com/graph-gophers/graphql-go/internal/common" "github.com/graph-gophers/graphql-go/internal/exec/resolvable" "github.com/graph-gophers/graphql-go/internal/exec/selected" - "github.com/graph-gophers/graphql-go/internal/query" + "github.com/graph-gophers/graphql-go/types" ) type Response struct { @@ -20,7 +19,7 @@ type Response struct { Errors []*errors.QueryError } -func (r *Request) Subscribe(ctx context.Context, s *resolvable.Schema, op *query.Operation) <-chan *Response { +func (r *Request) Subscribe(ctx context.Context, s *resolvable.Schema, op *types.OperationDefinition) <-chan *Response { var result reflect.Value var f *fieldToExec var err *errors.QueryError @@ -71,7 +70,7 @@ func (r *Request) Subscribe(ctx context.Context, s *resolvable.Schema, op *query } if err != nil { - if _, nonNullChild := f.field.Type.(*common.NonNull); nonNullChild { + if _, nonNullChild := f.field.Type.(*types.NonNull); nonNullChild { return sendAndReturnClosed(&Response{Errors: []*errors.QueryError{err}}) } return sendAndReturnClosed(&Response{Data: []byte(fmt.Sprintf(`{"%s":null}`, f.field.Alias)), Errors: []*errors.QueryError{err}}) @@ -142,7 +141,7 @@ func (r *Request) Subscribe(ctx context.Context, s *resolvable.Schema, op *query subR.execSelectionSet(subCtx, f.sels, f.field.Type, &pathSegment{nil, f.field.Alias}, s, resp, &buf) propagateChildError := false - if _, nonNullChild := f.field.Type.(*common.NonNull); nonNullChild && resolvedToNull(&buf) { + if _, nonNullChild := f.field.Type.(*types.NonNull); nonNullChild && resolvedToNull(&buf) { propagateChildError = true } diff --git a/internal/query/query.go b/internal/query/query.go index fffc88e7..abeb32be 100644 --- a/internal/query/query.go +++ b/internal/query/query.go @@ -6,113 +6,35 @@ import ( "github.com/graph-gophers/graphql-go/errors" "github.com/graph-gophers/graphql-go/internal/common" + "github.com/graph-gophers/graphql-go/types" ) -type Document struct { - Operations OperationList - Fragments FragmentList -} - -type OperationList []*Operation - -func (l OperationList) Get(name string) *Operation { - for _, f := range l { - if f.Name.Name == name { - return f - } - } - return nil -} - -type FragmentList []*FragmentDecl - -func (l FragmentList) Get(name string) *FragmentDecl { - for _, f := range l { - if f.Name.Name == name { - return f - } - } - return nil -} - -type Operation struct { - Type OperationType - Name common.Ident - Vars common.InputValueList - Selections []Selection - Directives common.DirectiveList - Loc errors.Location -} - -type OperationType string - const ( - Query OperationType = "QUERY" - Mutation = "MUTATION" - Subscription = "SUBSCRIPTION" + Query types.OperationType = "QUERY" + Mutation = "MUTATION" + Subscription = "SUBSCRIPTION" ) -type Fragment struct { - On common.TypeName - Selections []Selection -} - -type FragmentDecl struct { - Fragment - Name common.Ident - Directives common.DirectiveList - Loc errors.Location -} - -type Selection interface { - isSelection() -} - -type Field struct { - Alias common.Ident - Name common.Ident - Arguments common.ArgumentList - Directives common.DirectiveList - Selections []Selection - SelectionSetLoc errors.Location -} - -type InlineFragment struct { - Fragment - Directives common.DirectiveList - Loc errors.Location -} - -type FragmentSpread struct { - Name common.Ident - Directives common.DirectiveList - Loc errors.Location -} - -func (Field) isSelection() {} -func (InlineFragment) isSelection() {} -func (FragmentSpread) isSelection() {} - -func Parse(queryString string) (*Document, *errors.QueryError) { +func Parse(queryString string) (*types.ExecutableDefinition, *errors.QueryError) { l := common.NewLexer(queryString, false) - var doc *Document - err := l.CatchSyntaxError(func() { doc = parseDocument(l) }) + var execDef *types.ExecutableDefinition + err := l.CatchSyntaxError(func() { execDef = parseExecutableDefinition(l) }) if err != nil { return nil, err } - return doc, nil + return execDef, nil } -func parseDocument(l *common.Lexer) *Document { - d := &Document{} +func parseExecutableDefinition(l *common.Lexer) *types.ExecutableDefinition { + ed := &types.ExecutableDefinition{} l.ConsumeWhitespace() for l.Peek() != scanner.EOF { if l.Peek() == '{' { - op := &Operation{Type: Query, Loc: l.Location()} + op := &types.OperationDefinition{Type: Query, Loc: l.Location()} op.Selections = parseSelectionSet(l) - d.Operations = append(d.Operations, op) + ed.Operations = append(ed.Operations, op) continue } @@ -121,28 +43,28 @@ func parseDocument(l *common.Lexer) *Document { case "query": op := parseOperation(l, Query) op.Loc = loc - d.Operations = append(d.Operations, op) + ed.Operations = append(ed.Operations, op) case "mutation": - d.Operations = append(d.Operations, parseOperation(l, Mutation)) + ed.Operations = append(ed.Operations, parseOperation(l, Mutation)) case "subscription": - d.Operations = append(d.Operations, parseOperation(l, Subscription)) + ed.Operations = append(ed.Operations, parseOperation(l, Subscription)) case "fragment": frag := parseFragment(l) frag.Loc = loc - d.Fragments = append(d.Fragments, frag) + ed.Fragments = append(ed.Fragments, frag) default: l.SyntaxError(fmt.Sprintf(`unexpected %q, expecting "fragment"`, x)) } } - return d + return ed } -func parseOperation(l *common.Lexer, opType OperationType) *Operation { - op := &Operation{Type: opType} +func parseOperation(l *common.Lexer, opType types.OperationType) *types.OperationDefinition { + op := &types.OperationDefinition{Type: opType} op.Name.Loc = l.Location() if l.Peek() == scanner.Ident { op.Name = l.ConsumeIdentWithLoc() @@ -163,18 +85,18 @@ func parseOperation(l *common.Lexer, opType OperationType) *Operation { return op } -func parseFragment(l *common.Lexer) *FragmentDecl { - f := &FragmentDecl{} +func parseFragment(l *common.Lexer) *types.FragmentDefinition { + f := &types.FragmentDefinition{} f.Name = l.ConsumeIdentWithLoc() l.ConsumeKeyword("on") - f.On = common.TypeName{Ident: l.ConsumeIdentWithLoc()} + f.On = types.TypeName{Ident: l.ConsumeIdentWithLoc()} f.Directives = common.ParseDirectives(l) f.Selections = parseSelectionSet(l) return f } -func parseSelectionSet(l *common.Lexer) []Selection { - var sels []Selection +func parseSelectionSet(l *common.Lexer) []types.Selection { + var sels []types.Selection l.ConsumeToken('{') for l.Peek() != '}' { sels = append(sels, parseSelection(l)) @@ -183,15 +105,15 @@ func parseSelectionSet(l *common.Lexer) []Selection { return sels } -func parseSelection(l *common.Lexer) Selection { +func parseSelection(l *common.Lexer) types.Selection { if l.Peek() == '.' { return parseSpread(l) } - return parseField(l) + return parseFieldDef(l) } -func parseField(l *common.Lexer) *Field { - f := &Field{} +func parseFieldDef(l *common.Lexer) *types.Field { + f := &types.Field{} f.Alias = l.ConsumeIdentWithLoc() f.Name = f.Alias if l.Peek() == ':' { @@ -199,34 +121,34 @@ func parseField(l *common.Lexer) *Field { f.Name = l.ConsumeIdentWithLoc() } if l.Peek() == '(' { - f.Arguments = common.ParseArguments(l) + f.Arguments = common.ParseArgumentList(l) } f.Directives = common.ParseDirectives(l) if l.Peek() == '{' { f.SelectionSetLoc = l.Location() - f.Selections = parseSelectionSet(l) + f.SelectionSet = parseSelectionSet(l) } return f } -func parseSpread(l *common.Lexer) Selection { +func parseSpread(l *common.Lexer) types.Selection { loc := l.Location() l.ConsumeToken('.') l.ConsumeToken('.') l.ConsumeToken('.') - f := &InlineFragment{Loc: loc} + f := &types.InlineFragment{Loc: loc} if l.Peek() == scanner.Ident { ident := l.ConsumeIdentWithLoc() if ident.Name != "on" { - fs := &FragmentSpread{ + fs := &types.FragmentSpread{ Name: ident, Loc: loc, } fs.Directives = common.ParseDirectives(l) return fs } - f.On = common.TypeName{Ident: l.ConsumeIdentWithLoc()} + f.On = types.TypeName{Ident: l.ConsumeIdentWithLoc()} } f.Directives = common.ParseDirectives(l) f.Selections = parseSelectionSet(l) diff --git a/internal/schema/meta.go b/internal/schema/meta.go index 2e311830..9f5bba56 100644 --- a/internal/schema/meta.go +++ b/internal/schema/meta.go @@ -1,17 +1,23 @@ package schema +import ( + "github.com/graph-gophers/graphql-go/types" +) + func init() { _ = newMeta() } // newMeta initializes an instance of the meta Schema. -func newMeta() *Schema { - s := &Schema{ - entryPointNames: make(map[string]string), - Types: make(map[string]NamedType), - Directives: make(map[string]*DirectiveDecl), +func newMeta() *types.Schema { + s := &types.Schema{ + EntryPointNames: make(map[string]string), + Types: make(map[string]types.NamedType), + Directives: make(map[string]*types.DirectiveDefinition), } - if err := s.Parse(metaSrc, false); err != nil { + + err := Parse(s, metaSrc, false) + if err != nil { panic(err) } return s diff --git a/internal/schema/schema.go b/internal/schema/schema.go index 38012040..09eed647 100644 --- a/internal/schema/schema.go +++ b/internal/schema/schema.go @@ -6,246 +6,15 @@ import ( "github.com/graph-gophers/graphql-go/errors" "github.com/graph-gophers/graphql-go/internal/common" + "github.com/graph-gophers/graphql-go/types" ) -// Schema represents a GraphQL service's collective type system capabilities. -// A schema is defined in terms of the types and directives it supports as well as the root -// operation types for each kind of operation: `query`, `mutation`, and `subscription`. -// -// For a more formal definition, read the relevant section in the specification: -// -// http://facebook.github.io/graphql/draft/#sec-Schema -type Schema struct { - // EntryPoints determines the place in the type system where `query`, `mutation`, and - // `subscription` operations begin. - // - // http://facebook.github.io/graphql/draft/#sec-Root-Operation-Types - // - // NOTE: The specification refers to this concept as "Root Operation Types". - // TODO: Rename the `EntryPoints` field to `RootOperationTypes` to align with spec terminology. - EntryPoints map[string]NamedType - - // Types are the fundamental unit of any GraphQL schema. - // There are six kinds of named types, and two wrapping types. - // - // http://facebook.github.io/graphql/draft/#sec-Types - Types map[string]NamedType - - // TODO: Type extensions? - // http://facebook.github.io/graphql/draft/#sec-Type-Extensions - - // Directives are used to annotate various parts of a GraphQL document as an indicator that they - // should be evaluated differently by a validator, executor, or client tool such as a code - // generator. - // - // http://facebook.github.io/graphql/draft/#sec-Type-System.Directives - Directives map[string]*DirectiveDecl - - UseFieldResolvers bool - - entryPointNames map[string]string - objects []*Object - unions []*Union - enums []*Enum - extensions []*Extension -} - -// Resolve a named type in the schema by its name. -func (s *Schema) Resolve(name string) common.Type { - return s.Types[name] -} - -// NamedType represents a type with a name. -// -// http://facebook.github.io/graphql/draft/#NamedType -type NamedType interface { - common.Type - TypeName() string - Description() string -} - -// Scalar types represent primitive leaf values (e.g. a string or an integer) in a GraphQL type -// system. -// -// GraphQL responses take the form of a hierarchical tree; the leaves on these trees are GraphQL -// scalars. -// -// http://facebook.github.io/graphql/draft/#sec-Scalars -type Scalar struct { - Name string - Desc string - Directives common.DirectiveList -} - -// Object types represent a list of named fields, each of which yield a value of a specific type. -// -// GraphQL queries are hierarchical and composed, describing a tree of information. -// While Scalar types describe the leaf values of these hierarchical types, Objects describe the -// intermediate levels. -// -// http://facebook.github.io/graphql/draft/#sec-Objects -type Object struct { - Name string - Interfaces []*Interface - Fields FieldList - Desc string - Directives common.DirectiveList - - interfaceNames []string -} - -// Interface types represent a list of named fields and their arguments. -// -// GraphQL objects can then implement these interfaces which requires that the object type will -// define all fields defined by those interfaces. -// -// http://facebook.github.io/graphql/draft/#sec-Interfaces -type Interface struct { - Name string - PossibleTypes []*Object - Fields FieldList // NOTE: the spec refers to this as `FieldsDefinition`. - Desc string - Directives common.DirectiveList -} - -// Union types represent objects that could be one of a list of GraphQL object types, but provides no -// guaranteed fields between those types. -// -// They also differ from interfaces in that object types declare what interfaces they implement, but -// are not aware of what unions contain them. -// -// http://facebook.github.io/graphql/draft/#sec-Unions -type Union struct { - Name string - PossibleTypes []*Object // NOTE: the spec refers to this as `UnionMemberTypes`. - Desc string - Directives common.DirectiveList - - typeNames []string -} - -// Enum types describe a set of possible values. -// -// Like scalar types, Enum types also represent leaf values in a GraphQL type system. -// -// http://facebook.github.io/graphql/draft/#sec-Enums -type Enum struct { - Name string - Values []*EnumValue // NOTE: the spec refers to this as `EnumValuesDefinition`. - Desc string - Directives common.DirectiveList -} - -// EnumValue types are unique values that may be serialized as a string: the name of the -// represented value. -// -// http://facebook.github.io/graphql/draft/#EnumValueDefinition -type EnumValue struct { - Name string - Directives common.DirectiveList - Desc string -} - -// InputObject types define a set of input fields; the input fields are either scalars, enums, or -// other input objects. -// -// This allows arguments to accept arbitrarily complex structs. -// -// http://facebook.github.io/graphql/draft/#sec-Input-Objects -type InputObject struct { - Name string - Desc string - Values common.InputValueList - Directives common.DirectiveList -} - -// Extension type defines a GraphQL type extension. -// Schemas, Objects, Inputs and Scalars can be extended. -// -// https://facebook.github.io/graphql/draft/#sec-Type-System-Extensions -type Extension struct { - Type NamedType - Directives common.DirectiveList -} - -// FieldsList is a list of an Object's Fields. -// -// http://facebook.github.io/graphql/draft/#FieldsDefinition -type FieldList []*Field - -// Get iterates over the field list, returning a pointer-to-Field when the field name matches the -// provided `name` argument. -// Returns nil when no field was found by that name. -func (l FieldList) Get(name string) *Field { - for _, f := range l { - if f.Name == name { - return f - } - } - return nil -} - -// Names returns a string slice of the field names in the FieldList. -func (l FieldList) Names() []string { - names := make([]string, len(l)) - for i, f := range l { - names[i] = f.Name - } - return names -} - -// http://facebook.github.io/graphql/draft/#sec-Type-System.Directives -type DirectiveDecl struct { - Name string - Desc string - Locs []string - Args common.InputValueList -} - -func (*Scalar) Kind() string { return "SCALAR" } -func (*Object) Kind() string { return "OBJECT" } -func (*Interface) Kind() string { return "INTERFACE" } -func (*Union) Kind() string { return "UNION" } -func (*Enum) Kind() string { return "ENUM" } -func (*InputObject) Kind() string { return "INPUT_OBJECT" } - -func (t *Scalar) String() string { return t.Name } -func (t *Object) String() string { return t.Name } -func (t *Interface) String() string { return t.Name } -func (t *Union) String() string { return t.Name } -func (t *Enum) String() string { return t.Name } -func (t *InputObject) String() string { return t.Name } - -func (t *Scalar) TypeName() string { return t.Name } -func (t *Object) TypeName() string { return t.Name } -func (t *Interface) TypeName() string { return t.Name } -func (t *Union) TypeName() string { return t.Name } -func (t *Enum) TypeName() string { return t.Name } -func (t *InputObject) TypeName() string { return t.Name } - -func (t *Scalar) Description() string { return t.Desc } -func (t *Object) Description() string { return t.Desc } -func (t *Interface) Description() string { return t.Desc } -func (t *Union) Description() string { return t.Desc } -func (t *Enum) Description() string { return t.Desc } -func (t *InputObject) Description() string { return t.Desc } - -// Field is a conceptual function which yields values. -// http://facebook.github.io/graphql/draft/#FieldDefinition -type Field struct { - Name string - Args common.InputValueList // NOTE: the spec refers to this as `ArgumentsDefinition`. - Type common.Type - Directives common.DirectiveList - Desc string -} - // New initializes an instance of Schema. -func New() *Schema { - s := &Schema{ - entryPointNames: make(map[string]string), - Types: make(map[string]NamedType), - Directives: make(map[string]*DirectiveDecl), +func New() *types.Schema { + s := &types.Schema{ + EntryPointNames: make(map[string]string), + Types: make(map[string]types.NamedType), + Directives: make(map[string]*types.DirectiveDefinition), } m := newMeta() for n, t := range m.Types { @@ -257,10 +26,8 @@ func New() *Schema { return s } -// Parse the schema string. -func (s *Schema) Parse(schemaString string, useStringDescriptions bool) error { +func Parse(s *types.Schema, schemaString string, useStringDescriptions bool) error { l := common.NewLexer(schemaString, useStringDescriptions) - err := l.CatchSyntaxError(func() { parseSchema(s, l) }) if err != nil { return err @@ -276,7 +43,7 @@ func (s *Schema) Parse(schemaString string, useStringDescriptions bool) error { } } for _, d := range s.Directives { - for _, arg := range d.Args { + for _, arg := range d.Arguments { t, err := common.ResolveType(arg.Type, s.Resolve) if err != nil { return err @@ -289,30 +56,28 @@ func (s *Schema) Parse(schemaString string, useStringDescriptions bool) error { // > While any type can be the root operation type for a GraphQL operation, the type system definition language can // > omit the schema definition when the query, mutation, and subscription root types are named Query, Mutation, // > and Subscription respectively. - if len(s.entryPointNames) == 0 { + if len(s.EntryPointNames) == 0 { if _, ok := s.Types["Query"]; ok { - s.entryPointNames["query"] = "Query" + s.EntryPointNames["query"] = "Query" } if _, ok := s.Types["Mutation"]; ok { - s.entryPointNames["mutation"] = "Mutation" + s.EntryPointNames["mutation"] = "Mutation" } if _, ok := s.Types["Subscription"]; ok { - s.entryPointNames["subscription"] = "Subscription" + s.EntryPointNames["subscription"] = "Subscription" } } - s.EntryPoints = make(map[string]NamedType) - for key, name := range s.entryPointNames { + s.EntryPoints = make(map[string]types.NamedType) + for key, name := range s.EntryPointNames { t, ok := s.Types[name] if !ok { - if !ok { - return errors.Errorf("type %q not found", name) - } + return errors.Errorf("type %q not found", name) } s.EntryPoints[key] = t } - for _, obj := range s.objects { - obj.Interfaces = make([]*Interface, len(obj.interfaceNames)) + for _, obj := range s.Objects { + obj.Interfaces = make([]*types.InterfaceTypeDefinition, len(obj.InterfaceNames)) if err := resolveDirectives(s, obj.Directives, "OBJECT"); err != nil { return err } @@ -321,18 +86,18 @@ func (s *Schema) Parse(schemaString string, useStringDescriptions bool) error { return err } } - for i, intfName := range obj.interfaceNames { + for i, intfName := range obj.InterfaceNames { t, ok := s.Types[intfName] if !ok { return errors.Errorf("interface %q not found", intfName) } - intf, ok := t.(*Interface) + intf, ok := t.(*types.InterfaceTypeDefinition) if !ok { return errors.Errorf("type %q is not an interface", intfName) } for _, f := range intf.Fields.Names() { if obj.Fields.Get(f) == nil { - return errors.Errorf("interface %q expects field %q but %q does not provide it", intfName, f, obj.Name) + return errors.Errorf("interface %q expects field %q but %q does not provide it", intfName, f, obj.Name.Name) } } obj.Interfaces[i] = intf @@ -340,29 +105,29 @@ func (s *Schema) Parse(schemaString string, useStringDescriptions bool) error { } } - for _, union := range s.unions { + for _, union := range s.Unions { if err := resolveDirectives(s, union.Directives, "UNION"); err != nil { return err } - union.PossibleTypes = make([]*Object, len(union.typeNames)) - for i, name := range union.typeNames { + union.UnionMemberTypes = make([]*types.ObjectTypeDefinition, len(union.TypeNames)) + for i, name := range union.TypeNames { t, ok := s.Types[name] if !ok { return errors.Errorf("object type %q not found", name) } - obj, ok := t.(*Object) + obj, ok := t.(*types.ObjectTypeDefinition) if !ok { return errors.Errorf("type %q is not an object", name) } - union.PossibleTypes[i] = obj + union.UnionMemberTypes[i] = obj } } - for _, enum := range s.enums { + for _, enum := range s.Enums { if err := resolveDirectives(s, enum.Directives, "ENUM"); err != nil { return err } - for _, value := range enum.Values { + for _, value := range enum.EnumValuesDefinition { if err := resolveDirectives(s, value.Directives, "ENUM_VALUE"); err != nil { return err } @@ -372,8 +137,14 @@ func (s *Schema) Parse(schemaString string, useStringDescriptions bool) error { return nil } -func mergeExtensions(s *Schema) error { - for _, ext := range s.extensions { +func ParseSchema(schemaString string, useStringDescriptions bool) (*types.Schema, error) { + s := New() + err := Parse(s, schemaString, useStringDescriptions) + return s, err +} + +func mergeExtensions(s *types.Schema) error { + for _, ext := range s.Extensions { typ := s.Types[ext.Type.TypeName()] if typ == nil { return fmt.Errorf("trying to extend unknown type %q", ext.Type.TypeName()) @@ -384,27 +155,27 @@ func mergeExtensions(s *Schema) error { } switch og := typ.(type) { - case *Object: - e := ext.Type.(*Object) + case *types.ObjectTypeDefinition: + e := ext.Type.(*types.ObjectTypeDefinition) for _, field := range e.Fields { - if og.Fields.Get(field.Name) != nil { - return fmt.Errorf("extended field %q already exists", field.Name) + if og.Fields.Get(field.Name.Name) != nil { + return fmt.Errorf("extended field %q already exists", field.Name.Name) } } og.Fields = append(og.Fields, e.Fields...) - for _, en := range e.interfaceNames { - for _, on := range og.interfaceNames { + for _, en := range e.InterfaceNames { + for _, on := range og.InterfaceNames { if on == en { - return fmt.Errorf("interface %q implemented in the extension is already implemented in %q", on, og.Name) + return fmt.Errorf("interface %q implemented in the extension is already implemented in %q", on, og.Name.Name) } } } - og.interfaceNames = append(og.interfaceNames, e.interfaceNames...) + og.InterfaceNames = append(og.InterfaceNames, e.InterfaceNames...) - case *InputObject: - e := ext.Type.(*InputObject) + case *types.InputObject: + e := ext.Type.(*types.InputObject) for _, field := range e.Values { if og.Values.Get(field.Name.Name) != nil { @@ -413,39 +184,39 @@ func mergeExtensions(s *Schema) error { } og.Values = append(og.Values, e.Values...) - case *Interface: - e := ext.Type.(*Interface) + case *types.InterfaceTypeDefinition: + e := ext.Type.(*types.InterfaceTypeDefinition) for _, field := range e.Fields { - if og.Fields.Get(field.Name) != nil { - return fmt.Errorf("extended field %s already exists", field.Name) + if og.Fields.Get(field.Name.Name) != nil { + return fmt.Errorf("extended field %s already exists", field.Name.Name) } } og.Fields = append(og.Fields, e.Fields...) - case *Union: - e := ext.Type.(*Union) + case *types.Union: + e := ext.Type.(*types.Union) - for _, en := range e.typeNames { - for _, on := range og.typeNames { + for _, en := range e.TypeNames { + for _, on := range og.TypeNames { if on == en { - return fmt.Errorf("union type %q already declared in %q", on, og.Name) + return fmt.Errorf("union type %q already declared in %q", on, og.Name.Name) } } } - og.typeNames = append(og.typeNames, e.typeNames...) + og.TypeNames = append(og.TypeNames, e.TypeNames...) - case *Enum: - e := ext.Type.(*Enum) + case *types.EnumTypeDefinition: + e := ext.Type.(*types.EnumTypeDefinition) - for _, en := range e.Values { - for _, on := range og.Values { - if on.Name == en.Name { - return fmt.Errorf("enum value %q already declared in %q", on.Name, og.Name) + for _, en := range e.EnumValuesDefinition { + for _, on := range og.EnumValuesDefinition { + if on.EnumValue == en.EnumValue { + return fmt.Errorf("enum value %q already declared in %q", on.EnumValue, og.Name.Name) } } } - og.Values = append(og.Values, e.Values...) + og.EnumValuesDefinition = append(og.EnumValuesDefinition, e.EnumValuesDefinition...) default: return fmt.Errorf(`unexpected %q, expecting "schema", "type", "enum", "interface", "union" or "input"`, og.TypeName()) } @@ -454,21 +225,21 @@ func mergeExtensions(s *Schema) error { return nil } -func resolveNamedType(s *Schema, t NamedType) error { +func resolveNamedType(s *types.Schema, t types.NamedType) error { switch t := t.(type) { - case *Object: + case *types.ObjectTypeDefinition: for _, f := range t.Fields { if err := resolveField(s, f); err != nil { return err } } - case *Interface: + case *types.InterfaceTypeDefinition: for _, f := range t.Fields { if err := resolveField(s, f); err != nil { return err } } - case *InputObject: + case *types.InputObject: if err := resolveInputObject(s, t.Values); err != nil { return err } @@ -476,7 +247,7 @@ func resolveNamedType(s *Schema, t NamedType) error { return nil } -func resolveField(s *Schema, f *Field) error { +func resolveField(s *types.Schema, f *types.FieldDefinition) error { t, err := common.ResolveType(f.Type, s.Resolve) if err != nil { return err @@ -485,10 +256,10 @@ func resolveField(s *Schema, f *Field) error { if err := resolveDirectives(s, f.Directives, "FIELD_DEFINITION"); err != nil { return err } - return resolveInputObject(s, f.Args) + return resolveInputObject(s, f.Arguments) } -func resolveDirectives(s *Schema, directives common.DirectiveList, loc string) error { +func resolveDirectives(s *types.Schema, directives types.DirectiveList, loc string) error { for _, d := range directives { dirName := d.Name.Name dd, ok := s.Directives[dirName] @@ -496,30 +267,30 @@ func resolveDirectives(s *Schema, directives common.DirectiveList, loc string) e return errors.Errorf("directive %q not found", dirName) } validLoc := false - for _, l := range dd.Locs { + for _, l := range dd.Locations { if l == loc { validLoc = true break } } if !validLoc { - return errors.Errorf("invalid location %q for directive %q (must be one of %v)", loc, dirName, dd.Locs) + return errors.Errorf("invalid location %q for directive %q (must be one of %v)", loc, dirName, dd.Locations) } - for _, arg := range d.Args { - if dd.Args.Get(arg.Name.Name) == nil { + for _, arg := range d.Arguments { + if dd.Arguments.Get(arg.Name.Name) == nil { return errors.Errorf("invalid argument %q for directive %q", arg.Name.Name, dirName) } } - for _, arg := range dd.Args { - if _, ok := d.Args.Get(arg.Name.Name); !ok { - d.Args = append(d.Args, common.Argument{Name: arg.Name, Value: arg.Default}) + for _, arg := range dd.Arguments { + if _, ok := d.Arguments.Get(arg.Name.Name); !ok { + d.Arguments = append(d.Arguments, &types.Argument{Name: arg.Name, Value: arg.Default}) } } } return nil } -func resolveInputObject(s *Schema, values common.InputValueList) error { +func resolveInputObject(s *types.Schema, values types.ArgumentsDefinition) error { for _, v := range values { t, err := common.ResolveType(v.Type, s.Resolve) if err != nil { @@ -530,7 +301,7 @@ func resolveInputObject(s *Schema, values common.InputValueList) error { return nil } -func parseSchema(s *Schema, l *common.Lexer) { +func parseSchema(s *types.Schema, l *common.Lexer) { l.ConsumeWhitespace() for l.Peek() != scanner.EOF { @@ -540,18 +311,19 @@ func parseSchema(s *Schema, l *common.Lexer) { case "schema": l.ConsumeToken('{') for l.Peek() != '}' { + name := l.ConsumeIdent() l.ConsumeToken(':') typ := l.ConsumeIdent() - s.entryPointNames[name] = typ + s.EntryPointNames[name] = typ } l.ConsumeToken('}') case "type": obj := parseObjectDef(l) obj.Desc = desc - s.Types[obj.Name] = obj - s.objects = append(s.objects, obj) + s.Types[obj.Name.Name] = obj + s.Objects = append(s.Objects, obj) case "interface": iface := parseInterfaceDef(l) @@ -561,29 +333,29 @@ func parseSchema(s *Schema, l *common.Lexer) { case "union": union := parseUnionDef(l) union.Desc = desc - s.Types[union.Name] = union - s.unions = append(s.unions, union) + s.Types[union.Name.Name] = union + s.Unions = append(s.Unions, union) case "enum": enum := parseEnumDef(l) enum.Desc = desc - s.Types[enum.Name] = enum - s.enums = append(s.enums, enum) + s.Types[enum.Name.Name] = enum + s.Enums = append(s.Enums, enum) case "input": input := parseInputDef(l) input.Desc = desc - s.Types[input.Name] = input + s.Types[input.Name.Name] = input case "scalar": - name := l.ConsumeIdent() + name := l.ConsumeIdentWithLoc() directives := common.ParseDirectives(l) - s.Types[name] = &Scalar{Name: name, Desc: desc, Directives: directives} + s.Types[name.Name] = &types.ScalarTypeDefinition{Name: name, Desc: desc, Directives: directives} case "directive": directive := parseDirectiveDef(l) directive.Desc = desc - s.Directives[directive.Name] = directive + s.Directives[directive.Name.Name] = directive case "extend": parseExtension(s, l) @@ -595,8 +367,8 @@ func parseSchema(s *Schema, l *common.Lexer) { } } -func parseObjectDef(l *common.Lexer) *Object { - object := &Object{Name: l.ConsumeIdent()} +func parseObjectDef(l *common.Lexer) *types.ObjectTypeDefinition { + object := &types.ObjectTypeDefinition{Name: l.ConsumeIdentWithLoc()} for { if l.Peek() == '{' { @@ -616,25 +388,25 @@ func parseObjectDef(l *common.Lexer) *Object { l.ConsumeToken('&') } - object.interfaceNames = append(object.interfaceNames, l.ConsumeIdent()) + object.InterfaceNames = append(object.InterfaceNames, l.ConsumeIdent()) } continue } - l.SyntaxError(fmt.Sprintf(`unexpected %q, expecting "implements", "directive" or "{"`, l.Peek())) } - l.ConsumeToken('{') object.Fields = parseFieldsDef(l) l.ConsumeToken('}') return object + } -func parseInterfaceDef(l *common.Lexer) *Interface { - i := &Interface{Name: l.ConsumeIdent()} +func parseInterfaceDef(l *common.Lexer) *types.InterfaceTypeDefinition { + i := &types.InterfaceTypeDefinition{Name: l.ConsumeIdent()} i.Directives = common.ParseDirectives(l) + l.ConsumeToken('{') i.Fields = parseFieldsDef(l) l.ConsumeToken('}') @@ -642,23 +414,23 @@ func parseInterfaceDef(l *common.Lexer) *Interface { return i } -func parseUnionDef(l *common.Lexer) *Union { - union := &Union{Name: l.ConsumeIdent()} +func parseUnionDef(l *common.Lexer) *types.Union { + union := &types.Union{Name: l.ConsumeIdentWithLoc()} union.Directives = common.ParseDirectives(l) l.ConsumeToken('=') - union.typeNames = []string{l.ConsumeIdent()} + union.TypeNames = []string{l.ConsumeIdent()} for l.Peek() == '|' { l.ConsumeToken('|') - union.typeNames = append(union.typeNames, l.ConsumeIdent()) + union.TypeNames = append(union.TypeNames, l.ConsumeIdent()) } return union } -func parseInputDef(l *common.Lexer) *InputObject { - i := &InputObject{} - i.Name = l.ConsumeIdent() +func parseInputDef(l *common.Lexer) *types.InputObject { + i := &types.InputObject{} + i.Name = l.ConsumeIdentWithLoc() i.Directives = common.ParseDirectives(l) l.ConsumeToken('{') for l.Peek() != '}' { @@ -668,33 +440,32 @@ func parseInputDef(l *common.Lexer) *InputObject { return i } -func parseEnumDef(l *common.Lexer) *Enum { - enum := &Enum{Name: l.ConsumeIdent()} +func parseEnumDef(l *common.Lexer) *types.EnumTypeDefinition { + enum := &types.EnumTypeDefinition{Name: l.ConsumeIdentWithLoc()} enum.Directives = common.ParseDirectives(l) l.ConsumeToken('{') for l.Peek() != '}' { - v := &EnumValue{ + v := &types.EnumValueDefinition{ Desc: l.DescComment(), - Name: l.ConsumeIdent(), + EnumValue: l.ConsumeIdent(), Directives: common.ParseDirectives(l), } - enum.Values = append(enum.Values, v) + enum.EnumValuesDefinition = append(enum.EnumValuesDefinition, v) } l.ConsumeToken('}') return enum } - -func parseDirectiveDef(l *common.Lexer) *DirectiveDecl { +func parseDirectiveDef(l *common.Lexer) *types.DirectiveDefinition { l.ConsumeToken('@') - d := &DirectiveDecl{Name: l.ConsumeIdent()} + d := &types.DirectiveDefinition{Name: l.ConsumeIdentWithLoc()} if l.Peek() == '(' { l.ConsumeToken('(') for l.Peek() != ')' { v := common.ParseInputValue(l) - d.Args = append(d.Args, v) + d.Arguments = append(d.Arguments, v) } l.ConsumeToken(')') } @@ -703,7 +474,7 @@ func parseDirectiveDef(l *common.Lexer) *DirectiveDecl { for { loc := l.ConsumeIdent() - d.Locs = append(d.Locs, loc) + d.Locations = append(d.Locations, loc) if l.Peek() != '|' { break } @@ -712,7 +483,7 @@ func parseDirectiveDef(l *common.Lexer) *DirectiveDecl { return d } -func parseExtension(s *Schema, l *common.Lexer) { +func parseExtension(s *types.Schema, l *common.Lexer) { switch x := l.ConsumeIdent(); x { case "schema": l.ConsumeToken('{') @@ -720,46 +491,46 @@ func parseExtension(s *Schema, l *common.Lexer) { name := l.ConsumeIdent() l.ConsumeToken(':') typ := l.ConsumeIdent() - s.entryPointNames[name] = typ + s.EntryPointNames[name] = typ } l.ConsumeToken('}') case "type": obj := parseObjectDef(l) - s.extensions = append(s.extensions, &Extension{Type: obj}) + s.Extensions = append(s.Extensions, &types.Extension{Type: obj}) case "interface": iface := parseInterfaceDef(l) - s.extensions = append(s.extensions, &Extension{Type: iface}) + s.Extensions = append(s.Extensions, &types.Extension{Type: iface}) case "union": union := parseUnionDef(l) - s.extensions = append(s.extensions, &Extension{Type: union}) + s.Extensions = append(s.Extensions, &types.Extension{Type: union}) case "enum": enum := parseEnumDef(l) - s.extensions = append(s.extensions, &Extension{Type: enum}) + s.Extensions = append(s.Extensions, &types.Extension{Type: enum}) case "input": input := parseInputDef(l) - s.extensions = append(s.extensions, &Extension{Type: input}) + s.Extensions = append(s.Extensions, &types.Extension{Type: input}) default: - // TODO: Add Scalar when adding directives + // TODO: Add ScalarTypeDefinition when adding directives l.SyntaxError(fmt.Sprintf(`unexpected %q, expecting "schema", "type", "enum", "interface", "union" or "input"`, x)) } } -func parseFieldsDef(l *common.Lexer) FieldList { - var fields FieldList +func parseFieldsDef(l *common.Lexer) types.FieldsDefinition { + var fields types.FieldsDefinition for l.Peek() != '}' { - f := &Field{} + f := &types.FieldDefinition{} f.Desc = l.DescComment() - f.Name = l.ConsumeIdent() + f.Name = l.ConsumeIdentWithLoc() if l.Peek() == '(' { l.ConsumeToken('(') for l.Peek() != ')' { - f.Args = append(f.Args, common.ParseInputValue(l)) + f.Arguments = append(f.Arguments, common.ParseInputValue(l)) } l.ConsumeToken(')') } diff --git a/internal/schema/schema_internal_test.go b/internal/schema/schema_internal_test.go index d652f5d5..0e731798 100644 --- a/internal/schema/schema_internal_test.go +++ b/internal/schema/schema_internal_test.go @@ -5,25 +5,26 @@ import ( "github.com/graph-gophers/graphql-go/errors" "github.com/graph-gophers/graphql-go/internal/common" + "github.com/graph-gophers/graphql-go/types" ) func TestParseInterfaceDef(t *testing.T) { type testCase struct { description string definition string - expected *Interface + expected *types.InterfaceTypeDefinition err *errors.QueryError } tests := []testCase{{ description: "Parses simple interface", definition: "Greeting { field: String }", - expected: &Interface{Name: "Greeting", Fields: []*Field{{Name: "field"}}}, + expected: &types.InterfaceTypeDefinition{Name: "Greeting", Fields: types.FieldsDefinition{&types.FieldDefinition{Name: types.Ident{Name: "field"}}}}, }} for _, test := range tests { t.Run(test.description, func(t *testing.T) { - var actual *Interface + var actual *types.InterfaceTypeDefinition lex := setup(t, test.definition) parse := func() { actual = parseInterfaceDef(lex) } @@ -41,31 +42,31 @@ func TestParseObjectDef(t *testing.T) { type testCase struct { description string definition string - expected *Object + expected *types.ObjectTypeDefinition err *errors.QueryError } tests := []testCase{{ description: "Parses type inheriting single interface", definition: "Hello implements World { field: String }", - expected: &Object{Name: "Hello", interfaceNames: []string{"World"}}, + expected: &types.ObjectTypeDefinition{Name: types.Ident{Name: "Hello", Loc: errors.Location{Line: 1, Column: 1}}, InterfaceNames: []string{"World"}}, }, { description: "Parses type inheriting multiple interfaces", definition: "Hello implements Wo & rld { field: String }", - expected: &Object{Name: "Hello", interfaceNames: []string{"Wo", "rld"}}, + expected: &types.ObjectTypeDefinition{Name: types.Ident{Name: "Hello", Loc: errors.Location{Line: 1, Column: 1}}, InterfaceNames: []string{"Wo", "rld"}}, }, { description: "Parses type inheriting multiple interfaces with leading ampersand", definition: "Hello implements & Wo & rld { field: String }", - expected: &Object{Name: "Hello", interfaceNames: []string{"Wo", "rld"}}, + expected: &types.ObjectTypeDefinition{Name: types.Ident{Name: "Hello", Loc: errors.Location{Line: 1, Column: 1}}, InterfaceNames: []string{"Wo", "rld"}}, }, { description: "Allows legacy SDL interfaces", definition: "Hello implements Wo, rld { field: String }", - expected: &Object{Name: "Hello", interfaceNames: []string{"Wo", "rld"}}, + expected: &types.ObjectTypeDefinition{Name: types.Ident{Name: "Hello", Loc: errors.Location{Line: 1, Column: 1}}, InterfaceNames: []string{"Wo", "rld"}}, }} for _, test := range tests { t.Run(test.description, func(t *testing.T) { - var actual *Object + var actual *types.ObjectTypeDefinition lex := setup(t, test.definition) parse := func() { actual = parseObjectDef(lex) } @@ -95,7 +96,7 @@ func compareErrors(t *testing.T, expected, actual *errors.QueryError) { } } -func compareInterfaces(t *testing.T, expected, actual *Interface) { +func compareInterfaces(t *testing.T, expected, actual *types.InterfaceTypeDefinition) { t.Helper() // TODO: We can probably extract this switch statement into its own function. @@ -117,13 +118,13 @@ func compareInterfaces(t *testing.T, expected, actual *Interface) { } for i, f := range expected.Fields { - if f.Name != actual.Fields[i].Name { + if f.Name.Name != actual.Fields[i].Name.Name { t.Errorf("fields[%d]: wrong field name: want %q, got %q", i, f.Name, actual.Fields[i].Name) } } } -func compareObjects(t *testing.T, expected, actual *Object) { +func compareObjects(t *testing.T, expected, actual *types.ObjectTypeDefinition) { t.Helper() switch { @@ -139,16 +140,16 @@ func compareObjects(t *testing.T, expected, actual *Object) { t.Errorf("wrong object name: want %q, got %q", expected.Name, actual.Name) } - if len(expected.interfaceNames) != len(actual.interfaceNames) { + if len(expected.InterfaceNames) != len(actual.InterfaceNames) { t.Fatalf( "wrong number of interface names: want %s, got %s", - expected.interfaceNames, - actual.interfaceNames, + expected.InterfaceNames, + actual.InterfaceNames, ) } - for i, expectedName := range expected.interfaceNames { - actualName := actual.interfaceNames[i] + for i, expectedName := range expected.InterfaceNames { + actualName := actual.InterfaceNames[i] if expectedName != actualName { t.Errorf("wrong interface name: want %q, got %q", expectedName, actualName) } diff --git a/internal/schema/schema_test.go b/internal/schema/schema_test.go index 1cb400cf..31a1eaaf 100644 --- a/internal/schema/schema_test.go +++ b/internal/schema/schema_test.go @@ -5,6 +5,7 @@ import ( "testing" "github.com/graph-gophers/graphql-go/internal/schema" + "github.com/graph-gophers/graphql-go/types" ) func TestParse(t *testing.T) { @@ -13,14 +14,14 @@ func TestParse(t *testing.T) { sdl string useStringDescriptions bool validateError func(err error) error - validateSchema func(s *schema.Schema) error + validateSchema func(s *types.Schema) error }{ { name: "Parses interface definition", sdl: "interface Greeting { message: String! }", - validateSchema: func(s *schema.Schema) error { + validateSchema: func(s *types.Schema) error { const typeName = "Greeting" - typ, ok := s.Types[typeName].(*schema.Interface) + typ, ok := s.Types[typeName].(*types.InterfaceTypeDefinition) if !ok { return fmt.Errorf("interface %q not found", typeName) } @@ -28,7 +29,7 @@ func TestParse(t *testing.T) { return fmt.Errorf("invalid number of fields: want %d, have %d", want, have) } const fieldName = "message" - if typ.Fields[0].Name != fieldName { + if typ.Fields[0].Name.Name != fieldName { return fmt.Errorf("field %q not found", fieldName) } return nil @@ -60,9 +61,9 @@ func TestParse(t *testing.T) { field: String }`, useStringDescriptions: true, - validateSchema: func(s *schema.Schema) error { + validateSchema: func(s *types.Schema) error { const typeName = "Type" - typ, ok := s.Types[typeName].(*schema.Object) + typ, ok := s.Types[typeName].(*types.ObjectTypeDefinition) if !ok { return fmt.Errorf("type %q not found", typeName) } @@ -82,9 +83,9 @@ func TestParse(t *testing.T) { field: String }`, useStringDescriptions: true, - validateSchema: func(s *schema.Schema) error { + validateSchema: func(s *types.Schema) error { const typeName = "Type" - typ, ok := s.Types[typeName].(*schema.Object) + typ, ok := s.Types[typeName].(*types.ObjectTypeDefinition) if !ok { return fmt.Errorf("type %q not found", typeName) } @@ -103,9 +104,9 @@ func TestParse(t *testing.T) { field: String }`, useStringDescriptions: true, - validateSchema: func(s *schema.Schema) error { + validateSchema: func(s *types.Schema) error { const typeName = "Type" - typ, ok := s.Types[typeName].(*schema.Object) + typ, ok := s.Types[typeName].(*types.ObjectTypeDefinition) if !ok { return fmt.Errorf("type %q not found", typeName) } @@ -138,9 +139,9 @@ func TestParse(t *testing.T) { field: String }`, useStringDescriptions: true, - validateSchema: func(s *schema.Schema) error { + validateSchema: func(s *types.Schema) error { const typeName = "Type" - typ, ok := s.Types[typeName].(*schema.Object) + typ, ok := s.Types[typeName].(*types.ObjectTypeDefinition) if !ok { return fmt.Errorf("type %q not found", typeName) } @@ -163,9 +164,9 @@ Second line of the description. field: String }`, useStringDescriptions: true, - validateSchema: func(s *schema.Schema) error { + validateSchema: func(s *types.Schema) error { const typeName = "Type" - typ, ok := s.Types[typeName].(*schema.Object) + typ, ok := s.Types[typeName].(*types.ObjectTypeDefinition) if !ok { return fmt.Errorf("type %q not found", typeName) } @@ -194,9 +195,9 @@ Second line of the description. field: String }`, useStringDescriptions: true, - validateSchema: func(s *schema.Schema) error { + validateSchema: func(s *types.Schema) error { const typeName = "Type" - typ, ok := s.Types[typeName].(*schema.Object) + typ, ok := s.Types[typeName].(*types.ObjectTypeDefinition) if !ok { return fmt.Errorf("type %q not found", typeName) } @@ -218,9 +219,9 @@ Second line of the description. field: String }`, useStringDescriptions: true, - validateSchema: func(s *schema.Schema) error { + validateSchema: func(s *types.Schema) error { const typeName = "Type" - typ, ok := s.Types[typeName].(*schema.Object) + typ, ok := s.Types[typeName].(*types.ObjectTypeDefinition) if !ok { return fmt.Errorf("type %q not found", typeName) } @@ -239,7 +240,7 @@ Second line of the description. field: String }`, useStringDescriptions: true, - validateSchema: func(s *schema.Schema) error { + validateSchema: func(s *types.Schema) error { typ, ok := s.Types["Type"] if !ok { return fmt.Errorf("type %q not found", "Type") @@ -260,7 +261,7 @@ Second line of the description. type Type { field: String }`, - validateSchema: func(s *schema.Schema) error { + validateSchema: func(s *types.Schema) error { typ, ok := s.Types["MyInt"] if !ok { return fmt.Errorf("scalar %q not found", "MyInt") @@ -288,8 +289,8 @@ Second line of the description. concat(a: String!, b: String!): String! } `, - validateSchema: func(s *schema.Schema) error { - typq, ok := s.Types["Query"].(*schema.Object) + validateSchema: func(s *types.Schema) error { + typq, ok := s.Types["Query"].(*types.ObjectTypeDefinition) if !ok { return fmt.Errorf("type %q not found", "Query") } @@ -301,7 +302,7 @@ Second line of the description. return fmt.Errorf("field %q has an invalid type: %q", "hello", helloField.Type.String()) } - typm, ok := s.Types["Mutation"].(*schema.Object) + typm, ok := s.Types["Mutation"].(*types.ObjectTypeDefinition) if !ok { return fmt.Errorf("type %q not found", "Mutation") } @@ -312,8 +313,8 @@ Second line of the description. if concatField.Type.String() != "String!" { return fmt.Errorf("field %q has an invalid type: %q", "concat", concatField.Type.String()) } - if len(concatField.Args) != 2 || concatField.Args[0] == nil || concatField.Args[1] == nil || concatField.Args[0].Type.String() != "String!" || concatField.Args[1].Type.String() != "String!" { - return fmt.Errorf("field %q has an invalid args: %+v", "concat", concatField.Args) + if len(concatField.Arguments) != 2 || concatField.Arguments[0] == nil || concatField.Arguments[1] == nil || concatField.Arguments[0].Type.String() != "String!" || concatField.Arguments[1].Type.String() != "String!" { + return fmt.Errorf("field %q has an invalid args: %+v", "concat", concatField.Arguments) } return nil }, @@ -328,8 +329,8 @@ Second line of the description. extend type Query { world: String! }`, - validateSchema: func(s *schema.Schema) error { - typ, ok := s.Types["Query"].(*schema.Object) + validateSchema: func(s *types.Schema) error { + typ, ok := s.Types["Query"].(*types.ObjectTypeDefinition) if !ok { return fmt.Errorf("type %q not found", "Query") } @@ -368,8 +369,8 @@ Second line of the description. concat(a: String!, b: String!): String! } `, - validateSchema: func(s *schema.Schema) error { - typq, ok := s.Types["Query"].(*schema.Object) + validateSchema: func(s *types.Schema) error { + typq, ok := s.Types["Query"].(*types.ObjectTypeDefinition) if !ok { return fmt.Errorf("type %q not found", "Query") } @@ -381,7 +382,7 @@ Second line of the description. return fmt.Errorf("field %q has an invalid type: %q", "hello", helloField.Type.String()) } - typm, ok := s.Types["Mutation"].(*schema.Object) + typm, ok := s.Types["Mutation"].(*types.ObjectTypeDefinition) if !ok { return fmt.Errorf("type %q not found", "Mutation") } @@ -392,8 +393,8 @@ Second line of the description. if concatField.Type.String() != "String!" { return fmt.Errorf("field %q has an invalid type: %q", "concat", concatField.Type.String()) } - if len(concatField.Args) != 2 || concatField.Args[0] == nil || concatField.Args[1] == nil || concatField.Args[0].Type.String() != "String!" || concatField.Args[1].Type.String() != "String!" { - return fmt.Errorf("field %q has an invalid args: %+v", "concat", concatField.Args) + if len(concatField.Arguments) != 2 || concatField.Arguments[0] == nil || concatField.Arguments[1] == nil || concatField.Arguments[0].Type.String() != "String!" || concatField.Arguments[1].Type.String() != "String!" { + return fmt.Errorf("field %q has an invalid args: %+v", "concat", concatField.Arguments) } return nil }, @@ -410,8 +411,8 @@ Second line of the description. extend type Product implements Named { name: String! }`, - validateSchema: func(s *schema.Schema) error { - typ, ok := s.Types["Product"].(*schema.Object) + validateSchema: func(s *types.Schema) error { + typ, ok := s.Types["Product"].(*types.ObjectTypeDefinition) if !ok { return fmt.Errorf("type %q not found", "Product") } @@ -430,7 +431,7 @@ Second line of the description. return fmt.Errorf("field %q has an invalid type: %q", "name", nameField.Type.String()) } - ifc, ok := s.Types["Named"].(*schema.Interface) + ifc, ok := s.Types["Named"].(*types.InterfaceTypeDefinition) if !ok { return fmt.Errorf("type %q not found", "Named") } @@ -459,21 +460,21 @@ Second line of the description. } extend union Item = Coloured `, - validateSchema: func(s *schema.Schema) error { - typ, ok := s.Types["Item"].(*schema.Union) + validateSchema: func(s *types.Schema) error { + typ, ok := s.Types["Item"].(*types.Union) if !ok { return fmt.Errorf("type %q not found", "Item") } - if len(typ.PossibleTypes) != 3 { - return fmt.Errorf("Expected 3 possible types, but instead got %d types", len(typ.PossibleTypes)) + if len(typ.UnionMemberTypes) != 3 { + return fmt.Errorf("Expected 3 possible types, but instead got %d types", len(typ.UnionMemberTypes)) } posible := map[string]struct{}{ "Coloured": struct{}{}, "Named": struct{}{}, "Numbered": struct{}{}, } - for _, pt := range typ.PossibleTypes { - if _, ok := posible[pt.Name]; !ok { + for _, pt := range typ.UnionMemberTypes { + if _, ok := posible[pt.Name.Name]; !ok { return fmt.Errorf("Unexpected possible type %q", pt.Name) } } @@ -493,13 +494,13 @@ Second line of the description. GBP } `, - validateSchema: func(s *schema.Schema) error { - typ, ok := s.Types["Currencies"].(*schema.Enum) + validateSchema: func(s *types.Schema) error { + typ, ok := s.Types["Currencies"].(*types.EnumTypeDefinition) if !ok { return fmt.Errorf("enum %q not found", "Currencies") } - if len(typ.Values) != 5 { - return fmt.Errorf("Expected 5 enum values, but instead got %d types", len(typ.Values)) + if len(typ.EnumValuesDefinition) != 5 { + return fmt.Errorf("Expected 5 enum values, but instead got %d types", len(typ.EnumValuesDefinition)) } posible := map[string]struct{}{ "AUD": struct{}{}, @@ -508,9 +509,9 @@ Second line of the description. "BGN": struct{}{}, "GBP": struct{}{}, } - for _, v := range typ.Values { - if _, ok := posible[v.Name]; !ok { - return fmt.Errorf("Unexpected enum value %q", v.Name) + for _, v := range typ.EnumValuesDefinition { + if _, ok := posible[v.EnumValue]; !ok { + return fmt.Errorf("Unexpected enum value %q", v.EnumValue) } } return nil @@ -594,21 +595,21 @@ Second line of the description. extend union Item = Coloured `, - validateSchema: func(s *schema.Schema) error { - typ, ok := s.Types["Item"].(*schema.Union) + validateSchema: func(s *types.Schema) error { + typ, ok := s.Types["Item"].(*types.Union) if !ok { return fmt.Errorf("type %q not found", "Item") } - if len(typ.PossibleTypes) != 3 { - return fmt.Errorf("Expected 3 possible types, but instead got %d types", len(typ.PossibleTypes)) + if len(typ.UnionMemberTypes) != 3 { + return fmt.Errorf("Expected 3 possible types, but instead got %d types", len(typ.UnionMemberTypes)) } posible := map[string]struct{}{ "Coloured": struct{}{}, "Named": struct{}{}, "Numbered": struct{}{}, } - for _, pt := range typ.PossibleTypes { - if _, ok := posible[pt.Name]; !ok { + for _, pt := range typ.UnionMemberTypes { + if _, ok := posible[pt.Name.Name]; !ok { return fmt.Errorf("Unexpected possible type %q", pt.Name) } } @@ -631,8 +632,8 @@ Second line of the description. name: String! } `, - validateSchema: func(s *schema.Schema) error { - typ, ok := s.Types["Product"].(*schema.InputObject) + validateSchema: func(s *types.Schema) error { + typ, ok := s.Types["Product"].(*types.InputObject) if !ok { return fmt.Errorf("type %q not found", "Product") } @@ -731,8 +732,8 @@ Second line of the description. category: String! } `, - validateSchema: func(s *schema.Schema) error { - typ, ok := s.Types["Product"].(*schema.Interface) + validateSchema: func(s *types.Schema) error { + typ, ok := s.Types["Product"].(*types.InterfaceTypeDefinition) if !ok { return fmt.Errorf("type %q not found", "Product") } @@ -745,7 +746,7 @@ Second line of the description. "category": struct{}{}, } for _, f := range typ.Fields { - if _, ok := fields[f.Name]; !ok { + if _, ok := fields[f.Name.Name]; !ok { return fmt.Errorf("Unexpected field %q", f.Name) } } @@ -821,18 +822,18 @@ Second line of the description. union Union @uniondirective = Photo | Person `, - validateSchema: func(s *schema.Schema) error { - namedEntityDirectives := s.Types["NamedEntity"].(*schema.Interface).Directives + validateSchema: func(s *types.Schema) error { + namedEntityDirectives := s.Types["NamedEntity"].(*types.InterfaceTypeDefinition).Directives if len(namedEntityDirectives) != 1 || namedEntityDirectives[0].Name.Name != "directive" { return fmt.Errorf("missing directive on NamedEntity interface, expected @directive but got %v", namedEntityDirectives) } - timeDirectives := s.Types["Time"].(*schema.Scalar).Directives + timeDirectives := s.Types["Time"].(*types.ScalarTypeDefinition).Directives if len(timeDirectives) != 1 || timeDirectives[0].Name.Name != "directive" { return fmt.Errorf("missing directive on Time scalar, expected @directive but got %v", timeDirectives) } - photo := s.Types["Photo"].(*schema.Object) + photo := s.Types["Photo"].(*types.ObjectTypeDefinition) photoDirectives := photo.Directives if len(photoDirectives) != 1 || photoDirectives[0].Name.Name != "objectdirective" { return fmt.Errorf("missing directive on Time scalar, expected @objectdirective but got %v", photoDirectives) @@ -841,12 +842,12 @@ Second line of the description. return fmt.Errorf("expected Photo.id to have 2 directives but got %v", photoDirectives) } - directionDirectives := s.Types["Direction"].(*schema.Enum).Directives + directionDirectives := s.Types["Direction"].(*types.EnumTypeDefinition).Directives if len(directionDirectives) != 1 || directionDirectives[0].Name.Name != "enumdirective" { return fmt.Errorf("missing directive on Direction enum, expected @enumdirective but got %v", directionDirectives) } - unionDirectives := s.Types["Union"].(*schema.Union).Directives + unionDirectives := s.Types["Union"].(*types.Union).Directives if len(unionDirectives) != 1 || unionDirectives[0].Name.Name != "uniondirective" { return fmt.Errorf("missing directive on Union union, expected @uniondirective but got %v", unionDirectives) } @@ -855,8 +856,8 @@ Second line of the description. }, } { t.Run(test.name, func(t *testing.T) { - s := schema.New() - if err := s.Parse(test.sdl, test.useStringDescriptions); err != nil { + s, err := schema.ParseSchema(test.sdl, test.useStringDescriptions) + if err != nil { if test.validateError == nil { t.Fatal(err) } diff --git a/internal/validation/validate_max_depth_test.go b/internal/validation/validate_max_depth_test.go index abc337cb..f8bfd9a8 100644 --- a/internal/validation/validate_max_depth_test.go +++ b/internal/validation/validate_max_depth_test.go @@ -5,6 +5,7 @@ import ( "github.com/graph-gophers/graphql-go/internal/query" "github.com/graph-gophers/graphql-go/internal/schema" + "github.com/graph-gophers/graphql-go/types" ) const ( @@ -70,7 +71,7 @@ type maxDepthTestCase struct { expectedErrors []string } -func (tc maxDepthTestCase) Run(t *testing.T, s *schema.Schema) { +func (tc maxDepthTestCase) Run(t *testing.T, s *types.Schema) { t.Run(tc.name, func(t *testing.T) { doc, qErr := query.Parse(tc.query) if qErr != nil { @@ -103,9 +104,7 @@ func (tc maxDepthTestCase) Run(t *testing.T, s *schema.Schema) { } func TestMaxDepth(t *testing.T) { - s := schema.New() - - err := s.Parse(simpleSchema, false) + s, err := schema.ParseSchema(simpleSchema, false) if err != nil { t.Fatal(err) } @@ -179,9 +178,7 @@ func TestMaxDepth(t *testing.T) { } func TestMaxDepthInlineFragments(t *testing.T) { - s := schema.New() - - err := s.Parse(interfaceSimple, false) + s, err := schema.ParseSchema(interfaceSimple, false) if err != nil { t.Fatal(err) } @@ -228,9 +225,7 @@ func TestMaxDepthInlineFragments(t *testing.T) { } func TestMaxDepthFragmentSpreads(t *testing.T) { - s := schema.New() - - err := s.Parse(interfaceSimple, false) + s, err := schema.ParseSchema(interfaceSimple, false) if err != nil { t.Fatal(err) } @@ -315,9 +310,7 @@ func TestMaxDepthFragmentSpreads(t *testing.T) { } func TestMaxDepthUnknownFragmentSpreads(t *testing.T) { - s := schema.New() - - err := s.Parse(interfaceSimple, false) + s, err := schema.ParseSchema(interfaceSimple, false) if err != nil { t.Fatal(err) } @@ -350,9 +343,7 @@ func TestMaxDepthUnknownFragmentSpreads(t *testing.T) { } func TestMaxDepthValidation(t *testing.T) { - s := schema.New() - - err := s.Parse(interfaceSimple, false) + s, err := schema.ParseSchema(interfaceSimple, false) if err != nil { t.Fatal(err) } diff --git a/internal/validation/validation.go b/internal/validation/validation.go index c8be7354..c03b4cb3 100644 --- a/internal/validation/validation.go +++ b/internal/validation/validation.go @@ -11,25 +11,27 @@ import ( "github.com/graph-gophers/graphql-go/errors" "github.com/graph-gophers/graphql-go/internal/common" "github.com/graph-gophers/graphql-go/internal/query" - "github.com/graph-gophers/graphql-go/internal/schema" + "github.com/graph-gophers/graphql-go/types" ) -type varSet map[*common.InputValue]struct{} +type varSet map[*types.InputValueDefinition]struct{} -type selectionPair struct{ a, b query.Selection } +type selectionPair struct{ a, b types.Selection } + +type nameSet map[string]errors.Location type fieldInfo struct { - sf *schema.Field - parent schema.NamedType + sf *types.FieldDefinition + parent types.NamedType } type context struct { - schema *schema.Schema - doc *query.Document + schema *types.Schema + doc *types.ExecutableDefinition errs []*errors.QueryError - opErrs map[*query.Operation][]*errors.QueryError - usedVars map[*query.Operation]varSet - fieldMap map[*query.Field]fieldInfo + opErrs map[*types.OperationDefinition][]*errors.QueryError + usedVars map[*types.OperationDefinition]varSet + fieldMap map[*types.Field]fieldInfo overlapValidated map[selectionPair]struct{} maxDepth int } @@ -48,29 +50,29 @@ func (c *context) addErrMultiLoc(locs []errors.Location, rule string, format str type opContext struct { *context - ops []*query.Operation + ops []*types.OperationDefinition } -func newContext(s *schema.Schema, doc *query.Document, maxDepth int) *context { +func newContext(s *types.Schema, doc *types.ExecutableDefinition, maxDepth int) *context { return &context{ schema: s, doc: doc, - opErrs: make(map[*query.Operation][]*errors.QueryError), - usedVars: make(map[*query.Operation]varSet), - fieldMap: make(map[*query.Field]fieldInfo), + opErrs: make(map[*types.OperationDefinition][]*errors.QueryError), + usedVars: make(map[*types.OperationDefinition]varSet), + fieldMap: make(map[*types.Field]fieldInfo), overlapValidated: make(map[selectionPair]struct{}), maxDepth: maxDepth, } } -func Validate(s *schema.Schema, doc *query.Document, variables map[string]interface{}, maxDepth int) []*errors.QueryError { +func Validate(s *types.Schema, doc *types.ExecutableDefinition, variables map[string]interface{}, maxDepth int) []*errors.QueryError { c := newContext(s, doc, maxDepth) opNames := make(nameSet) - fragUsedBy := make(map[*query.FragmentDecl][]*query.Operation) + fragUsedBy := make(map[*types.FragmentDefinition][]*types.OperationDefinition) for _, op := range doc.Operations { c.usedVars[op] = make(varSet) - opc := &opContext{c, []*query.Operation{op}} + opc := &opContext{c, []*types.OperationDefinition{op}} // Check if max depth is exceeded, if it's set. If max depth is exceeded, // don't continue to validate the document and exit early. @@ -101,7 +103,7 @@ func Validate(s *schema.Schema, doc *query.Document, variables map[string]interf validateLiteral(opc, v.Default) if t != nil { - if nn, ok := t.(*common.NonNull); ok { + if nn, ok := t.(*types.NonNull); ok { c.addErr(v.Default.Location(), "DefaultValuesOfCorrectType", "Variable %q of type %q is required and will not use the default value. Perhaps you meant to use type %q.", "$"+v.Name.Name, t, nn.OfType) } @@ -112,7 +114,7 @@ func Validate(s *schema.Schema, doc *query.Document, variables map[string]interf } } - var entryPoint schema.NamedType + var entryPoint types.NamedType switch op.Type { case query.Query: entryPoint = s.EntryPoints["query"] @@ -126,7 +128,7 @@ func Validate(s *schema.Schema, doc *query.Document, variables map[string]interf validateSelectionSet(opc, op.Selections, entryPoint) - fragUsed := make(map[*query.FragmentDecl]struct{}) + fragUsed := make(map[*types.FragmentDefinition]struct{}) markUsedFragments(c, op.Selections, fragUsed) for frag := range fragUsed { fragUsedBy[frag] = append(fragUsedBy[frag], op) @@ -134,7 +136,7 @@ func Validate(s *schema.Schema, doc *query.Document, variables map[string]interf } fragNames := make(nameSet) - fragVisited := make(map[*query.FragmentDecl]struct{}) + fragVisited := make(map[*types.FragmentDefinition]struct{}) for _, frag := range doc.Fragments { opc := &opContext{c, fragUsedBy[frag]} @@ -179,15 +181,15 @@ func Validate(s *schema.Schema, doc *query.Document, variables map[string]interf return c.errs } -func validateValue(c *opContext, v *common.InputValue, val interface{}, t common.Type) { +func validateValue(c *opContext, v *types.InputValueDefinition, val interface{}, t types.Type) { switch t := t.(type) { - case *common.NonNull: + case *types.NonNull: if val == nil { c.addErr(v.Loc, "VariablesOfCorrectType", "Variable \"%s\" has invalid value null.\nExpected type \"%s\", found null.", v.Name.Name, t) return } validateValue(c, v, val, t.OfType) - case *common.List: + case *types.List: if val == nil { return } @@ -200,7 +202,7 @@ func validateValue(c *opContext, v *common.InputValue, val interface{}, t common for _, elem := range vv { validateValue(c, v, elem, t.OfType) } - case *schema.Enum: + case *types.EnumTypeDefinition: if val == nil { return } @@ -209,13 +211,13 @@ func validateValue(c *opContext, v *common.InputValue, val interface{}, t common c.addErr(v.Loc, "VariablesOfCorrectType", "Variable \"%s\" has invalid type %T.\nExpected type \"%s\", found %v.", v.Name.Name, val, t, val) return } - for _, option := range t.Values { - if option.Name == e { + for _, option := range t.EnumValuesDefinition { + if option.EnumValue == e { return } } c.addErr(v.Loc, "VariablesOfCorrectType", "Variable \"%s\" has invalid value %s.\nExpected type \"%s\", found %s.", v.Name.Name, e, t, e) - case *schema.InputObject: + case *types.InputObject: if val == nil { return } @@ -233,7 +235,7 @@ func validateValue(c *opContext, v *common.InputValue, val interface{}, t common // validates the query doesn't go deeper than maxDepth (if set). Returns whether // or not query validated max depth to avoid excessive recursion. -func validateMaxDepth(c *opContext, sels []query.Selection, depth int) bool { +func validateMaxDepth(c *opContext, sels []types.Selection, depth int) bool { // maxDepth checking is turned off when maxDepth is 0 if c.maxDepth == 0 { return false @@ -243,18 +245,18 @@ func validateMaxDepth(c *opContext, sels []query.Selection, depth int) bool { for _, sel := range sels { switch sel := sel.(type) { - case *query.Field: + case *types.Field: if depth > c.maxDepth { exceededMaxDepth = true c.addErr(sel.Alias.Loc, "MaxDepthExceeded", "Field %q has depth %d that exceeds max depth %d", sel.Name.Name, depth, c.maxDepth) continue } - exceededMaxDepth = exceededMaxDepth || validateMaxDepth(c, sel.Selections, depth+1) - case *query.InlineFragment: + exceededMaxDepth = exceededMaxDepth || validateMaxDepth(c, sel.SelectionSet, depth+1) + case *types.InlineFragment: // Depth is not checked because inline fragments resolve to other fields which are checked. // Depth is not incremented because inline fragments have the same depth as neighboring fields exceededMaxDepth = exceededMaxDepth || validateMaxDepth(c, sel.Selections, depth) - case *query.FragmentSpread: + case *types.FragmentSpread: // Depth is not checked because fragments resolve to other fields which are checked. frag := c.doc.Fragments.Get(sel.Name.Name) if frag == nil { @@ -270,7 +272,7 @@ func validateMaxDepth(c *opContext, sels []query.Selection, depth int) bool { return exceededMaxDepth } -func validateSelectionSet(c *opContext, sels []query.Selection, t schema.NamedType) { +func validateSelectionSet(c *opContext, sels []types.Selection, t types.NamedType) { for _, sel := range sels { validateSelection(c, sel, t) } @@ -282,31 +284,31 @@ func validateSelectionSet(c *opContext, sels []query.Selection, t schema.NamedTy } } -func validateSelection(c *opContext, sel query.Selection, t schema.NamedType) { +func validateSelection(c *opContext, sel types.Selection, t types.NamedType) { switch sel := sel.(type) { - case *query.Field: + case *types.Field: validateDirectives(c, "FIELD", sel.Directives) fieldName := sel.Name.Name - var f *schema.Field + var f *types.FieldDefinition switch fieldName { case "__typename": - f = &schema.Field{ - Name: "__typename", + f = &types.FieldDefinition{ + Name: types.Ident{Name: "__typename"}, Type: c.schema.Types["String"], } case "__schema": - f = &schema.Field{ - Name: "__schema", + f = &types.FieldDefinition{ + Name: types.Ident{Name: "__schema"}, Type: c.schema.Types["__Schema"], } case "__type": - f = &schema.Field{ - Name: "__type", - Args: common.InputValueList{ - &common.InputValue{ - Name: common.Ident{Name: "name"}, - Type: &common.NonNull{OfType: c.schema.Types["String"]}, + f = &types.FieldDefinition{ + Name: types.Ident{Name: "__type"}, + Arguments: types.ArgumentsDefinition{ + &types.InputValueDefinition{ + Name: types.Ident{Name: "name"}, + Type: &types.NonNull{OfType: c.schema.Types["String"]}, }, }, Type: c.schema.Types["__Type"], @@ -322,28 +324,28 @@ func validateSelection(c *opContext, sel query.Selection, t schema.NamedType) { validateArgumentLiterals(c, sel.Arguments) if f != nil { - validateArgumentTypes(c, sel.Arguments, f.Args, sel.Alias.Loc, + validateArgumentTypes(c, sel.Arguments, f.Arguments, sel.Alias.Loc, func() string { return fmt.Sprintf("field %q of type %q", fieldName, t) }, func() string { return fmt.Sprintf("Field %q", fieldName) }, ) } - var ft common.Type + var ft types.Type if f != nil { ft = f.Type sf := hasSubfields(ft) - if sf && sel.Selections == nil { + if sf && sel.SelectionSet == nil { c.addErr(sel.Alias.Loc, "ScalarLeafs", "Field %q of type %q must have a selection of subfields. Did you mean \"%s { ... }\"?", fieldName, ft, fieldName) } - if !sf && sel.Selections != nil { + if !sf && sel.SelectionSet != nil { c.addErr(sel.SelectionSetLoc, "ScalarLeafs", "Field %q must not have a selection since type %q has no subfields.", fieldName, ft) } } - if sel.Selections != nil { - validateSelectionSet(c, sel.Selections, unwrapType(ft)) + if sel.SelectionSet != nil { + validateSelectionSet(c, sel.SelectionSet, unwrapType(ft)) } - case *query.InlineFragment: + case *types.InlineFragment: validateDirectives(c, "INLINE_FRAGMENT", sel.Directives) if sel.On.Name != "" { fragTyp := unwrapType(resolveType(c.context, &sel.On)) @@ -359,7 +361,7 @@ func validateSelection(c *opContext, sel query.Selection, t schema.NamedType) { } validateSelectionSet(c, sel.Selections, unwrapType(t)) - case *query.FragmentSpread: + case *types.FragmentSpread: validateDirectives(c, "FRAGMENT_SPREAD", sel.Directives) frag := c.doc.Fragments.Get(sel.Name.Name) if frag == nil { @@ -376,7 +378,7 @@ func validateSelection(c *opContext, sel query.Selection, t schema.NamedType) { } } -func compatible(a, b common.Type) bool { +func compatible(a, b types.Type) bool { for _, pta := range possibleTypes(a) { for _, ptb := range possibleTypes(b) { if pta == ptb { @@ -387,31 +389,31 @@ func compatible(a, b common.Type) bool { return false } -func possibleTypes(t common.Type) []*schema.Object { +func possibleTypes(t types.Type) []*types.ObjectTypeDefinition { switch t := t.(type) { - case *schema.Object: - return []*schema.Object{t} - case *schema.Interface: - return t.PossibleTypes - case *schema.Union: + case *types.ObjectTypeDefinition: + return []*types.ObjectTypeDefinition{t} + case *types.InterfaceTypeDefinition: return t.PossibleTypes + case *types.Union: + return t.UnionMemberTypes default: return nil } } -func markUsedFragments(c *context, sels []query.Selection, fragUsed map[*query.FragmentDecl]struct{}) { +func markUsedFragments(c *context, sels []types.Selection, fragUsed map[*types.FragmentDefinition]struct{}) { for _, sel := range sels { switch sel := sel.(type) { - case *query.Field: - if sel.Selections != nil { - markUsedFragments(c, sel.Selections, fragUsed) + case *types.Field: + if sel.SelectionSet != nil { + markUsedFragments(c, sel.SelectionSet, fragUsed) } - case *query.InlineFragment: + case *types.InlineFragment: markUsedFragments(c, sel.Selections, fragUsed) - case *query.FragmentSpread: + case *types.FragmentSpread: frag := c.doc.Fragments.Get(sel.Name.Name) if frag == nil { return @@ -430,23 +432,23 @@ func markUsedFragments(c *context, sels []query.Selection, fragUsed map[*query.F } } -func detectFragmentCycle(c *context, sels []query.Selection, fragVisited map[*query.FragmentDecl]struct{}, spreadPath []*query.FragmentSpread, spreadPathIndex map[string]int) { +func detectFragmentCycle(c *context, sels []types.Selection, fragVisited map[*types.FragmentDefinition]struct{}, spreadPath []*types.FragmentSpread, spreadPathIndex map[string]int) { for _, sel := range sels { detectFragmentCycleSel(c, sel, fragVisited, spreadPath, spreadPathIndex) } } -func detectFragmentCycleSel(c *context, sel query.Selection, fragVisited map[*query.FragmentDecl]struct{}, spreadPath []*query.FragmentSpread, spreadPathIndex map[string]int) { +func detectFragmentCycleSel(c *context, sel types.Selection, fragVisited map[*types.FragmentDefinition]struct{}, spreadPath []*types.FragmentSpread, spreadPathIndex map[string]int) { switch sel := sel.(type) { - case *query.Field: - if sel.Selections != nil { - detectFragmentCycle(c, sel.Selections, fragVisited, spreadPath, spreadPathIndex) + case *types.Field: + if sel.SelectionSet != nil { + detectFragmentCycle(c, sel.SelectionSet, fragVisited, spreadPath, spreadPathIndex) } - case *query.InlineFragment: + case *types.InlineFragment: detectFragmentCycle(c, sel.Selections, fragVisited, spreadPath, spreadPathIndex) - case *query.FragmentSpread: + case *types.FragmentSpread: frag := c.doc.Fragments.Get(sel.Name.Name) if frag == nil { return @@ -486,7 +488,7 @@ func detectFragmentCycleSel(c *context, sel query.Selection, fragVisited map[*qu } } -func (c *context) validateOverlap(a, b query.Selection, reasons *[]string, locs *[]errors.Location) { +func (c *context) validateOverlap(a, b types.Selection, reasons *[]string, locs *[]errors.Location) { if a == b { return } @@ -498,9 +500,9 @@ func (c *context) validateOverlap(a, b query.Selection, reasons *[]string, locs c.overlapValidated[selectionPair{b, a}] = struct{}{} switch a := a.(type) { - case *query.Field: + case *types.Field: switch b := b.(type) { - case *query.Field: + case *types.Field: if b.Alias.Loc.Before(a.Alias.Loc) { a, b = b, a } @@ -516,12 +518,12 @@ func (c *context) validateOverlap(a, b query.Selection, reasons *[]string, locs *locs = append(*locs, locs2...) } - case *query.InlineFragment: + case *types.InlineFragment: for _, sel := range b.Selections { c.validateOverlap(a, sel, reasons, locs) } - case *query.FragmentSpread: + case *types.FragmentSpread: if frag := c.doc.Fragments.Get(b.Name.Name); frag != nil { for _, sel := range frag.Selections { c.validateOverlap(a, sel, reasons, locs) @@ -532,12 +534,12 @@ func (c *context) validateOverlap(a, b query.Selection, reasons *[]string, locs panic("unreachable") } - case *query.InlineFragment: + case *types.InlineFragment: for _, sel := range a.Selections { c.validateOverlap(sel, b, reasons, locs) } - case *query.FragmentSpread: + case *types.FragmentSpread: if frag := c.doc.Fragments.Get(a.Name.Name); frag != nil { for _, sel := range frag.Selections { c.validateOverlap(sel, b, reasons, locs) @@ -549,7 +551,7 @@ func (c *context) validateOverlap(a, b query.Selection, reasons *[]string, locs } } -func (c *context) validateFieldOverlap(a, b *query.Field) ([]string, []errors.Location) { +func (c *context) validateFieldOverlap(a, b *types.Field) ([]string, []errors.Location) { if a.Alias.Name != b.Alias.Name { return nil, nil } @@ -576,49 +578,49 @@ func (c *context) validateFieldOverlap(a, b *query.Field) ([]string, []errors.Lo var reasons []string var locs []errors.Location - for _, a2 := range a.Selections { - for _, b2 := range b.Selections { + for _, a2 := range a.SelectionSet { + for _, b2 := range b.SelectionSet { c.validateOverlap(a2, b2, &reasons, &locs) } } return reasons, locs } -func argumentsConflict(a, b common.ArgumentList) bool { +func argumentsConflict(a, b types.ArgumentList) bool { if len(a) != len(b) { return true } for _, argA := range a { valB, ok := b.Get(argA.Name.Name) - if !ok || !reflect.DeepEqual(argA.Value.Value(nil), valB.Value(nil)) { + if !ok || !reflect.DeepEqual(argA.Value.Deserialize(nil), valB.Deserialize(nil)) { return true } } return false } -func fields(t common.Type) schema.FieldList { +func fields(t types.Type) types.FieldsDefinition { switch t := t.(type) { - case *schema.Object: + case *types.ObjectTypeDefinition: return t.Fields - case *schema.Interface: + case *types.InterfaceTypeDefinition: return t.Fields default: return nil } } -func unwrapType(t common.Type) schema.NamedType { +func unwrapType(t types.Type) types.NamedType { if t == nil { return nil } for { switch t2 := t.(type) { - case schema.NamedType: + case types.NamedType: return t2 - case *common.List: + case *types.List: t = t2.OfType - case *common.NonNull: + case *types.NonNull: t = t2.OfType default: panic("unreachable") @@ -626,7 +628,7 @@ func unwrapType(t common.Type) schema.NamedType { } } -func resolveType(c *context, t common.Type) common.Type { +func resolveType(c *context, t types.Type) types.Type { t2, err := common.ResolveType(t, c.schema.Resolve) if err != nil { c.errs = append(c.errs, err) @@ -634,7 +636,7 @@ func resolveType(c *context, t common.Type) common.Type { return t2 } -func validateDirectives(c *opContext, loc string, directives common.DirectiveList) { +func validateDirectives(c *opContext, loc string, directives types.DirectiveList) { directiveNames := make(nameSet) for _, d := range directives { dirName := d.Name.Name @@ -642,7 +644,7 @@ func validateDirectives(c *opContext, loc string, directives common.DirectiveLis return fmt.Sprintf("The directive %q can only be used once at this location.", dirName) }) - validateArgumentLiterals(c, d.Args) + validateArgumentLiterals(c, d.Arguments) dd, ok := c.schema.Directives[dirName] if !ok { @@ -651,7 +653,7 @@ func validateDirectives(c *opContext, loc string, directives common.DirectiveLis } locOK := false - for _, allowedLoc := range dd.Locs { + for _, allowedLoc := range dd.Locations { if loc == allowedLoc { locOK = true break @@ -661,22 +663,20 @@ func validateDirectives(c *opContext, loc string, directives common.DirectiveLis c.addErr(d.Name.Loc, "KnownDirectives", "Directive %q may not be used on %s.", dirName, loc) } - validateArgumentTypes(c, d.Args, dd.Args, d.Name.Loc, + validateArgumentTypes(c, d.Arguments, dd.Arguments, d.Name.Loc, func() string { return fmt.Sprintf("directive %q", "@"+dirName) }, func() string { return fmt.Sprintf("Directive %q", "@"+dirName) }, ) } } -type nameSet map[string]errors.Location - -func validateName(c *context, set nameSet, name common.Ident, rule string, kind string) { +func validateName(c *context, set nameSet, name types.Ident, rule string, kind string) { validateNameCustomMsg(c, set, name, rule, func() string { return fmt.Sprintf("There can be only one %s named %q.", kind, name.Name) }) } -func validateNameCustomMsg(c *context, set nameSet, name common.Ident, rule string, msg func() string) { +func validateNameCustomMsg(c *context, set nameSet, name types.Ident, rule string, msg func() string) { if loc, ok := set[name.Name]; ok { c.addErrMultiLoc([]errors.Location{loc, name.Loc}, rule, msg()) return @@ -684,7 +684,7 @@ func validateNameCustomMsg(c *context, set nameSet, name common.Ident, rule stri set[name.Name] = name.Loc } -func validateArgumentTypes(c *opContext, args common.ArgumentList, argDecls common.InputValueList, loc errors.Location, owner1, owner2 func() string) { +func validateArgumentTypes(c *opContext, args types.ArgumentList, argDecls types.ArgumentsDefinition, loc errors.Location, owner1, owner2 func() string) { for _, selArg := range args { arg := argDecls.Get(selArg.Name.Name) if arg == nil { @@ -697,7 +697,7 @@ func validateArgumentTypes(c *opContext, args common.ArgumentList, argDecls comm } } for _, decl := range argDecls { - if _, ok := decl.Type.(*common.NonNull); ok { + if _, ok := decl.Type.(*types.NonNull); ok { if _, ok := args.Get(decl.Name.Name); !ok { c.addErr(loc, "ProvidedNonNullArguments", "%s argument %q of type %q is required but not provided.", owner2(), decl.Name.Name, decl.Type) } @@ -705,7 +705,7 @@ func validateArgumentTypes(c *opContext, args common.ArgumentList, argDecls comm } } -func validateArgumentLiterals(c *opContext, args common.ArgumentList) { +func validateArgumentLiterals(c *opContext, args types.ArgumentList) { argNames := make(nameSet) for _, arg := range args { validateName(c.context, argNames, arg.Name, "UniqueArgumentNames", "argument") @@ -713,19 +713,19 @@ func validateArgumentLiterals(c *opContext, args common.ArgumentList) { } } -func validateLiteral(c *opContext, l common.Literal) { +func validateLiteral(c *opContext, l types.Value) { switch l := l.(type) { - case *common.ObjectLit: + case *types.ObjectValue: fieldNames := make(nameSet) for _, f := range l.Fields { validateName(c.context, fieldNames, f.Name, "UniqueInputFieldNames", "input field") validateLiteral(c, f.Value) } - case *common.ListLit: - for _, entry := range l.Entries { + case *types.ListValue: + for _, entry := range l.Values { validateLiteral(c, entry) } - case *common.Variable: + case *types.Variable: for _, op := range c.ops { v := op.Vars.Get(l.Name) if v == nil { @@ -746,13 +746,13 @@ func validateLiteral(c *opContext, l common.Literal) { } } -func validateValueType(c *opContext, v common.Literal, t common.Type) (bool, string) { - if v, ok := v.(*common.Variable); ok { +func validateValueType(c *opContext, v types.Value, t types.Type) (bool, string) { + if v, ok := v.(*types.Variable); ok { for _, op := range c.ops { if v2 := op.Vars.Get(v.Name); v2 != nil { t2, err := common.ResolveType(v2.Type, c.schema.Resolve) - if _, ok := t2.(*common.NonNull); !ok && v2.Default != nil { - t2 = &common.NonNull{OfType: t2} + if _, ok := t2.(*types.NonNull); !ok && v2.Default != nil { + t2 = &types.NonNull{OfType: t2} } if err == nil && !typeCanBeUsedAs(t2, t) { c.addErrMultiLoc([]errors.Location{v2.Loc, v.Loc}, "VariablesInAllowedPosition", "Variable %q of type %q used in position expecting type %q.", "$"+v.Name, t2, t) @@ -762,7 +762,7 @@ func validateValueType(c *opContext, v common.Literal, t common.Type) (bool, str return true, "" } - if nn, ok := t.(*common.NonNull); ok { + if nn, ok := t.(*types.NonNull); ok { if isNull(v) { return false, fmt.Sprintf("Expected %q, found null.", t) } @@ -773,27 +773,27 @@ func validateValueType(c *opContext, v common.Literal, t common.Type) (bool, str } switch t := t.(type) { - case *schema.Scalar, *schema.Enum: - if lit, ok := v.(*common.BasicLit); ok { + case *types.ScalarTypeDefinition, *types.EnumTypeDefinition: + if lit, ok := v.(*types.PrimitiveValue); ok { if validateBasicLit(lit, t) { return true, "" } } - case *common.List: - list, ok := v.(*common.ListLit) + case *types.List: + list, ok := v.(*types.ListValue) if !ok { return validateValueType(c, v, t.OfType) // single value instead of list } - for i, entry := range list.Entries { + for i, entry := range list.Values { if ok, reason := validateValueType(c, entry, t.OfType); !ok { return false, fmt.Sprintf("In element #%d: %s", i, reason) } } return true, "" - case *schema.InputObject: - v, ok := v.(*common.ObjectLit) + case *types.InputObject: + v, ok := v.(*types.ObjectValue) if !ok { return false, fmt.Sprintf("Expected %q, found not an object.", t) } @@ -816,7 +816,7 @@ func validateValueType(c *opContext, v common.Literal, t common.Type) (bool, str } } if !found { - if _, ok := iv.Type.(*common.NonNull); ok && iv.Default == nil { + if _, ok := iv.Type.(*types.NonNull); ok && iv.Default == nil { return false, fmt.Sprintf("In field %q: Expected %q, found null.", iv.Name.Name, iv.Type) } } @@ -827,10 +827,10 @@ func validateValueType(c *opContext, v common.Literal, t common.Type) (bool, str return false, fmt.Sprintf("Expected type %q, found %s.", t, v) } -func validateBasicLit(v *common.BasicLit, t common.Type) bool { +func validateBasicLit(v *types.PrimitiveValue, t types.Type) bool { switch t := t.(type) { - case *schema.Scalar: - switch t.Name { + case *types.ScalarTypeDefinition: + switch t.Name.Name { case "Int": if v.Type != scanner.Int { return false @@ -853,12 +853,12 @@ func validateBasicLit(v *common.BasicLit, t common.Type) bool { return true } - case *schema.Enum: + case *types.EnumTypeDefinition: if v.Type != scanner.Ident { return false } - for _, option := range t.Values { - if option.Name == v.Text { + for _, option := range t.EnumValuesDefinition { + if option.EnumValue == v.Text { return true } } @@ -868,44 +868,44 @@ func validateBasicLit(v *common.BasicLit, t common.Type) bool { return false } -func canBeFragment(t common.Type) bool { +func canBeFragment(t types.Type) bool { switch t.(type) { - case *schema.Object, *schema.Interface, *schema.Union: + case *types.ObjectTypeDefinition, *types.InterfaceTypeDefinition, *types.Union: return true default: return false } } -func canBeInput(t common.Type) bool { +func canBeInput(t types.Type) bool { switch t := t.(type) { - case *schema.InputObject, *schema.Scalar, *schema.Enum: + case *types.InputObject, *types.ScalarTypeDefinition, *types.EnumTypeDefinition: return true - case *common.List: + case *types.List: return canBeInput(t.OfType) - case *common.NonNull: + case *types.NonNull: return canBeInput(t.OfType) default: return false } } -func hasSubfields(t common.Type) bool { +func hasSubfields(t types.Type) bool { switch t := t.(type) { - case *schema.Object, *schema.Interface, *schema.Union: + case *types.ObjectTypeDefinition, *types.InterfaceTypeDefinition, *types.Union: return true - case *common.List: + case *types.List: return hasSubfields(t.OfType) - case *common.NonNull: + case *types.NonNull: return hasSubfields(t.OfType) default: return false } } -func isLeaf(t common.Type) bool { +func isLeaf(t types.Type) bool { switch t.(type) { - case *schema.Scalar, *schema.Enum: + case *types.ScalarTypeDefinition, *types.EnumTypeDefinition: return true default: return false @@ -913,19 +913,19 @@ func isLeaf(t common.Type) bool { } func isNull(lit interface{}) bool { - _, ok := lit.(*common.NullLit) + _, ok := lit.(*types.NullValue) return ok } -func typesCompatible(a, b common.Type) bool { - al, aIsList := a.(*common.List) - bl, bIsList := b.(*common.List) +func typesCompatible(a, b types.Type) bool { + al, aIsList := a.(*types.List) + bl, bIsList := b.(*types.List) if aIsList || bIsList { return aIsList && bIsList && typesCompatible(al.OfType, bl.OfType) } - ann, aIsNN := a.(*common.NonNull) - bnn, bIsNN := b.(*common.NonNull) + ann, aIsNN := a.(*types.NonNull) + bnn, bIsNN := b.(*types.NonNull) if aIsNN || bIsNN { return aIsNN && bIsNN && typesCompatible(ann.OfType, bnn.OfType) } @@ -937,13 +937,13 @@ func typesCompatible(a, b common.Type) bool { return true } -func typeCanBeUsedAs(t, as common.Type) bool { - nnT, okT := t.(*common.NonNull) +func typeCanBeUsedAs(t, as types.Type) bool { + nnT, okT := t.(*types.NonNull) if okT { t = nnT.OfType } - nnAs, okAs := as.(*common.NonNull) + nnAs, okAs := as.(*types.NonNull) if okAs { as = nnAs.OfType if !okT { @@ -955,8 +955,8 @@ func typeCanBeUsedAs(t, as common.Type) bool { return true } - if lT, ok := t.(*common.List); ok { - if lAs, ok := as.(*common.List); ok { + if lT, ok := t.(*types.List); ok { + if lAs, ok := as.(*types.List); ok { return typeCanBeUsedAs(lT.OfType, lAs.OfType) } } diff --git a/internal/validation/validation_test.go b/internal/validation/validation_test.go index e287a526..c26647f8 100644 --- a/internal/validation/validation_test.go +++ b/internal/validation/validation_test.go @@ -12,6 +12,7 @@ import ( "github.com/graph-gophers/graphql-go/internal/query" "github.com/graph-gophers/graphql-go/internal/schema" "github.com/graph-gophers/graphql-go/internal/validation" + "github.com/graph-gophers/graphql-go/types" ) type Test struct { @@ -37,10 +38,11 @@ func TestValidate(t *testing.T) { t.Fatal(err) } - schemas := make([]*schema.Schema, len(testData.Schemas)) + schemas := make([]*types.Schema, len(testData.Schemas)) for i, schemaStr := range testData.Schemas { schemas[i] = schema.New() - if err := schemas[i].Parse(schemaStr, false); err != nil { + err := schema.Parse(schemas[i], schemaStr, false) + if err != nil { t.Fatal(err) } } diff --git a/introspection/introspection.go b/introspection/introspection.go index 2f4acad0..cee61694 100644 --- a/introspection/introspection.go +++ b/introspection/introspection.go @@ -3,16 +3,15 @@ package introspection import ( "sort" - "github.com/graph-gophers/graphql-go/internal/common" - "github.com/graph-gophers/graphql-go/internal/schema" + "github.com/graph-gophers/graphql-go/types" ) type Schema struct { - schema *schema.Schema + schema *types.Schema } // WrapSchema is only used internally. -func WrapSchema(schema *schema.Schema) *Schema { +func WrapSchema(schema *types.Schema) *Schema { return &Schema{schema} } @@ -69,11 +68,11 @@ func (r *Schema) SubscriptionType() *Type { } type Type struct { - typ common.Type + typ types.Type } // WrapType is only used internally. -func WrapType(typ common.Type) *Type { +func WrapType(typ types.Type) *Type { return &Type{typ} } @@ -82,7 +81,7 @@ func (r *Type) Kind() string { } func (r *Type) Name() *string { - if named, ok := r.typ.(schema.NamedType); ok { + if named, ok := r.typ.(types.NamedType); ok { name := named.TypeName() return &name } @@ -90,7 +89,7 @@ func (r *Type) Name() *string { } func (r *Type) Description() *string { - if named, ok := r.typ.(schema.NamedType); ok { + if named, ok := r.typ.(types.NamedType); ok { desc := named.Description() if desc == "" { return nil @@ -101,11 +100,11 @@ func (r *Type) Description() *string { } func (r *Type) Fields(args *struct{ IncludeDeprecated bool }) *[]*Field { - var fields schema.FieldList + var fields types.FieldsDefinition switch t := r.typ.(type) { - case *schema.Object: + case *types.ObjectTypeDefinition: fields = t.Fields - case *schema.Interface: + case *types.InterfaceTypeDefinition: fields = t.Fields default: return nil @@ -114,14 +113,14 @@ func (r *Type) Fields(args *struct{ IncludeDeprecated bool }) *[]*Field { var l []*Field for _, f := range fields { if d := f.Directives.Get("deprecated"); d == nil || args.IncludeDeprecated { - l = append(l, &Field{f}) + l = append(l, &Field{field: f}) } } return &l } func (r *Type) Interfaces() *[]*Type { - t, ok := r.typ.(*schema.Object) + t, ok := r.typ.(*types.ObjectTypeDefinition) if !ok { return nil } @@ -134,12 +133,12 @@ func (r *Type) Interfaces() *[]*Type { } func (r *Type) PossibleTypes() *[]*Type { - var possibleTypes []*schema.Object + var possibleTypes []*types.ObjectTypeDefinition switch t := r.typ.(type) { - case *schema.Interface: - possibleTypes = t.PossibleTypes - case *schema.Union: + case *types.InterfaceTypeDefinition: possibleTypes = t.PossibleTypes + case *types.Union: + possibleTypes = t.UnionMemberTypes default: return nil } @@ -152,13 +151,13 @@ func (r *Type) PossibleTypes() *[]*Type { } func (r *Type) EnumValues(args *struct{ IncludeDeprecated bool }) *[]*EnumValue { - t, ok := r.typ.(*schema.Enum) + t, ok := r.typ.(*types.EnumTypeDefinition) if !ok { return nil } var l []*EnumValue - for _, v := range t.Values { + for _, v := range t.EnumValuesDefinition { if d := v.Directives.Get("deprecated"); d == nil || args.IncludeDeprecated { l = append(l, &EnumValue{v}) } @@ -167,7 +166,7 @@ func (r *Type) EnumValues(args *struct{ IncludeDeprecated bool }) *[]*EnumValue } func (r *Type) InputFields() *[]*InputValue { - t, ok := r.typ.(*schema.InputObject) + t, ok := r.typ.(*types.InputObject) if !ok { return nil } @@ -181,9 +180,9 @@ func (r *Type) InputFields() *[]*InputValue { func (r *Type) OfType() *Type { switch t := r.typ.(type) { - case *common.List: + case *types.List: return &Type{t.OfType} - case *common.NonNull: + case *types.NonNull: return &Type{t.OfType} default: return nil @@ -191,11 +190,11 @@ func (r *Type) OfType() *Type { } type Field struct { - field *schema.Field + field *types.FieldDefinition } func (r *Field) Name() string { - return r.field.Name + return r.field.Name.Name } func (r *Field) Description() *string { @@ -206,8 +205,8 @@ func (r *Field) Description() *string { } func (r *Field) Args() []*InputValue { - l := make([]*InputValue, len(r.field.Args)) - for i, v := range r.field.Args { + l := make([]*InputValue, len(r.field.Arguments)) + for i, v := range r.field.Arguments { l[i] = &InputValue{v} } return l @@ -226,12 +225,12 @@ func (r *Field) DeprecationReason() *string { if d == nil { return nil } - reason := d.Args.MustGet("reason").Value(nil).(string) + reason := d.Arguments.MustGet("reason").Deserialize(nil).(string) return &reason } type InputValue struct { - value *common.InputValue + value *types.InputValueDefinition } func (r *InputValue) Name() string { @@ -258,11 +257,11 @@ func (r *InputValue) DefaultValue() *string { } type EnumValue struct { - value *schema.EnumValue + value *types.EnumValueDefinition } func (r *EnumValue) Name() string { - return r.value.Name + return r.value.EnumValue } func (r *EnumValue) Description() *string { @@ -281,16 +280,16 @@ func (r *EnumValue) DeprecationReason() *string { if d == nil { return nil } - reason := d.Args.MustGet("reason").Value(nil).(string) + reason := d.Arguments.MustGet("reason").Deserialize(nil).(string) return &reason } type Directive struct { - directive *schema.DirectiveDecl + directive *types.DirectiveDefinition } func (r *Directive) Name() string { - return r.directive.Name + return r.directive.Name.Name } func (r *Directive) Description() *string { @@ -301,12 +300,12 @@ func (r *Directive) Description() *string { } func (r *Directive) Locations() []string { - return r.directive.Locs + return r.directive.Locations } func (r *Directive) Args() []*InputValue { - l := make([]*InputValue, len(r.directive.Args)) - for i, v := range r.directive.Args { + l := make([]*InputValue, len(r.directive.Arguments)) + for i, v := range r.directive.Arguments { l[i] = &InputValue{v} } return l diff --git a/types/argument.go b/types/argument.go new file mode 100644 index 00000000..b2681a28 --- /dev/null +++ b/types/argument.go @@ -0,0 +1,44 @@ +package types + +// Argument is a representation of the GraphQL Argument. +// +// https://spec.graphql.org/draft/#sec-Language.Arguments +type Argument struct { + Name Ident + Value Value +} + +// ArgumentList is a collection of GraphQL Arguments. +type ArgumentList []*Argument + +// Returns a Value in the ArgumentList by name. +func (l ArgumentList) Get(name string) (Value, bool) { + for _, arg := range l { + if arg.Name.Name == name { + return arg.Value, true + } + } + return nil, false +} + +// MustGet returns a Value in the ArgumentList by name. +// MustGet will panic if the argument name is not found in the ArgumentList. +func (l ArgumentList) MustGet(name string) Value { + value, ok := l.Get(name) + if !ok { + panic("argument not found") + } + return value +} + +type ArgumentsDefinition []*InputValueDefinition + +// Get returns an InputValueDefinition in the ArgumentsDefinition by name or nil if not found. +func (a ArgumentsDefinition) Get(name string) *InputValueDefinition { + for _, inputValue := range a { + if inputValue.Name.Name == name { + return inputValue + } + } + return nil +} diff --git a/types/directive.go b/types/directive.go new file mode 100644 index 00000000..afc6fa2f --- /dev/null +++ b/types/directive.go @@ -0,0 +1,31 @@ +package types + +// Directive is a representation of the GraphQL Directive. +// +// http://spec.graphql.org/draft/#sec-Language.Directives +type Directive struct { + Name Ident + Arguments ArgumentList +} + +// DirectiveDefinition is a representation of the GraphQL DirectiveDefinition. +// +// http://spec.graphql.org/draft/#sec-Type-System.Directives +type DirectiveDefinition struct { + Name Ident + Desc string + Locations []string + Arguments ArgumentsDefinition +} + +type DirectiveList []*Directive + +// Returns the Directive in the DirectiveList by name or nil if not found. +func (l DirectiveList) Get(name string) *Directive { + for _, d := range l { + if d.Name.Name == name { + return d + } + } + return nil +} diff --git a/types/doc.go b/types/doc.go new file mode 100644 index 00000000..87caa60b --- /dev/null +++ b/types/doc.go @@ -0,0 +1,9 @@ +/* + Package types represents all types from the GraphQL specification in code. + + + The names of the Go types, whenever possible, match 1:1 with the names from + the specification. + +*/ +package types diff --git a/types/enum.go b/types/enum.go new file mode 100644 index 00000000..74bd0975 --- /dev/null +++ b/types/enum.go @@ -0,0 +1,28 @@ +package types + +// EnumTypeDefinition defines a set of possible enum values. +// +// Like scalar types, an EnumTypeDefinition also represents a leaf value in a GraphQL type system. +// +// http://spec.graphql.org/draft/#sec-Enums +type EnumTypeDefinition struct { + Name Ident + EnumValuesDefinition []*EnumValueDefinition + Desc string + Directives DirectiveList +} + +// EnumValueDefinition are unique values that may be serialized as a string: the name of the +// represented value. +// +// http://spec.graphql.org/draft/#EnumValueDefinition +type EnumValueDefinition struct { + EnumValue string + Directives DirectiveList + Desc string +} + +func (*EnumTypeDefinition) Kind() string { return "ENUM" } +func (t *EnumTypeDefinition) String() string { return t.Name.Name } +func (t *EnumTypeDefinition) TypeName() string { return t.Name.Name } +func (t *EnumTypeDefinition) Description() string { return t.Desc } diff --git a/types/extension.go b/types/extension.go new file mode 100644 index 00000000..029e90b4 --- /dev/null +++ b/types/extension.go @@ -0,0 +1,10 @@ +package types + +// Extension type defines a GraphQL type extension. +// Schemas, Objects, Inputs and Scalars can be extended. +// +// https://spec.graphql.org/draft/#sec-Type-System-Extensions +type Extension struct { + Type NamedType + Directives DirectiveList +} diff --git a/types/field.go b/types/field.go new file mode 100644 index 00000000..54b3b697 --- /dev/null +++ b/types/field.go @@ -0,0 +1,37 @@ +package types + +// FieldDefinition is a representation of a GraphQL FieldDefinition. +// +// http://spec.graphql.org/draft/#FieldDefinition +type FieldDefinition struct { + Alias Ident + Name Ident + Arguments ArgumentsDefinition + Type Type + Directives DirectiveList + Desc string +} + +// FieldsDefinition is a list of an ObjectTypeDefinition's Fields. +// +// https://spec.graphql.org/draft/#FieldsDefinition +type FieldsDefinition []*FieldDefinition + +// Get returns a FieldDefinition in a FieldsDefinition by name or nil if not found. +func (l FieldsDefinition) Get(name string) *FieldDefinition { + for _, f := range l { + if f.Name.Name == name { + return f + } + } + return nil +} + +// Names returns a slice of FieldDefinition names. +func (l FieldsDefinition) Names() []string { + names := make([]string, len(l)) + for i, f := range l { + names[i] = f.Name.Name + } + return names +} diff --git a/types/fragment.go b/types/fragment.go new file mode 100644 index 00000000..606219ca --- /dev/null +++ b/types/fragment.go @@ -0,0 +1,51 @@ +package types + +import "github.com/graph-gophers/graphql-go/errors" + +type Fragment struct { + On TypeName + Selections SelectionSet +} + +// InlineFragment is a representation of the GraphQL InlineFragment. +// +// http://spec.graphql.org/draft/#InlineFragment +type InlineFragment struct { + Fragment + Directives DirectiveList + Loc errors.Location +} + +// FragmentDefinition is a representation of the GraphQL FragmentDefinition. +// +// http://spec.graphql.org/draft/#FragmentDefinition +type FragmentDefinition struct { + Fragment + Name Ident + Directives DirectiveList + Loc errors.Location +} + +// FragmentSpread is a representation of the GraphQL FragmentSpread. +// +// http://spec.graphql.org/draft/#FragmentSpread +type FragmentSpread struct { + Name Ident + Directives DirectiveList + Loc errors.Location +} + +type FragmentList []*FragmentDefinition + +// Returns a FragmentDefinition by name or nil if not found. +func (l FragmentList) Get(name string) *FragmentDefinition { + for _, f := range l { + if f.Name.Name == name { + return f + } + } + return nil +} + +func (InlineFragment) isSelection() {} +func (FragmentSpread) isSelection() {} diff --git a/types/input.go b/types/input.go new file mode 100644 index 00000000..8ce922a2 --- /dev/null +++ b/types/input.go @@ -0,0 +1,46 @@ +package types + +import "github.com/graph-gophers/graphql-go/errors" + +// InputValueDefinition is a representation of the GraphQL InputValueDefinition. +// +// http://spec.graphql.org/draft/#InputValueDefinition +type InputValueDefinition struct { + Name Ident + Type Type + Default Value + Desc string + Directives DirectiveList + Loc errors.Location + TypeLoc errors.Location +} + +type InputValueDefinitionList []*InputValueDefinition + +// Returns an InputValueDefinition by name or nil if not found. +func (l InputValueDefinitionList) Get(name string) *InputValueDefinition { + for _, v := range l { + if v.Name.Name == name { + return v + } + } + return nil +} + +// InputObject types define a set of input fields; the input fields are either scalars, enums, or +// other input objects. +// +// This allows arguments to accept arbitrarily complex structs. +// +// http://spec.graphql.org/draft/#sec-Input-Objects +type InputObject struct { + Name Ident + Desc string + Values ArgumentsDefinition + Directives DirectiveList +} + +func (*InputObject) Kind() string { return "INPUT_OBJECT" } +func (t *InputObject) String() string { return t.Name.Name } +func (t *InputObject) TypeName() string { return t.Name.Name } +func (t *InputObject) Description() string { return t.Desc } diff --git a/types/interface.go b/types/interface.go new file mode 100644 index 00000000..58258b2e --- /dev/null +++ b/types/interface.go @@ -0,0 +1,20 @@ +package types + +// InterfaceTypeDefinition represents a list of named fields and their arguments. +// +// GraphQL objects can then implement these interfaces which requires that the object type will +// define all fields defined by those interfaces. +// +// http://spec.graphql.org/draft/#sec-Interfaces +type InterfaceTypeDefinition struct { + Name string + PossibleTypes []*ObjectTypeDefinition + Fields FieldsDefinition + Desc string + Directives DirectiveList +} + +func (*InterfaceTypeDefinition) Kind() string { return "INTERFACE" } +func (t *InterfaceTypeDefinition) String() string { return t.Name } +func (t *InterfaceTypeDefinition) TypeName() string { return t.Name } +func (t *InterfaceTypeDefinition) Description() string { return t.Desc } diff --git a/types/object.go b/types/object.go new file mode 100644 index 00000000..4e80bcc4 --- /dev/null +++ b/types/object.go @@ -0,0 +1,23 @@ +package types + +// ObjectTypeDefinition represents a GraphQL ObjectTypeDefinition. +// +// type FooObject { +// foo: String +// } +// +// https://spec.graphql.org/draft/#sec-Objects +type ObjectTypeDefinition struct { + Name Ident + Interfaces []*InterfaceTypeDefinition + Fields FieldsDefinition + Desc string + Directives DirectiveList + + InterfaceNames []string +} + +func (*ObjectTypeDefinition) Kind() string { return "OBJECT" } +func (t *ObjectTypeDefinition) String() string { return t.Name.Name } +func (t *ObjectTypeDefinition) TypeName() string { return t.Name.Name } +func (t *ObjectTypeDefinition) Description() string { return t.Desc } diff --git a/types/query.go b/types/query.go new file mode 100644 index 00000000..caca6ef4 --- /dev/null +++ b/types/query.go @@ -0,0 +1,62 @@ +package types + +import "github.com/graph-gophers/graphql-go/errors" + +// ExecutableDefinition represents a set of operations or fragments that can be executed +// against a schema. +// +// http://spec.graphql.org/draft/#ExecutableDefinition +type ExecutableDefinition struct { + Operations OperationList + Fragments FragmentList +} + +// OperationDefinition represents a GraphQL Operation. +// +// https://spec.graphql.org/draft/#sec-Language.Operations +type OperationDefinition struct { + Type OperationType + Name Ident + Vars ArgumentsDefinition + Selections SelectionSet + Directives DirectiveList + Loc errors.Location +} + +type OperationType string + +// A Selection is a field requested in a GraphQL operation. +// +// http://spec.graphql.org/draft/#Selection +type Selection interface { + isSelection() +} + +// A SelectionSet represents a collection of Selections +// +// http://spec.graphql.org/draft/#sec-Selection-Sets +type SelectionSet []Selection + +// Field represents a field used in a query. +type Field struct { + Alias Ident + Name Ident + Arguments ArgumentList + Directives DirectiveList + SelectionSet SelectionSet + SelectionSetLoc errors.Location +} + +func (Field) isSelection() {} + +type OperationList []*OperationDefinition + +// Get returns an OperationDefinition by name or nil if not found. +func (l OperationList) Get(name string) *OperationDefinition { + for _, f := range l { + if f.Name.Name == name { + return f + } + } + return nil +} diff --git a/types/scalar.go b/types/scalar.go new file mode 100644 index 00000000..026b303b --- /dev/null +++ b/types/scalar.go @@ -0,0 +1,19 @@ +package types + +// ScalarTypeDefinition types represent primitive leaf values (e.g. a string or an integer) in a GraphQL type +// system. +// +// GraphQL responses take the form of a hierarchical tree; the leaves on these trees are GraphQL +// scalars. +// +// http://spec.graphql.org/draft/#sec-Scalars +type ScalarTypeDefinition struct { + Name Ident + Desc string + Directives DirectiveList +} + +func (*ScalarTypeDefinition) Kind() string { return "SCALAR" } +func (t *ScalarTypeDefinition) String() string { return t.Name.Name } +func (t *ScalarTypeDefinition) TypeName() string { return t.Name.Name } +func (t *ScalarTypeDefinition) Description() string { return t.Desc } diff --git a/types/schema.go b/types/schema.go new file mode 100644 index 00000000..06811a97 --- /dev/null +++ b/types/schema.go @@ -0,0 +1,42 @@ +package types + +// Schema represents a GraphQL service's collective type system capabilities. +// A schema is defined in terms of the types and directives it supports as well as the root +// operation types for each kind of operation: `query`, `mutation`, and `subscription`. +// +// For a more formal definition, read the relevant section in the specification: +// +// http://spec.graphql.org/draft/#sec-Schema +type Schema struct { + // EntryPoints determines the place in the type system where `query`, `mutation`, and + // `subscription` operations begin. + // + // http://spec.graphql.org/draft/#sec-Root-Operation-Types + // + EntryPoints map[string]NamedType + + // Types are the fundamental unit of any GraphQL schema. + // There are six kinds of named types, and two wrapping types. + // + // http://spec.graphql.org/draft/#sec-Types + Types map[string]NamedType + + // Directives are used to annotate various parts of a GraphQL document as an indicator that they + // should be evaluated differently by a validator, executor, or client tool such as a code + // generator. + // + // http://spec.graphql.org/#sec-Type-System.Directives + Directives map[string]*DirectiveDefinition + + UseFieldResolvers bool + + EntryPointNames map[string]string + Objects []*ObjectTypeDefinition + Unions []*Union + Enums []*EnumTypeDefinition + Extensions []*Extension +} + +func (s *Schema) Resolve(name string) Type { + return s.Types[name] +} diff --git a/types/types.go b/types/types.go new file mode 100644 index 00000000..df34d08a --- /dev/null +++ b/types/types.go @@ -0,0 +1,63 @@ +package types + +import ( + "github.com/graph-gophers/graphql-go/errors" +) + +// TypeName is a base building block for GraphQL type references. +type TypeName struct { + Ident +} + +// NamedType represents a type with a name. +// +// http://spec.graphql.org/draft/#NamedType +type NamedType interface { + Type + TypeName() string + Description() string +} + +type Ident struct { + Name string + Loc errors.Location +} + +type Type interface { + // Kind returns one possible GraphQL type kind. A type kind must be + // valid as defined by the GraphQL spec. + // + // https://spec.graphql.org/draft/#sec-Type-Kinds + Kind() string + + // String serializes a Type into a GraphQL specification format type. + // + // http://spec.graphql.org/draft/#sec-Serialization-Format + String() string +} + +// List represents a GraphQL ListType. +// +// http://spec.graphql.org/draft/#ListType +type List struct { + // OfType represents the inner-type of a List type. + // For example, the List type `[Foo]` has an OfType of Foo. + OfType Type +} + +// NonNull represents a GraphQL NonNullType. +// +// https://spec.graphql.org/draft/#NonNullType +type NonNull struct { + // OfType represents the inner-type of a NonNull type. + // For example, the NonNull type `Foo!` has an OfType of Foo. + OfType Type +} + +func (*List) Kind() string { return "LIST" } +func (*NonNull) Kind() string { return "NON_NULL" } +func (*TypeName) Kind() string { panic("TypeName needs to be resolved to actual type") } + +func (t *List) String() string { return "[" + t.OfType.String() + "]" } +func (t *NonNull) String() string { return t.OfType.String() + "!" } +func (*TypeName) String() string { panic("TypeName needs to be resolved to actual type") } diff --git a/types/union.go b/types/union.go new file mode 100644 index 00000000..b8e0c668 --- /dev/null +++ b/types/union.go @@ -0,0 +1,21 @@ +package types + +// Union types represent objects that could be one of a list of GraphQL object types, but provides no +// guaranteed fields between those types. +// +// They also differ from interfaces in that object types declare what interfaces they implement, but +// are not aware of what unions contain them. +// +// http://spec.graphql.org/draft/#sec-Unions +type Union struct { + Name Ident + UnionMemberTypes []*ObjectTypeDefinition + Desc string + Directives DirectiveList + TypeNames []string +} + +func (*Union) Kind() string { return "UNION" } +func (t *Union) String() string { return t.Name.Name } +func (t *Union) TypeName() string { return t.Name.Name } +func (t *Union) Description() string { return t.Desc } diff --git a/types/value.go b/types/value.go new file mode 100644 index 00000000..9f8d041a --- /dev/null +++ b/types/value.go @@ -0,0 +1,141 @@ +package types + +import ( + "strconv" + "strings" + "text/scanner" + + "github.com/graph-gophers/graphql-go/errors" +) + +// Value represents a literal input or literal default value in the GraphQL Specification. +// +// http://spec.graphql.org/draft/#sec-Input-Values +type Value interface { + // Deserialize transforms a GraphQL specification format literal into a Go type. + Deserialize(vars map[string]interface{}) interface{} + + // String serializes a Value into a GraphQL specification format literal. + String() string + Location() errors.Location +} + +// PrimitiveValue represents one of the following GraphQL scalars: Int, Float, +// String, or Boolean +type PrimitiveValue struct { + Type rune + Text string + Loc errors.Location +} + +func (val *PrimitiveValue) Deserialize(vars map[string]interface{}) interface{} { + switch val.Type { + case scanner.Int: + value, err := strconv.ParseInt(val.Text, 10, 32) + if err != nil { + panic(err) + } + return int32(value) + + case scanner.Float: + value, err := strconv.ParseFloat(val.Text, 64) + if err != nil { + panic(err) + } + return value + + case scanner.String: + value, err := strconv.Unquote(val.Text) + if err != nil { + panic(err) + } + return value + + case scanner.Ident: + switch val.Text { + case "true": + return true + case "false": + return false + default: + return val.Text + } + + default: + panic("invalid literal value") + } +} + +func (val *PrimitiveValue) String() string { return val.Text } +func (val *PrimitiveValue) Location() errors.Location { return val.Loc } + +// ListValue represents a literal list Value in the GraphQL specification. +// +// http://spec.graphql.org/draft/#sec-List-Value +type ListValue struct { + Values []Value + Loc errors.Location +} + +func (val *ListValue) Deserialize(vars map[string]interface{}) interface{} { + entries := make([]interface{}, len(val.Values)) + for i, entry := range val.Values { + entries[i] = entry.Deserialize(vars) + } + return entries +} + +func (val *ListValue) String() string { + entries := make([]string, len(val.Values)) + for i, entry := range val.Values { + entries[i] = entry.String() + } + return "[" + strings.Join(entries, ", ") + "]" +} + +func (val *ListValue) Location() errors.Location { return val.Loc } + +// ObjectValue represents a literal object Value in the GraphQL specification. +// +// http://spec.graphql.org/draft/#sec-Object-Value +type ObjectValue struct { + Fields []*ObjectField + Loc errors.Location +} + +// ObjectField represents field/value pairs in a literal ObjectValue. +type ObjectField struct { + Name Ident + Value Value +} + +func (val *ObjectValue) Deserialize(vars map[string]interface{}) interface{} { + fields := make(map[string]interface{}, len(val.Fields)) + for _, f := range val.Fields { + fields[f.Name.Name] = f.Value.Deserialize(vars) + } + return fields +} + +func (val *ObjectValue) String() string { + entries := make([]string, 0, len(val.Fields)) + for _, f := range val.Fields { + entries = append(entries, f.Name.Name+": "+f.Value.String()) + } + return "{" + strings.Join(entries, ", ") + "}" +} + +func (val *ObjectValue) Location() errors.Location { + return val.Loc +} + +// NullValue represents a literal `null` Value in the GraphQL specification. +// +// http://spec.graphql.org/draft/#sec-Null-Value +type NullValue struct { + Loc errors.Location +} + +func (val *NullValue) Deserialize(vars map[string]interface{}) interface{} { return nil } +func (val *NullValue) String() string { return "null" } +func (val *NullValue) Location() errors.Location { return val.Loc } diff --git a/types/variable.go b/types/variable.go new file mode 100644 index 00000000..1a4e2a51 --- /dev/null +++ b/types/variable.go @@ -0,0 +1,15 @@ +package types + +import "github.com/graph-gophers/graphql-go/errors" + +// Variable is used in GraphQL operations to parameterize an input value. +// +// http://spec.graphql.org/draft/#Variable +type Variable struct { + Name string + Loc errors.Location +} + +func (v Variable) Deserialize(vars map[string]interface{}) interface{} { return vars[v.Name] } +func (v Variable) String() string { return "$" + v.Name } +func (v *Variable) Location() errors.Location { return v.Loc }