diff --git a/src/execution/__tests__/stream-test.ts b/src/execution/__tests__/stream-test.ts index cd9b9b3965..5645fd9b8b 100644 --- a/src/execution/__tests__/stream-test.ts +++ b/src/execution/__tests__/stream-test.ts @@ -2,6 +2,7 @@ import { assert } from 'chai'; import { describe, it } from 'mocha'; import { expectJSON } from '../../__testUtils__/expectJSON.js'; +import { resolveOnNextTick } from '../../__testUtils__/resolveOnNextTick.js'; import type { PromiseOrValue } from '../../jsutils/PromiseOrValue.js'; @@ -1134,7 +1135,7 @@ describe('Execute: stream directive', () => { }, ]); }); - it('Handles async errors thrown by completeValue after initialCount is reached from async iterable for a non-nullable list', async () => { + it('Handles async errors thrown by completeValue after initialCount is reached from async generator for a non-nullable list', async () => { const document = parse(` query { nonNullFriendList @stream(initialCount: 1) { @@ -1174,9 +1175,152 @@ describe('Execute: stream directive', () => { ], }, ], + hasNext: false, + }, + ]); + }); + it('Handles async errors thrown by completeValue after initialCount is reached from async iterable for a non-nullable list when the async iterable does not provide a return method) ', async () => { + const document = parse(` + query { + nonNullFriendList @stream(initialCount: 1) { + nonNullName + } + } + `); + let count = 0; + const result = await complete(document, { + nonNullFriendList: { + [Symbol.asyncIterator]: () => ({ + next: async () => { + switch (count++) { + case 0: + return Promise.resolve({ + done: false, + value: { nonNullName: friends[0].name }, + }); + case 1: + return Promise.resolve({ + done: false, + value: { + nonNullName: () => Promise.reject(new Error('Oops')), + }, + }); + case 2: + return Promise.resolve({ + done: false, + value: { nonNullName: friends[1].name }, + }); + // Not reached + /* c8 ignore next 5 */ + case 3: + return Promise.resolve({ + done: false, + value: { nonNullName: friends[2].name }, + }); + } + }, + }), + }, + }); + expectJSON(result).toDeepEqual([ + { + data: { + nonNullFriendList: [{ nonNullName: 'Luke' }], + }, hasNext: true, }, { + incremental: [ + { + items: null, + path: ['nonNullFriendList', 1], + errors: [ + { + message: 'Oops', + locations: [{ line: 4, column: 11 }], + path: ['nonNullFriendList', 1, 'nonNullName'], + }, + ], + }, + ], + hasNext: false, + }, + ]); + }); + it('Handles async errors thrown by completeValue after initialCount is reached from async iterable for a non-nullable list when the async iterable provides concurrent next/return methods and has a slow return ', async () => { + const document = parse(` + query { + nonNullFriendList @stream(initialCount: 1) { + nonNullName + } + } + `); + let count = 0; + let returned = false; + const result = await complete(document, { + nonNullFriendList: { + [Symbol.asyncIterator]: () => ({ + next: async () => { + /* c8 ignore next 3 */ + if (returned) { + return Promise.resolve({ done: true }); + } + switch (count++) { + case 0: + return Promise.resolve({ + done: false, + value: { nonNullName: friends[0].name }, + }); + case 1: + return Promise.resolve({ + done: false, + value: { + nonNullName: () => Promise.reject(new Error('Oops')), + }, + }); + case 2: + return Promise.resolve({ + done: false, + value: { nonNullName: friends[1].name }, + }); + // Not reached + /* c8 ignore next 5 */ + case 3: + return Promise.resolve({ + done: false, + value: { nonNullName: friends[2].name }, + }); + } + }, + return: async () => { + await resolveOnNextTick(); + returned = true; + return { done: true }; + }, + }), + }, + }); + expectJSON(result).toDeepEqual([ + { + data: { + nonNullFriendList: [{ nonNullName: 'Luke' }], + }, + hasNext: true, + }, + { + incremental: [ + { + items: null, + path: ['nonNullFriendList', 1], + errors: [ + { + message: 'Oops', + locations: [{ line: 4, column: 11 }], + path: ['nonNullFriendList', 1, 'nonNullName'], + }, + ], + }, + ], hasNext: false, }, ]); @@ -1200,25 +1344,19 @@ describe('Execute: stream directive', () => { } /* c8 ignore stop */, }, }); - expectJSON(result).toDeepEqual([ - { - errors: [ - { - message: - 'Cannot return null for non-nullable field NestedObject.nonNullScalarField.', - locations: [{ line: 4, column: 11 }], - path: ['nestedObject', 'nonNullScalarField'], - }, - ], - data: { - nestedObject: null, + expectJSON(result).toDeepEqual({ + errors: [ + { + message: + 'Cannot return null for non-nullable field NestedObject.nonNullScalarField.', + locations: [{ line: 4, column: 11 }], + path: ['nestedObject', 'nonNullScalarField'], }, - hasNext: true, - }, - { - hasNext: false, + ], + data: { + nestedObject: null, }, - ]); + }); }); it('Filters payloads that are nulled by a later synchronous error', async () => { const document = parse(` @@ -1359,9 +1497,6 @@ describe('Execute: stream directive', () => { ], }, ], - hasNext: true, - }, - { hasNext: false, }, ]); diff --git a/src/execution/execute.ts b/src/execution/execute.ts index 1bc6c4267b..f2a57a414c 100644 --- a/src/execution/execute.ts +++ b/src/execution/execute.ts @@ -122,6 +122,7 @@ export interface ExecutionContext { subscribeFieldResolver: GraphQLFieldResolver; errors: Array; subsequentPayloads: Set; + streams: Set; } /** @@ -504,6 +505,7 @@ export function buildExecutionContext( typeResolver: typeResolver ?? defaultTypeResolver, subscribeFieldResolver: subscribeFieldResolver ?? defaultFieldResolver, subsequentPayloads: new Set(), + streams: new Set(), errors: [], }; } @@ -516,6 +518,7 @@ function buildPerEventExecutionContext( ...exeContext, rootValue: payload, subsequentPayloads: new Set(), + streams: new Set(), errors: [], }; } @@ -1036,6 +1039,11 @@ async function completeAsyncIteratorValue( typeof stream.initialCount === 'number' && index >= stream.initialCount ) { + const streamContext: StreamContext = { + path: pathToArray(path), + iterator, + }; + exeContext.streams.add(streamContext); // eslint-disable-next-line @typescript-eslint/no-floating-promises executeStreamIterator( index, @@ -1045,6 +1053,7 @@ async function completeAsyncIteratorValue( info, itemType, path, + streamContext, stream.label, asyncPayloadRecord, ); @@ -1129,6 +1138,7 @@ function completeListValue( let previousAsyncPayloadRecord = asyncPayloadRecord; const completedResults: Array = []; let index = 0; + let streamContext: StreamContext | undefined; for (const item of result) { // No need to modify the info object containing the path, // since from here on it is not ever accessed by resolver functions. @@ -1139,6 +1149,8 @@ function completeListValue( typeof stream.initialCount === 'number' && index >= stream.initialCount ) { + streamContext = { path: pathToArray(path) }; + exeContext.streams.add(streamContext); previousAsyncPayloadRecord = executeStreamField( path, itemPath, @@ -1147,6 +1159,7 @@ function completeListValue( fieldNodes, info, itemType, + streamContext, stream.label, previousAsyncPayloadRecord, ); @@ -1173,6 +1186,10 @@ function completeListValue( index++; } + if (streamContext) { + exeContext.streams.delete(streamContext); + } + return containsPromise ? Promise.all(completedResults) : completedResults; } @@ -1813,6 +1830,7 @@ function executeStreamField( fieldNodes: ReadonlyArray, info: GraphQLResolveInfo, itemType: GraphQLOutputType, + streamContext: StreamContext, label?: string, parentContext?: AsyncPayloadRecord, ): AsyncPayloadRecord { @@ -1835,6 +1853,8 @@ function executeStreamField( (value) => [value], (error) => { asyncPayloadRecord.errors.push(error); + returnStreamIteratorIgnoringError(streamContext); + exeContext.streams.delete(streamContext); filterSubsequentPayloads(exeContext, path, asyncPayloadRecord); return null; }, @@ -1867,6 +1887,8 @@ function executeStreamField( } } catch (error) { asyncPayloadRecord.errors.push(error); + returnStreamIteratorIgnoringError(streamContext); + exeContext.streams.delete(streamContext); filterSubsequentPayloads(exeContext, path, asyncPayloadRecord); asyncPayloadRecord.addItems(null); return asyncPayloadRecord; @@ -1887,6 +1909,8 @@ function executeStreamField( .then( (value) => [value], (error) => { + returnStreamIteratorIgnoringError(streamContext); + exeContext.streams.delete(streamContext); asyncPayloadRecord.errors.push(error); filterSubsequentPayloads(exeContext, path, asyncPayloadRecord); return null; @@ -1965,6 +1989,7 @@ async function executeStreamIterator( info: GraphQLResolveInfo, itemType: GraphQLOutputType, path: Path, + streamContext: StreamContext, label?: string, parentContext?: AsyncPayloadRecord, ): Promise { @@ -1977,7 +2002,6 @@ async function executeStreamIterator( label, path: itemPath, parentContext: previousAsyncPayloadRecord, - iterator, exeContext, }); @@ -1995,14 +2019,10 @@ async function executeStreamIterator( ); } catch (error) { asyncPayloadRecord.errors.push(error); + returnStreamIteratorIgnoringError(streamContext); + exeContext.streams.delete(streamContext); filterSubsequentPayloads(exeContext, path, asyncPayloadRecord); asyncPayloadRecord.addItems(null); - // entire stream has errored and bubbled upwards - if (iterator?.return) { - iterator.return().catch(() => { - // ignore errors - }); - } return; } @@ -2014,6 +2034,8 @@ async function executeStreamIterator( (value) => [value], (error) => { asyncPayloadRecord.errors.push(error); + returnStreamIteratorIgnoringError(streamContext); + exeContext.streams.delete(streamContext); filterSubsequentPayloads(exeContext, path, asyncPayloadRecord); return null; }, @@ -2025,8 +2047,16 @@ async function executeStreamIterator( asyncPayloadRecord.addItems(completedItems); if (done) { + exeContext.streams.delete(streamContext); + break; + } + + if (!exeContext.streams.has(streamContext)) { + // stream was filtered + returnStreamIteratorIgnoringError(streamContext); break; } + previousAsyncPayloadRecord = asyncPayloadRecord; index++; } @@ -2038,6 +2068,16 @@ function filterSubsequentPayloads( currentAsyncRecord: AsyncPayloadRecord | undefined, ): void { const nullPathArray = pathToArray(nullPath); + exeContext.streams.forEach((stream) => { + for (let i = 0; i < nullPathArray.length; i++) { + if (stream.path[i] !== nullPathArray[i]) { + // stream points to a path unaffected by this payload + return; + } + } + returnStreamIteratorIgnoringError(stream); + exeContext.streams.delete(stream); + }); exeContext.subsequentPayloads.forEach((asyncRecord) => { if (asyncRecord === currentAsyncRecord) { // don't remove payload from where error originates @@ -2049,16 +2089,16 @@ function filterSubsequentPayloads( return; } } - // asyncRecord path points to nulled error field - if (isStreamPayload(asyncRecord) && asyncRecord.iterator?.return) { - asyncRecord.iterator.return().catch(() => { - // ignore error - }); - } exeContext.subsequentPayloads.delete(asyncRecord); }); } +function returnStreamIteratorIgnoringError(streamContext: StreamContext): void { + streamContext.iterator?.return?.().catch(() => { + // ignore error + }); +} + function getCompletedIncrementalResults( exeContext: ExecutionContext, ): Array { @@ -2133,12 +2173,9 @@ function yieldSubsequentPayloads( function returnStreamIterators() { const promises: Array>> = []; - exeContext.subsequentPayloads.forEach((asyncPayloadRecord) => { - if ( - isStreamPayload(asyncPayloadRecord) && - asyncPayloadRecord.iterator?.return - ) { - promises.push(asyncPayloadRecord.iterator.return()); + exeContext.streams.forEach((stream) => { + if (stream.iterator?.return) { + promises.push(stream.iterator.return()); } }); return Promise.all(promises); @@ -2211,6 +2248,10 @@ class DeferredFragmentRecord { this._resolve?.(data); } } +interface StreamContext { + path: Array; + iterator?: AsyncIterator | undefined; +} class StreamRecord { type: 'stream'; @@ -2220,7 +2261,6 @@ class StreamRecord { items: Array | null; promise: Promise; parentContext: AsyncPayloadRecord | undefined; - iterator: AsyncIterator | undefined; isCompletedIterator?: boolean; isCompleted: boolean; _exeContext: ExecutionContext; @@ -2228,7 +2268,6 @@ class StreamRecord { constructor(opts: { label: string | undefined; path: Path | undefined; - iterator?: AsyncIterator; parentContext: AsyncPayloadRecord | undefined; exeContext: ExecutionContext; }) { @@ -2237,7 +2276,6 @@ class StreamRecord { this.label = opts.label; this.path = pathToArray(opts.path); this.parentContext = opts.parentContext; - this.iterator = opts.iterator; this.errors = []; this._exeContext = opts.exeContext; this._exeContext.subsequentPayloads.add(this);