diff --git a/pom.xml b/pom.xml index 8f9b227e..186bde02 100644 --- a/pom.xml +++ b/pom.xml @@ -2,7 +2,7 @@ 4.0.0 com.intuit.graphql graphql-orchestrator-java - 5.0.19-SNAPSHOT + 5.0.20-DEFER-SNAPSHOT jar graphql-orchestrator-java GraphQL Orchestrator combines multiple graphql services into a single unified schema @@ -39,6 +39,7 @@ 3.0.1 2.26.0 17.4 + 3.4.24 @@ -159,6 +160,17 @@ jackson-databind 2.13.4.1 + + io.projectreactor + reactor-core + ${reactorVersion} + + + io.projectreactor + reactor-test + ${reactorVersion} + test + diff --git a/src/main/java/com/intuit/graphql/orchestrator/GraphQLOrchestrator.java b/src/main/java/com/intuit/graphql/orchestrator/GraphQLOrchestrator.java index 4812344b..b0b71ab8 100644 --- a/src/main/java/com/intuit/graphql/orchestrator/GraphQLOrchestrator.java +++ b/src/main/java/com/intuit/graphql/orchestrator/GraphQLOrchestrator.java @@ -1,8 +1,12 @@ package com.intuit.graphql.orchestrator; +import com.intuit.graphql.orchestrator.deferDirective.DeferDirectiveInstrumentation; +import com.intuit.graphql.orchestrator.deferDirective.DeferOptions; import com.intuit.graphql.orchestrator.schema.RuntimeGraph; +import com.intuit.graphql.orchestrator.utils.MultiEIGenerator; import graphql.ExecutionInput; import graphql.ExecutionResult; +import graphql.ExecutionResultImpl; import graphql.GraphQL; import graphql.GraphQLContext; import graphql.execution.AsyncExecutionStrategy; @@ -11,11 +15,14 @@ import graphql.execution.instrumentation.ChainedInstrumentation; import graphql.execution.instrumentation.Instrumentation; import graphql.execution.instrumentation.dataloader.DataLoaderDispatcherInstrumentation; +import graphql.execution.reactive.SubscriptionPublisher; import graphql.schema.GraphQLSchema; import lombok.extern.slf4j.Slf4j; import org.dataloader.BatchLoader; import org.dataloader.DataLoader; import org.dataloader.DataLoaderRegistry; +import reactor.core.publisher.Flux; +import reactor.core.scheduler.Schedulers; import java.util.Arrays; import java.util.LinkedList; @@ -23,16 +30,20 @@ import java.util.Map; import java.util.Objects; import java.util.concurrent.CompletableFuture; +import java.util.concurrent.atomic.AtomicInteger; import java.util.function.Function; import java.util.function.UnaryOperator; import java.util.stream.Collectors; +import static com.intuit.graphql.orchestrator.utils.DirectivesUtil.USE_DEFER; import static java.util.Objects.requireNonNull; @Slf4j public class GraphQLOrchestrator { public static final String DATA_LOADER_REGISTRY_CONTEXT_KEY = DataLoaderRegistry.class.getName() + ".context.key"; + private static final DeferOptions DEFAULT_DEFER_OPTIONS = DeferOptions.builder().nestedDefersAllowed(false).build(); + private static final boolean DISABLED_DEFER = false; private final RuntimeGraph runtimeGraph; private final List instrumentations; @@ -41,8 +52,8 @@ public class GraphQLOrchestrator { private final ExecutionStrategy mutationExecutionStrategy; private GraphQLOrchestrator(final RuntimeGraph runtimeGraph, final List instrumentations, - final ExecutionIdProvider executionIdProvider, final ExecutionStrategy queryExecutionStrategy, - final ExecutionStrategy mutationExecutionStrategy) { + final ExecutionIdProvider executionIdProvider, final ExecutionStrategy queryExecutionStrategy, + final ExecutionStrategy mutationExecutionStrategy) { this.runtimeGraph = runtimeGraph; this.instrumentations = instrumentations; this.executionIdProvider = executionIdProvider; @@ -63,16 +74,80 @@ private DataLoaderRegistry buildNewDataLoaderRegistry() { // to create a new DataLoader per request. Else it will use the cache which is shared // across request. final Map temporaryMap = this.runtimeGraph.getBatchLoaderMap().values().stream().distinct() - .collect(Collectors.toMap(Function.identity(), DataLoader::new)); + .collect(Collectors.toMap(Function.identity(), DataLoader::new)); this.runtimeGraph.getBatchLoaderMap() - .forEach((key, batchLoader) -> - dataLoaderRegistry.register(key, temporaryMap.getOrDefault(batchLoader, new DataLoader(batchLoader)))); + .forEach((key, batchLoader) -> + dataLoaderRegistry.register(key, temporaryMap.getOrDefault(batchLoader, new DataLoader(batchLoader)))); return dataLoaderRegistry; } public CompletableFuture execute(ExecutionInput executionInput) { + return execute(executionInput, DEFAULT_DEFER_OPTIONS, DISABLED_DEFER); + } + + public CompletableFuture execute(ExecutionInput executionInput, DeferOptions deferOptions, boolean hasDefer) { + if(hasDefer) { + return executeWithDefer(executionInput, deferOptions); + } + final GraphQL graphQL = constructGraphQL(); + + final ExecutionInput newExecutionInput = executionInput + .transform(builder -> builder.dataLoaderRegistry(buildNewDataLoaderRegistry())); + + if (newExecutionInput.getContext() instanceof GraphQLContext) { + ((GraphQLContext) newExecutionInput.getContext()) + .put(DATA_LOADER_REGISTRY_CONTEXT_KEY, newExecutionInput.getDataLoaderRegistry()); + ((GraphQLContext) newExecutionInput.getContext()) + .put(USE_DEFER , false); + } + return graphQL.executeAsync(newExecutionInput); + } + private CompletableFuture executeWithDefer(ExecutionInput executionInput, DeferOptions options) { + AtomicInteger responses = new AtomicInteger(0); + MultiEIGenerator eiGenerator = new MultiEIGenerator(executionInput, options, this.getSchema()); + + Flux executionResultPublisher = eiGenerator.generateEIs() + .filter(ei -> !ei.getQuery().equals("")) + .publishOn(Schedulers.elastic()) + .map(ei -> { + log.error("Timestamp processing emittedValue: {}", System.currentTimeMillis()); + return this.generateEIWIthNewContext(ei); + }) + .map(constructGraphQL()::executeAsync) + .map(CompletableFuture::join) + .doOnNext(executionResult -> responses.getAndIncrement()) + .map(ExecutionResultImpl.newExecutionResult()::from) + .map(builder -> builder.addExtension("hasMoreData", hasMoreData(eiGenerator.getNumOfEIs(), responses.get()))) + .map(ExecutionResultImpl.Builder::build) + .map(Object.class::cast) + .takeUntil(object -> eiGenerator.getNumOfEIs() != null && !hasMoreData(eiGenerator.getNumOfEIs(), responses.get())); + + SubscriptionPublisher multiResultPublisher = new SubscriptionPublisher(executionResultPublisher, null); + + return CompletableFuture.completedFuture(ExecutionResultImpl.newExecutionResult().data(multiResultPublisher).build()); + } + + private boolean hasMoreData(Integer expectedNumOfEIs, Integer numOfResponses) { + return expectedNumOfEIs == null || expectedNumOfEIs.intValue() != numOfResponses.intValue(); + } + private ExecutionInput generateEIWIthNewContext(ExecutionInput ei) { + DataLoaderRegistry registry = buildNewDataLoaderRegistry(); + + GraphQLContext graphqlContext = GraphQLContext.newContext() + .of((GraphQLContext)ei.getContext()) + .put(DATA_LOADER_REGISTRY_CONTEXT_KEY, registry) + .put(USE_DEFER, true) + .build(); + + return ei.transform(builder -> { + builder.dataLoaderRegistry(registry); + builder.context(graphqlContext); + }); + } + + private GraphQL constructGraphQL() { final GraphQL.Builder graphqlBuilder = GraphQL.newGraphQL(runtimeGraph.getExecutableSchema()) .instrumentation(new ChainedInstrumentation(instrumentations)) .executionIdProvider(executionIdProvider) @@ -81,17 +156,7 @@ public CompletableFuture execute(ExecutionInput executionInput) if (Objects.nonNull(mutationExecutionStrategy)) { graphqlBuilder.mutationExecutionStrategy(mutationExecutionStrategy); } - - final GraphQL graphQL = graphqlBuilder.build(); - - final ExecutionInput newExecutionInput = executionInput - .transform(builder -> builder.dataLoaderRegistry(buildNewDataLoaderRegistry())); - - if (newExecutionInput.getContext() instanceof GraphQLContext) { - ((GraphQLContext) executionInput.getContext()) - .put(DATA_LOADER_REGISTRY_CONTEXT_KEY, newExecutionInput.getDataLoaderRegistry()); - } - return graphQL.executeAsync(newExecutionInput); + return graphqlBuilder.build(); } public GraphQLSchema getSchema() { @@ -113,7 +178,7 @@ public static class Builder { private ExecutionStrategy queryExecutionStrategy = new AsyncExecutionStrategy(); private ExecutionStrategy mutationExecutionStrategy = null; private List instrumentations = new LinkedList<>( - Arrays.asList(new DataLoaderDispatcherInstrumentation())); + Arrays.asList(new DataLoaderDispatcherInstrumentation(), new DeferDirectiveInstrumentation())); private Builder() { } @@ -155,7 +220,7 @@ public Builder mutationExecutionStrategy(final ExecutionStrategy mutationExecuti public GraphQLOrchestrator build() { return new GraphQLOrchestrator(runtimeGraph, instrumentations, executionIdProvider, queryExecutionStrategy, - mutationExecutionStrategy); + mutationExecutionStrategy); } } } diff --git a/src/main/java/com/intuit/graphql/orchestrator/authorization/DownstreamQueryRedactor.java b/src/main/java/com/intuit/graphql/orchestrator/authorization/DownstreamQueryRedactor.java index 9f9a0d01..e0de8ddd 100644 --- a/src/main/java/com/intuit/graphql/orchestrator/authorization/DownstreamQueryRedactor.java +++ b/src/main/java/com/intuit/graphql/orchestrator/authorization/DownstreamQueryRedactor.java @@ -3,20 +3,19 @@ import com.intuit.graphql.orchestrator.batch.AuthDownstreamQueryModifier; import com.intuit.graphql.orchestrator.schema.ServiceMetadata; import com.intuit.graphql.orchestrator.utils.SelectionCollector; -import graphql.language.AstTransformer; import graphql.language.FragmentDefinition; import graphql.language.Node; import graphql.schema.DataFetchingEnvironment; import graphql.schema.GraphQLType; -import java.util.Map; import lombok.Builder; import lombok.NonNull; -@Builder -public class DownstreamQueryRedactor { +import java.util.Map; - private static final AstTransformer AST_TRANSFORMER = new AstTransformer(); +import static com.intuit.graphql.orchestrator.utils.GraphQLUtil.AST_TRANSFORMER; +@Builder +public class DownstreamQueryRedactor { @NonNull private Node root; @NonNull private GraphQLType rootType; @NonNull private GraphQLType rootParentType; @@ -31,7 +30,9 @@ public DownstreamQueryRedactorResult redact() { return new DownstreamQueryRedactorResult( transformedRoot, downstreamQueryModifier.getDeclineFieldErrors(), - downstreamQueryModifier.redactedQueryHasEmptySelectionSet()); + downstreamQueryModifier.redactedQueryHasEmptySelectionSet(), + downstreamQueryModifier.getFragmentSpreadsRemoved() + ); } private AuthDownstreamQueryModifier createQueryModifier() { diff --git a/src/main/java/com/intuit/graphql/orchestrator/authorization/DownstreamQueryRedactorResult.java b/src/main/java/com/intuit/graphql/orchestrator/authorization/DownstreamQueryRedactorResult.java index e77457a1..82605f3c 100644 --- a/src/main/java/com/intuit/graphql/orchestrator/authorization/DownstreamQueryRedactorResult.java +++ b/src/main/java/com/intuit/graphql/orchestrator/authorization/DownstreamQueryRedactorResult.java @@ -2,11 +2,12 @@ import graphql.GraphqlErrorException; import graphql.language.Node; -import java.util.List; import lombok.AllArgsConstructor; import lombok.Getter; import lombok.NonNull; +import java.util.List; + @Getter @AllArgsConstructor public class DownstreamQueryRedactorResult { @@ -17,4 +18,7 @@ public class DownstreamQueryRedactorResult { private List errors; boolean hasEmptySelectionSet; + + @NonNull + private List fragmentSpreadsRemoved; } \ No newline at end of file diff --git a/src/main/java/com/intuit/graphql/orchestrator/batch/AuthDownstreamQueryModifier.java b/src/main/java/com/intuit/graphql/orchestrator/batch/AuthDownstreamQueryModifier.java index 70d7fdee..1199b406 100644 --- a/src/main/java/com/intuit/graphql/orchestrator/batch/AuthDownstreamQueryModifier.java +++ b/src/main/java/com/intuit/graphql/orchestrator/batch/AuthDownstreamQueryModifier.java @@ -1,17 +1,5 @@ package com.intuit.graphql.orchestrator.batch; -import static com.intuit.graphql.orchestrator.resolverdirective.FieldResolverDirectiveUtil.hasResolverDirective; -import static com.intuit.graphql.orchestrator.utils.QueryPathUtils.getNodesAsPathList; -import static com.intuit.graphql.orchestrator.utils.QueryPathUtils.pathListToFQN; -import static com.intuit.graphql.orchestrator.utils.RenameDirectiveUtil.convertGraphqlFieldWithOriginalName; -import static com.intuit.graphql.orchestrator.utils.RenameDirectiveUtil.getRenameKey; -import static graphql.introspection.Introspection.TypeNameMetaFieldDef; -import static graphql.schema.FieldCoordinates.coordinates; -import static graphql.util.TreeTransformerUtil.changeNode; -import static graphql.util.TreeTransformerUtil.deleteNode; -import static java.util.Objects.nonNull; -import static java.util.Objects.requireNonNull; - import com.intuit.graphql.orchestrator.authorization.FieldAuthorization; import com.intuit.graphql.orchestrator.authorization.FieldAuthorizationEnvironment; import com.intuit.graphql.orchestrator.authorization.FieldAuthorizationResult; @@ -23,9 +11,14 @@ import com.intuit.graphql.orchestrator.schema.transform.FieldResolverContext; import com.intuit.graphql.orchestrator.utils.SelectionCollector; import graphql.GraphQLContext; +import graphql.GraphQLException; import graphql.GraphqlErrorException; +import graphql.language.Argument; +import graphql.language.BooleanValue; +import graphql.language.Directive; import graphql.language.Field; import graphql.language.FragmentDefinition; +import graphql.language.FragmentSpread; import graphql.language.InlineFragment; import graphql.language.Node; import graphql.language.NodeVisitorStub; @@ -40,6 +33,11 @@ import graphql.schema.GraphQLUnionType; import graphql.util.TraversalControl; import graphql.util.TraverserContext; +import lombok.Builder; +import lombok.NonNull; +import org.apache.commons.collections4.CollectionUtils; +import org.apache.commons.collections4.MapUtils; + import java.util.ArrayList; import java.util.Collections; import java.util.List; @@ -47,10 +45,21 @@ import java.util.Objects; import java.util.Set; import java.util.stream.Collectors; -import lombok.Builder; -import lombok.NonNull; -import org.apache.commons.collections4.CollectionUtils; -import org.apache.commons.collections4.MapUtils; + +import static com.intuit.graphql.orchestrator.resolverdirective.FieldResolverDirectiveUtil.hasResolverDirective; +import static com.intuit.graphql.orchestrator.utils.DirectivesUtil.DEFER_DIRECTIVE_NAME; +import static com.intuit.graphql.orchestrator.utils.DirectivesUtil.DEFER_IF_ARG; +import static com.intuit.graphql.orchestrator.utils.DirectivesUtil.USE_DEFER; +import static com.intuit.graphql.orchestrator.utils.QueryPathUtils.getNodesAsPathList; +import static com.intuit.graphql.orchestrator.utils.QueryPathUtils.pathListToFQN; +import static com.intuit.graphql.orchestrator.utils.RenameDirectiveUtil.convertGraphqlFieldWithOriginalName; +import static com.intuit.graphql.orchestrator.utils.RenameDirectiveUtil.getRenameKey; +import static graphql.introspection.Introspection.TypeNameMetaFieldDef; +import static graphql.schema.FieldCoordinates.coordinates; +import static graphql.util.TreeTransformerUtil.changeNode; +import static graphql.util.TreeTransformerUtil.deleteNode; +import static java.util.Objects.nonNull; +import static java.util.Objects.requireNonNull; /** * This class modifies for query for a downstream provider. @@ -67,6 +76,7 @@ public class AuthDownstreamQueryModifier extends NodeVisitorStub { private static final ArgumentValueResolver ARGUMENT_VALUE_RESOLVER = new ArgumentValueResolver(); // thread-safe private final List processedSelectionSetMetadata = new ArrayList<>(); private final List declinedFieldsErrors = new ArrayList<>(); + private final List fragmentSpreadsRemoved = new ArrayList<>(); private boolean hasEmptySelectionSet; @@ -98,6 +108,11 @@ public TraversalControl visitField(Field node, TraverserContext context) { return deleteNode(context); } + List directives = node.getDirectives(); + if(containsDeferDirective(directives)) { + return pruneDeferInfo(node, context, directives); + } + if(!serviceMetadata.getRenamedMetadata().getOriginalFieldNamesByRenamedName().isEmpty()) { String renamedKey = getRenameKey(null, node.getName(), true); String originalName = serviceMetadata.getRenamedMetadata().getOriginalFieldNamesByRenamedName().get(renamedKey); @@ -125,6 +140,22 @@ public TraversalControl visitField(Field node, TraverserContext context) { } } + if(!node.getDirectives(DEFER_DIRECTIVE_NAME).isEmpty()) { + Argument deferArg = node.getDirectives(DEFER_DIRECTIVE_NAME).get(0).getArgument(DEFER_IF_ARG); + if(graphQLContext.getOrDefault(USE_DEFER, false) && (deferArg == null || ((BooleanValue)deferArg.getValue()).isValue())) { + decreaseParentSelectionSetCount(context.getParentContext()); + return deleteNode(context); + } else { + //remove directive from query since directive is not built in and will fail downstream if added + List directives = node.getDirectives() + .stream() + .filter(directive -> !DEFER_DIRECTIVE_NAME.equals(directive.getName())) + .collect(Collectors.toList()); + + return changeNode(context, node.transform(builder -> builder.directives(directives))); + } + } + String renameKey = getRenameKey(parentType.getName(), node.getName(), false); String originalName = serviceMetadata.getRenamedMetadata().getOriginalFieldNamesByRenamedName().get(renameKey); @@ -193,6 +224,24 @@ public TraversalControl visitFragmentDefinition( public TraversalControl visitInlineFragment(InlineFragment node, TraverserContext context) { String typeName = node.getTypeCondition().getName(); context.setVar(GraphQLType.class, this.graphQLSchema.getType(typeName)); + + List directives = node.getDirectives(); + if(containsDeferDirective(directives)) { + return pruneDeferInfo(node, context, directives); + } + + return TraversalControl.CONTINUE; + } + + @Override + public TraversalControl visitFragmentSpread(FragmentSpread node, TraverserContext context) { + context.setVar(GraphQLType.class, this.graphQLSchema.getType(node.getName())); + + List directives = node.getDirectives(); + if(containsDeferDirective(directives)) { + return pruneDeferInfo(node, context, directives); + } + return TraversalControl.CONTINUE; } @@ -280,10 +329,52 @@ private GraphQLType getParentType(TraverserContext context) { return GraphQLTypeUtil.unwrapAll(parentType); } + private boolean containsDeferDirective(List directives) { + return directives != null && directives.stream() + .anyMatch(directive -> DEFER_DIRECTIVE_NAME.equals(directive.getName())); + } + + private TraversalControl pruneDeferInfo(Node node, TraverserContext context, List nodeDirectives) { + Directive deferDirective = nodeDirectives + .stream() + .filter(directive -> DEFER_DIRECTIVE_NAME.equals(directive.getName())) + .findFirst() + .get(); + + Argument deferArg = deferDirective.getArgument(DEFER_IF_ARG); + if(graphQLContext.getOrDefault(USE_DEFER, false) && (deferArg == null || ((BooleanValue)deferArg.getValue()).isValue())) { + decreaseParentSelectionSetCount(context.getParentContext()); + if(node instanceof FragmentSpread) { + this.fragmentSpreadsRemoved.add(((FragmentSpread)node).getName()); + } + + return deleteNode(context); + } else { + final List directives = nodeDirectives + .stream() + .filter(directive -> !DEFER_DIRECTIVE_NAME.equals(directive.getName())) + .collect(Collectors.toList()); + //remove directive from query since directive is not built in and will fail downstream if added + if(node instanceof Field) { + return changeNode(context, ((Field)node).transform(builder -> builder.directives(directives))); + } else if(node instanceof InlineFragment) { + return changeNode(context, ((InlineFragment)node).transform(builder -> builder.directives(directives))); + } else if(node instanceof FragmentSpread) { + return changeNode(context, ((FragmentSpread)node).transform(builder -> builder.directives(directives))); + } else { + throw new GraphQLException("Not Supported Defer Location."); + } + } + } + public List getDeclineFieldErrors() { return declinedFieldsErrors; } + public List getFragmentSpreadsRemoved() { + return this.fragmentSpreadsRemoved; + } + public List getEmptySelectionSets() { return this.processedSelectionSetMetadata.stream() .filter(selectionSetMetadata -> selectionSetMetadata.getRemainingSelectionsCount() == 0) diff --git a/src/main/java/com/intuit/graphql/orchestrator/batch/DownstreamQueryModifier.java b/src/main/java/com/intuit/graphql/orchestrator/batch/DownstreamQueryModifier.java index c6f572e4..36fc60da 100644 --- a/src/main/java/com/intuit/graphql/orchestrator/batch/DownstreamQueryModifier.java +++ b/src/main/java/com/intuit/graphql/orchestrator/batch/DownstreamQueryModifier.java @@ -1,19 +1,14 @@ package com.intuit.graphql.orchestrator.batch; -import static com.intuit.graphql.orchestrator.resolverdirective.FieldResolverDirectiveUtil.hasResolverDirective; -import static com.intuit.graphql.orchestrator.utils.RenameDirectiveUtil.convertGraphqlFieldWithOriginalName; -import static com.intuit.graphql.orchestrator.utils.RenameDirectiveUtil.getRenameKey; -import static graphql.introspection.Introspection.TypeNameMetaFieldDef; -import static graphql.schema.FieldCoordinates.coordinates; -import static graphql.util.TreeTransformerUtil.changeNode; -import static graphql.util.TreeTransformerUtil.deleteNode; -import static java.util.Objects.requireNonNull; - import com.intuit.graphql.orchestrator.federation.RequiredFieldsCollector; import com.intuit.graphql.orchestrator.federation.metadata.FederationMetadata; import com.intuit.graphql.orchestrator.schema.ServiceMetadata; import com.intuit.graphql.orchestrator.schema.transform.FieldResolverContext; import com.intuit.graphql.orchestrator.utils.SelectionCollector; +import graphql.GraphQLContext; +import graphql.language.Argument; +import graphql.language.BooleanValue; +import graphql.language.Directive; import graphql.language.Field; import graphql.language.FragmentDefinition; import graphql.language.InlineFragment; @@ -30,14 +25,27 @@ import graphql.schema.GraphQLUnionType; import graphql.util.TraversalControl; import graphql.util.TraverserContext; +import org.apache.commons.collections4.CollectionUtils; +import org.apache.commons.collections4.MapUtils; + import java.util.Collections; import java.util.List; import java.util.Map; import java.util.Objects; import java.util.Set; import java.util.stream.Collectors; -import org.apache.commons.collections4.CollectionUtils; -import org.apache.commons.collections4.MapUtils; + +import static com.intuit.graphql.orchestrator.resolverdirective.FieldResolverDirectiveUtil.hasResolverDirective; +import static com.intuit.graphql.orchestrator.utils.DirectivesUtil.DEFER_DIRECTIVE_NAME; +import static com.intuit.graphql.orchestrator.utils.DirectivesUtil.DEFER_IF_ARG; +import static com.intuit.graphql.orchestrator.utils.DirectivesUtil.USE_DEFER; +import static com.intuit.graphql.orchestrator.utils.RenameDirectiveUtil.convertGraphqlFieldWithOriginalName; +import static com.intuit.graphql.orchestrator.utils.RenameDirectiveUtil.getRenameKey; +import static graphql.introspection.Introspection.TypeNameMetaFieldDef; +import static graphql.schema.FieldCoordinates.coordinates; +import static graphql.util.TreeTransformerUtil.changeNode; +import static graphql.util.TreeTransformerUtil.deleteNode; +import static java.util.Objects.requireNonNull; /** * This class modifies for query for a downstream provider. @@ -59,11 +67,14 @@ public class DownstreamQueryModifier extends NodeVisitorStub { private final SelectionCollector selectionCollector; private final GraphQLSchema graphQLSchema; + private final GraphQLContext graphQLContext; + public DownstreamQueryModifier( - GraphQLType rootType, - ServiceMetadata serviceMetadata, - Map fragmentsByName, - GraphQLSchema graphQLSchema) { + GraphQLType rootType, + ServiceMetadata serviceMetadata, + Map fragmentsByName, + GraphQLSchema graphQLSchema, + GraphQLContext context) { Objects.requireNonNull(rootType); Objects.requireNonNull(serviceMetadata); Objects.requireNonNull(fragmentsByName); @@ -71,6 +82,7 @@ public DownstreamQueryModifier( this.serviceMetadata = serviceMetadata; this.selectionCollector = new SelectionCollector(fragmentsByName); this.graphQLSchema = graphQLSchema; + this.graphQLContext = context; } @Override @@ -85,6 +97,21 @@ public TraversalControl visitField(Field node, TraverserContext context) { } } + if(!node.getDirectives(DEFER_DIRECTIVE_NAME).isEmpty()) { + Argument deferArg = node.getDirectives(DEFER_DIRECTIVE_NAME).get(0).getArgument(DEFER_IF_ARG); + if(graphQLContext.getOrDefault(USE_DEFER, false) && (deferArg == null || ((BooleanValue)deferArg.getValue()).isValue())) { + return deleteNode(context); + } else { + //remove directive from query since directive is not built in and will fail downstream if added + List directives = node.getDirectives() + .stream() + .filter(directive -> !DEFER_DIRECTIVE_NAME.equals(directive.getName())) + .collect(Collectors.toList()); + + return changeNode(context, node.transform(builder -> builder.directives(directives))); + } + } + return TraversalControl.CONTINUE; } else { GraphQLFieldsContainer parentType = context.getParentContext().getVar(GraphQLType.class); @@ -112,6 +139,21 @@ public TraversalControl visitField(Field node, TraverserContext context) { return changeNode(context, convertGraphqlFieldWithOriginalName(node, originalName)); } } + + if(!node.getDirectives(DEFER_DIRECTIVE_NAME).isEmpty()) { + Argument deferArg = node.getDirectives(DEFER_DIRECTIVE_NAME).get(0).getArgument(DEFER_IF_ARG); + if(graphQLContext.getOrDefault(USE_DEFER, false) && (deferArg == null || ((BooleanValue)deferArg.getValue()).isValue())) { + return deleteNode(context); + } else { + //remove directive from query since directive is not built in and will fail downstream if added + List directives = node.getDirectives() + .stream() + .filter(directive -> !DEFER_DIRECTIVE_NAME.equals(directive.getName())) + .collect(Collectors.toList()); + + return changeNode(context, node.transform(builder -> builder.directives(directives))); + } + } return TraversalControl.CONTINUE; } } diff --git a/src/main/java/com/intuit/graphql/orchestrator/batch/EntityFetcherBatchLoader.java b/src/main/java/com/intuit/graphql/orchestrator/batch/EntityFetcherBatchLoader.java index 7bf3e37d..832b1168 100644 --- a/src/main/java/com/intuit/graphql/orchestrator/batch/EntityFetcherBatchLoader.java +++ b/src/main/java/com/intuit/graphql/orchestrator/batch/EntityFetcherBatchLoader.java @@ -1,11 +1,5 @@ package com.intuit.graphql.orchestrator.batch; -import static com.intuit.graphql.orchestrator.utils.GraphQLUtil.AST_TRANSFORMER; -import static com.intuit.graphql.orchestrator.utils.GraphQLUtil.unwrapAll; -import static com.intuit.graphql.orchestrator.utils.IntrospectionUtil.__typenameField; -import static graphql.language.Field.newField; -import static graphql.language.InlineFragment.newInlineFragment; - import com.intuit.graphql.orchestrator.ServiceProvider; import com.intuit.graphql.orchestrator.federation.EntityFetchingException; import com.intuit.graphql.orchestrator.federation.EntityQuery; @@ -22,6 +16,9 @@ import graphql.schema.DataFetchingEnvironment; import graphql.schema.GraphQLType; import graphql.schema.GraphQLTypeUtil; +import org.apache.commons.collections4.CollectionUtils; +import org.dataloader.BatchLoader; + import java.util.ArrayList; import java.util.Collection; import java.util.HashMap; @@ -29,8 +26,12 @@ import java.util.Map; import java.util.concurrent.CompletionStage; import java.util.stream.Collectors; -import org.apache.commons.collections4.CollectionUtils; -import org.dataloader.BatchLoader; + +import static com.intuit.graphql.orchestrator.utils.GraphQLUtil.AST_TRANSFORMER; +import static com.intuit.graphql.orchestrator.utils.GraphQLUtil.unwrapAll; +import static com.intuit.graphql.orchestrator.utils.IntrospectionUtil.__typenameField; +import static graphql.language.Field.newField; +import static graphql.language.InlineFragment.newInlineFragment; public class EntityFetcherBatchLoader implements BatchLoader> { @@ -116,7 +117,9 @@ private InlineFragment createEntityRequestInlineFragment(DataFetchingEnvironment if (!GraphQLTypeUtil.isLeaf(fieldType)) { final Field transformedField = (Field) AST_TRANSFORMER.transform(originalField, new DownstreamQueryModifier(fieldType, entityServiceMetadata, - dfe.getFragmentsByName(), dfe.getGraphQLSchema())); + dfe.getFragmentsByName(), dfe.getGraphQLSchema(), dfe.getContext() + ) + ); // is an object fieldSelectionSet = diff --git a/src/main/java/com/intuit/graphql/orchestrator/batch/GraphQLServiceBatchLoader.java b/src/main/java/com/intuit/graphql/orchestrator/batch/GraphQLServiceBatchLoader.java index b48173f4..501fad17 100644 --- a/src/main/java/com/intuit/graphql/orchestrator/batch/GraphQLServiceBatchLoader.java +++ b/src/main/java/com/intuit/graphql/orchestrator/batch/GraphQLServiceBatchLoader.java @@ -1,11 +1,5 @@ package com.intuit.graphql.orchestrator.batch; -import static com.intuit.graphql.orchestrator.schema.transform.DomainTypesTransformer.DELIMITER; -import static graphql.language.AstPrinter.printAstCompact; -import static graphql.language.OperationDefinition.Operation.QUERY; -import static graphql.schema.GraphQLTypeUtil.unwrapAll; -import static java.util.Objects.requireNonNull; - import com.intuit.graphql.orchestrator.authorization.DefaultFieldAuthorization; import com.intuit.graphql.orchestrator.authorization.DownstreamQueryRedactor; import com.intuit.graphql.orchestrator.authorization.DownstreamQueryRedactorResult; @@ -37,10 +31,18 @@ import graphql.schema.GraphQLObjectType; import graphql.schema.GraphQLSchema; import graphql.schema.GraphQLType; +import org.apache.commons.collections4.CollectionUtils; +import org.apache.commons.collections4.MapUtils; +import org.apache.commons.collections4.MultiValuedMap; +import org.apache.commons.collections4.multimap.ArrayListValuedHashMap; +import org.apache.commons.lang3.StringUtils; +import org.dataloader.BatchLoader; + import java.util.ArrayList; import java.util.Collection; import java.util.Collections; import java.util.HashMap; +import java.util.HashSet; import java.util.List; import java.util.Map; import java.util.Objects; @@ -49,12 +51,13 @@ import java.util.concurrent.CompletableFuture; import java.util.concurrent.CompletionStage; import java.util.stream.Collectors; -import org.apache.commons.collections4.CollectionUtils; -import org.apache.commons.collections4.MapUtils; -import org.apache.commons.collections4.MultiValuedMap; -import org.apache.commons.collections4.multimap.ArrayListValuedHashMap; -import org.apache.commons.lang3.StringUtils; -import org.dataloader.BatchLoader; + +import static com.intuit.graphql.orchestrator.schema.transform.DomainTypesTransformer.DELIMITER; +import static com.intuit.graphql.orchestrator.utils.DirectivesUtil.USE_DEFER; +import static graphql.language.AstPrinter.printAstCompact; +import static graphql.language.OperationDefinition.Operation.QUERY; +import static graphql.schema.GraphQLTypeUtil.unwrapAll; +import static java.util.Objects.requireNonNull; public class GraphQLServiceBatchLoader implements BatchLoader> { @@ -115,7 +118,9 @@ private CompletableFuture>> load(List mergedFragmentDefinitions = new HashMap<>(); + //Possibly add a complex object to hold these two MultiValuedMap queryRedactErrorsByKey = new ArrayListValuedHashMap<>(); + Set removedFragments = new HashSet<>(); for (final DataFetchingEnvironment key : keys) { MergedFieldModifierResult result = new MergedFieldModifier(key).getFilteredRootField(); @@ -123,7 +128,7 @@ private CompletableFuture>> load(List removeFieldsWithExternalTypes(field, - operationObjectType, key, authData, fieldAuthorization, queryRedactErrorsByKey)) + operationObjectType, key, authData, fieldAuthorization, queryRedactErrorsByKey, removedFragments, context)) .filter(Objects::nonNull) // denied access or has an empty selectionSet .forEach(selectionSetBuilder::selection); } @@ -193,6 +198,7 @@ private CompletableFuture>> load(List fragmentsAsDefinitions = mergedFragmentDefinitions.values().stream() + .filter(fragment -> !removedFragments.contains(fragment.getName())) .map(GraphQLObjects::cast).collect(Collectors.toList()); OperationDefinition query = OperationDefinition.newOperationDefinition() @@ -227,6 +233,7 @@ private CompletableFuture>> load(List batchResultTransformer.toBatchResult(result, keys)) @@ -368,10 +375,12 @@ private FragmentDefinition removeDomainTypeFromFragment(final FragmentDefinition private Field removeFieldsWithExternalTypes(Field origField, GraphQLObjectType operationObjectType, DataFetchingEnvironment dfe, Object authData, FieldAuthorization fieldAuthorization, - final MultiValuedMap queryRedactErrorsByKey) { + final MultiValuedMap queryRedactErrorsByKey, Set removedFragments, GraphQLContext context) { if (serviceMetadata.shouldModifyDownStreamQuery() || - !(fieldAuthorization instanceof DefaultFieldAuthorization)) { + !(fieldAuthorization instanceof DefaultFieldAuthorization) || + context.getOrDefault(USE_DEFER, false) + ) { GraphQLType origFieldType = getRootFieldDefinition(dfe.getExecutionStepInfo()).getType(); DownstreamQueryRedactor downstreamQueryRedactor = DownstreamQueryRedactor.builder() .root(origField) @@ -387,6 +396,10 @@ private Field removeFieldsWithExternalTypes(Field origField, GraphQLObjectType o String keyPath = dfe.getExecutionStepInfo().getPath().toString(); queryRedactErrorsByKey.putAll(keyPath, redactResult.getErrors()); } + if (CollectionUtils.isNotEmpty(redactResult.getFragmentSpreadsRemoved())) { + removedFragments.addAll(redactResult.getFragmentSpreadsRemoved()); + } + if (redactResult.isHasEmptySelectionSet()) { return null; // if null, not added to selection } diff --git a/src/main/java/com/intuit/graphql/orchestrator/deferDirective/DeferDirectiveInstrumentation.java b/src/main/java/com/intuit/graphql/orchestrator/deferDirective/DeferDirectiveInstrumentation.java new file mode 100644 index 00000000..6c023bc3 --- /dev/null +++ b/src/main/java/com/intuit/graphql/orchestrator/deferDirective/DeferDirectiveInstrumentation.java @@ -0,0 +1,36 @@ +package com.intuit.graphql.orchestrator.deferDirective; + +import graphql.GraphQLContext; +import graphql.execution.instrumentation.SimpleInstrumentation; +import graphql.execution.instrumentation.parameters.InstrumentationFieldFetchParameters; +import graphql.language.Argument; +import graphql.language.BooleanValue; +import graphql.language.Directive; +import graphql.schema.DataFetcher; +import graphql.schema.StaticDataFetcher; + +import java.util.List; + +import static com.intuit.graphql.orchestrator.utils.DirectivesUtil.DEFER_DIRECTIVE_NAME; +import static com.intuit.graphql.orchestrator.utils.DirectivesUtil.DEFER_IF_ARG; +import static com.intuit.graphql.orchestrator.utils.DirectivesUtil.USE_DEFER; + +public class DeferDirectiveInstrumentation extends SimpleInstrumentation { + + @Override + public DataFetcher instrumentDataFetcher(DataFetcher dataFetcher, InstrumentationFieldFetchParameters parameters) { + GraphQLContext gqlContext = parameters.getEnvironment().getContext(); + boolean useDefer = gqlContext != null && gqlContext.getOrDefault(USE_DEFER, false); + + if(useDefer && !parameters.getEnvironment().getField().getDirectives(DEFER_DIRECTIVE_NAME).isEmpty()) { + List deferDirs = parameters.getEnvironment().getField().getDirectives(DEFER_DIRECTIVE_NAME); + Argument ifArgument = deferDirs.get(0).getArgument(DEFER_IF_ARG); + + //Only return null for defer fields if it is enabled otherwise remove from query + if(ifArgument == null || ((BooleanValue) ifArgument.getValue()).isValue()) { + return new StaticDataFetcher(null); + } + } + return dataFetcher; + } +} \ No newline at end of file diff --git a/src/main/java/com/intuit/graphql/orchestrator/deferDirective/DeferOptions.java b/src/main/java/com/intuit/graphql/orchestrator/deferDirective/DeferOptions.java new file mode 100644 index 00000000..16a4f853 --- /dev/null +++ b/src/main/java/com/intuit/graphql/orchestrator/deferDirective/DeferOptions.java @@ -0,0 +1,10 @@ +package com.intuit.graphql.orchestrator.deferDirective; + +import lombok.Builder; +import lombok.Getter; + +@Builder +@Getter +public class DeferOptions { + private boolean nestedDefersAllowed; +} diff --git a/src/main/java/com/intuit/graphql/orchestrator/deferDirective/DeferUtil.java b/src/main/java/com/intuit/graphql/orchestrator/deferDirective/DeferUtil.java new file mode 100644 index 00000000..f5fc1f29 --- /dev/null +++ b/src/main/java/com/intuit/graphql/orchestrator/deferDirective/DeferUtil.java @@ -0,0 +1,53 @@ +package com.intuit.graphql.orchestrator.deferDirective; + +import graphql.language.Argument; +import graphql.language.BooleanValue; +import graphql.language.Directive; +import graphql.language.DirectivesContainer; +import graphql.language.Node; +import graphql.language.Selection; +import graphql.language.SelectionSet; +import lombok.NonNull; + +import java.util.List; + +import static com.intuit.graphql.orchestrator.utils.DirectivesUtil.DEFER_DIRECTIVE_NAME; +import static com.intuit.graphql.orchestrator.utils.DirectivesUtil.DEFER_IF_ARG; + +public class DeferUtil { + /** + * Checks if it is necessary to create ei for deferred field. + * Currently, selection should be skipped if all the children field are deferred resulting in an empty selection set. + * @param selection: node to check if children are all deferred + * @return boolean: true if all children are deferred, false otherwise + */ + public static boolean hasNonDeferredSelection(@NonNull Selection selection) { + return ((List)selection.getChildren()) + .stream() + .filter(SelectionSet.class::isInstance) + .map(SelectionSet.class::cast) + .findAny() + .get() + .getSelections() + .stream() + .anyMatch(child -> !containsEnabledDeferDirective(child)); + } + + /** + * Verifies that Node has defer directive that is not disabled + * @param node: node to check if contains defer directive + * @return boolean: true if node has an enabled defer, false otherwise + */ + public static boolean containsEnabledDeferDirective(Selection node) { + return node instanceof DirectivesContainer && + ((List) ((DirectivesContainer) node).getDirectives()) + .stream() + .filter(directive -> DEFER_DIRECTIVE_NAME.equals(directive.getName())) + .findFirst() + .map(directive -> { + Argument ifArg = directive.getArgument(DEFER_IF_ARG); + return ifArg == null || ((BooleanValue) ifArg.getValue()).isValue(); + }) + .orElse(false); + } +} diff --git a/src/main/java/com/intuit/graphql/orchestrator/deferDirective/MultipartQueryRequest.java b/src/main/java/com/intuit/graphql/orchestrator/deferDirective/MultipartQueryRequest.java new file mode 100644 index 00000000..2b239eb2 --- /dev/null +++ b/src/main/java/com/intuit/graphql/orchestrator/deferDirective/MultipartQueryRequest.java @@ -0,0 +1,16 @@ +package com.intuit.graphql.orchestrator.deferDirective; + +import graphql.language.OperationDefinition; +import lombok.Builder; +import lombok.Getter; +import lombok.Singular; + +import java.util.Set; + +@Getter +@Builder +public class MultipartQueryRequest { + private OperationDefinition multipartOperationDef; + @Singular + private Set fragmentSpreadNames; +} diff --git a/src/main/java/com/intuit/graphql/orchestrator/fieldresolver/FieldResolverBatchSelectionSetSupplier.java b/src/main/java/com/intuit/graphql/orchestrator/fieldresolver/FieldResolverBatchSelectionSetSupplier.java index e4f9cb37..42869def 100644 --- a/src/main/java/com/intuit/graphql/orchestrator/fieldresolver/FieldResolverBatchSelectionSetSupplier.java +++ b/src/main/java/com/intuit/graphql/orchestrator/fieldresolver/FieldResolverBatchSelectionSetSupplier.java @@ -1,20 +1,12 @@ package com.intuit.graphql.orchestrator.fieldresolver; -import static com.intuit.graphql.orchestrator.resolverdirective.FieldResolverDirectiveUtil.getNameFromFieldReference; -import static com.intuit.graphql.orchestrator.resolverdirective.FieldResolverDirectiveUtil.ifInvalidFieldReferenceThrowException; -import static com.intuit.graphql.orchestrator.resolverdirective.FieldResolverDirectiveUtil.isReferenceToFieldInParentType; -import static com.intuit.graphql.orchestrator.utils.GraphQLUtil.AST_TRANSFORMER; -import static com.intuit.graphql.orchestrator.utils.GraphQLUtil.getFieldType; -import static com.intuit.graphql.orchestrator.utils.XtextTypeUtils.isPrimitiveType; -import static graphql.schema.GraphQLTypeUtil.unwrapAll; -import static graphql.schema.InputValueWithState.newExternalValue; - import com.intuit.graphql.orchestrator.batch.DownstreamQueryModifier; import com.intuit.graphql.orchestrator.resolverdirective.FieldResolverDirectiveUtil; import com.intuit.graphql.orchestrator.resolverdirective.ResolverArgumentDefinition; import com.intuit.graphql.orchestrator.resolverdirective.ResolverDirectiveDefinition; import com.intuit.graphql.orchestrator.schema.ServiceMetadata; import com.intuit.graphql.orchestrator.schema.transform.FieldResolverContext; +import graphql.GraphQLContext; import graphql.Scalars; import graphql.execution.ValuesResolver; import graphql.language.Argument; @@ -29,14 +21,24 @@ import graphql.schema.GraphQLSchema; import graphql.schema.GraphQLType; import graphql.schema.GraphQLTypeUtil; +import lombok.AllArgsConstructor; +import org.apache.commons.collections4.CollectionUtils; +import org.apache.commons.lang3.StringUtils; + import java.util.ArrayList; import java.util.List; import java.util.Map; import java.util.Objects; import java.util.function.Supplier; -import lombok.AllArgsConstructor; -import org.apache.commons.collections4.CollectionUtils; -import org.apache.commons.lang3.StringUtils; + +import static com.intuit.graphql.orchestrator.resolverdirective.FieldResolverDirectiveUtil.getNameFromFieldReference; +import static com.intuit.graphql.orchestrator.resolverdirective.FieldResolverDirectiveUtil.ifInvalidFieldReferenceThrowException; +import static com.intuit.graphql.orchestrator.resolverdirective.FieldResolverDirectiveUtil.isReferenceToFieldInParentType; +import static com.intuit.graphql.orchestrator.utils.GraphQLUtil.AST_TRANSFORMER; +import static com.intuit.graphql.orchestrator.utils.GraphQLUtil.getFieldType; +import static com.intuit.graphql.orchestrator.utils.XtextTypeUtils.isPrimitiveType; +import static graphql.schema.GraphQLTypeUtil.unwrapAll; +import static graphql.schema.InputValueWithState.newExternalValue; @AllArgsConstructor public class FieldResolverBatchSelectionSetSupplier implements Supplier { @@ -65,7 +67,7 @@ private SelectionSet createBatchSelectionSet() { GraphQLObjectType rootFieldParentType = graphQLSchema.getQueryType(); GraphQLType rootFieldType = getFieldType(rootField, rootFieldParentType).get(); rootField = removeFieldsWithExternalTypes(rootField, rootFieldType, dataFetchingEnvironment - .getFragmentsByName(), graphQLSchema); + .getFragmentsByName(), graphQLSchema, dataFetchingEnvironment.getContext()); parentSelectionSetBuilder.selection(rootField); } @@ -73,10 +75,11 @@ private SelectionSet createBatchSelectionSet() { } private Field removeFieldsWithExternalTypes(final Field field, - GraphQLType parentType, Map fragmentsByName, GraphQLSchema graphQLSchema) { + GraphQLType parentType, Map fragmentsByName, + GraphQLSchema graphQLSchema, GraphQLContext context) { // call serviceMetadata.hasFieldResolverDirective() before calling this method return (Field) AST_TRANSFORMER.transform(field, - new DownstreamQueryModifier(unwrapAll(parentType), serviceMetadata, fragmentsByName, graphQLSchema)); + new DownstreamQueryModifier(unwrapAll(parentType), serviceMetadata, fragmentsByName, graphQLSchema, context)); } private List createFieldArguments(ResolverDirectiveDefinition resolverDirectiveDefinition, DataFetchingEnvironment dataFetchingEnvironment) { diff --git a/src/main/java/com/intuit/graphql/orchestrator/resolverdirective/FieldResolverDirectiveUtil.java b/src/main/java/com/intuit/graphql/orchestrator/resolverdirective/FieldResolverDirectiveUtil.java index bf319f44..f9fc4581 100644 --- a/src/main/java/com/intuit/graphql/orchestrator/resolverdirective/FieldResolverDirectiveUtil.java +++ b/src/main/java/com/intuit/graphql/orchestrator/resolverdirective/FieldResolverDirectiveUtil.java @@ -22,6 +22,7 @@ import java.util.stream.Collectors; import static com.intuit.graphql.orchestrator.resolverdirective.ResolverDirectiveDefinition.extractRequiredFieldsFrom; +import static com.intuit.graphql.orchestrator.utils.DirectivesUtil.DEFER_DIRECTIVE_NAME; import static com.intuit.graphql.orchestrator.utils.XtextTypeUtils.getFieldDefinitions; public class FieldResolverDirectiveUtil { @@ -118,6 +119,10 @@ public static boolean hasResolverDirective(GraphQLFieldDefinition fieldDefinitio return fieldDefinition.getDirective(RESOLVER_DIRECTIVE_NAME) != null; } + public static boolean hasDeferDirective(GraphQLFieldDefinition fieldDefinition) { + return fieldDefinition.getDirective(DEFER_DIRECTIVE_NAME) != null; + } + public static boolean isObjectOrInterfaceType(TypeDefinition typeDefinition) { return typeDefinition instanceof ObjectTypeDefinition || typeDefinition instanceof InterfaceTypeDefinition; } diff --git a/src/main/java/com/intuit/graphql/orchestrator/stitching/XtextStitcher.java b/src/main/java/com/intuit/graphql/orchestrator/stitching/XtextStitcher.java index 359c2e11..091ab3b5 100644 --- a/src/main/java/com/intuit/graphql/orchestrator/stitching/XtextStitcher.java +++ b/src/main/java/com/intuit/graphql/orchestrator/stitching/XtextStitcher.java @@ -1,17 +1,5 @@ package com.intuit.graphql.orchestrator.stitching; -import static com.intuit.graphql.orchestrator.batch.DataLoaderKeyUtil.createDataLoaderKey; -import static com.intuit.graphql.orchestrator.resolverdirective.FieldResolverDirectiveUtil.RESOLVER_ARGUMENT_INPUT_NAME; -import static com.intuit.graphql.orchestrator.resolverdirective.FieldResolverDirectiveUtil.RESOLVER_DIRECTIVE_NAME; -import static com.intuit.graphql.orchestrator.utils.XtextUtils.getAllTypes; -import static com.intuit.graphql.orchestrator.xtext.DataFetcherContext.DataFetcherType.ENTITY_FETCHER; -import static com.intuit.graphql.orchestrator.xtext.DataFetcherContext.DataFetcherType.RESOLVER_ARGUMENT; -import static com.intuit.graphql.orchestrator.xtext.DataFetcherContext.DataFetcherType.RESOLVER_ON_FIELD_DEFINITION; -import static com.intuit.graphql.orchestrator.xtext.DataFetcherContext.DataFetcherType.SERVICE; -import static com.intuit.graphql.orchestrator.xtext.DataFetcherContext.DataFetcherType.STATIC; -import static graphql.schema.FieldCoordinates.coordinates; -import static java.util.Objects.requireNonNull; - import com.intuit.graphql.graphQL.ObjectTypeDefinition; import com.intuit.graphql.graphQL.TypeDefinition; import com.intuit.graphql.orchestrator.ServiceProvider; @@ -61,6 +49,8 @@ import graphql.schema.GraphQLObjectType; import graphql.schema.GraphQLType; import graphql.schema.StaticDataFetcher; +import org.dataloader.BatchLoader; + import java.util.Arrays; import java.util.Collections; import java.util.EnumMap; @@ -70,7 +60,19 @@ import java.util.Set; import java.util.function.Function; import java.util.stream.Collectors; -import org.dataloader.BatchLoader; + +import static com.intuit.graphql.orchestrator.batch.DataLoaderKeyUtil.createDataLoaderKey; +import static com.intuit.graphql.orchestrator.resolverdirective.FieldResolverDirectiveUtil.RESOLVER_ARGUMENT_INPUT_NAME; +import static com.intuit.graphql.orchestrator.resolverdirective.FieldResolverDirectiveUtil.RESOLVER_DIRECTIVE_NAME; +import static com.intuit.graphql.orchestrator.utils.DirectivesUtil.getBuiltInClientDirectives; +import static com.intuit.graphql.orchestrator.utils.XtextUtils.getAllTypes; +import static com.intuit.graphql.orchestrator.xtext.DataFetcherContext.DataFetcherType.ENTITY_FETCHER; +import static com.intuit.graphql.orchestrator.xtext.DataFetcherContext.DataFetcherType.RESOLVER_ARGUMENT; +import static com.intuit.graphql.orchestrator.xtext.DataFetcherContext.DataFetcherType.RESOLVER_ON_FIELD_DEFINITION; +import static com.intuit.graphql.orchestrator.xtext.DataFetcherContext.DataFetcherType.SERVICE; +import static com.intuit.graphql.orchestrator.xtext.DataFetcherContext.DataFetcherType.STATIC; +import static graphql.schema.FieldCoordinates.coordinates; +import static java.util.Objects.requireNonNull; /** * The type Xtext stitcher. @@ -295,11 +297,14 @@ private RuntimeGraph.Builder createRuntimeGraph(UnifiedXtextGraph unifiedXtextGr .forEach((operation, objectTypeDefinition) -> operationMap.put(operation, (GraphQLObjectType) visitor.doSwitch(objectTypeDefinition))); + Set directives = getAdditionalDirectives(unifiedXtextGraph, visitor.getDirectiveDefinitions()); + directives.addAll(getBuiltInClientDirectives()); + return RuntimeGraph.newBuilder() .operationMap(operationMap) .objectTypes(visitor.getGraphQLObjectTypes()) .additionalTypes(getAdditionalTypes(unifiedXtextGraph, visitor.getGraphQLObjectTypes())) - .additionalDirectives(getAdditionalDirectives(unifiedXtextGraph, visitor.getDirectiveDefinitions())); + .additionalDirectives(directives); } /** diff --git a/src/main/java/com/intuit/graphql/orchestrator/utils/DirectivesUtil.java b/src/main/java/com/intuit/graphql/orchestrator/utils/DirectivesUtil.java index 302418e5..4fb96c66 100644 --- a/src/main/java/com/intuit/graphql/orchestrator/utils/DirectivesUtil.java +++ b/src/main/java/com/intuit/graphql/orchestrator/utils/DirectivesUtil.java @@ -2,34 +2,43 @@ import graphql.Scalars; import graphql.introspection.Introspection.DirectiveLocation; +import graphql.language.BooleanValue; import graphql.schema.GraphQLArgument; import graphql.schema.GraphQLDirective; import org.apache.commons.collections4.CollectionUtils; +import java.util.HashSet; import java.util.List; import java.util.Map; import java.util.Optional; +import java.util.Set; import static java.util.stream.Collectors.toMap; /** - * Holds creates and maintains DEPRECATED_DIRECTIVE object and Directives specific util functions. The @skip and + * Holds creates and maintains DEPRECATED_DIRECTIVE object and Directives specific util functions. The skip and * - * @include is used in queries and not used in schema parsing, thus, not included here. + * include directives are used in queries and not used in schema parsing, thus, not included here. */ -class DirectivesUtil { +public class DirectivesUtil { static final String NO_LONGER_SUPPORTED = "No longer supported"; static final GraphQLDirective DEPRECATED_DIRECTIVE; + static final GraphQLDirective DEFER_DIRECTIVE; + public static final String DEFER_DIRECTIVE_NAME = "defer"; + public static final String DEFER_IF_ARG = "if"; + public static final String USE_DEFER = "useDefer"; private DirectivesUtil() { } static { DEPRECATED_DIRECTIVE = createDeprecatedDirective(); + DEFER_DIRECTIVE = createDeferDirective(); } + /** * creates GraphQLDirective for @deprecated directive. * @@ -37,29 +46,29 @@ private DirectivesUtil() { */ private static GraphQLDirective createDeprecatedDirective() { GraphQLArgument reasonArgument = GraphQLArgument.newArgument() - .name("reason") - .type(Scalars.GraphQLString) - .defaultValue(NO_LONGER_SUPPORTED) - .build(); + .name("reason") + .type(Scalars.GraphQLString) + .defaultValue(NO_LONGER_SUPPORTED) + .build(); return GraphQLDirective.newDirective() - .name("deprecated") - .argument(reasonArgument) - .validLocation(DirectiveLocation.FIELD_DEFINITION) - .validLocation(DirectiveLocation.ENUM_VALUE) - .build(); + .name("deprecated") + .argument(reasonArgument) + .validLocation(DirectiveLocation.FIELD_DEFINITION) + .validLocation(DirectiveLocation.ENUM_VALUE) + .build(); } public static String buildDeprecationReason(List directives) { if (CollectionUtils.isNotEmpty(directives)) { Optional directive = directives.stream() - .filter(d -> "deprecated".equals(d.getName())) - .findFirst(); + .filter(d -> "deprecated".equals(d.getName())) + .findFirst(); if (directive.isPresent()) { Map args = directive.get().getArguments().stream() - .collect(toMap(GraphQLArgument::getName, arg -> (String) arg.getArgumentValue().getValue())); + .collect(toMap(GraphQLArgument::getName, arg -> (String) arg.getArgumentValue().getValue())); if (args.isEmpty()) { return NO_LONGER_SUPPORTED; // default value from spec } else { @@ -70,4 +79,33 @@ public static String buildDeprecationReason(List directives) { } return null; } + + private static GraphQLDirective createDeferDirective() { + GraphQLArgument ifArgument = GraphQLArgument.newArgument() + .name(DEFER_IF_ARG) + .type(Scalars.GraphQLBoolean) + .defaultValueLiteral(BooleanValue.newBooleanValue(true).build()) + .build(); + + GraphQLArgument labelArgument = GraphQLArgument.newArgument() + .name(DEFER_IF_ARG) + .type(Scalars.GraphQLString) + .build(); + + return GraphQLDirective.newDirective() + .name(DEFER_DIRECTIVE_NAME) + .validLocation(DirectiveLocation.FIELD) + .validLocation(DirectiveLocation.INLINE_FRAGMENT) + .validLocation(DirectiveLocation.FRAGMENT_SPREAD) + .argument(ifArgument) + .argument(labelArgument) + .build(); + } + + public static Set getBuiltInClientDirectives() { + HashSet directives = new HashSet(); + directives.add(DEFER_DIRECTIVE); + return directives; + } + } diff --git a/src/main/java/com/intuit/graphql/orchestrator/utils/GraphQLUtil.java b/src/main/java/com/intuit/graphql/orchestrator/utils/GraphQLUtil.java index a3963b96..a68507a2 100644 --- a/src/main/java/com/intuit/graphql/orchestrator/utils/GraphQLUtil.java +++ b/src/main/java/com/intuit/graphql/orchestrator/utils/GraphQLUtil.java @@ -38,7 +38,7 @@ public class GraphQLUtil { public static final AstTransformer AST_TRANSFORMER = new AstTransformer(); private static final String ERR_CREATE_TYPE_UNEXPECTED_TYPE = "Failed to create Type due to " - + "unexpected GraphQL Type %s"; + + "unexpected GraphQL Type %s"; private GraphQLUtil() { } @@ -71,9 +71,9 @@ public static GraphQLType unwrapAll(GraphQLType type) { public static List getErrors(DataFetcherResult> result, Field f) { return result.getErrors().stream() - .filter(e -> e.getPath() == null || e.getPath().isEmpty() || f.getName() - .equals(String.valueOf(e.getPath().get(0)))) - .collect(Collectors.toList()); + .filter(e -> e.getPath() == null || e.getPath().isEmpty() || f.getName() + .equals(String.valueOf(e.getPath().get(0)))) + .collect(Collectors.toList()); } public static Type createTypeBasedOnGraphQLType(GraphQLType graphQLType) { @@ -83,14 +83,14 @@ public static Type createTypeBasedOnGraphQLType(GraphQLType graphQLType) { } if (graphQLType instanceof GraphQLNonNull) { return NonNullType.newNonNullType() - .type(createTypeBasedOnGraphQLType(((GraphQLNonNull) graphQLType).getWrappedType())).build(); + .type(createTypeBasedOnGraphQLType(((GraphQLNonNull) graphQLType).getWrappedType())).build(); } if (graphQLType instanceof GraphQLList) { return ListType.newListType().type(createTypeBasedOnGraphQLType(((GraphQLList) graphQLType).getWrappedType())) - .build(); + .build(); } throw new CreateTypeException( - String.format(ERR_CREATE_TYPE_UNEXPECTED_TYPE, GraphQLTypeUtil.simplePrint(graphQLType))); + String.format(ERR_CREATE_TYPE_UNEXPECTED_TYPE, GraphQLTypeUtil.simplePrint(graphQLType))); } public static Optional getFieldType(Field field, GraphQLFieldsContainer fieldsContainer) { diff --git a/src/main/java/com/intuit/graphql/orchestrator/utils/MultiEIGenerator.java b/src/main/java/com/intuit/graphql/orchestrator/utils/MultiEIGenerator.java new file mode 100644 index 00000000..ca9ba947 --- /dev/null +++ b/src/main/java/com/intuit/graphql/orchestrator/utils/MultiEIGenerator.java @@ -0,0 +1,112 @@ +package com.intuit.graphql.orchestrator.utils; + +import com.google.common.annotations.VisibleForTesting; +import com.intuit.graphql.orchestrator.deferDirective.DeferOptions; +import com.intuit.graphql.orchestrator.visitors.queryVisitors.DeferQueryExtractor; +import graphql.ExecutionInput; +import graphql.analysis.QueryTransformer; +import graphql.language.Document; +import graphql.language.FragmentDefinition; +import graphql.language.OperationDefinition; +import graphql.language.SelectionSet; +import graphql.schema.GraphQLSchema; +import lombok.extern.slf4j.Slf4j; +import reactor.core.publisher.Flux; + +import java.util.ArrayList; +import java.util.List; +import java.util.Map; +import java.util.concurrent.atomic.AtomicReference; +import java.util.function.Function; +import java.util.stream.Collectors; + +import static com.intuit.graphql.orchestrator.utils.GraphQLUtil.parser; + +@Slf4j +public class MultiEIGenerator { + private final List eis = new ArrayList<>(); + private final DeferOptions deferOptions; + private final GraphQLSchema schema; + private Integer numOfEIs = null; + + @VisibleForTesting + private long timeProcessedSplit = 0; + + public MultiEIGenerator(ExecutionInput ei, DeferOptions deferOptions, GraphQLSchema schema) { + this.eis.add(ei); + this.deferOptions = deferOptions; + this.schema = schema; + } + + public Flux generateEIs() { + return Flux.generate(() -> 0, (indexToProcess, sink) -> { + ExecutionInput emittedEI = null; + int nextIndex = indexToProcess + 1; + + //Only emit if it is the initial index or a valid index + if(this.numOfEIs == null || nextIndex <= this.numOfEIs) { + emittedEI = this.eis.get(indexToProcess); + //emit the index that needs to be processed + sink.next(emittedEI); + } + + //if null/first iteration then proceed to split ei and add to list of eis needing to be emitted + if(this.numOfEIs == null) { + this.timeProcessedSplit = System.currentTimeMillis(); + //Adds elements to list of eis that need to be processed + try { + Document rootDocument = parser.parseDocument(emittedEI.getQuery()); + Map fragmentDefinitionMap = rootDocument.getDefinitionsOfType(FragmentDefinition.class) + .stream() + .collect(Collectors.toMap(FragmentDefinition::getName , Function.identity())); + + ExecutionInput finalEmittedEI = emittedEI; + AtomicReference operationDefinitionReference = new AtomicReference<>(); + rootDocument.getDefinitionsOfType(OperationDefinition.class) + .stream() + .peek(operationDefinitionReference::set) + .map(OperationDefinition::getSelectionSet) + .map(SelectionSet::getSelections) + .flatMap(List::stream) + .forEach(selection -> { + QueryTransformer transformer = QueryTransformer.newQueryTransformer() + .schema(this.schema) + .root(selection) + .rootParentType(this.schema.getQueryType()) + .fragmentsByName(fragmentDefinitionMap) + .variables(finalEmittedEI.getVariables()) + .build(); + + DeferQueryExtractor visitor = DeferQueryExtractor.builder() + .deferOptions(deferOptions) + .originalEI(finalEmittedEI) + .rootNode(rootDocument) + .operationDefinition(operationDefinitionReference.get()) + .fragmentDefinitionMap(fragmentDefinitionMap) + .build(); + + transformer.transform(visitor); + + this.eis.addAll(visitor.getExtractedEIs()); + }); + } + catch (Exception ex) { + sink.error(ex); + sink.complete(); + } + //sets the number of expected eis which is also the number of responses expected + this.numOfEIs = this.eis.size(); + } else if(nextIndex > this.numOfEIs) { + //index reached the end of all the eis that need to be processed + sink.complete(); + } + + //Call generator with the next index to process + return nextIndex; + }); + } + + public Integer getNumOfEIs() { + return this.numOfEIs; + } +} \ No newline at end of file diff --git a/src/main/java/com/intuit/graphql/orchestrator/utils/MultipartUtil.java b/src/main/java/com/intuit/graphql/orchestrator/utils/MultipartUtil.java new file mode 100644 index 00000000..932f772f --- /dev/null +++ b/src/main/java/com/intuit/graphql/orchestrator/utils/MultipartUtil.java @@ -0,0 +1,284 @@ +package com.intuit.graphql.orchestrator.utils; + +import com.intuit.graphql.orchestrator.deferDirective.MultipartQueryRequest; +import graphql.ExecutionInput; +import graphql.language.Argument; +import graphql.language.AstPrinter; +import graphql.language.BooleanValue; +import graphql.language.Definition; +import graphql.language.Directive; +import graphql.language.Document; +import graphql.language.Field; +import graphql.language.FragmentDefinition; +import graphql.language.FragmentSpread; +import graphql.language.InlineFragment; +import graphql.language.OperationDefinition; +import graphql.language.Selection; +import graphql.language.SelectionSet; +import org.apache.commons.collections4.CollectionUtils; +import org.apache.commons.lang3.StringUtils; + +import java.util.ArrayList; +import java.util.HashMap; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Objects; +import java.util.Set; +import java.util.function.Function; +import java.util.stream.Collectors; + +import static com.intuit.graphql.orchestrator.utils.DirectivesUtil.DEFER_DIRECTIVE_NAME; +import static com.intuit.graphql.orchestrator.utils.DirectivesUtil.DEFER_IF_ARG; + +/** + * This class contains helper methods for GraphQL types. This class will often contain modifications of already built + * methods in graphql-java's own GraphQlTypeUtil class. + */ +public class MultipartUtil { + + public static List splitMultipartExecutionInput(ExecutionInput originalEI) { + Document originalDoc = GraphQLUtil.parser.parseDocument(originalEI.getQuery()); + List eiList = new ArrayList<>(); + + Map fragmentDefinitions = originalDoc.getDefinitionsOfType(FragmentDefinition.class) + .stream() + .collect(Collectors.toMap(FragmentDefinition::getName, Function.identity())); + + originalDoc.getDefinitionsOfType(OperationDefinition.class).forEach(operationDef -> { + OperationDefinition mappedOpDef = operationDef; + List deferredPaths = new ArrayList<>(); + + //adds all the paths that are needed to split for path to the list + addMultipartChildPaths(mappedOpDef.getSelectionSet(), deferredPaths, "", new HashMap<>()); + + if(!deferredPaths.isEmpty()) { + constructDeferredOperationDefinitions(operationDef, deferredPaths) + .stream() + .map(multipartRequest -> { + List operationDefinitions = new ArrayList<>(); + List fragmentSpreadDefs = multipartRequest.getFragmentSpreadNames() + .stream() + .map(fragmentDefinitions::get) + .filter(Objects::nonNull) + .collect(Collectors.toList()); + + operationDefinitions.add(multipartRequest.getMultipartOperationDef()); + operationDefinitions.addAll(fragmentSpreadDefs); + return operationDefinitions; + }) + .map(multipartOperationDefs -> originalDoc.transform(builder -> builder.definitions(multipartOperationDefs))) + .map(AstPrinter::printAst) + .map(query -> originalEI.transform(builder -> builder.query(query))) + .forEach(eiList::add); + } + }); + + return eiList; + } + + //fragment name maps to definition + private static List constructDeferredOperationDefinitions(OperationDefinition operationDefinition, List multipartPaths) { + List multipartQueryRequests = new ArrayList<>(); + multipartPaths.forEach(deferPath -> { + String[] deferFieldPath = deferPath.split("\\."); + int lastSelectionInPathIndex = deferFieldPath.length -1; + String multipartFieldName = deferFieldPath[lastSelectionInPathIndex]; + List currentFieldPath = getQueryPathSelections(operationDefinition.getSelectionSet(), deferFieldPath); + Set fragmentsNeeded = new HashSet<>(); + + Selection multipartChildFieldContext = null; + + //iterates through path and constructs field/selectionSet without defer directive and fields irrelevant fields to defer query + for (int j = lastSelectionInPathIndex; j >= 0; j--) { + Selection multipartFieldContext = currentFieldPath.get(j); + if(multipartFieldContext instanceof FragmentSpread) { + fragmentsNeeded.add(((FragmentSpread) multipartFieldContext).getName()); + } + if(multipartFieldContext instanceof InlineFragment) { + multipartFieldName = multipartFieldName.split("-")[0]; + } + + multipartChildFieldContext = updateMultipartSelectionSetField(multipartFieldContext, multipartChildFieldContext, multipartFieldName); + } + + SelectionSet multipartSelectionSet = SelectionSet.newSelectionSet() + .selection(multipartChildFieldContext) + .build(); + + OperationDefinition newMultipartOperationDefinition = operationDefinition.transform( + builder -> builder.selectionSet(multipartSelectionSet) + ); + + multipartQueryRequests.add( + MultipartQueryRequest.builder() + .multipartOperationDef(newMultipartOperationDefinition) + .fragmentSpreadNames(fragmentsNeeded) + .build() + ); + }); + + return multipartQueryRequests; + } + + private static List getQueryPathSelections(SelectionSet originalSelectionSet, String[] fieldKeys) { + List currentFieldPath = new ArrayList<>(); + Selection currentField; + + for (int i = 0; i < fieldKeys.length; i++) { + Selection lastField = ( i > 0) ? currentFieldPath.get(currentFieldPath.size() -1) : null; + SelectionSet selectionSet = (i==0) ? originalSelectionSet : getLastFieldsSelectionSet(lastField); + + boolean isChildAnInlineFragment = fieldKeys[i].contains("-"); + currentField = getSelectionWithName(selectionSet, fieldKeys[i], isChildAnInlineFragment); + currentFieldPath.add(currentField); + } + + return currentFieldPath; + } + + private static SelectionSet getLastFieldsSelectionSet(Selection selection) { + if(selection instanceof Field) { + return ((Field) selection).getSelectionSet(); + } else if(selection instanceof InlineFragment) { + return ((InlineFragment) selection).getSelectionSet(); + } else { + //throw exception because it should not reach + return null; + } + } + + private static String getSelectionName(Selection selection) { + if(selection instanceof Field) return ((Field) selection).getName(); + else if (selection instanceof FragmentSpread) return ((FragmentSpread) selection).getName(); + else if (selection instanceof InlineFragment) return ((InlineFragment) selection).getTypeCondition().getName(); + else { + //throw error + return null; + } + } + + private static Selection getSelectionWithName (SelectionSet selectionSet, String selectionName, boolean isInlineFragment) { + String name = (!isInlineFragment) ? selectionName : selectionName.split("-")[0]; + int idx = (!isInlineFragment) ? 0 : Integer.parseInt(selectionName.split("-")[1]); + + return selectionSet.getSelections() + .stream() + .filter(selection -> getSelectionName(selection).equals(name)) + .collect(Collectors.toList()).get(idx); + } + + private static Selection updateMultipartSelectionSetField(Selection parentField, Selection currentField, String deferName) { + Selection newSelection = null; + + if(getSelectionName(parentField).equals(deferName)) { + if(parentField instanceof Field) { + List prunedDirectives = ((Field)parentField).getDirectives() + .stream() + .filter(directive -> !directive.getName().equals(DEFER_DIRECTIVE_NAME)) + .collect(Collectors.toList()); + + newSelection = ((Field)parentField).transform(builder -> builder.directives(prunedDirectives)); + } else if (parentField instanceof FragmentSpread) { + List prunedDirectives = ((FragmentSpread)parentField).getDirectives() + .stream() + .filter(directive -> !directive.getName().equals(DEFER_DIRECTIVE_NAME)) + .collect(Collectors.toList()); + + newSelection = ((FragmentSpread) parentField).transform(builder -> builder.directives(prunedDirectives)); + } else if (parentField instanceof InlineFragment) { + List prunedDirectives = ((InlineFragment)parentField).getDirectives() + .stream() + .filter(directive -> !directive.getName().equals(DEFER_DIRECTIVE_NAME)) + .collect(Collectors.toList()); + + newSelection = ((InlineFragment)parentField).transform(builder -> builder.directives(prunedDirectives)); + } else { + return null; + } + } else { + Field __typeName = Field.newField().name("__typename").build(); + + if(parentField instanceof Field) { + newSelection = ((Field)parentField).transform(builder -> builder.selectionSet( + SelectionSet.newSelectionSet().selection(currentField).selection(__typeName).build() + )); + } else if (parentField instanceof InlineFragment) { + newSelection = ((InlineFragment)parentField).transform(builder -> builder.selectionSet( + SelectionSet.newSelectionSet().selection(currentField).selection(__typeName).build() + )); + } else { + return null; + } + } + + return newSelection; + } + + private static void addMultipartChildPaths(SelectionSet selectionSet, List deferredPaths, String currentPath, HashMap inlineFragmentIdxMap) { + selectionSet.getSelections() + .stream() + .filter(InlineFragment.class::isInstance) + .map(InlineFragment.class::cast) + .forEach(childInlineFragment -> { + String inlineFragmentName = childInlineFragment.getTypeCondition().getName(); + Integer currentIdx = inlineFragmentIdxMap.getOrDefault(inlineFragmentName, 0); + processInlineFragment(childInlineFragment, deferredPaths, currentPath, currentIdx, inlineFragmentIdxMap); + currentIdx++; + inlineFragmentIdxMap.put(inlineFragmentName, currentIdx); + }); + + selectionSet.getSelections() + .stream() + .filter(FragmentSpread.class::isInstance) + .map(FragmentSpread.class::cast) + .forEach(childFragmentSpread -> processFragmentSpread(childFragmentSpread, deferredPaths, currentPath)); + + selectionSet.getSelections() + .stream() + .filter(Field.class::isInstance) + .map(Field.class::cast) + .forEach(childField -> processField(childField, deferredPaths, currentPath, inlineFragmentIdxMap)); + } + + private static void processField(Field field, List deferredPaths, String currentPath, HashMap inlineFragmentMap) { + String currentChildName = field.getName(); + String childPath = (StringUtils.isBlank(currentPath)) ? currentChildName : StringUtils.join( currentPath, ".", currentChildName); + if(field.hasDirective(DEFER_DIRECTIVE_NAME)) { + Argument ifArg = field.getDirectives(DEFER_DIRECTIVE_NAME).get(0).getArgument(DEFER_IF_ARG); + if(ifArg == null || ((BooleanValue) ifArg.getValue()).isValue()) { + deferredPaths.add(childPath); + } + } + + if(field.getSelectionSet() != null && CollectionUtils.isNotEmpty(field.getSelectionSet().getSelections())) { + addMultipartChildPaths(field.getSelectionSet(), deferredPaths, childPath, inlineFragmentMap); + } + } + + private static void processInlineFragment(InlineFragment inlineFragment, List deferredPaths, String currentPath, int index, HashMap inlineFragmentMap) { + String currentChildName = StringUtils.join(inlineFragment.getTypeCondition().getName(), "-", index); + String childPath = (StringUtils.isBlank(currentPath)) ? currentChildName : StringUtils.join( currentPath, ".", currentChildName); + if(inlineFragment.hasDirective(DEFER_DIRECTIVE_NAME)) { + Argument ifArg = inlineFragment.getDirectives(DEFER_DIRECTIVE_NAME).get(0).getArgument(DEFER_IF_ARG); + if(ifArg == null || ((BooleanValue) ifArg.getValue()).isValue()) { + deferredPaths.add(childPath); + } + } + + if(inlineFragment.getSelectionSet() != null && CollectionUtils.isNotEmpty(inlineFragment.getSelectionSet().getSelections())) { + addMultipartChildPaths(inlineFragment.getSelectionSet(), deferredPaths, childPath, inlineFragmentMap); + } + } + + private static void processFragmentSpread(FragmentSpread fragmentSpread, List deferredPaths, String currentPath) { + String currentChildName = fragmentSpread.getName(); + String childPath = (StringUtils.isBlank(currentPath)) ? currentChildName : StringUtils.join( currentPath, ".", currentChildName); + if(fragmentSpread.hasDirective(DEFER_DIRECTIVE_NAME)) { + Argument ifArg = fragmentSpread.getDirectives(DEFER_DIRECTIVE_NAME).get(0).getArgument(DEFER_IF_ARG); + if(ifArg == null || ((BooleanValue) ifArg.getValue()).isValue()) { + deferredPaths.add(childPath); + } + } + } +} diff --git a/src/main/java/com/intuit/graphql/orchestrator/utils/NodeUtils.java b/src/main/java/com/intuit/graphql/orchestrator/utils/NodeUtils.java new file mode 100644 index 00000000..f06c8d74 --- /dev/null +++ b/src/main/java/com/intuit/graphql/orchestrator/utils/NodeUtils.java @@ -0,0 +1,98 @@ +package com.intuit.graphql.orchestrator.utils; + +import graphql.language.Directive; +import graphql.language.DirectivesContainer; +import graphql.language.Field; +import graphql.language.FragmentDefinition; +import graphql.language.FragmentSpread; +import graphql.language.InlineFragment; +import graphql.language.Node; +import graphql.language.OperationDefinition; +import graphql.language.Selection; +import graphql.language.SelectionSet; +import org.apache.commons.lang3.ObjectUtils; + +import java.util.Arrays; +import java.util.Collection; +import java.util.List; +import java.util.Map; +import java.util.Objects; +import java.util.stream.Collectors; + +public class NodeUtils { + + /** + * Transforms node by removing desired directive from the node. + * Returns transformed node + * Throws an exception if it is at a location that is not supported. + * @param node: the selection that you would like to prune + * @param directiveName: the directive you would like to search for + * @return selection: pruned selection + * */ + public static Selection removeDirectiveFromNode(Selection node, String directiveName) { + //DirectivesContainer returns getDirectives as a List so need to cast to stream correctly + List prunedDirectives = ((List)((DirectivesContainer)node).getDirectives()) + .stream() + .filter(directive -> ObjectUtils.notEqual(directiveName, directive.getName())) + .collect(Collectors.toList()); + + if(node instanceof Field) { + return ((Field) node).transform(builder -> builder.directives(prunedDirectives)); + } else if(node instanceof InlineFragment) { + return ((InlineFragment) node).transform(builder -> builder.directives(prunedDirectives)); + } else { + return ((FragmentSpread) node).transform(builder -> builder.directives(prunedDirectives)); + } + } + + /** + * Check if the selection has children selections + * @param node selection that you are trying to check + * @return boolean: true if selection doesn't have selectionset as child, false otherwise + * */ + public static boolean isLeaf(Selection node) { + return node.getChildren().isEmpty() || + node.getChildren() + .stream() + .noneMatch(SelectionSet.class::isInstance); + } + + + /** + * Generates new node + * Transforms the parentNode with a new selection set consisting of the pruned child and typename fields + * @param parentNode node that will be transformed and will contain only selections + * @param selections selections that need to be in selection set for parentNode + * @return node that only contains the passed in selections + * */ + public static Node transformNodeWithSelections(Node parentNode, Selection... selections) { + SelectionSet prunedSelectionSet = SelectionSet.newSelectionSet() + .selections(Arrays.asList(selections)) + .build(); + + if(parentNode instanceof Field) { + return ((Field) parentNode).transform(builder -> builder.selectionSet(prunedSelectionSet)); + } else if (parentNode instanceof FragmentDefinition) { + //add fragment spread names here in case of nested fragment spreads + return ((FragmentDefinition) parentNode).transform(builder -> builder.selectionSet(prunedSelectionSet)); + } else if (parentNode instanceof InlineFragment) { + return ((InlineFragment) parentNode).transform(builder -> builder.selectionSet(prunedSelectionSet)); + } else { + return ((OperationDefinition) parentNode).transform(builder -> builder.selectionSet(prunedSelectionSet)); + } + } + + /** + * Retrieves the objects from map matching key + * @param kvMap: map that will be searched + * @param keyCollection keys to check + * @return list of values from map with matching key + * */ + public static List getAllMapValuesWithMatchingKeys(Map kvMap, Collection keyCollection) { + return keyCollection + .stream() + .map(kvMap::get) + .filter(Objects::nonNull) + .collect(Collectors.toList()); + } +} diff --git a/src/main/java/com/intuit/graphql/orchestrator/utils/QueryPathUtils.java b/src/main/java/com/intuit/graphql/orchestrator/utils/QueryPathUtils.java index ee81538a..89f1f0e3 100644 --- a/src/main/java/com/intuit/graphql/orchestrator/utils/QueryPathUtils.java +++ b/src/main/java/com/intuit/graphql/orchestrator/utils/QueryPathUtils.java @@ -4,11 +4,12 @@ import graphql.language.FragmentDefinition; import graphql.language.Node; import graphql.util.TraverserContext; +import org.apache.commons.lang3.StringUtils; + import java.util.ArrayList; import java.util.Collections; import java.util.List; import java.util.stream.Collectors; -import org.apache.commons.lang3.StringUtils; public class QueryPathUtils { diff --git a/src/main/java/com/intuit/graphql/orchestrator/utils/TraverserContextUtils.java b/src/main/java/com/intuit/graphql/orchestrator/utils/TraverserContextUtils.java new file mode 100644 index 00000000..715bb198 --- /dev/null +++ b/src/main/java/com/intuit/graphql/orchestrator/utils/TraverserContextUtils.java @@ -0,0 +1,22 @@ +package com.intuit.graphql.orchestrator.utils; + +import graphql.language.Node; +import graphql.language.SelectionSetContainer; +import graphql.util.TraverserContext; + +import java.util.List; +import java.util.stream.Collectors; + +public class TraverserContextUtils { + /** + * Returns the current nodes parent node definitions + * @param currentNode context for current node + * @return List of graphql definitions + * */ + public static List getParentDefinitions(TraverserContext currentNode) { + return currentNode.getParentNodes() + .stream() + .filter(SelectionSetContainer.class::isInstance) + .collect(Collectors.toList()); + } +} diff --git a/src/main/java/com/intuit/graphql/orchestrator/visitors/queryVisitors/DeferQueryExtractor.java b/src/main/java/com/intuit/graphql/orchestrator/visitors/queryVisitors/DeferQueryExtractor.java new file mode 100644 index 00000000..6bb693c8 --- /dev/null +++ b/src/main/java/com/intuit/graphql/orchestrator/visitors/queryVisitors/DeferQueryExtractor.java @@ -0,0 +1,170 @@ +package com.intuit.graphql.orchestrator.visitors.queryVisitors; + +import com.intuit.graphql.orchestrator.deferDirective.DeferOptions; +import graphql.ExecutionInput; +import graphql.analysis.QueryVisitorFieldEnvironment; +import graphql.analysis.QueryVisitorFragmentSpreadEnvironment; +import graphql.analysis.QueryVisitorInlineFragmentEnvironment; +import graphql.analysis.QueryVisitorStub; +import graphql.language.AstPrinter; +import graphql.language.Definition; +import graphql.language.Document; +import graphql.language.FragmentDefinition; +import graphql.language.FragmentSpread; +import graphql.language.Node; +import graphql.language.OperationDefinition; +import graphql.language.Selection; +import graphql.language.SelectionSet; +import graphql.util.TraverserContext; +import graphql.util.TreeTransformerUtil; +import lombok.Builder; +import lombok.Getter; +import lombok.NonNull; + +import java.util.ArrayList; +import java.util.HashMap; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Set; + +import static com.intuit.graphql.orchestrator.deferDirective.DeferUtil.containsEnabledDeferDirective; +import static com.intuit.graphql.orchestrator.deferDirective.DeferUtil.hasNonDeferredSelection; +import static com.intuit.graphql.orchestrator.utils.DirectivesUtil.DEFER_DIRECTIVE_NAME; +import static com.intuit.graphql.orchestrator.utils.GraphQLUtil.AST_TRANSFORMER; +import static com.intuit.graphql.orchestrator.utils.IntrospectionUtil.__typenameField; +import static com.intuit.graphql.orchestrator.utils.NodeUtils.getAllMapValuesWithMatchingKeys; +import static com.intuit.graphql.orchestrator.utils.NodeUtils.isLeaf; +import static com.intuit.graphql.orchestrator.utils.NodeUtils.removeDirectiveFromNode; +import static com.intuit.graphql.orchestrator.utils.NodeUtils.transformNodeWithSelections; +import static com.intuit.graphql.orchestrator.utils.TraverserContextUtils.getParentDefinitions; + +@Builder +public class DeferQueryExtractor extends QueryVisitorStub { + + @NonNull private final Document rootNode; + @NonNull private final OperationDefinition operationDefinition; + @Builder.Default + @NonNull + private Map fragmentDefinitionMap = new HashMap<>(); + @NonNull private final PruneChildDeferSelectionsModifier childModifier; + @NonNull private ExecutionInput originalEI; + + @NonNull private DeferOptions deferOptions; + + @Getter + private final List extractedEIs = new ArrayList<>(); + + /** + * Functions: + * Updates the field if it contains defer directive. + * If it is a valid deferred node, generate new EI and add to list to be extracted + * @param queryVisitorFieldEnvironment field that is being visited + */ + @Override + public void visitField(QueryVisitorFieldEnvironment queryVisitorFieldEnvironment) { + updateDeferredInfoForNode(queryVisitorFieldEnvironment.getField(), queryVisitorFieldEnvironment.getTraverserContext()); + } + + /** + * Functions: + * Updates the inline fragment if it contains defer directive. + * If it is a valid deferred node, generate new EI and add to list to be extracted + * @param queryVisitorInlineFragmentEnvironment field that is being visited + */ + @Override + public void visitInlineFragment(QueryVisitorInlineFragmentEnvironment queryVisitorInlineFragmentEnvironment) { + updateDeferredInfoForNode(queryVisitorInlineFragmentEnvironment.getInlineFragment(), queryVisitorInlineFragmentEnvironment.getTraverserContext()); + } + + /** + * Functions: + * Updates the fragment spread if it contains defer directive. + * If it is a valid deferred node, generate new EI and add to list to be extracted + * @param queryVisitorFragmentSpreadEnvironment field that is being visited + */ + @Override + public void visitFragmentSpread(QueryVisitorFragmentSpreadEnvironment queryVisitorFragmentSpreadEnvironment) { + updateDeferredInfoForNode(queryVisitorFragmentSpreadEnvironment.getFragmentSpread(), queryVisitorFragmentSpreadEnvironment.getTraverserContext()); + } + + /** + * Updates node if it contains defer + * @param node node that is currently being visited + * @param context context for traversed node + * */ + private void updateDeferredInfoForNode(Selection node, TraverserContext context) { + if(containsEnabledDeferDirective(node)) { + Selection prunedNode = removeDirectiveFromNode(node, DEFER_DIRECTIVE_NAME); + + if(isLeaf(node) || hasNonDeferredSelection(node)) { + extractedEIs.add(generateDeferredEI(prunedNode, context)); + } + + //update node so children nodes has the correct definition + TreeTransformerUtil.changeNode(context, prunedNode); + } + } + + /** + * Generates an Execution Input given the node and the context + * @param context current selection + * @param currentNode context for the selection + * @return An ExecutionInput + * */ + private ExecutionInput generateDeferredEI(Selection currentNode, TraverserContext context) { + //prune defer information from children + Node prunedNode = AST_TRANSFORMER.transform(currentNode, childModifier); + List parentNodes = getParentDefinitions(context); + + Set neededFragmentSpreads = new HashSet<>(); + if(currentNode instanceof FragmentSpread) { + neededFragmentSpreads.add(((FragmentSpread) currentNode).getName()); + } + + //builds parent nodes with pruned information + for (Node parentNode : parentNodes) { + prunedNode = transformNodeWithSelections(parentNode, (Selection)prunedNode, __typenameField); + + if(parentNode instanceof FragmentSpread) { + neededFragmentSpreads.add(((FragmentSpread) parentNode).getName()); + } + } + + //Gets all the definitions for the fragment spreads + List fragmentSpreadDefs = getAllMapValuesWithMatchingKeys(fragmentDefinitionMap, neededFragmentSpreads); + + SelectionSet ss = SelectionSet.newSelectionSet().selection((Selection) prunedNode).build(); + //builds new OperationDefinition consisting with only pruned nodes + OperationDefinition newOperation = this.operationDefinition.transform(builder -> builder.selectionSet(ss)); + + List deferredDefinitions = new ArrayList<>(); + deferredDefinitions.add(newOperation); + deferredDefinitions.addAll(fragmentSpreadDefs); + + Document deferredDocument = this.rootNode.transform(builder -> builder.definitions(deferredDefinitions)); + + String query = AstPrinter.printAst(deferredDocument); + + return originalEI.transform(builder -> builder.query(query)); + } + + /** + * Builder for class + * */ + public static class DeferQueryExtractorBuilder { + PruneChildDeferSelectionsModifier childModifier; + DeferOptions deferOptions; + + public DeferQueryExtractorBuilder deferOptions(DeferOptions deferOptions) { + this.deferOptions = deferOptions; + return childModifier(PruneChildDeferSelectionsModifier.builder().deferOptions(deferOptions).build()); + } + + private DeferQueryExtractorBuilder childModifier(PruneChildDeferSelectionsModifier childModifier) { + this.childModifier = childModifier; + return this; + } + + } +} diff --git a/src/main/java/com/intuit/graphql/orchestrator/visitors/queryVisitors/PruneChildDeferSelectionsModifier.java b/src/main/java/com/intuit/graphql/orchestrator/visitors/queryVisitors/PruneChildDeferSelectionsModifier.java new file mode 100644 index 00000000..0cd171a5 --- /dev/null +++ b/src/main/java/com/intuit/graphql/orchestrator/visitors/queryVisitors/PruneChildDeferSelectionsModifier.java @@ -0,0 +1,122 @@ +package com.intuit.graphql.orchestrator.visitors.queryVisitors; + +import com.intuit.graphql.orchestrator.deferDirective.DeferOptions; +import graphql.GraphQLException; +import graphql.language.Directive; +import graphql.language.DirectivesContainer; +import graphql.language.Field; +import graphql.language.FragmentSpread; +import graphql.language.InlineFragment; +import graphql.language.Node; +import graphql.language.NodeVisitorStub; +import graphql.language.Selection; +import graphql.language.SelectionSet; +import graphql.util.TraversalControl; +import graphql.util.TraverserContext; +import lombok.Builder; + +import static com.intuit.graphql.orchestrator.deferDirective.DeferUtil.containsEnabledDeferDirective; +import static com.intuit.graphql.orchestrator.deferDirective.DeferUtil.hasNonDeferredSelection; +import static com.intuit.graphql.orchestrator.utils.DirectivesUtil.DEFER_DIRECTIVE_NAME; +import static com.intuit.graphql.orchestrator.utils.IntrospectionUtil.__typenameField; +import static com.intuit.graphql.orchestrator.utils.NodeUtils.isLeaf; +import static com.intuit.graphql.orchestrator.utils.NodeUtils.removeDirectiveFromNode; +import static graphql.util.TreeTransformerUtil.changeNode; +import static graphql.util.TreeTransformerUtil.deleteNode; + +@Builder +public class PruneChildDeferSelectionsModifier extends NodeVisitorStub { + private DeferOptions deferOptions; + + /** + * Visits Field and deletes defer information or throws an exception + * @param node current node traverser is visiting + * @param context context for the current node + * @return TraversalControl whether to continue or abort + * */ + @Override + public TraversalControl visitField(Field node, TraverserContext context) { + return deleteDeferIfExists(node, context); + } + + /** + * Visits FragmentSpread and deletes defer information or throws an exception + * @param node current node traverser is visiting + * @param context context for the current node + * @return TraversalControl whether to continue or abort + * */ + @Override + public TraversalControl visitFragmentSpread(FragmentSpread node, TraverserContext context) { + return deleteDeferIfExists(node, context); + } + + /** + * Visits Inline Fragment and deletes defer information or throws an exception + * @param node current node traverser is visiting + * @param context context for the current node + * @return TraversalControl whether to continue or abort + * */ + @Override + public TraversalControl visitInlineFragment(InlineFragment node, TraverserContext context) { + return deleteDeferIfExists(node, context); + } + + /** + * Visits SelectionSet and add typename to selection set + * @param node current node traverser is visiting + * @param context context for the current node + * @return TraversalControl whether to continue or abort + * */ + @Override + public TraversalControl visitSelectionSet(SelectionSet node, TraverserContext context) { + SelectionSet newSelectionSet = node.transform(builder -> builder.selection(__typenameField)); + + return changeNode(context, newSelectionSet); + } + + /** + * Visits Directive and deletes defer information + * @param node current node traverser is visiting + * @param context context for the current node + * @return TraversalControl whether to continue or abort + * */ + public TraversalControl visitDirective(Directive node, TraverserContext context) { + //removes unnecessary and disabled defer directive if exists + if(DEFER_DIRECTIVE_NAME.equals(node.getName())) { + deleteNode(context); + } + + return this.visitNode(node, context); + } + + /** + * deletes defer information or throws an exception for selection + * @param node current node traverser is visiting + * @param context context for the current node + * @return TraversalControl whether to continue or abort + * */ + private TraversalControl deleteDeferIfExists(Selection node, TraverserContext context) { + //skip if it does not have defer directive + if(((DirectivesContainer)node).hasDirective(DEFER_DIRECTIVE_NAME)) { + if(containsEnabledDeferDirective(node)) { + //if node has an enabled defer, check if option allows it + if(!this.deferOptions.isNestedDefersAllowed()) { + throw new GraphQLException("Nested defers are currently unavailable."); + } + + if(isLeaf(node) || hasNonDeferredSelection(node)) { + //delete node if it is enabled because extractor will create query for it + return deleteNode(context); + } else { + //remove directive so it is not included in downstream query + return changeNode(context, removeDirectiveFromNode(node, DEFER_DIRECTIVE_NAME)); + } + } else { + //remove directive so it is not included in downstream query + return changeNode(context, removeDirectiveFromNode(node, DEFER_DIRECTIVE_NAME)); + } + } + + return TraversalControl.CONTINUE; + } +} diff --git a/src/test/groovy/com/intuit/graphql/orchestrator/authorization/AuthDownstreamQueryRedactorVisitorSpec.groovy b/src/test/groovy/com/intuit/graphql/orchestrator/authorization/AuthDownstreamQueryRedactorVisitorSpec.groovy index 7ed843ed..29d85e7a 100644 --- a/src/test/groovy/com/intuit/graphql/orchestrator/authorization/AuthDownstreamQueryRedactorVisitorSpec.groovy +++ b/src/test/groovy/com/intuit/graphql/orchestrator/authorization/AuthDownstreamQueryRedactorVisitorSpec.groovy @@ -54,6 +54,56 @@ class AuthDownstreamQueryRedactorVisitorSpec extends Specification { } """ + def testDeferInlineFragmentQuery = """ + { + a { + ... on B1 @defer { + c1 { + s1 + } + } + b2 { i1 } + } + } + """ + + def testIfDeferInlineFragmentQuery = """ + { + a { + ... on B1 @defer(if:false) { + c1 { + s1 + } + } + b2 { i1 } + } + } + """ + + def testDeferSpreadFragmentQuery = """ + { + a { + ... DeferredFrag @defer + b2 { i1 } + } + } + fragment DeferredFrag on B1 { + c1 + } + """ + + def testIfDeferSpreadFragmentQuery = """ + { + a { + ... DeferredFrag @defer(if: false) + b2 { i1 } + } + } + fragment DeferredFrag on B1 { + c1 + } + """ + static final Object TEST_AUTH_DATA = "TestAuthDataCanBeAnyObject" Field mockField = Mock() @@ -211,4 +261,244 @@ class AuthDownstreamQueryRedactorVisitorSpec extends Specification { .createDeniedResult(testGraphqlErrorException) } + def "deferred inline fragments are removed"() { + given: + + Document document = new Parser().parseDocument(testDeferInlineFragmentQuery) + OperationDefinition operationDefinition = document.getDefinitionsOfType(OperationDefinition.class).get(0) + Field rootField = SelectionSetUtil.getFieldByPath(Arrays.asList("a"), operationDefinition.getSelectionSet()) + GraphQLFieldsContainer rootFieldParentType = (GraphQLFieldsContainer) testGraphQLSchema.getType("Query") + + AuthDownstreamQueryModifier specUnderTest = AuthDownstreamQueryModifier.builder() + .rootParentType((GraphQLFieldsContainer) rootFieldParentType) + .fieldAuthorization(mockFieldAuthorization) + .graphQLContext(mockGraphQLContext) + .queryVariables(Collections.emptyMap()) + .graphQLSchema(testGraphQLSchema) + .selectionCollector(new SelectionCollector(fragmentsByName)) + .serviceMetadata(mockServiceMetadata) + .authData(TEST_AUTH_DATA) + .build() + + when: + mockGraphQLContext.getOrDefault("useDefer", false) >> true + Field transformedField = (Field) astTransformer.transform(rootField, specUnderTest) + + then: + rootField.getName() == "a" + rootField.selectionSet.selections.size() == 2 + ((InlineFragment)(rootField.selectionSet.selections.get(0))).getTypeCondition().name == "B1" + ((Field)(rootField.selectionSet.selections.get(1))).getName() == "b2" + + transformedField.getName() == "a" + transformedField.selectionSet.selections.size() == 1 + ((Field)(transformedField.selectionSet.selections.get(0))).getName() == "b2" + + mockFieldAuthorization.authorize(queryA) >> FieldAuthorizationResult.ALLOWED_FIELD_AUTH_RESULT + mockFieldAuthorization.authorize(aB1) >> FieldAuthorizationResult.ALLOWED_FIELD_AUTH_RESULT + mockFieldAuthorization.authorize(b1C1) >> FieldAuthorizationResult.ALLOWED_FIELD_AUTH_RESULT + mockFieldAuthorization.authorize(c1S1) >> FieldAuthorizationResult.ALLOWED_FIELD_AUTH_RESULT + mockFieldAuthorization.authorize(aB2) >> FieldAuthorizationResult.ALLOWED_FIELD_AUTH_RESULT + mockFieldAuthorization.authorize(b2i1) >> FieldAuthorizationResult.ALLOWED_FIELD_AUTH_RESULT + + mockRenamedMetadata.getOriginalFieldNamesByRenamedName() >> Collections.emptyMap() + mockServiceMetadata.getRenamedMetadata() >> mockRenamedMetadata + } + + def "deferred inline fragments with if argument as false are kept"() { + given: + + Document document = new Parser().parseDocument(testIfDeferInlineFragmentQuery) + OperationDefinition operationDefinition = document.getDefinitionsOfType(OperationDefinition.class).get(0) + Field rootField = SelectionSetUtil.getFieldByPath(Arrays.asList("a"), operationDefinition.getSelectionSet()) + GraphQLFieldsContainer rootFieldParentType = (GraphQLFieldsContainer) testGraphQLSchema.getType("Query") + + AuthDownstreamQueryModifier specUnderTest = AuthDownstreamQueryModifier.builder() + .rootParentType((GraphQLFieldsContainer) rootFieldParentType) + .fieldAuthorization(mockFieldAuthorization) + .graphQLContext(mockGraphQLContext) + .queryVariables(Collections.emptyMap()) + .graphQLSchema(testGraphQLSchema) + .selectionCollector(new SelectionCollector(fragmentsByName)) + .serviceMetadata(mockServiceMetadata) + .authData(TEST_AUTH_DATA) + .build() + + when: + mockGraphQLContext.getOrDefault("useDefer", false) >> true + Field transformedField = (Field) astTransformer.transform(rootField, specUnderTest) + + then: + rootField.getName() == "a" + rootField.selectionSet.selections.size() == 2 + ((InlineFragment)(rootField.selectionSet.selections.get(0))).getTypeCondition().name == "B1" + ((Field)(rootField.selectionSet.selections.get(1))).getName() == "b2" + + transformedField.getName() == "a" + transformedField.selectionSet.selections.size() == 2 + ((InlineFragment)(transformedField.selectionSet.selections.get(0))).getTypeCondition().name == "B1" + ((Field)(transformedField.selectionSet.selections.get(1))).getName() == "b2" + + mockFieldAuthorization.authorize(queryA) >> FieldAuthorizationResult.ALLOWED_FIELD_AUTH_RESULT + mockFieldAuthorization.authorize(aB1) >> FieldAuthorizationResult.ALLOWED_FIELD_AUTH_RESULT + mockFieldAuthorization.authorize(b1C1) >> FieldAuthorizationResult.ALLOWED_FIELD_AUTH_RESULT + mockFieldAuthorization.authorize(c1S1) >> FieldAuthorizationResult.ALLOWED_FIELD_AUTH_RESULT + mockFieldAuthorization.authorize(aB2) >> FieldAuthorizationResult.ALLOWED_FIELD_AUTH_RESULT + mockFieldAuthorization.authorize(b2i1) >> FieldAuthorizationResult.ALLOWED_FIELD_AUTH_RESULT + + mockRenamedMetadata.getOriginalFieldNamesByRenamedName() >> Collections.emptyMap() + mockServiceMetadata.getRenamedMetadata() >> mockRenamedMetadata + } + + def "deferred fragment spreads are removed"() { + given: + + Document document = new Parser().parseDocument(testDeferSpreadFragmentQuery) + OperationDefinition operationDefinition = document.getDefinitionsOfType(OperationDefinition.class).get(0) + Field rootField = SelectionSetUtil.getFieldByPath(Arrays.asList("a"), operationDefinition.getSelectionSet()) + GraphQLFieldsContainer rootFieldParentType = (GraphQLFieldsContainer) testGraphQLSchema.getType("Query") + FragmentDefinition mockedFragmentDef = Mock() + HashMap fragsByName = ["DeferredFrag": mockedFragmentDef] + Field mockCField = Mock() + SelectionSet selectionSet = SelectionSet.newSelectionSet().selection(mockCField).build() + + AuthDownstreamQueryModifier specUnderTest = AuthDownstreamQueryModifier.builder() + .rootParentType((GraphQLFieldsContainer) rootFieldParentType) + .fieldAuthorization(mockFieldAuthorization) + .graphQLContext(mockGraphQLContext) + .queryVariables(Collections.emptyMap()) + .graphQLSchema(testGraphQLSchema) + .selectionCollector(new SelectionCollector(fragsByName)) + .serviceMetadata(mockServiceMetadata) + .authData(TEST_AUTH_DATA) + .build() + + when: + mockGraphQLContext.getOrDefault("useDefer", false) >> true + mockedFragmentDef.getSelectionSet() >> selectionSet + Field transformedField = (Field) astTransformer.transform(rootField, specUnderTest) + + then: + rootField.getName() == "a" + rootField.selectionSet.selections.size() == 2 + ((FragmentSpread)(rootField.selectionSet.selections.get(0))).getName() == "DeferredFrag" + ((Field)(rootField.selectionSet.selections.get(1))).getName() == "b2" + + transformedField.getName() == "a" + transformedField.selectionSet.selections.size() == 1 + ((Field)(transformedField.selectionSet.selections.get(0))).getName() == "b2" + + specUnderTest.fragmentSpreadsRemoved.size() == 1 + specUnderTest.fragmentSpreadsRemoved.get(0) == "DeferredFrag" + + mockFieldAuthorization.authorize(queryA) >> FieldAuthorizationResult.ALLOWED_FIELD_AUTH_RESULT + mockFieldAuthorization.authorize(aB1) >> FieldAuthorizationResult.ALLOWED_FIELD_AUTH_RESULT + mockFieldAuthorization.authorize(b1C1) >> FieldAuthorizationResult.ALLOWED_FIELD_AUTH_RESULT + mockFieldAuthorization.authorize(c1S1) >> FieldAuthorizationResult.ALLOWED_FIELD_AUTH_RESULT + mockFieldAuthorization.authorize(aB2) >> FieldAuthorizationResult.ALLOWED_FIELD_AUTH_RESULT + mockFieldAuthorization.authorize(b2i1) >> FieldAuthorizationResult.ALLOWED_FIELD_AUTH_RESULT + + mockRenamedMetadata.getOriginalFieldNamesByRenamedName() >> Collections.emptyMap() + mockServiceMetadata.getRenamedMetadata() >> mockRenamedMetadata + } + + def "deferred fragment spreads with if argument as false are kept"() { + given: + + Document document = new Parser().parseDocument(testIfDeferSpreadFragmentQuery) + OperationDefinition operationDefinition = document.getDefinitionsOfType(OperationDefinition.class).get(0) + Field rootField = SelectionSetUtil.getFieldByPath(Arrays.asList("a"), operationDefinition.getSelectionSet()) + GraphQLFieldsContainer rootFieldParentType = (GraphQLFieldsContainer) testGraphQLSchema.getType("Query") + FragmentDefinition mockedFragmentDef = Mock() + HashMap fragsByName = ["DeferredFrag": mockedFragmentDef] + Field mockCField = Mock() + SelectionSet selectionSet = SelectionSet.newSelectionSet().selection(mockCField).build() + + AuthDownstreamQueryModifier specUnderTest = AuthDownstreamQueryModifier.builder() + .rootParentType((GraphQLFieldsContainer) rootFieldParentType) + .fieldAuthorization(mockFieldAuthorization) + .graphQLContext(mockGraphQLContext) + .queryVariables(Collections.emptyMap()) + .graphQLSchema(testGraphQLSchema) + .selectionCollector(new SelectionCollector(fragsByName)) + .serviceMetadata(mockServiceMetadata) + .authData(TEST_AUTH_DATA) + .build() + + when: + mockGraphQLContext.getOrDefault("useDefer", false) >> true + mockedFragmentDef.getSelectionSet() >> selectionSet + Field transformedField = (Field) astTransformer.transform(rootField, specUnderTest) + + then: + rootField.getName() == "a" + rootField.selectionSet.selections.size() == 2 + ((FragmentSpread)(rootField.selectionSet.selections.get(0))).getName() == "DeferredFrag" + ((Field)(rootField.selectionSet.selections.get(1))).getName() == "b2" + + transformedField.getName() == "a" + transformedField.selectionSet.selections.size() == 2 + ((FragmentSpread)(transformedField.selectionSet.selections.get(0))).getName() == "DeferredFrag" + ((Field)(transformedField.selectionSet.selections.get(1))).getName() == "b2" + + mockFieldAuthorization.authorize(queryA) >> FieldAuthorizationResult.ALLOWED_FIELD_AUTH_RESULT + mockFieldAuthorization.authorize(aB1) >> FieldAuthorizationResult.ALLOWED_FIELD_AUTH_RESULT + mockFieldAuthorization.authorize(b1C1) >> FieldAuthorizationResult.ALLOWED_FIELD_AUTH_RESULT + mockFieldAuthorization.authorize(c1S1) >> FieldAuthorizationResult.ALLOWED_FIELD_AUTH_RESULT + mockFieldAuthorization.authorize(aB2) >> FieldAuthorizationResult.ALLOWED_FIELD_AUTH_RESULT + mockFieldAuthorization.authorize(b2i1) >> FieldAuthorizationResult.ALLOWED_FIELD_AUTH_RESULT + + mockRenamedMetadata.getOriginalFieldNamesByRenamedName() >> Collections.emptyMap() + mockServiceMetadata.getRenamedMetadata() >> mockRenamedMetadata + } + + def "Exception when defer directive is not in supported location"() { + given: + + Document document = new Parser().parseDocument(testIfDeferSpreadFragmentQuery) + OperationDefinition operationDefinition = document.getDefinitionsOfType(OperationDefinition.class).get(0) + Field rootField = SelectionSetUtil.getFieldByPath(Arrays.asList("a"), operationDefinition.getSelectionSet()) + GraphQLFieldsContainer rootFieldParentType = (GraphQLFieldsContainer) testGraphQLSchema.getType("Query") + FragmentDefinition mockedFragmentDef = Mock() + HashMap fragsByName = ["DeferredFrag": mockedFragmentDef] + Field mockCField = Mock() + SelectionSet selectionSet = SelectionSet.newSelectionSet().selection(mockCField).build() + + AuthDownstreamQueryModifier specUnderTest = AuthDownstreamQueryModifier.builder() + .rootParentType((GraphQLFieldsContainer) rootFieldParentType) + .fieldAuthorization(mockFieldAuthorization) + .graphQLContext(mockGraphQLContext) + .queryVariables(Collections.emptyMap()) + .graphQLSchema(testGraphQLSchema) + .selectionCollector(new SelectionCollector(fragsByName)) + .serviceMetadata(mockServiceMetadata) + .authData(TEST_AUTH_DATA) + .build() + + when: + mockGraphQLContext.getOrDefault("useDefer", false) >> true + mockedFragmentDef.getSelectionSet() >> selectionSet + Field transformedField = (Field) astTransformer.transform(rootField, specUnderTest) + + then: + rootField.getName() == "a" + rootField.selectionSet.selections.size() == 2 + ((FragmentSpread)(rootField.selectionSet.selections.get(0))).getName() == "DeferredFrag" + ((Field)(rootField.selectionSet.selections.get(1))).getName() == "b2" + + transformedField.getName() == "a" + transformedField.selectionSet.selections.size() == 2 + ((FragmentSpread)(transformedField.selectionSet.selections.get(0))).getName() == "DeferredFrag" + ((Field)(transformedField.selectionSet.selections.get(1))).getName() == "b2" + + mockFieldAuthorization.authorize(queryA) >> FieldAuthorizationResult.ALLOWED_FIELD_AUTH_RESULT + mockFieldAuthorization.authorize(aB1) >> FieldAuthorizationResult.ALLOWED_FIELD_AUTH_RESULT + mockFieldAuthorization.authorize(b1C1) >> FieldAuthorizationResult.ALLOWED_FIELD_AUTH_RESULT + mockFieldAuthorization.authorize(c1S1) >> FieldAuthorizationResult.ALLOWED_FIELD_AUTH_RESULT + mockFieldAuthorization.authorize(aB2) >> FieldAuthorizationResult.ALLOWED_FIELD_AUTH_RESULT + mockFieldAuthorization.authorize(b2i1) >> FieldAuthorizationResult.ALLOWED_FIELD_AUTH_RESULT + + mockRenamedMetadata.getOriginalFieldNamesByRenamedName() >> Collections.emptyMap() + mockServiceMetadata.getRenamedMetadata() >> mockRenamedMetadata + } } diff --git a/src/test/groovy/com/intuit/graphql/orchestrator/batch/DownstreamQueryModifierSpec.groovy b/src/test/groovy/com/intuit/graphql/orchestrator/batch/DownstreamQueryModifierSpec.groovy index 138b91ff..4071f750 100644 --- a/src/test/groovy/com/intuit/graphql/orchestrator/batch/DownstreamQueryModifierSpec.groovy +++ b/src/test/groovy/com/intuit/graphql/orchestrator/batch/DownstreamQueryModifierSpec.groovy @@ -10,6 +10,7 @@ import com.intuit.graphql.orchestrator.resolverdirective.DownstreamQueryModifier import com.intuit.graphql.orchestrator.resolverdirective.DownstreamQueryModifierTestHelper.TestService import com.intuit.graphql.orchestrator.schema.ServiceMetadataImpl import com.intuit.graphql.orchestrator.schema.transform.FieldResolverContext +import graphql.GraphQLContext import graphql.language.AstTransformer import graphql.language.Field import graphql.language.FragmentDefinition @@ -91,7 +92,7 @@ class DownstreamQueryModifierSpec extends Specification { serviceMetadataMock.isOwnedByEntityExtension(_) >> false serviceMetadataMock.shouldModifyDownStreamQuery() >> true - subjectUnderTest = new DownstreamQueryModifier(aType, serviceMetadataMock, Collections.emptyMap(), graphQLSchema) + subjectUnderTest = new DownstreamQueryModifier(aType, serviceMetadataMock, Collections.emptyMap(), graphQLSchema, GraphQLContext.newContext().build()) } def "can Remove Field"() { diff --git a/src/test/groovy/com/intuit/graphql/orchestrator/integration/DownstreamQueryModifierUnionRootSpec.groovy b/src/test/groovy/com/intuit/graphql/orchestrator/integration/DownstreamQueryModifierUnionRootSpec.groovy index 9bedce64..f0fc41f3 100644 --- a/src/test/groovy/com/intuit/graphql/orchestrator/integration/DownstreamQueryModifierUnionRootSpec.groovy +++ b/src/test/groovy/com/intuit/graphql/orchestrator/integration/DownstreamQueryModifierUnionRootSpec.groovy @@ -3,6 +3,7 @@ package com.intuit.graphql.orchestrator.integration import com.intuit.graphql.orchestrator.batch.DownstreamQueryModifier import com.intuit.graphql.orchestrator.metadata.RenamedMetadata import com.intuit.graphql.orchestrator.schema.ServiceMetadata +import graphql.GraphQLContext import graphql.Scalars import graphql.language.* import graphql.schema.* @@ -73,7 +74,7 @@ class DownstreamQueryModifierUnionRootSpec extends Specification { Map emptyFragmentsByName = emptyMap() queryModifierWithUnionRootType = new DownstreamQueryModifier(rootType, serviceMetadataMock, - emptyFragmentsByName, graphQLSchemaMock) + emptyFragmentsByName, graphQLSchemaMock, GraphQLContext.newContext().build()) } diff --git a/src/test/groovy/com/intuit/graphql/orchestrator/integration/QueryDirectiveSpec.groovy b/src/test/groovy/com/intuit/graphql/orchestrator/integration/QueryDirectiveSpec.groovy index 18e903f6..9c1840d8 100644 --- a/src/test/groovy/com/intuit/graphql/orchestrator/integration/QueryDirectiveSpec.groovy +++ b/src/test/groovy/com/intuit/graphql/orchestrator/integration/QueryDirectiveSpec.groovy @@ -3,15 +3,20 @@ package com.intuit.graphql.orchestrator.integration import com.google.common.collect.ImmutableMap import com.intuit.graphql.orchestrator.GraphQLOrchestrator import com.intuit.graphql.orchestrator.ServiceProvider +import com.intuit.graphql.orchestrator.deferDirective.DeferOptions +import com.intuit.graphql.orchestrator.testhelpers.MockServiceProvider import com.intuit.graphql.orchestrator.testhelpers.SimpleMockServiceProvider import com.intuit.graphql.orchestrator.utils.GraphQLUtil import com.intuit.graphql.orchestrator.utils.SelectionSetUtil import graphql.ExecutionInput +import graphql.ExecutionResult +import graphql.execution.reactive.SubscriptionPublisher import graphql.language.Document import graphql.language.Field import graphql.language.OperationDefinition import graphql.parser.Parser import helpers.BaseIntegrationTestSpecification +import reactor.core.publisher.Flux class QueryDirectiveSpec extends BaseIntegrationTestSpecification { @@ -256,4 +261,112 @@ class QueryDirectiveSpec extends BaseIntegrationTestSpecification { users_lastName_field.getDirectives().get(0).getArgument("if") != null } + def "test defer directive on Query"() { + given: + + def initialEI = "query getPetsDeferred {pets {id name}}" + def deferredEI = "query getPetsDeferred {pets {type __typename}}" + + def petsInitResponse = [ + data: [ + pets: [ + [id: "pet-1", name: "Charlie"], + [id: "pet-2", name: "Milo"], + [id: "pet-3", name: "Poppy"] + ] + ] + ] + + def petsDeferResponse = [ + data: [ + pets: [ + [type: "DOG", "__typename": "Pet" ], + [type: "RABBIT", "__typename": "Pet"], + [type: "CAT", "__typename": "Pet"] + ] + ] + ] + MockServiceProvider petsService = createQueryMatchingService("petsService", + petsSchema, + [ + (initialEI) :petsInitResponse, + (deferredEI) : petsDeferResponse + ] + ) + + ServiceProvider[] services = [ petsService ] + specUnderTest = createGraphQLOrchestrator(services) + + when: + ExecutionInput petsEI = ExecutionInput.newExecutionInput() + .query(''' + query getPetsDeferred { + pets { + id + name + type @defer + } + } + ''').build() + + DeferOptions deferOptions = DeferOptions.builder() + .nestedDefersAllowed(true) + .build() + + ExecutionResult executionResult = specUnderTest.execute(petsEI, deferOptions,true).get() + SubscriptionPublisher subscriptionPublisher = (SubscriptionPublisher)executionResult.data + Flux publisher = (Flux) subscriptionPublisher.upstreamPublisher; + List results = (List) publisher.collectList().block() + + then: + results.size() == 2 + results.get(0).errors.size() == 0 + results.get(0).data != null + + Map initDataValue = (Map) results.get(0).data + initDataValue.keySet().contains("pets") + initDataValue.get("pets").size() == 3 + + Map initPet1 = initDataValue.get("pets").get(0) + initPet1.get("id") == "pet-1" + initPet1.get("name") == "Charlie" + initPet1.keySet().contains("type") == true + initPet1.get("type") == null + + Map initPet2 = initDataValue.get("pets").get(1) + initPet2.get("id") == "pet-2" + initPet2.get("name") == "Milo" + initPet2.keySet().contains("type") == true + initPet2.get("type") == null + + Map initPet3 = initDataValue.get("pets").get(2) + initPet3.get("id") == "pet-3" + initPet3.get("name") == "Poppy" + initPet3.keySet().contains("type") == true + initPet3.get("type") == null + + Map deferDataValue = (Map) results.get(1).data + deferDataValue.keySet().contains("pets") + deferDataValue.get("pets").size() == 3 + + Map deferPet1 = deferDataValue.get("pets").get(0) + deferPet1.keySet().contains("id") == false + deferPet1.keySet().contains("name") == false + deferPet1.get("type") == "DOG" + deferPet1.get("__typename") == "Pet" + + Map deferPet2 = deferDataValue.get("pets").get(1) + deferPet2.keySet().contains("id") == false + deferPet2.keySet().contains("name") == false + deferPet2.get("type") == "RABBIT" + deferPet2.get("__typename") == "Pet" + + Map deferPet3 = deferDataValue.get("pets").get(2) + deferPet3.keySet().contains("id") == false + deferPet3.keySet().contains("name") == false + deferPet3.get("type") == "CAT" + deferPet3.get("__typename") == "Pet" + + } + } diff --git a/src/test/groovy/com/intuit/graphql/orchestrator/integration/authorization/FieldAuthorizationNestedFieldsSpec.groovy b/src/test/groovy/com/intuit/graphql/orchestrator/integration/authorization/FieldAuthorizationNestedFieldsSpec.groovy index 6f43b8a7..981a5572 100644 --- a/src/test/groovy/com/intuit/graphql/orchestrator/integration/authorization/FieldAuthorizationNestedFieldsSpec.groovy +++ b/src/test/groovy/com/intuit/graphql/orchestrator/integration/authorization/FieldAuthorizationNestedFieldsSpec.groovy @@ -92,6 +92,8 @@ class FieldAuthorizationNestedFieldsSpec extends BaseIntegrationTestSpecificatio Map argumentValues = Collections.emptyMap(); void setup() { + mockGraphQLContext.getOrDefault("useDefer", false) >> false + testServiceA = createSimpleMockService("testServiceA", testSchemaA, mockServiceResponseA) testServiceB = createSimpleMockService("testServiceB", testSchemaB, mockServiceResponseB) testServiceC = createSimpleMockService("testServiceC", testSchemaC, mockServiceResponseC) diff --git a/src/test/groovy/com/intuit/graphql/orchestrator/integration/authorization/FieldAuthorizationTopLevelFieldsSpec.groovy b/src/test/groovy/com/intuit/graphql/orchestrator/integration/authorization/FieldAuthorizationTopLevelFieldsSpec.groovy index e7d53dfe..ab98f6ab 100644 --- a/src/test/groovy/com/intuit/graphql/orchestrator/integration/authorization/FieldAuthorizationTopLevelFieldsSpec.groovy +++ b/src/test/groovy/com/intuit/graphql/orchestrator/integration/authorization/FieldAuthorizationTopLevelFieldsSpec.groovy @@ -79,6 +79,8 @@ class FieldAuthorizationTopLevelFieldsSpec extends BaseIntegrationTestSpecificat Map argumentValues = Collections.emptyMap(); void setup() { + mockGraphQLContext.getOrDefault("useDefer", false) >> false + testServiceA = createSimpleMockService("testServiceA", testSchemaA, mockServiceResponseA) testServiceB = createSimpleMockService("testServiceB", testSchemaB, mockServiceResponseB) testServiceC = createQueryMatchingService("testServiceC", testSchemaC, mockServiceResponseC) diff --git a/src/test/groovy/com/intuit/graphql/orchestrator/utils/DeferUtilSpec.groovy b/src/test/groovy/com/intuit/graphql/orchestrator/utils/DeferUtilSpec.groovy new file mode 100644 index 00000000..b9a8cb26 --- /dev/null +++ b/src/test/groovy/com/intuit/graphql/orchestrator/utils/DeferUtilSpec.groovy @@ -0,0 +1,167 @@ +package com.intuit.graphql.orchestrator.utils + +import graphql.language.* +import helpers.BaseIntegrationTestSpecification + +import static com.intuit.graphql.orchestrator.deferDirective.DeferUtil.containsEnabledDeferDirective +import static com.intuit.graphql.orchestrator.deferDirective.DeferUtil.hasNonDeferredSelection + +class DeferUtilSpec extends BaseIntegrationTestSpecification{ + Directive enabledDefer = Directive.newDirective().name("defer").build() + Directive disabledDefer = Directive.newDirective() + .name("defer") + .argument( + Argument.newArgument("if", BooleanValue.of(false)) + .build()) + .build() + + def "hasNonDeferredSelection throws exception if node is null"(){ + when: + hasNonDeferredSelection(null) + + then: + thrown(NullPointerException) + } + + def "hasNonDeferredSelection return true if node has child selections do not have defer"(){ + when: + Field childField1 = Field.newField("childField").build() + Field childField2 = Field.newField("childField").build() + SelectionSet ss = SelectionSet.newSelectionSet() + .selection(childField1) + .selection(childField2) + .build() + Field parentField = Field.newField("parentNode", ss).build() + + then: + hasNonDeferredSelection(parentField) + } + + def "hasNonDeferredSelection return true if node has child selections with one defer and others arent"(){ + when: + Field childField1 = Field.newField("childField").directive(enabledDefer).build() + Field childField2 = Field.newField("childField").build() + SelectionSet ss = SelectionSet.newSelectionSet() + .selection(childField1) + .selection(childField2) + .build() + Field parentField = Field.newField("parentNode", ss).build() + + then: + hasNonDeferredSelection(parentField) + } + + def "hasNonDeferredSelection return true if node has child selections that are disabled"(){ + when: + Field childField1 = Field.newField("childField").directive(disabledDefer).build() + Field childField2 = Field.newField("childField").build() + SelectionSet ss = SelectionSet.newSelectionSet() + .selection(childField1) + .selection(childField2) + .build() + Field parentField = Field.newField("parentNode", ss).build() + + then: + hasNonDeferredSelection(parentField) + } + + def "hasNonDeferredSelection return false if node has child selections and all are deferred"(){ + when: + Field childField1 = Field.newField("childField1").directive(enabledDefer).build() + Field childField2 = Field.newField("childField2").directive(enabledDefer).build() + SelectionSet ss = SelectionSet.newSelectionSet() + .selection(childField1) + .selection(childField2) + .build() + Field parentField = Field.newField("parentNode", ss).build() + + then: + !hasNonDeferredSelection(parentField) + } + + def "containsEnabledDeferDirective returns false for non DirectiveContainer"() { + when: + SelectionCollectorSpec.NewImplementationSelection node = new SelectionCollectorSpec.NewImplementationSelection() + + then: + !containsEnabledDeferDirective(node) + } + + def "FragmentSpread with defer returns true"() { + when: + FragmentSpread node = FragmentSpread.newFragmentSpread("testFragment").directive(enabledDefer).build() + + then: + containsEnabledDeferDirective(node) + } + + def "FragmentSpread with disabled defer returns false"() { + when: + FragmentSpread node = FragmentSpread.newFragmentSpread("testFragment").directive(disabledDefer).build() + + then: + !containsEnabledDeferDirective(node) + } + + def "FragmentSpread without defer directive returns false"() { + when: + FragmentSpread node = FragmentSpread.newFragmentSpread("testFragment") + .directive(Directive.newDirective().name("testDir").build()) + .build() + + then: + !containsEnabledDeferDirective(node) + } + + def "InlineFragment with defer returns true"() { + when: + InlineFragment node = InlineFragment.newInlineFragment().directive(enabledDefer).build() + + then: + containsEnabledDeferDirective((DirectivesContainer)node) + } + + def "InlineFragment with disabled defer returns false"() { + when: + InlineFragment node = InlineFragment.newInlineFragment().directive(disabledDefer).build() + + then: + !containsEnabledDeferDirective((DirectivesContainer)node) + } + + def "InlineFragment without defer directive returns false"() { + when: + InlineFragment node = InlineFragment.newInlineFragment() + .directive(Directive.newDirective().name("testDir").build()) + .build() + + then: + !containsEnabledDeferDirective((DirectivesContainer)node) + } + + def "Field with defer returns true"() { + when: + Field node = Field.newField("testField").directive(enabledDefer).build() + + then: + containsEnabledDeferDirective((Selection)node) + } + + def "Field with disabled defer returns false"() { + when: + Field node = Field.newField("testField").directive(disabledDefer).build() + + then: + !containsEnabledDeferDirective((Selection)node) + } + + def "Field without defer directive returns false"() { + when: + Field node = Field.newField("testField") + .directive(Directive.newDirective().name("testDir").build()) + .build() + + then: + !containsEnabledDeferDirective(node) + } +} diff --git a/src/test/groovy/com/intuit/graphql/orchestrator/utils/MultiEIGeneratorSpec.groovy b/src/test/groovy/com/intuit/graphql/orchestrator/utils/MultiEIGeneratorSpec.groovy new file mode 100644 index 00000000..ce99f359 --- /dev/null +++ b/src/test/groovy/com/intuit/graphql/orchestrator/utils/MultiEIGeneratorSpec.groovy @@ -0,0 +1,170 @@ +package com.intuit.graphql.orchestrator.utils + +import com.intuit.graphql.orchestrator.deferDirective.DeferOptions +import graphql.ExecutionInput +import graphql.parser.InvalidSyntaxException +import graphql.scalar.GraphqlStringCoercing +import graphql.schema.GraphQLFieldDefinition +import graphql.schema.GraphQLObjectType +import graphql.schema.GraphQLScalarType +import graphql.schema.GraphQLSchema +import reactor.test.StepVerifier +import spock.lang.Specification + +class MultiEIGeneratorSpec extends Specification { + + MultiEIGenerator multiEIGenerator + DeferOptions options = DeferOptions.builder() + .nestedDefersAllowed(true) + .build() + + GraphQLScalarType scalarType = GraphQLScalarType.newScalar() + .name("scale") + .coercing(new GraphqlStringCoercing()) + .build() + + GraphQLFieldDefinition idField = GraphQLFieldDefinition.newFieldDefinition() + .name("id") + .type(scalarType) + .build() + + GraphQLFieldDefinition nameField = GraphQLFieldDefinition.newFieldDefinition() + .name("name") + .type(scalarType) + .build() + + GraphQLFieldDefinition typeField = GraphQLFieldDefinition.newFieldDefinition() + .name("type") + .type(scalarType) + .build() + + GraphQLObjectType petType = GraphQLObjectType.newObject() + .name("Pet") + .field(idField) + .field(nameField) + .field(typeField) + .build() + + GraphQLFieldDefinition petsQuery = GraphQLFieldDefinition.newFieldDefinition() + .name("pets") + .type(petType) + .build() + + GraphQLObjectType queryType = GraphQLObjectType.newObject() + .name("query") + .field(petsQuery) + .build() + + GraphQLSchema schema = GraphQLSchema.newSchema().query(queryType).additionalType(petType).build() + + + + def "Generator split query correctly"() { + given: + String query = ''' + query getPetsDeferred { + pets { + id + name + type @defer + } + } + ''' + + String deferredQuery = "query getPetsDeferred {\n" + + " pets {\n" + + " type\n" + + " __typename\n" + + " }\n" + + "}\n" + + ExecutionInput ei = ExecutionInput.newExecutionInput(query).build() + + when: + multiEIGenerator = new MultiEIGenerator(ei, options, schema) + + then: + StepVerifier.create(multiEIGenerator.generateEIs()) + .expectNextMatches({ response -> response.getQuery() == query }) + .expectNextMatches({ response -> response.getQuery() == deferredQuery }) + .verifyComplete() + } + + def "Initial EI is processed before splitting query"(){ + given: + String query = ''' + query getPetsDeferred { + pets { + id + name + type @defer + } + } + ''' + + String deferredQuery = "query getPetsDeferred {\n" + + " pets {\n" + + " type\n" + + " __typename\n" + + " }\n" + + "}\n" + + ExecutionInput ei = ExecutionInput.newExecutionInput(query).build() + + long timeEmitted = 0; + + when: + multiEIGenerator = new MultiEIGenerator(ei, options, schema) + + then: + StepVerifier.create(multiEIGenerator.generateEIs()) + .expectNextMatches({ response -> + timeEmitted = System.currentTimeMillis() + return response.query == query && response + }) + .expectAccessibleContext() + .assertThat({ e -> this.multiEIGenerator.timeProcessedSplit > timeEmitted }) + .then() + .expectNextMatches({ response -> response.getQuery() == deferredQuery }) + .verifyComplete() + } + + def "EI w/o defer flux completes and only emits 1 object"(){ + given: + String query = ''' + query getPetsDeferred { + pets { + id + name + type + } + } + ''' + + ExecutionInput ei = ExecutionInput.newExecutionInput(query).build() + + when: + multiEIGenerator = new MultiEIGenerator(ei, options, schema) + + then: + StepVerifier.create(multiEIGenerator.generateEIs()) + .expectNextMatches({ response -> response.query == query}) + .verifyComplete() + } + + def "emits error if it throws error when trying to split ei"(){ + given: + String query = "" + + ExecutionInput ei = ExecutionInput.newExecutionInput(query).build() + + when: + multiEIGenerator = new MultiEIGenerator(ei, options, schema) + + then: + StepVerifier.create(multiEIGenerator.generateEIs()) + .expectNextMatches( {emptyEi -> emptyEi.getQuery() == query}) + .expectError(InvalidSyntaxException.class) + .verify() + } +} diff --git a/src/test/groovy/com/intuit/graphql/orchestrator/utils/NodeUtilsSpec.groovy b/src/test/groovy/com/intuit/graphql/orchestrator/utils/NodeUtilsSpec.groovy new file mode 100644 index 00000000..bcad31a4 --- /dev/null +++ b/src/test/groovy/com/intuit/graphql/orchestrator/utils/NodeUtilsSpec.groovy @@ -0,0 +1,62 @@ +package com.intuit.graphql.orchestrator.utils + +import graphql.language.Directive +import graphql.language.Field +import graphql.language.SelectionSet +import helpers.BaseIntegrationTestSpecification + +import static NodeUtils.removeDirectiveFromNode + +class NodeUtilsSpec extends BaseIntegrationTestSpecification { + def "removeDirectiveFromNode throws exception if node is null"(){ + when: + removeDirectiveFromNode(null, "") + then: + thrown(RuntimeException) + } + + def "removeDirectiveFromNode throws exception if node isn't a directiveContainer"(){ + when: + SelectionSet selectionSet = SelectionSet.newSelectionSet().build() + removeDirectiveFromNode(selectionSet, "") + then: + thrown(RuntimeException) + } + + def "removeDirectiveFromNode returns original node if directive isn't found"(){ + given: + Directive dir1 = Directive.newDirective().name("test1").build() + Directive dir2 = Directive.newDirective().name("test2").build() + Field testField = Field.newField("testField") + .directive(dir1) + .directive(dir2) + .build() + + when: + Field result = removeDirectiveFromNode(testField, "test3") + + then: + result.getName() == "testField" + result.getDirectives().size() == 2 + result.getDirectives().get(0).name == "test1" + result.getDirectives().get(1).name == "test2" + } + + def "removeDirectiveFromNode returns node without desired directive name"(){ + given: + Directive dir1 = Directive.newDirective().name("test1").build() + Directive dir2 = Directive.newDirective().name("test2").build() + Field testField = Field.newField("testField") + .directive(dir1) + .directive(dir2) + .build() + + when: + Field result = removeDirectiveFromNode(testField, "test1") + + then: + result.getName() == "testField" + result.getDirectives().size() == 1 + result.getDirectives().get(0).name == "test2" + } +} diff --git a/src/test/groovy/com/intuit/graphql/orchestrator/visitors/DeferQueryExtractorSpec.groovy b/src/test/groovy/com/intuit/graphql/orchestrator/visitors/DeferQueryExtractorSpec.groovy new file mode 100644 index 00000000..7f4df563 --- /dev/null +++ b/src/test/groovy/com/intuit/graphql/orchestrator/visitors/DeferQueryExtractorSpec.groovy @@ -0,0 +1,693 @@ +package com.intuit.graphql.orchestrator.visitors + +import com.intuit.graphql.orchestrator.GraphQLOrchestrator +import com.intuit.graphql.orchestrator.deferDirective.DeferOptions +import com.intuit.graphql.orchestrator.visitors.queryVisitors.DeferQueryExtractor +import graphql.ExecutionInput +import graphql.GraphQLException +import graphql.analysis.QueryTransformer +import graphql.language.Document +import graphql.language.Field +import graphql.language.FragmentDefinition +import graphql.language.OperationDefinition +import helpers.BaseIntegrationTestSpecification +import lombok.extern.slf4j.Slf4j + +import java.util.function.Function +import java.util.stream.Collectors + +import static com.intuit.graphql.orchestrator.utils.GraphQLUtil.parser +/** + * Covers test for ObjectTypeExtension, InterfaceTypeExtension, UnionTypeExtension, EnumTypeExtension, + * InputObjectTypeExtension TODO ScalarTypeExtension. + */ +@Slf4j +class DeferQueryExtractorSpec extends BaseIntegrationTestSpecification { + private static final DeferOptions deferOptions = DeferOptions.builder() + .nestedDefersAllowed(true) + .build() + private static final DeferOptions disabledDeferOptions = DeferOptions.builder() + .nestedDefersAllowed(false) + .build() + + private deferTestSchema = """ + type Query { + queryA: NestedObjectA + argQuery(id: String): NestedObjectA + } + + type NestedObjectA { + fieldA: String + fieldB: String + fieldC: String + objectField: TopLevelObject + } + + type TopLevelObject { + fieldD: String + fieldE: String + fieldF: String + nestedObject: NestedObject + } + + type NestedObject { + fieldG: String + fieldH: String + fieldI: String + } + """ + + private deferService = createSimpleMockService("DEFER", deferTestSchema, new HashMap()) + private GraphQLOrchestrator orchestrator = createGraphQLOrchestrator(deferService) + + def "can split Execution input"() { + given: + String query = "query { queryA { fieldA fieldB fieldC @defer } }" + ExecutionInput ei = ExecutionInput.newExecutionInput(query).build() + Document rootDocument = parser.parseDocument(query) + OperationDefinition opDef = rootDocument.getFirstDefinitionOfType(OperationDefinition).get() + Field selection = opDef.selectionSet.getSelections().get(0) as Field + + when: + DeferQueryExtractor visitor = DeferQueryExtractor.builder() + .originalEI(ei) + .rootNode(rootDocument) + .operationDefinition(opDef) + .deferOptions(deferOptions) + .build() + + QueryTransformer.newQueryTransformer() + .schema(orchestrator.getSchema()) + .rootParentType(orchestrator.getSchema().getQueryType()) + .root(selection) + .fragmentsByName(new HashMap()) + .variables(ei.getVariables()) + .build() + .transform(visitor) + + + then: + List splitSet = visitor.getExtractedEIs() + splitSet.size() == 1 + splitSet.get(0).query == "query {\n" + + " queryA {\n" + + " fieldC\n" + + " __typename\n" + + " }\n" + + "}\n" + } + + def "can split EI with alias selections"(){ + given: + String query = "query { queryA { aliasA: fieldA aliasB: fieldB @defer } }" + ExecutionInput ei = ExecutionInput.newExecutionInput(query).build() + Document rootDocument = parser.parseDocument(query) + OperationDefinition opDef = rootDocument.getFirstDefinitionOfType(OperationDefinition).get() + Field selection = opDef.selectionSet.getSelections().get(0) as Field + + when: + DeferQueryExtractor visitor = DeferQueryExtractor.builder() + .originalEI(ei) + .rootNode(rootDocument) + .operationDefinition(opDef) + .deferOptions(deferOptions) + .build() + + QueryTransformer.newQueryTransformer() + .schema(orchestrator.getSchema()) + .rootParentType(orchestrator.getSchema().getQueryType()) + .root(selection) + .fragmentsByName(new HashMap()) + .variables(ei.getVariables()) + .build() + .transform(visitor) + + then: + List splitSet = visitor.getExtractedEIs() + splitSet.get(0).query == "query {\n" + + " queryA {\n" + + " aliasB: fieldB\n" + + " __typename\n" + + " }\n" + + "}\n" + } + + def "can split execution input with arguments"() { + given: + String query = "query { argQuery(id: \"inputA\") { fieldA fieldB @defer } }" + ExecutionInput ei = ExecutionInput.newExecutionInput(query).build() + Document rootDocument = parser.parseDocument(query) + OperationDefinition opDef = rootDocument.getFirstDefinitionOfType(OperationDefinition).get() + Field selection = opDef.selectionSet.getSelections().get(0) as Field + + when: + DeferQueryExtractor visitor = DeferQueryExtractor.builder() + .originalEI(ei) + .rootNode(rootDocument) + .operationDefinition(opDef) + .deferOptions(deferOptions) + .build() + + QueryTransformer.newQueryTransformer() + .schema(orchestrator.getSchema()) + .rootParentType(orchestrator.getSchema().getQueryType()) + .root(selection) + .fragmentsByName(new HashMap()) + .variables(ei.getVariables()) + .build() + .transform(visitor) + + then: + List splitSet = visitor.getExtractedEIs() + splitSet.size() == 1 + splitSet.get(0).query == "query {\n" + + " argQuery(id: \"inputA\") {\n" + + " fieldB\n" + + " __typename\n" + + " }\n" + + "}\n" + } + + def "can split EI with multiple defer on same level"() { + given: + String query = "query { queryA { fieldA fieldB @defer fieldC @defer } }" + ExecutionInput ei = ExecutionInput.newExecutionInput(query).build() + Document rootDocument = parser.parseDocument(query) + OperationDefinition opDef = rootDocument.getFirstDefinitionOfType(OperationDefinition).get() + Field selection = opDef.selectionSet.getSelections().get(0) as Field + + when: + DeferQueryExtractor visitor = DeferQueryExtractor.builder() + .originalEI(ei) + .rootNode(rootDocument) + .operationDefinition(opDef) + .deferOptions(deferOptions) + .build() + + QueryTransformer.newQueryTransformer() + .schema(orchestrator.getSchema()) + .rootParentType(orchestrator.getSchema().getQueryType()) + .root(selection) + .fragmentsByName(new HashMap()) + .variables(ei.getVariables()) + .build() + .transform(visitor) + + then: + List splitSet = visitor.getExtractedEIs() + splitSet.size() == 2 + splitSet.get(0).query == "query {\n" + + " queryA {\n" + + " fieldB\n" + + " __typename\n" + + " }\n" + + "}\n" + splitSet.get(1).query == "query {\n" + + " queryA {\n" + + " fieldC\n" + + " __typename\n" + + " }\n" + + "}\n" + + + } + + def "can split EI with deferred nested selections"() { + given: + String query = "query { queryA { fieldA objectField { fieldD fieldE @defer } } }" + ExecutionInput ei = ExecutionInput.newExecutionInput(query).build() + Document rootDocument = parser.parseDocument(query) + OperationDefinition opDef = rootDocument.getFirstDefinitionOfType(OperationDefinition).get() + Field selection = opDef.selectionSet.getSelections().get(0) as Field + + when: + DeferQueryExtractor visitor = DeferQueryExtractor.builder() + .originalEI(ei) + .rootNode(rootDocument) + .operationDefinition(opDef) + .deferOptions(deferOptions) + .build() + + QueryTransformer.newQueryTransformer() + .schema(orchestrator.getSchema()) + .rootParentType(orchestrator.getSchema().getQueryType()) + .root(selection) + .fragmentsByName(new HashMap()) + .variables(ei.getVariables()) + .build() + .transform(visitor) + + then: + List splitSet = visitor.getExtractedEIs() + splitSet.get(0).query == "query {\n" + + " queryA {\n" + + " objectField {\n" + + " fieldE" + + "\n" + + " __typename\n" + + " }\n" + + " __typename\n" + + " }\n" + + "}\n" + } + + def "Does not split EI when if arg is false"() { + given: + String query = "query { queryA { fieldA fieldB fieldC @defer(if: false) } }" + ExecutionInput ei = ExecutionInput.newExecutionInput(query).build() + Document rootDocument = parser.parseDocument(query) + OperationDefinition opDef = rootDocument.getFirstDefinitionOfType(OperationDefinition).get() + Field selection = opDef.selectionSet.getSelections().get(0) as Field + + when: + DeferQueryExtractor visitor = DeferQueryExtractor.builder() + .originalEI(ei) + .rootNode(rootDocument) + .operationDefinition(opDef) + .deferOptions(deferOptions) + .build() + + QueryTransformer.newQueryTransformer() + .schema(orchestrator.getSchema()) + .rootParentType(orchestrator.getSchema().getQueryType()) + .root(selection) + .fragmentsByName(new HashMap()) + .variables(ei.getVariables()) + .build() + .transform(visitor) + + then: + List splitSet = visitor.getExtractedEIs() + splitSet.size() == 0 + } + + def "prunes selections sets without fields after removing deferred fields"() { + given: + String query = "query { queryA { fieldA objectField { fieldD nestedObject @defer { fieldH @defer} } } }" + ExecutionInput ei = ExecutionInput.newExecutionInput(query).build() + Document rootDocument = parser.parseDocument(query) + OperationDefinition opDef = rootDocument.getFirstDefinitionOfType(OperationDefinition).get() + Field selection = opDef.selectionSet.getSelections().get(0) as Field + + when: + DeferQueryExtractor visitor = DeferQueryExtractor.builder() + .originalEI(ei) + .rootNode(rootDocument) + .operationDefinition(opDef) + .deferOptions(deferOptions) + .build() + + QueryTransformer.newQueryTransformer() + .schema(orchestrator.getSchema()) + .rootParentType(orchestrator.getSchema().getQueryType()) + .root(selection) + .fragmentsByName(new HashMap()) + .variables(ei.getVariables()) + .build() + .transform(visitor) + + then: + List splitSet = visitor.getExtractedEIs() + splitSet.size() == 1 + splitSet.get(0).query == "query {\n" + + " queryA {\n" + + " objectField {\n" + + " nestedObject {\n" + + " fieldH\n" + + " __typename\n" + + " }\n" + + " __typename\n" + + " }\n" + + " __typename\n" + + " }\n" + + "}\n" + } + + def "can split nested defer selections"() { + given: + String query = "query { queryA { fieldA objectField @defer { fieldD fieldE @defer } } }" + ExecutionInput ei = ExecutionInput.newExecutionInput(query).build() + Document rootDocument = parser.parseDocument(query) + OperationDefinition opDef = rootDocument.getFirstDefinitionOfType(OperationDefinition).get() + Field selection = opDef.selectionSet.getSelections().get(0) as Field + + when: + DeferQueryExtractor visitor = DeferQueryExtractor.builder() + .originalEI(ei) + .rootNode(rootDocument) + .operationDefinition(opDef) + .deferOptions(deferOptions) + .build() + + QueryTransformer.newQueryTransformer() + .schema(orchestrator.getSchema()) + .rootParentType(orchestrator.getSchema().getQueryType()) + .root(selection) + .fragmentsByName(new HashMap()) + .variables(ei.getVariables()) + .build() + .transform(visitor) + + then: + List splitSet = visitor.getExtractedEIs() + splitSet.size() == 2 + splitSet.get(0).query == "query {\n" + + " queryA {\n" + + " objectField {\n" + + " fieldD\n" + + " __typename\n" + + " }\n" + + " __typename\n" + + " }\n" + + "}\n" + + splitSet.get(1).query == "query {\n" + + " queryA {\n" + + " objectField {\n" + + " fieldE\n" + + " __typename\n" + + " }\n" + + " __typename\n" + + " }\n" + + "}\n" + } + + def "exception thrown for nested defer selections when option is off"() { + given: + String query = "query { queryA { fieldA objectField @defer { fieldD fieldE @defer } } }" + ExecutionInput ei = ExecutionInput.newExecutionInput(query).build() + Document rootDocument = parser.parseDocument(query) + OperationDefinition opDef = rootDocument.getFirstDefinitionOfType(OperationDefinition).get() + Field selection = opDef.selectionSet.getSelections().get(0) as Field + + when: + DeferQueryExtractor visitor = DeferQueryExtractor.builder() + .originalEI(ei) + .rootNode(rootDocument) + .operationDefinition(opDef) + .deferOptions(disabledDeferOptions) + .build() + + QueryTransformer.newQueryTransformer() + .schema(orchestrator.getSchema()) + .rootParentType(orchestrator.getSchema().getQueryType()) + .root(selection) + .fragmentsByName(new HashMap()) + .variables(ei.getVariables()) + .build() + .transform(visitor) + + then: + def exception = thrown(GraphQLException) + exception.getMessage() ==~ "Nested defers are currently unavailable." + } + + //todo + def "can split EI with variables" () {} + + def "can split EI with inline fragment"() { + given: + String query = """ + query { + queryA { + fieldA + objectField { + fieldD + } + ... on TopLevelObject @defer { + fieldE + } + } + } + """ + ExecutionInput ei = ExecutionInput.newExecutionInput(query).build() + Document rootDocument = parser.parseDocument(query) + OperationDefinition opDef = rootDocument.getFirstDefinitionOfType(OperationDefinition).get() + Field selection = opDef.selectionSet.getSelections().get(0) as Field + + when: + DeferQueryExtractor visitor = DeferQueryExtractor.builder() + .originalEI(ei) + .rootNode(rootDocument) + .operationDefinition(opDef) + .deferOptions(deferOptions) + .build() + + QueryTransformer.newQueryTransformer() + .schema(orchestrator.getSchema()) + .rootParentType(orchestrator.getSchema().getQueryType()) + .root(selection) + .fragmentsByName(new HashMap()) + .variables(ei.getVariables()) + .build() + .transform(visitor) + + then: + List splitSet = visitor.getExtractedEIs() + splitSet.size() == 1 + splitSet.get(0).query == "query {\n" + + " queryA {\n" + + " ... on TopLevelObject {\n" + + " fieldE\n" + + " __typename\n" + + " }\n" + + " __typename\n" + + " }\n" + + "}\n" + } + + def "split EI has correct inline fragments different types"() { + given: + String query = """ + query { + queryA { + fieldA + objectField { + fieldD + } + ... on TopLevelObject @defer { + fieldF + } + objectField { + nestedObject { + fieldH + } + } + ... on TopLevelObject @defer { + fieldE + } + } + } + """ + ExecutionInput ei = ExecutionInput.newExecutionInput(query).build() + Document rootDocument = parser.parseDocument(query) + OperationDefinition opDef = rootDocument.getFirstDefinitionOfType(OperationDefinition).get() + Field selection = opDef.selectionSet.getSelections().get(0) as Field + + when: + DeferQueryExtractor visitor = DeferQueryExtractor.builder() + .originalEI(ei) + .rootNode(rootDocument) + .operationDefinition(opDef) + .deferOptions(deferOptions) + .build() + + QueryTransformer.newQueryTransformer() + .schema(orchestrator.getSchema()) + .rootParentType(orchestrator.getSchema().getQueryType()) + .root(selection) + .fragmentsByName(new HashMap()) + .variables(ei.getVariables()) + .build() + .transform(visitor) + + then: + List splitSet = visitor.getExtractedEIs() + splitSet.size() == 2 + splitSet.get(0).query == "query {\n" + + " queryA {\n" + + " ... on TopLevelObject {\n" + + " fieldF\n" + + " __typename\n" + + " }\n" + + " __typename\n" + + " }\n" + + "}\n" + splitSet.get(1).query == "query {\n" + + " queryA {\n" + + " ... on TopLevelObject {\n" + + " fieldE\n" + + " __typename\n" + + " }\n" + + " __typename\n" + + " }\n" + + "}\n" + } + + def "split EI has correct inline fragments non merged types"() { + given: + String query = """ + query { + queryA { + fieldA + objectField { + fieldD + } + ... on TopLevelObject @defer { + fieldF + } + objectField { + fieldE + } + ... on TopLevelObject @defer { + fieldE + } + } + } + """ + ExecutionInput ei = ExecutionInput.newExecutionInput(query).build() + Document rootDocument = parser.parseDocument(query) + OperationDefinition opDef = rootDocument.getFirstDefinitionOfType(OperationDefinition).get() + Field selection = opDef.selectionSet.getSelections().get(0) as Field + + when: + DeferQueryExtractor visitor = DeferQueryExtractor.builder() + .originalEI(ei) + .rootNode(rootDocument) + .operationDefinition(opDef) + .deferOptions(deferOptions) + .build() + + QueryTransformer.newQueryTransformer() + .schema(orchestrator.getSchema()) + .rootParentType(orchestrator.getSchema().getQueryType()) + .root(selection) + .fragmentsByName(new HashMap()) + .variables(ei.getVariables()) + .build() + .transform(visitor) + + then: + List splitSet = visitor.getExtractedEIs() + splitSet.size() == 2 + splitSet.get(0).query == "query {\n" + + " queryA {\n" + + " ... on TopLevelObject {\n" + + " fieldF\n" + + " __typename\n" + + " }\n" + + " __typename\n" + + " }\n" + + "}\n" + splitSet.get(1).query == "query {\n" + + " queryA {\n" + + " ... on TopLevelObject {\n" + + " fieldE\n" + + " __typename\n" + + " }\n" + + " __typename\n" + + " }\n" + + "}\n" + } + + def "can split EI with fragment spread"() { + given: + String query = """ + query { + queryA { + fieldA + objectField { + fieldD + ... deferredInfo @defer + } + } + } + fragment deferredInfo on TopLevelObject { + fieldE + } + """ + + ExecutionInput ei = ExecutionInput.newExecutionInput(query).build() + Document rootDocument = parser.parseDocument(query) + OperationDefinition opDef = rootDocument.getFirstDefinitionOfType(OperationDefinition).get() + Field selection = opDef.selectionSet.getSelections().get(0) as Field + + when: + Map fragmentDefinitionMap = rootDocument.getDefinitionsOfType(FragmentDefinition.class) + .stream() + .collect(Collectors.toMap({ fragment -> ((FragmentDefinition)fragment).getName() }, Function.identity())); + + DeferQueryExtractor visitor = DeferQueryExtractor.builder() + .originalEI(ei) + .rootNode(rootDocument) + .operationDefinition(opDef) + .deferOptions(deferOptions) + .fragmentDefinitionMap(fragmentDefinitionMap) + .build() + + QueryTransformer.newQueryTransformer() + .schema(orchestrator.getSchema()) + .rootParentType(orchestrator.getSchema().getQueryType()) + .root(selection) + .fragmentsByName(fragmentDefinitionMap) + .variables(ei.getVariables()) + .build() + .transform(visitor) + + then: + List splitSet = visitor.getExtractedEIs() + splitSet.size() == 1 + splitSet.get(0).query == "query {\n" + + " queryA {\n" + + " objectField {\n" + + " ...deferredInfo\n" + + " __typename\n" + + " }\n" + + " __typename\n" + + " }\n" + + "}\n" + + "\nfragment deferredInfo on TopLevelObject {\n" + + " fieldE\n" + + "}\n" + } + + def "thrown exception when building with null defer options"(){ + given: + ExecutionInput ei = ExecutionInput.newExecutionInput() + .query("query { queryA { fieldA } }") + .build() + Document rootDocument = parser.parseDocument(ei.getQuery()) + OperationDefinition opDef = rootDocument.getFirstDefinitionOfType(OperationDefinition).get() + + when: + DeferQueryExtractor.builder() + .originalEI(ei) + .rootNode(rootDocument) + .operationDefinition(opDef) + .deferOptions(null) + .build() + + then: + thrown(NullPointerException) + } + + def "thrown exception when building with null ei"(){ + given: + ExecutionInput ei = ExecutionInput.newExecutionInput() + .query("query { queryA { fieldA } }") + .build() + Document rootDocument = parser.parseDocument(ei.getQuery()) + OperationDefinition opDef = rootDocument.getFirstDefinitionOfType(OperationDefinition).get() + + when: + DeferQueryExtractor.builder() + .originalEI(null) + .rootNode(rootDocument) + .operationDefinition(opDef) + .deferOptions(deferOptions) + .build() + + then: + thrown(NullPointerException) + } +} diff --git a/src/test/groovy/com/intuit/graphql/orchestrator/visitors/PruneChildDeferSelectionsModifierSpec.groovy b/src/test/groovy/com/intuit/graphql/orchestrator/visitors/PruneChildDeferSelectionsModifierSpec.groovy new file mode 100644 index 00000000..60fb7f99 --- /dev/null +++ b/src/test/groovy/com/intuit/graphql/orchestrator/visitors/PruneChildDeferSelectionsModifierSpec.groovy @@ -0,0 +1,294 @@ +package com.intuit.graphql.orchestrator.visitors + +import com.intuit.graphql.orchestrator.deferDirective.DeferOptions +import com.intuit.graphql.orchestrator.visitors.queryVisitors.PruneChildDeferSelectionsModifier +import graphql.GraphQLException +import graphql.language.* +import lombok.extern.slf4j.Slf4j +import spock.lang.Specification + +@Slf4j +class PruneChildDeferSelectionsModifierSpec extends Specification { + + Directive enabledDirective = Directive.newDirective() + .name("defer") + .argument(Argument.newArgument("if", BooleanValue.of(true)).build()) + .build() + Directive disabledDirective = Directive.newDirective() + .name("defer") + .argument(Argument.newArgument("if", BooleanValue.of(false)).build()) + .build() + + DeferOptions enabledNestedDefer = DeferOptions.builder().nestedDefersAllowed(true).build() + DeferOptions disabledNestedDefer = DeferOptions.builder().nestedDefersAllowed(false).build() + + PruneChildDeferSelectionsModifier specToTest = PruneChildDeferSelectionsModifier.builder() + .deferOptions(disabledNestedDefer) + .build() + + PruneChildDeferSelectionsModifier nestedSpecToTest = PruneChildDeferSelectionsModifier.builder() + .deferOptions(enabledNestedDefer) + .build() + + AstTransformer astTransformer = new AstTransformer() + + def "Top Level Deferred Child Field is Removed"(){ + given: + Field deferredField = Field.newField("topLevelDeferredChild").directive(enabledDirective).build() + Field childField2 = Field.newField("topLevelChild").build() + SelectionSet selectionSet = SelectionSet.newSelectionSet().selection(deferredField).selection(childField2).build() + Field root = Field.newField("rootField", selectionSet).build() + + when: + Field result = (Field) astTransformer.transform(root, nestedSpecToTest) + + then: + result != null + result.getSelectionSet().getSelections().size() == 2 + ((Field)result.getSelectionSet().getSelections().get(0)).getName() == "topLevelChild" + ((Field)result.getSelectionSet().getSelections().get(1)).getName() == "__typename" + } + + def "Nested Deferred Child Field is Removed"(){ + given: + Field deferredField = Field.newField("nestedDeferredChild").directive(enabledDirective).build() + Field childField2 = Field.newField("nestedChild").build() + SelectionSet nestedSelectionSet = SelectionSet.newSelectionSet().selection(deferredField).selection(childField2).build() + Field topLevelChild = Field.newField("topLevelChild", nestedSelectionSet).build() + + SelectionSet selectionSet = SelectionSet.newSelectionSet().selection(topLevelChild).build() + Field root = Field.newField("rootField", selectionSet).build() + + when: + Field result = (Field) astTransformer.transform(root, nestedSpecToTest) + + then: + result != null + result.getSelectionSet().getSelections().size() == 2 + + Field resultTopChild = (Field)result.getSelectionSet().getSelections().get(0) + topLevelChild.getName() == "topLevelChild" + ((Field)result.getSelectionSet().getSelections().get(1)).getName() == "__typename" + + resultTopChild.getSelectionSet().getSelections().size() == 2 + ((Field) resultTopChild.getSelectionSet().getSelections().get(0)).getName() == "nestedChild" + ((Field) resultTopChild.getSelectionSet().getSelections().get(1)).getName() == "__typename" + } + + def "Removes Disabled Defer Directive From Child Field"(){ + given: + Field deferredField = Field.newField("topLevelDeferredChild").directive(disabledDirective).build() + Field childField2 = Field.newField("topLevelChild").build() + SelectionSet selectionSet = SelectionSet.newSelectionSet().selection(deferredField).selection(childField2).build() + Field root = Field.newField("rootField", selectionSet).build() + + when: + Field result = (Field) astTransformer.transform(root, nestedSpecToTest) + + then: + result != null + result.getSelectionSet().getSelections().size() == 3 + ((Field)result.getSelectionSet().getSelections().get(0)).getName() == "topLevelDeferredChild" + ((Field)result.getSelectionSet().getSelections().get(1)).getName() == "topLevelChild" + ((Field)result.getSelectionSet().getSelections().get(2)).getName() == "__typename" + } + + def "Throws exception if Field Has Defer Directive and nestedDefer is not Allowed"(){ + given: + Field deferredField = Field.newField("topLevelDeferredChild").directive(enabledDirective).build() + Field childField2 = Field.newField("topLevelChild").build() + SelectionSet selectionSet = SelectionSet.newSelectionSet().selection(deferredField).selection(childField2).build() + Field root = Field.newField("rootField", selectionSet).build() + + when: + astTransformer.transform(root, specToTest) + + then: + def exception = thrown(GraphQLException) + exception.getMessage() ==~ "Nested defers are currently unavailable." + } + + def "Top Level Deferred Child FragmentSpread is Removed"(){ + given: + FragmentSpread deferredFragment = FragmentSpread.newFragmentSpread("topLevelFragment") + .directive(enabledDirective) + .build() + Field childField2 = Field.newField().name("topLevelChild").build() + SelectionSet selectionSet = SelectionSet.newSelectionSet().selection(deferredFragment).selection(childField2).build() + Field root = Field.newField("rootField", selectionSet).build() + + when: + Field result = (Field) astTransformer.transform(root, nestedSpecToTest) + + then: + result != null + result.getSelectionSet().getSelections().size() == 2 + ((Field)result.getSelectionSet().getSelections().get(0)).getName() == "topLevelChild" + ((Field)result.getSelectionSet().getSelections().get(1)).getName() == "__typename" + } + + def "Nested Deferred Child FragmentSpread is Removed"(){ + given: + FragmentSpread deferredFragment = FragmentSpread.newFragmentSpread("nestedDeferredChild") + .directive(enabledDirective) + .build() + Field childField2 = Field.newField().name("nestedChild").build() + SelectionSet nestedSelectionSet = SelectionSet.newSelectionSet().selection(deferredFragment).selection(childField2).build() + Field topLevelChild = Field.newField("topLevelChild", nestedSelectionSet).build() + + SelectionSet selectionSet = SelectionSet.newSelectionSet().selection(topLevelChild).build() + Field root = Field.newField("rootField", selectionSet).build() + + when: + Field result = (Field) astTransformer.transform(root, nestedSpecToTest) + + then: + result != null + result.getSelectionSet().getSelections().size() == 2 + + Field resultTopChild = (Field)result.getSelectionSet().getSelections().get(0) + topLevelChild.getName() == "topLevelChild" + ((Field)result.getSelectionSet().getSelections().get(1)).getName() == "__typename" + + resultTopChild.getSelectionSet().getSelections().size() == 2 + ((Field) resultTopChild.getSelectionSet().getSelections().get(0)).getName() == "nestedChild" + ((Field) resultTopChild.getSelectionSet().getSelections().get(1)).getName() == "__typename" + + } + + def "Removes Disabled Defer Directive From Child FragmentSpread"(){ + given: + FragmentSpread deferredFragment = FragmentSpread.newFragmentSpread("DeferredFrag") + .directive(disabledDirective) + .build() + Field childField2 = Field.newField("topLevelChild").build() + SelectionSet selectionSet = SelectionSet.newSelectionSet().selection(deferredFragment).selection(childField2).build() + Field root = Field.newField("rootField", selectionSet).build() + + when: + Field result = (Field) astTransformer.transform(root, nestedSpecToTest) + + then: + result != null + result.getSelectionSet().getSelections().size() == 3 + ((FragmentSpread)result.getSelectionSet().getSelections().get(0)).getName() == "DeferredFrag" + ((Field)result.getSelectionSet().getSelections().get(1)).getName() == "topLevelChild" + ((Field)result.getSelectionSet().getSelections().get(2)).getName() == "__typename" + } + + def "Throws exception if FragmentSpread Has Defer Directive and nestedDefer is not Allowed"(){ + given: + FragmentSpread deferredFragment = FragmentSpread.newFragmentSpread("topLevelFragment") + .directive(enabledDirective) + .build() + Field childField2 = Field.newField().name("topLevelChild").build() + SelectionSet selectionSet = SelectionSet.newSelectionSet().selection(deferredFragment).selection(childField2).build() + Field root = Field.newField("rootField", selectionSet).build() + + when: + astTransformer.transform(root, specToTest) + + then: + def exception = thrown(GraphQLException) + exception.getMessage() ==~ "Nested defers are currently unavailable." + } + + def "Top Level Deferred Child InlineFragment is Removed"(){ + given: + SelectionSet deferredSelectionSet = SelectionSet.newSelectionSet() + .selection(Field.newField("deferredFragField").build()) + .build() + InlineFragment deferredFragment = InlineFragment.newInlineFragment().selectionSet(deferredSelectionSet) + .directive(enabledDirective) + .build() + Field childField2 = Field.newField().name("topLevelChild").build() + SelectionSet selectionSet = SelectionSet.newSelectionSet().selection(deferredFragment).selection(childField2).build() + Field root = Field.newField("rootField", selectionSet).build() + + when: + Field result = (Field) astTransformer.transform(root, nestedSpecToTest) + + then: + result != null + result.getSelectionSet().getSelections().size() == 2 + ((Field)result.getSelectionSet().getSelections().get(0)).getName() == "topLevelChild" + ((Field)result.getSelectionSet().getSelections().get(1)).getName() == "__typename" + } + + def "Nested Deferred Child InlineFragment is Removed"(){ + given: + SelectionSet deferredSelectionSet = SelectionSet.newSelectionSet() + .selection(Field.newField("deferredFragField").build()) + .build() + InlineFragment deferredFragment = InlineFragment.newInlineFragment().selectionSet(deferredSelectionSet) + .directive(enabledDirective) + .build() + Field childField2 = Field.newField().name("nestedChild").build() + SelectionSet nestedSelectionSet = SelectionSet.newSelectionSet().selection(deferredFragment).selection(childField2).build() + Field topLevelChild = Field.newField("topLevelChild", nestedSelectionSet).build() + + SelectionSet selectionSet = SelectionSet.newSelectionSet().selection(topLevelChild).build() + Field root = Field.newField("rootField", selectionSet).build() + + when: + Field result = (Field) astTransformer.transform(root, nestedSpecToTest) + + then: + result != null + result.getSelectionSet().getSelections().size() == 2 + + Field resultTopChild = (Field)result.getSelectionSet().getSelections().get(0) + topLevelChild.getName() == "topLevelChild" + ((Field)result.getSelectionSet().getSelections().get(1)).getName() == "__typename" + + resultTopChild.getSelectionSet().getSelections().size() == 2 + ((Field) resultTopChild.getSelectionSet().getSelections().get(0)).getName() == "nestedChild" + ((Field) resultTopChild.getSelectionSet().getSelections().get(1)).getName() == "__typename" + + } + + def "Removes Disabled Defer Directive From Child InlineFragment"(){ + given: + SelectionSet deferredSelectionSet = SelectionSet.newSelectionSet() + .selection(Field.newField("deferredFragField").build()) + .build() + InlineFragment deferredFragment = InlineFragment.newInlineFragment().selectionSet(deferredSelectionSet) + .directive(disabledDirective) + .build() + Field childField2 = Field.newField("topLevelChild").build() + SelectionSet selectionSet = SelectionSet.newSelectionSet().selection(deferredFragment).selection(childField2).build() + Field root = Field.newField("rootField", selectionSet).build() + + when: + Field result = (Field) astTransformer.transform(root, nestedSpecToTest) + + then: + result != null + result.getSelectionSet().getSelections().size() == 3 + InlineFragment fragment = (InlineFragment)result.getSelectionSet().getSelections().get(0) + ((Field)fragment.getSelectionSet().getSelections().get(0)).getName() == "deferredFragField" + ((Field)fragment.getSelectionSet().getSelections().get(1)).getName() == "__typename" + ((Field)result.getSelectionSet().getSelections().get(1)).getName() == "topLevelChild" + ((Field)result.getSelectionSet().getSelections().get(2)).getName() == "__typename" + } + + def "Throws exception if InlineFragment Has Defer Directive and nestedDefer is not Allowed"(){ + given: + SelectionSet deferredSelectionSet = SelectionSet.newSelectionSet() + .selection(Field.newField("deferredFragField").build()) + .build() + InlineFragment deferredFragment = InlineFragment.newInlineFragment().selectionSet(deferredSelectionSet) + .directive(enabledDirective) + .build() + Field childField2 = Field.newField().name("topLevelChild").build() + SelectionSet selectionSet = SelectionSet.newSelectionSet().selection(deferredFragment).selection(childField2).build() + Field root = Field.newField("rootField", selectionSet).build() + + when: + astTransformer.transform(root, specToTest) + + then: + def exception = thrown(GraphQLException) + exception.getMessage() ==~ "Nested defers are currently unavailable." + } +} diff --git a/src/test/groovy/helpers/BaseIntegrationTestSpecification.groovy b/src/test/groovy/helpers/BaseIntegrationTestSpecification.groovy index c2d69116..bd52c8be 100644 --- a/src/test/groovy/helpers/BaseIntegrationTestSpecification.groovy +++ b/src/test/groovy/helpers/BaseIntegrationTestSpecification.groovy @@ -10,6 +10,7 @@ import graphql.ExecutionInput import graphql.execution.AsyncExecutionStrategy import graphql.execution.ExecutionIdProvider import graphql.execution.ExecutionStrategy +import graphql.language.AstTransformer import graphql.language.Document import graphql.language.OperationDefinition import graphql.parser.Parser @@ -18,6 +19,7 @@ import spock.lang.Specification class BaseIntegrationTestSpecification extends Specification { public static final Parser PARSER = new Parser() + public static final AstTransformer AST_TRANSFORMER = new AstTransformer() def testService