@@ -10,7 +10,9 @@ import com.apollographql.apollo.ast.GQLDirectiveDefinition
10
10
import com.apollographql.apollo.ast.GQLDocument
11
11
import com.apollographql.apollo.ast.GQLEnumTypeDefinition
12
12
import com.apollographql.apollo.ast.GQLField
13
+ import com.apollographql.apollo.ast.GQLFieldDefinition
13
14
import com.apollographql.apollo.ast.GQLInputObjectTypeDefinition
15
+ import com.apollographql.apollo.ast.GQLInputValueDefinition
14
16
import com.apollographql.apollo.ast.GQLInterfaceTypeDefinition
15
17
import com.apollographql.apollo.ast.GQLListValue
16
18
import com.apollographql.apollo.ast.GQLNamed
@@ -266,23 +268,23 @@ internal fun validateSchema(definitions: List<GQLDefinition>, options: SchemaVal
266
268
).flatMap {
267
269
it.parseAsGQLDocument().getOrThrow().definitions
268
270
}
269
- .forEach { expected ->
270
- val existing = when (expected) {
271
- is GQLTypeDefinition -> typeDefinitions.get(expected.name)
272
- is GQLDirectiveDefinition -> directiveDefinitions.get(expected.name)
273
- else -> error(" " )// should never happen
274
- }
275
- if (existing != null && ! foreignNames.containsKey(expected.definitionName()) && ! existing.semanticEquals(expected)) {
276
- /*
277
- * For non-linked definitions, check that the definitions match 1:1.
278
- * We do not check linked definitions because:
279
- * - we know we support them by construction.
280
- * - someone may rename argument types, which makes validation much harder. One example is importing `@catch` but not
281
- * `@catchTo`.
282
- */
283
- issues.add(IncompatibleDefinition (expected.name, expected.toSemanticSdl(), existing.sourceLocation))
284
- }
285
- }
271
+ .forEach { expected ->
272
+ val existing = when (expected) {
273
+ is GQLTypeDefinition -> typeDefinitions.get(expected.name)
274
+ is GQLDirectiveDefinition -> directiveDefinitions.get(expected.name)
275
+ else -> error(" " )// should never happen
276
+ }
277
+ if (existing != null && ! foreignNames.containsKey(expected.definitionName()) && ! existing.semanticEquals(expected)) {
278
+ /*
279
+ * For non-linked definitions, check that the definitions match 1:1.
280
+ * We do not check linked definitions because:
281
+ * - we know we support them by construction.
282
+ * - someone may rename argument types, which makes validation much harder. One example is importing `@catch` but not
283
+ * `@catchTo`.
284
+ */
285
+ issues.add(IncompatibleDefinition (expected.name, expected.toSemanticSdl(), existing.sourceLocation))
286
+ }
287
+ }
286
288
287
289
/* *
288
290
* I'm not 100% clear on the order of validations, here I'm merging the extensions first thing
@@ -313,6 +315,7 @@ internal fun validateSchema(definitions: List<GQLDefinition>, options: SchemaVal
313
315
mergedScope.validateUnions()
314
316
mergedScope.validateInputObjects()
315
317
mergedScope.validateScalars()
318
+ mergedScope.validateDirectiveDefinitions()
316
319
317
320
val keyFields = mergedScope.validateAndComputeKeyFields()
318
321
val connectionTypes = mergedScope.computeConnectionTypes()
@@ -515,6 +518,7 @@ private fun List<GQLValue>.parseImport(issues: MutableList<Issue>): Map<Definiti
515
518
// Simple case: import the definition without renaming
516
519
it.value to it.value
517
520
}
521
+
518
522
is GQLObjectValue -> {
519
523
if (it.fields.size != 2 ) {
520
524
issues.add(OtherValidationIssue (" Too many fields in 'import' argument" , it.sourceLocation))
@@ -663,8 +667,8 @@ private fun ValidationScope.validateInterfaces() {
663
667
664
668
validateDirectivesInConstContext(i.directives, i)
665
669
666
- i.fields.forEach { gqlFieldDefinition ->
667
- validateDirectivesInConstContext(gqlFieldDefinition.directives, gqlFieldDefinition )
670
+ i.fields.forEach { fieldDefinition ->
671
+ validateField(fieldDefinition )
668
672
}
669
673
}
670
674
}
@@ -684,8 +688,20 @@ private fun ValidationScope.validateObjects() {
684
688
685
689
validateDirectivesInConstContext(o.directives, o)
686
690
687
- o.fields.forEach { gqlFieldDefinition ->
688
- validateDirectivesInConstContext(gqlFieldDefinition.directives, gqlFieldDefinition)
691
+ o.fields.forEach { fieldDefinition ->
692
+ validateField(fieldDefinition)
693
+ }
694
+ }
695
+ }
696
+
697
+ private fun ValidationScope.validateField (fieldDefinition : GQLFieldDefinition ) {
698
+ validateDirectivesInConstContext(fieldDefinition.directives, fieldDefinition)
699
+
700
+ fieldDefinition.arguments.forEach {
701
+ if (it.defaultValue != null ) {
702
+ validateAndCoerceValue(it.defaultValue, it.type, false , false ) {
703
+ issues.add(it.constContextError())
704
+ }
689
705
}
690
706
}
691
707
}
@@ -716,6 +732,17 @@ private fun ValidationScope.validateCatch(schemaDefinition: GQLSchemaDefinition)
716
732
return
717
733
}
718
734
}
735
+ private fun ValidationScope.validateDirectiveDefinitions () {
736
+ directiveDefinitions.values.forEach {
737
+ it.arguments.forEach {
738
+ if (it.defaultValue != null ) {
739
+ validateAndCoerceValue(it.defaultValue, it.type, false , false ) {
740
+ issues.add(it.constContextError())
741
+ }
742
+ }
743
+ }
744
+ }
745
+ }
719
746
720
747
private fun ValidationScope.validateScalars () {
721
748
typeDefinitions.values.filterIsInstance<GQLScalarTypeDefinition >().forEach { scalarTypeDefinition ->
@@ -732,18 +759,23 @@ private fun ValidationScope.validateScalars() {
732
759
issues.add(OtherValidationIssue (
733
760
message = " Only one of @map and @mapTo can be added to a scalar." ,
734
761
sourceLocation = scalarTypeDefinition.sourceLocation
735
- ))
762
+ )
763
+ )
736
764
}
737
765
}
738
766
}
739
767
740
768
private fun ValidationScope.validateInputObjects () {
769
+ val traversalState = TraversalState ()
770
+ val defaultValueTraversalState = DefaultValueTraversalState ()
741
771
typeDefinitions.values.filterIsInstance<GQLInputObjectTypeDefinition >().forEach { o ->
742
772
if (o.inputFields.isEmpty()) {
743
773
registerIssue(" Input object must specify one or more input fields" , o.sourceLocation)
744
774
}
745
775
746
776
validateDirectivesInConstContext(o.directives, o)
777
+ validateInputFieldCycles(o, traversalState)
778
+ validateInputObjectDefaultValue(o, defaultValueTraversalState)
747
779
748
780
val isOneOfInputObject = o.directives.findOneOf()
749
781
o.inputFields.forEach { gqlInputValueDefinition ->
@@ -759,6 +791,143 @@ private fun ValidationScope.validateInputObjects() {
759
791
}
760
792
}
761
793
794
+ private class TraversalState {
795
+ val visitedTypes = mutableSetOf<String >()
796
+ val fieldPath = mutableListOf<Pair <String , SourceLocation ?>>()
797
+ val fieldPathIndexByTypeName = mutableMapOf<String , Int >()
798
+ }
799
+
800
+ private class DefaultValueTraversalState {
801
+ val visitedFields = mutableSetOf<String >()
802
+ val fieldPath = mutableListOf<Pair <String , SourceLocation ?>>()
803
+ val fieldPathIndex = mutableMapOf<String , Int >()
804
+ }
805
+
806
+
807
+ private fun ValidationScope.validateInputFieldCycles (inputObjectTypeDefinition : GQLInputObjectTypeDefinition , state : TraversalState ) {
808
+ if (state.visitedTypes.contains(inputObjectTypeDefinition.name)) {
809
+ return
810
+ }
811
+ state.visitedTypes.add(inputObjectTypeDefinition.name)
812
+
813
+ state.fieldPathIndexByTypeName[inputObjectTypeDefinition.name] = state.fieldPath.size
814
+
815
+ inputObjectTypeDefinition.inputFields.forEach {
816
+ val type = it.type
817
+ if (type is GQLNonNullType && type.type is GQLNamedType ) {
818
+ val fieldType = typeDefinitions.get(type.type.name)
819
+ if (fieldType is GQLInputObjectTypeDefinition ) {
820
+ val cycleIndex = state.fieldPathIndexByTypeName.get(fieldType.name)
821
+
822
+ state.fieldPath.add(" ${fieldType.name} .${it.name} " to it.sourceLocation)
823
+
824
+ if (cycleIndex == null ) {
825
+ validateInputFieldCycles(fieldType, state)
826
+ } else {
827
+ val cyclePath = state.fieldPath.subList(cycleIndex, state.fieldPath.size)
828
+
829
+ cyclePath.forEach {
830
+ issues.add(
831
+ OtherValidationIssue (
832
+ buildString {
833
+ append(" Invalid circular reference. The Input Object '${fieldType.name} ' references itself " )
834
+ if (cyclePath.size > 1 ) {
835
+ append(" via the non-null fields: " )
836
+ } else {
837
+ append(" in the non-null field " )
838
+ }
839
+ append(cyclePath.map { it.first }.joinToString(" , " ))
840
+ },
841
+ it.second
842
+ )
843
+ )
844
+ }
845
+ }
846
+
847
+ state.fieldPath.removeLast()
848
+ }
849
+ }
850
+ }
851
+
852
+ state.fieldPathIndexByTypeName.remove(inputObjectTypeDefinition.name)
853
+ }
854
+ private fun ValidationScope.validateInputObjectDefaultValue (
855
+ inputObjectTypeDefinition : GQLInputObjectTypeDefinition ,
856
+ state : DefaultValueTraversalState
857
+ ) {
858
+ validateInputObjectDefaultValue(inputObjectTypeDefinition, GQLObjectValue (null ,emptyList()), state)
859
+ }
860
+ private fun ValidationScope.validateInputObjectDefaultValue (
861
+ inputObjectTypeDefinition : GQLInputObjectTypeDefinition ,
862
+ defaultValue : GQLValue ,
863
+ state : DefaultValueTraversalState
864
+ ) {
865
+ if (defaultValue is GQLListValue ) {
866
+ defaultValue.values.forEach {
867
+ validateInputObjectDefaultValue(inputObjectTypeDefinition, it, state)
868
+ }
869
+ } else if (defaultValue is GQLObjectValue ) {
870
+ inputObjectTypeDefinition.inputFields.forEach { inputField ->
871
+ val rawType = inputField.type.rawType()
872
+ val typeDefinition = typeDefinitions.get(rawType.name)
873
+ if (typeDefinition !is GQLInputObjectTypeDefinition ) {
874
+ return
875
+ }
876
+ val fieldDefaultValue = defaultValue.fields.firstOrNull { it.name == inputField.name}
877
+ if (fieldDefaultValue != null ) {
878
+ validateInputObjectDefaultValue(typeDefinition, fieldDefaultValue.value, state)
879
+ } else {
880
+ validateInputFieldDefaultValue(inputField, " ${inputObjectTypeDefinition.name} .${inputField.name} " , defaultValue, typeDefinition, state)
881
+ }
882
+ }
883
+ }
884
+ }
885
+
886
+ private fun ValidationScope.validateInputFieldDefaultValue (
887
+ inputFieldDefinition : GQLInputValueDefinition ,
888
+ fieldStr : String ,
889
+ defaultValue : GQLObjectValue ,
890
+ typeDefinition : GQLInputObjectTypeDefinition ,
891
+ state : DefaultValueTraversalState
892
+ ) {
893
+ val fieldDefaultValue = inputFieldDefinition.defaultValue
894
+ if (fieldDefaultValue == null ) {
895
+ return
896
+ }
897
+
898
+ val cycleIndex = state.fieldPathIndex[fieldStr]
899
+ if (cycleIndex != null ) {
900
+ val cyclePath = state.fieldPath.subList(cycleIndex, state.fieldPath.size)
901
+ cyclePath.forEach {
902
+ issues.add(
903
+ OtherValidationIssue (
904
+ buildString {
905
+ append(" Invalid circular reference. The default value of Input Object field $fieldStr references itself" )
906
+ if (cyclePath.size > 1 ) {
907
+ append(" via the default values of: " )
908
+ append(cyclePath.map { it.first }.joinToString(" , " ))
909
+ }
910
+ append(' .' )
911
+ },
912
+ it.second
913
+ )
914
+ )
915
+ }
916
+ }
917
+ if (state.visitedFields.contains(fieldStr)) {
918
+ return
919
+ }
920
+
921
+ state.visitedFields.add(fieldStr)
922
+ state.fieldPathIndex.put(fieldStr, state.fieldPath.size)
923
+ state.fieldPath.add(fieldStr to fieldDefaultValue.sourceLocation)
924
+
925
+ validateInputObjectDefaultValue(typeDefinition, fieldDefaultValue, state)
926
+
927
+ state.fieldPathIndex.remove(fieldStr)
928
+ state.fieldPath.removeLast()
929
+ }
930
+
762
931
private fun ValidationScope.validateNoIntrospectionNames () {
763
932
// 3.3 All types and directives defined within a schema must not have a name which begins with "__"
764
933
(typeDefinitions.values + directiveDefinitions.values).forEach { definition ->
0 commit comments