From 20904e593e3c08288bf2b6401bc95a4a59f7594b Mon Sep 17 00:00:00 2001 From: Yaacov Rydzinski Date: Mon, 13 May 2024 23:29:22 +0300 Subject: [PATCH] incremental: add highWaterMark option to apply backpressure when using async streams This protects against a potential OOM error if we end up pulling data from a stream faster than it is consumed. Default is set at 100; after pulling 100 entries, we will pause until some have been flushed. --- src/execution/IncrementalPublisher.ts | 60 ++++++++++++--- src/execution/__tests__/stream-test.ts | 100 +++++++++++++++++++++++++ src/execution/execute.ts | 32 +++++++- 3 files changed, 177 insertions(+), 15 deletions(-) diff --git a/src/execution/IncrementalPublisher.ts b/src/execution/IncrementalPublisher.ts index 0722da1ed1..5a40e50eb8 100644 --- a/src/execution/IncrementalPublisher.ts +++ b/src/execution/IncrementalPublisher.ts @@ -185,6 +185,7 @@ export function buildIncrementalResponse( } interface IncrementalPublisherContext { + streamHighWaterMark: number; cancellableStreams: Set | undefined; } @@ -201,6 +202,7 @@ class IncrementalPublisher { private _completedResultQueue: Array; private _newPending: Set; private _incremental: Array; + private _asyncStreamCounts: Map; private _completed: Array; // these are assigned within the Promise executor called synchronously within the constructor private _signalled!: Promise; @@ -213,6 +215,7 @@ class IncrementalPublisher { this._completedResultQueue = []; this._newPending = new Set(); this._incremental = []; + this._asyncStreamCounts = new Map(); this._completed = []; this._reset(); } @@ -427,7 +430,18 @@ class IncrementalPublisher { subsequentIncrementalExecutionResult.completed = this._completed; } + for (const [streamRecord, count] of this._asyncStreamCounts) { + streamRecord.waterMark -= count; + if ( + streamRecord.resume !== undefined && + streamRecord.waterMark < this._context.streamHighWaterMark + ) { + streamRecord.resume(); + } + } + this._incremental = []; + this._asyncStreamCounts.clear(); this._completed = []; return { value: subsequentIncrementalExecutionResult, done: false }; @@ -593,20 +607,26 @@ class IncrementalPublisher { errors: streamItemsResult.errors, }); this._pending.delete(streamRecord); - if (isCancellableStreamRecord(streamRecord)) { - invariant(this._context.cancellableStreams !== undefined); - this._context.cancellableStreams.delete(streamRecord); - streamRecord.earlyReturn().catch(() => { - /* c8 ignore next 1 */ - // ignore error - }); + if (isAsyncStreamRecord(streamRecord)) { + this._asyncStreamCounts.delete(streamRecord); + if (isCancellableStreamRecord(streamRecord)) { + invariant(this._context.cancellableStreams !== undefined); + this._context.cancellableStreams.delete(streamRecord); + streamRecord.earlyReturn().catch(() => { + /* c8 ignore next 1 */ + // ignore error + }); + } } } else if (streamItemsResult.result === undefined) { this._completed.push({ id }); this._pending.delete(streamRecord); - if (isCancellableStreamRecord(streamRecord)) { - invariant(this._context.cancellableStreams !== undefined); - this._context.cancellableStreams.delete(streamRecord); + if (isAsyncStreamRecord(streamRecord)) { + this._asyncStreamCounts.delete(streamRecord); + if (isCancellableStreamRecord(streamRecord)) { + invariant(this._context.cancellableStreams !== undefined); + this._context.cancellableStreams.delete(streamRecord); + } } } else { const incrementalEntry: IncrementalStreamResult = { @@ -615,6 +635,13 @@ class IncrementalPublisher { }; this._incremental.push(incrementalEntry); + if (isAsyncStreamRecord(streamRecord)) { + const count = this._asyncStreamCounts.get(streamRecord); + this._asyncStreamCounts.set( + streamRecord, + count === undefined ? 1 : count + 1, + ); + } if (streamItemsResult.incrementalDataRecords !== undefined) { this._addIncrementalDataRecords( @@ -739,7 +766,18 @@ export class DeferredFragmentRecord implements SubsequentResultRecord { } } -export interface CancellableStreamRecord extends SubsequentResultRecord { +export interface AsyncStreamRecord extends SubsequentResultRecord { + waterMark: number; + resume: (() => void) | undefined; +} + +function isAsyncStreamRecord( + subsequentResultRecord: SubsequentResultRecord, +): subsequentResultRecord is AsyncStreamRecord { + return 'waterMark' in subsequentResultRecord; +} + +export interface CancellableStreamRecord extends AsyncStreamRecord { earlyReturn: () => Promise; } diff --git a/src/execution/__tests__/stream-test.ts b/src/execution/__tests__/stream-test.ts index 522b82f3d4..9885902c53 100644 --- a/src/execution/__tests__/stream-test.ts +++ b/src/execution/__tests__/stream-test.ts @@ -667,6 +667,106 @@ describe('Execute: stream directive', () => { }, }); }); + it('Can stream a field that returns an async iterable with backpressure', async () => { + const document = parse(` + query { + friendList @stream { + name + id + } + } + `); + let count = 0; + const executeResult = await experimentalExecuteIncrementally({ + schema, + document, + rootValue: { + async *friendList() { + for (const friend of friends) { + count++; + // eslint-disable-next-line no-await-in-loop + yield await Promise.resolve(friend); + } + }, + }, + streamHighWaterMark: 2, + }); + assert('initialResult' in executeResult); + const iterator = executeResult.subsequentResults[Symbol.asyncIterator](); + + const result1 = executeResult.initialResult; + expectJSON(result1).toDeepEqual({ + data: { + friendList: [], + }, + pending: [{ id: '0', path: ['friendList'] }], + hasNext: true, + }); + + expect(count).to.equal(2); + + await resolveOnNextTick(); + await resolveOnNextTick(); + await resolveOnNextTick(); + await resolveOnNextTick(); + await resolveOnNextTick(); + + const result2 = await iterator.next(); + expectJSON(result2).toDeepEqual({ + done: false, + value: { + incremental: [ + { + items: [{ name: 'Luke', id: '1' }], + id: '0', + }, + ], + hasNext: true, + }, + }); + + expect(count).to.equal(3); + + const result3 = await iterator.next(); + expectJSON(result3).toDeepEqual({ + done: false, + value: { + incremental: [ + { + items: [{ name: 'Han', id: '2' }], + id: '0', + }, + ], + hasNext: true, + }, + }); + + const result4 = await iterator.next(); + expectJSON(result4).toDeepEqual({ + done: false, + value: { + incremental: [ + { + items: [{ name: 'Leia', id: '3' }], + id: '0', + }, + ], + hasNext: true, + }, + }); + + const result5 = await iterator.next(); + expectJSON(result5).toDeepEqual({ + done: false, + value: { + completed: [{ id: '0' }], + hasNext: false, + }, + }); + + const result6 = await iterator.next(); + expectJSON(result6).toDeepEqual({ done: true, value: undefined }); + }); it('Can handle concurrent calls to .next() without waiting', async () => { const document = parse(` query { diff --git a/src/execution/execute.ts b/src/execution/execute.ts index e5e220dd66..2ee05221d0 100644 --- a/src/execution/execute.ts +++ b/src/execution/execute.ts @@ -12,6 +12,7 @@ import { addPath, pathToArray } from '../jsutils/Path.js'; import { promiseForObject } from '../jsutils/promiseForObject.js'; import type { PromiseOrValue } from '../jsutils/PromiseOrValue.js'; import { promiseReduce } from '../jsutils/promiseReduce.js'; +import { promiseWithResolvers } from '../jsutils/promiseWithResolvers.js'; import { GraphQLError } from '../error/GraphQLError.js'; import { locatedError } from '../error/locatedError.js'; @@ -59,6 +60,7 @@ import { collectSubfields as _collectSubfields, } from './collectFields.js'; import type { + AsyncStreamRecord, CancellableStreamRecord, DeferredGroupedFieldSetRecord, DeferredGroupedFieldSetResult, @@ -143,6 +145,7 @@ export interface ExecutionContext { typeResolver: GraphQLTypeResolver; subscribeFieldResolver: GraphQLFieldResolver; errors: Array | undefined; + streamHighWaterMark: number; cancellableStreams: Set | undefined; } @@ -161,6 +164,7 @@ export interface ExecutionArgs { fieldResolver?: Maybe>; typeResolver?: Maybe>; subscribeFieldResolver?: Maybe>; + streamHighWaterMark?: Maybe; } export interface StreamUsage { @@ -439,6 +443,7 @@ export function buildExecutionContext( fieldResolver, typeResolver, subscribeFieldResolver, + streamHighWaterMark, } = args; // If the schema used for execution is invalid, throw an error. @@ -504,6 +509,7 @@ export function buildExecutionContext( subscribeFieldResolver: subscribeFieldResolver ?? defaultFieldResolver, errors: undefined, cancellableStreams: undefined, + streamHighWaterMark: streamHighWaterMark ?? 100, }; } @@ -1096,16 +1102,20 @@ async function completeAsyncIteratorValue( while (true) { if (streamUsage && index >= streamUsage.initialCount) { const returnFn = asyncIterator.return; - let streamRecord: SubsequentResultRecord | CancellableStreamRecord; + let streamRecord: AsyncStreamRecord | CancellableStreamRecord; if (returnFn === undefined) { streamRecord = { label: streamUsage.label, path, - } as SubsequentResultRecord; + waterMark: 0, + resume: undefined, + }; } else { streamRecord = { label: streamUsage.label, path, + waterMark: 0, + resume: undefined, earlyReturn: returnFn.bind(asyncIterator), }; if (exeContext.cancellableStreams === undefined) { @@ -2317,7 +2327,7 @@ function prependNextResolvedStreamItems( } function firstAsyncStreamItems( - streamRecord: SubsequentResultRecord, + streamRecord: AsyncStreamRecord, path: Path, initialIndex: number, asyncIterator: AsyncIterator, @@ -2343,7 +2353,7 @@ function firstAsyncStreamItems( } async function getNextAsyncStreamItemsResult( - streamRecord: SubsequentResultRecord, + streamRecord: AsyncStreamRecord, path: Path, index: number, asyncIterator: AsyncIterator, @@ -2353,6 +2363,18 @@ async function getNextAsyncStreamItemsResult( itemType: GraphQLOutputType, ): Promise { let iteration; + + const waterMark = streamRecord.waterMark; + + if (waterMark === exeContext.streamHighWaterMark) { + // promiseWithResolvers uses void only as a generic type parameter + // see: https://typescript-eslint.io/rules/no-invalid-void-type/ + // eslint-disable-next-line @typescript-eslint/no-invalid-void-type + const { promise: resumed, resolve: resume } = promiseWithResolvers(); + streamRecord.resume = resume; + await resumed; + } + try { iteration = await asyncIterator.next(); } catch (error) { @@ -2366,6 +2388,8 @@ async function getNextAsyncStreamItemsResult( return { streamRecord }; } + streamRecord.waterMark++; + const itemPath = addPath(path, index, undefined); const result = completeStreamItems(