Skip to content

Commit 05ded31

Browse files
committed
add support for variable references for built in query directives
1 parent ec8d92f commit 05ded31

File tree

3 files changed

+116
-4
lines changed

3 files changed

+116
-4
lines changed

src/main/java/com/intuit/graphql/orchestrator/batch/VariableDefinitionFilter.java

Lines changed: 38 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,30 +6,38 @@
66
import graphql.analysis.QueryVisitorInlineFragmentEnvironment;
77
import graphql.analysis.QueryVisitorStub;
88
import graphql.language.Argument;
9+
import graphql.language.AstTransformer;
910
import graphql.language.Document;
1011
import graphql.language.Field;
1112
import graphql.language.FragmentDefinition;
1213
import graphql.language.FragmentSpread;
1314
import graphql.language.InlineFragment;
1415
import graphql.language.Node;
16+
import graphql.language.NodeVisitorStub;
1517
import graphql.language.OperationDefinition;
1618
import graphql.language.Value;
1719
import graphql.language.VariableReference;
1820
import graphql.schema.GraphQLObjectType;
1921
import graphql.schema.GraphQLSchema;
22+
import graphql.util.TraversalControl;
23+
import graphql.util.TraverserContext;
24+
import lombok.Getter;
25+
2026
import java.util.ArrayList;
27+
import java.util.Collection;
2128
import java.util.List;
2229
import java.util.Map;
2330
import java.util.Set;
2431
import java.util.stream.Collectors;
2532
import java.util.stream.Stream;
26-
import lombok.Getter;
2733

2834
/**
2935
* This class provides assistance in extracting all VariableReference names used in GraphQL nodes.
3036
*/
3137
public class VariableDefinitionFilter {
3238

39+
private static AstTransformer astTransformer = new AstTransformer();
40+
3341
/**
3442
* Traverses a GraphQL Node and returns all VariableReference names used in all nodes in the graph.
3543
*
@@ -67,8 +75,20 @@ public Set<String> getVariableReferencesFromNode(GraphQLSchema graphQLSchema, Gr
6775

6876
Set<VariableReference> additionalReferences = operationDirectiveVariableReferences(operationDefinitions);
6977

70-
return Stream.concat(variableReferenceVisitor.getVariableReferences().stream(), additionalReferences.stream())
71-
.map(VariableReference::getName).collect(Collectors.toSet());
78+
Stream<VariableReference> variableReferenceStream;
79+
if((variableReferenceVisitor.getVariableReferences().size() + additionalReferences.size()) != variables.size()) {
80+
NodeTraverser nodeTraverser = new NodeTraverser();
81+
astTransformer.transform(rootNode, nodeTraverser);
82+
83+
variableReferenceStream = Stream.of(variableReferenceVisitor.getVariableReferences(),
84+
additionalReferences,
85+
nodeTraverser.getVariableReferenceExtractor().getVariableReferences())
86+
.flatMap(Collection::stream);
87+
} else {
88+
variableReferenceStream = Stream.concat(variableReferenceVisitor.getVariableReferences().stream(), additionalReferences.stream());
89+
}
90+
return variableReferenceStream.map(VariableReference::getName).collect(Collectors.toSet());
91+
7292
}
7393

7494
private Set<VariableReference> operationDirectiveVariableReferences(List<OperationDefinition> operationDefinitions) {
@@ -163,4 +183,19 @@ private void captureVariableReferences(Stream<Argument> arguments) {
163183
variableReferenceExtractor.captureVariableReferences(values);
164184
}
165185
}
186+
187+
static class NodeTraverser extends NodeVisitorStub {
188+
189+
@Getter
190+
private final VariableReferenceExtractor variableReferenceExtractor = new VariableReferenceExtractor();
191+
192+
public TraversalControl visitArgument(Argument node, TraverserContext<Node> context) {
193+
return this.visitNode(node, context);
194+
}
195+
196+
public TraversalControl visitVariableReference(VariableReference node, TraverserContext<Node> context) {
197+
variableReferenceExtractor.captureVariableReference(node);
198+
return this.visitValue(node, context);
199+
}
200+
}
166201
}

src/main/java/com/intuit/graphql/orchestrator/batch/VariableReferenceExtractor.java

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,10 +19,14 @@ public Set<VariableReference> getVariableReferences() {
1919

2020
public void captureVariableReferences(List<Value> values) {
2121
for (final Value value : values) {
22-
doSwitch(value);
22+
captureVariableReference(value);
2323
}
2424
}
2525

26+
public void captureVariableReference(Value value) {
27+
doSwitch(value);
28+
}
29+
2630
private void doSwitch(Value value) {
2731
if (value instanceof ArrayValue) {
2832
handleArrayValue((ArrayValue) value);

src/test/groovy/com/intuit/graphql/orchestrator/batch/VariableDefinitionFilterSpec.groovy

Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,17 @@ class VariableDefinitionFilterSpec extends Specification {
4747
directive @field_directive_argument(arg: InputObject) on FIELD_DEFINITION
4848
'''
4949

50+
private String schema2 = '''
51+
type Query { person: Person }
52+
53+
type Person {
54+
address : Address
55+
id: String
56+
}
57+
58+
type Address { city: String state: String zip: String }
59+
'''
60+
5061
private GraphQLSchema graphQLSchema
5162

5263
private VariableDefinitionFilter variableDefinitionFilter
@@ -63,6 +74,12 @@ class VariableDefinitionFilterSpec extends Specification {
6374
RuntimeWiring.newRuntimeWiring().build())
6475
}
6576

77+
private GraphQLSchema getSchema2() {
78+
return new SchemaGenerator()
79+
.makeExecutableSchema(new SchemaParser().parse(schema2),
80+
RuntimeWiring.newRuntimeWiring().build())
81+
}
82+
6683
private Map<String, FragmentDefinition> getFragmentsByName(Document document) {
6784
return document.getDefinitionsOfType(FragmentDefinition.class).stream()
6885
.inject([:]) {map, it -> map << [(it.getName()): it]}
@@ -179,6 +196,62 @@ class VariableDefinitionFilterSpec extends Specification {
179196
results.containsAll("int_arg", "string_arg")
180197
}
181198

199+
def "variable References In Built in Query Directive includes"() {
200+
given:
201+
String query = '''
202+
query($includeContext: Boolean!) {
203+
consumer {
204+
liabilities(arg: 1) @include(if: $includeContext) {
205+
totalDebt(arg: 1)
206+
}
207+
income
208+
}
209+
}
210+
'''
211+
212+
Document document = parser.parseDocument(query)
213+
HashMap<String, Object> variables = new HashMap<>()
214+
variables.put("includeContext", false)
215+
216+
when:
217+
final Set<String> results = variableDefinitionFilter
218+
.getVariableReferencesFromNode(graphQLSchema, graphQLSchema.getQueryType(), Collections.emptyMap(),
219+
variables, document)
220+
221+
then:
222+
results.size() == 1
223+
224+
results.containsAll("includeContext")
225+
}
226+
227+
def "variable References In Built in Query Directive skip"() {
228+
given:
229+
String query = '''
230+
query($includeContext: Boolean!) {
231+
consumer {
232+
liabilities(arg: 1) @skip(if: $includeContext) {
233+
totalDebt(arg: 1)
234+
}
235+
income
236+
}
237+
}
238+
'''
239+
240+
Document document = parser.parseDocument(query)
241+
HashMap<String, Object> variables = new HashMap<>()
242+
variables.put("includeContext", true)
243+
244+
when:
245+
final Set<String> results = variableDefinitionFilter
246+
.getVariableReferencesFromNode(graphQLSchema, graphQLSchema.getQueryType(), Collections.emptyMap(),
247+
variables, document)
248+
249+
then:
250+
results.size() == 1
251+
252+
results.containsAll("includeContext")
253+
}
254+
182255
def "test Negative Cases"() {
183256
given:
184257
final String negativeTestCaseQuery = "query { consumer { liabilities { totalDebt(arg: 1234) } } }"

0 commit comments

Comments
 (0)