diff --git a/.changeset/long-keys-watch.md b/.changeset/long-keys-watch.md
new file mode 100644
index 00000000000..fb13ed74987
--- /dev/null
+++ b/.changeset/long-keys-watch.md
@@ -0,0 +1,6 @@
+---
+'firebase': minor
+'@firebase/ai': minor
+---
+
+Add support for `AbortSignal`, allowing requests to be aborted.
diff --git a/common/api-review/ai.api.md b/common/api-review/ai.api.md
index ab79447798f..7f0a2f09320 100644
--- a/common/api-review/ai.api.md
+++ b/common/api-review/ai.api.md
@@ -120,8 +120,8 @@ export class ChatSession {
     params?: StartChatParams | undefined;
     // (undocumented)
    requestOptions?: RequestOptions | undefined;
-    sendMessage(request: string | Array<string | Part>): Promise<GenerateContentResult>;
-    sendMessageStream(request: string | Array<string | Part>): Promise<GenerateContentStreamResult>;
+    sendMessage(request: string | Array<string | Part>, singleRequestOptions?: SingleRequestOptions): Promise<GenerateContentResult>;
+    sendMessageStream(request: string | Array<string | Part>, singleRequestOptions?: SingleRequestOptions): Promise<GenerateContentStreamResult>;
 }
 
 // @public
@@ -397,9 +397,9 @@ export interface GenerativeContentBlob {
 // @public
 export class GenerativeModel extends AIModel {
     constructor(ai: AI, modelParams: ModelParams, requestOptions?: RequestOptions);
-    countTokens(request: CountTokensRequest | string | Array<string | Part>): Promise<CountTokensResponse>;
-    generateContent(request: GenerateContentRequest | string | Array<string | Part>): Promise<GenerateContentResult>;
-    generateContentStream(request: GenerateContentRequest | string | Array<string | Part>): Promise<GenerateContentStreamResult>;
+    countTokens(request: CountTokensRequest | string | Array<string | Part>, singleRequestOptions?: SingleRequestOptions): Promise<CountTokensResponse>;
+    generateContent(request: GenerateContentRequest | string | Array<string | Part>, singleRequestOptions?: SingleRequestOptions): Promise<GenerateContentResult>;
+    generateContentStream(request: GenerateContentRequest | string | Array<string | Part>, singleRequestOptions?: SingleRequestOptions): Promise<GenerateContentStreamResult>;
     // (undocumented)
     generationConfig: GenerationConfig;
     // (undocumented)
@@ -599,9 +599,9 @@ export interface ImagenInlineImage {
 // @beta
 export class ImagenModel extends AIModel {
     constructor(ai: AI, modelParams: ImagenModelParams, requestOptions?: RequestOptions | undefined);
-    generateImages(prompt: string): Promise<ImagenGenerationResponse<ImagenInlineImage>>;
+    generateImages(prompt: string, singleRequestOptions?: SingleRequestOptions): Promise<ImagenGenerationResponse<ImagenInlineImage>>;
     // @internal
-    generateImagesGCS(prompt: string, gcsURI: string): Promise<ImagenGenerationResponse<ImagenGCSImage>>;
+    generateImagesGCS(prompt: string, gcsURI: string, singleRequestOptions?: SingleRequestOptions): Promise<ImagenGenerationResponse<ImagenGCSImage>>;
     generationConfig?: ImagenGenerationConfig;
     // (undocumented)
     requestOptions?: RequestOptions | undefined;
@@ -868,6 +868,11 @@ export interface Segment {
     startIndex: number;
 }
 
+// @public
+export interface SingleRequestOptions extends RequestOptions {
+    signal?: AbortSignal;
+}
+
 // @public
 export interface StartChatParams extends BaseParams {
     // (undocumented)
diff --git a/common/api-review/vertexai.api.md b/common/api-review/vertexai.api.md
index 42da114f9e9..df6e7535581 100644
--- a/common/api-review/vertexai.api.md
+++ b/common/api-review/vertexai.api.md
@@ -120,8 +120,8 @@ export class ChatSession {
     params?: StartChatParams | undefined;
     // (undocumented)
     requestOptions?: RequestOptions | undefined;
-    sendMessage(request: string | Array<string | Part>): Promise<GenerateContentResult>;
-    sendMessageStream(request: string | Array<string | Part>): Promise<GenerateContentStreamResult>;
+    sendMessage(request: string | Array<string | Part>, singleRequestOptions?: SingleRequestOptions): Promise<GenerateContentResult>;
+    sendMessageStream(request: string | Array<string | Part>, singleRequestOptions?: SingleRequestOptions): Promise<GenerateContentStreamResult>;
 }
 
 // @public
@@ -394,11 +394,11 @@ export interface GenerativeContentBlob {
 }
 
 // @public
 export class GenerativeModel extends AIModel {
     constructor(ai: AI, modelParams: ModelParams, requestOptions?: RequestOptions);
-    countTokens(request: CountTokensRequest | string | Array<string | Part>): Promise<CountTokensResponse>;
-    generateContent(request: GenerateContentRequest | string | Array<string | Part>): Promise<GenerateContentResult>;
-    generateContentStream(request: GenerateContentRequest | string | Array<string | Part>): Promise<GenerateContentStreamResult>;
+    countTokens(request: CountTokensRequest | string | Array<string | Part>, singleRequestOptions?: SingleRequestOptions): Promise<CountTokensResponse>;
+    generateContent(request: GenerateContentRequest | string | Array<string | Part>, singleRequestOptions?: SingleRequestOptions): Promise<GenerateContentResult>;
+    generateContentStream(request: GenerateContentRequest | string | Array<string | Part>, singleRequestOptions?: SingleRequestOptions): Promise<GenerateContentStreamResult>;
     // (undocumented)
     generationConfig: GenerationConfig;
     // (undocumented)
@@ -595,11 +595,11 @@ export interface ImagenInlineImage {
 }
 
 // @beta
 export class ImagenModel extends AIModel {
     constructor(ai: AI, modelParams: ImagenModelParams, requestOptions?: RequestOptions | undefined);
-    generateImages(prompt: string): Promise<ImagenGenerationResponse<ImagenInlineImage>>;
+    generateImages(prompt: string, singleRequestOptions?: SingleRequestOptions): Promise<ImagenGenerationResponse<ImagenInlineImage>>;
     // @internal
-    generateImagesGCS(prompt: string, gcsURI: string): Promise<ImagenGenerationResponse<ImagenGCSImage>>;
+    generateImagesGCS(prompt: string, gcsURI: string, singleRequestOptions?: SingleRequestOptions): Promise<ImagenGenerationResponse<ImagenGCSImage>>;
     generationConfig?: ImagenGenerationConfig;
     // (undocumented)
     requestOptions?: RequestOptions | undefined;
@@ -857,6 +857,11 @@ export interface Segment {
     startIndex: number;
 }
 
+// @public
+export interface SingleRequestOptions extends RequestOptions {
+    signal?: AbortSignal;
+}
+
 // @public
 export interface StartChatParams extends BaseParams {
     // (undocumented)
diff --git a/docs-devsite/_toc.yaml b/docs-devsite/_toc.yaml
index b77a6b5910e..6e07ffa792b 100644
--- a/docs-devsite/_toc.yaml
+++ b/docs-devsite/_toc.yaml
@@ -132,6 +132,8 @@ toc:
     path: /docs/reference/js/ai.schemashared.md
   - title: Segment
     path: /docs/reference/js/ai.segment.md
+  - title: SingleRequestOptions
+    path: /docs/reference/js/ai.singlerequestoptions.md
   - title: StartChatParams
     path: /docs/reference/js/ai.startchatparams.md
   - title: StringSchema
diff --git a/docs-devsite/ai.chatsession.md b/docs-devsite/ai.chatsession.md
index 1d6e403b6a8..211502b4076 100644
--- a/docs-devsite/ai.chatsession.md
+++ b/docs-devsite/ai.chatsession.md
@@ -37,8 +37,8 @@ export declare class ChatSession
 | Method | Modifiers | Description |
 | --- | --- | --- |
 | [getHistory()](./ai.chatsession.md#chatsessiongethistory) | | Gets the chat history so far. Blocked prompts are not added to history. Neither blocked candidates nor the prompts that generated them are added to history. |
-| [sendMessage(request)](./ai.chatsession.md#chatsessionsendmessage) | | Sends a chat message and receives a non-streaming [GenerateContentResult](./ai.generatecontentresult.md#generatecontentresult_interface) |
-| [sendMessageStream(request)](./ai.chatsession.md#chatsessionsendmessagestream) | | Sends a chat message and receives the response as a [GenerateContentStreamResult](./ai.generatecontentstreamresult.md#generatecontentstreamresult_interface) containing an iterable stream and a response promise.
| +| [sendMessage(request, singleRequestOptions)](./ai.chatsession.md#chatsessionsendmessage) | | Sends a chat message and receives a non-streaming [GenerateContentResult](./ai.generatecontentresult.md#generatecontentresult_interface) | +| [sendMessageStream(request, singleRequestOptions)](./ai.chatsession.md#chatsessionsendmessagestream) | | Sends a chat message and receives the response as a [GenerateContentStreamResult](./ai.generatecontentstreamresult.md#generatecontentstreamresult_interface) containing an iterable stream and a response promise. | ## ChatSession.(constructor) @@ -103,7 +103,7 @@ Sends a chat message and receives a non-streaming [GenerateContentResult](./ai.g Signature: ```typescript -sendMessage(request: string | Array): Promise; +sendMessage(request: string | Array, singleRequestOptions?: SingleRequestOptions): Promise; ``` #### Parameters @@ -111,6 +111,7 @@ sendMessage(request: string | Array): Promise> | | +| singleRequestOptions | [SingleRequestOptions](./ai.singlerequestoptions.md#singlerequestoptions_interface) | | Returns: @@ -123,7 +124,7 @@ Sends a chat message and receives the response as a [GenerateContentStreamResult Signature: ```typescript -sendMessageStream(request: string | Array): Promise; +sendMessageStream(request: string | Array, singleRequestOptions?: SingleRequestOptions): Promise; ``` #### Parameters @@ -131,6 +132,7 @@ sendMessageStream(request: string | Array): Promise> | | +| singleRequestOptions | [SingleRequestOptions](./ai.singlerequestoptions.md#singlerequestoptions_interface) | | Returns: diff --git a/docs-devsite/ai.generativemodel.md b/docs-devsite/ai.generativemodel.md index d91cf80e881..948251271b9 100644 --- a/docs-devsite/ai.generativemodel.md +++ b/docs-devsite/ai.generativemodel.md @@ -40,9 +40,9 @@ export declare class GenerativeModel extends AIModel | Method | Modifiers | Description | | --- | --- | --- | -| [countTokens(request)](./ai.generativemodel.md#generativemodelcounttokens) | | Counts the tokens in the provided request. | -| [generateContent(request)](./ai.generativemodel.md#generativemodelgeneratecontent) | | Makes a single non-streaming call to the model and returns an object containing a single [GenerateContentResponse](./ai.generatecontentresponse.md#generatecontentresponse_interface). | -| [generateContentStream(request)](./ai.generativemodel.md#generativemodelgeneratecontentstream) | | Makes a single streaming call to the model and returns an object containing an iterable stream that iterates over all chunks in the streaming response as well as a promise that returns the final aggregated response. | +| [countTokens(request, singleRequestOptions)](./ai.generativemodel.md#generativemodelcounttokens) | | Counts the tokens in the provided request. | +| [generateContent(request, singleRequestOptions)](./ai.generativemodel.md#generativemodelgeneratecontent) | | Makes a single non-streaming call to the model and returns an object containing a single [GenerateContentResponse](./ai.generatecontentresponse.md#generatecontentresponse_interface). | +| [generateContentStream(request, singleRequestOptions)](./ai.generativemodel.md#generativemodelgeneratecontentstream) | | Makes a single streaming call to the model and returns an object containing an iterable stream that iterates over all chunks in the streaming response as well as a promise that returns the final aggregated response. 
| | [startChat(startChatParams)](./ai.generativemodel.md#generativemodelstartchat) | | Gets a new [ChatSession](./ai.chatsession.md#chatsession_class) instance which can be used for multi-turn chats. | ## GenerativeModel.(constructor) @@ -118,7 +118,7 @@ Counts the tokens in the provided request. Signature: ```typescript -countTokens(request: CountTokensRequest | string | Array): Promise; +countTokens(request: CountTokensRequest | string | Array, singleRequestOptions?: SingleRequestOptions): Promise; ``` #### Parameters @@ -126,6 +126,7 @@ countTokens(request: CountTokensRequest | string | Array): Promis | Parameter | Type | Description | | --- | --- | --- | | request | [CountTokensRequest](./ai.counttokensrequest.md#counttokensrequest_interface) \| string \| Array<string \| [Part](./ai.md#part)> | | +| singleRequestOptions | [SingleRequestOptions](./ai.singlerequestoptions.md#singlerequestoptions_interface) | | Returns: @@ -138,7 +139,7 @@ Makes a single non-streaming call to the model and returns an object containing Signature: ```typescript -generateContent(request: GenerateContentRequest | string | Array): Promise; +generateContent(request: GenerateContentRequest | string | Array, singleRequestOptions?: SingleRequestOptions): Promise; ``` #### Parameters @@ -146,6 +147,7 @@ generateContent(request: GenerateContentRequest | string | Array) | Parameter | Type | Description | | --- | --- | --- | | request | [GenerateContentRequest](./ai.generatecontentrequest.md#generatecontentrequest_interface) \| string \| Array<string \| [Part](./ai.md#part)> | | +| singleRequestOptions | [SingleRequestOptions](./ai.singlerequestoptions.md#singlerequestoptions_interface) | | Returns: @@ -158,7 +160,7 @@ Makes a single streaming call to the model and returns an object containing an i Signature: ```typescript -generateContentStream(request: GenerateContentRequest | string | Array): Promise; +generateContentStream(request: GenerateContentRequest | string | Array, singleRequestOptions?: SingleRequestOptions): Promise; ``` #### Parameters @@ -166,6 +168,7 @@ generateContentStream(request: GenerateContentRequest | string | Array> | | +| singleRequestOptions | [SingleRequestOptions](./ai.singlerequestoptions.md#singlerequestoptions_interface) | | Returns: diff --git a/docs-devsite/ai.imagenmodel.md b/docs-devsite/ai.imagenmodel.md index 911971e0988..ae0d4daa2d0 100644 --- a/docs-devsite/ai.imagenmodel.md +++ b/docs-devsite/ai.imagenmodel.md @@ -42,7 +42,7 @@ export declare class ImagenModel extends AIModel | Method | Modifiers | Description | | --- | --- | --- | -| [generateImages(prompt)](./ai.imagenmodel.md#imagenmodelgenerateimages) | | (Public Preview) Generates images using the Imagen model and returns them as base64-encoded strings. | +| [generateImages(prompt, singleRequestOptions)](./ai.imagenmodel.md#imagenmodelgenerateimages) | | (Public Preview) Generates images using the Imagen model and returns them as base64-encoded strings. | ## ImagenModel.(constructor) @@ -118,7 +118,7 @@ If the prompt was not blocked, but one or more of the generated images were filt Signature: ```typescript -generateImages(prompt: string): Promise>; +generateImages(prompt: string, singleRequestOptions?: SingleRequestOptions): Promise>; ``` #### Parameters @@ -126,6 +126,7 @@ generateImages(prompt: string): PromiseReturns: diff --git a/docs-devsite/ai.md b/docs-devsite/ai.md index 286c8351fd7..3b160b6fcf1 100644 --- a/docs-devsite/ai.md +++ b/docs-devsite/ai.md @@ -117,6 +117,7 @@ The Firebase AI Web SDK. 
| [SchemaRequest](./ai.schemarequest.md#schemarequest_interface) | Final format for [Schema](./ai.schema.md#schema_class) params passed to backend requests. | | [SchemaShared](./ai.schemashared.md#schemashared_interface) | Basic [Schema](./ai.schema.md#schema_class) properties shared across several Schema-related types. | | [Segment](./ai.segment.md#segment_interface) | | +| [SingleRequestOptions](./ai.singlerequestoptions.md#singlerequestoptions_interface) | Options that can be provided per-request. Extends the base [RequestOptions](./ai.requestoptions.md#requestoptions_interface) (like timeout and baseUrl) with request-specific controls like cancellation via AbortSignal.Options specified here will override any default [RequestOptions](./ai.requestoptions.md#requestoptions_interface) configured on a model (for example, [GenerativeModel](./ai.generativemodel.md#generativemodel_class)). | | [StartChatParams](./ai.startchatparams.md#startchatparams_interface) | Params for [GenerativeModel.startChat()](./ai.generativemodel.md#generativemodelstartchat). | | [TextPart](./ai.textpart.md#textpart_interface) | Content part interface if the part represents a text string. | | [ToolConfig](./ai.toolconfig.md#toolconfig_interface) | Tool config. This config is shared for all tools provided in the request. | diff --git a/docs-devsite/ai.singlerequestoptions.md b/docs-devsite/ai.singlerequestoptions.md new file mode 100644 index 00000000000..a55bd3c2f3c --- /dev/null +++ b/docs-devsite/ai.singlerequestoptions.md @@ -0,0 +1,61 @@ +Project: /docs/reference/js/_project.yaml +Book: /docs/reference/_book.yaml +page_type: reference + +{% comment %} +DO NOT EDIT THIS FILE! +This is generated by the JS SDK team, and any local changes will be +overwritten. Changes should be made in the source code at +https://github.com/firebase/firebase-js-sdk +{% endcomment %} + +# SingleRequestOptions interface +Options that can be provided per-request. Extends the base [RequestOptions](./ai.requestoptions.md#requestoptions_interface) (like `timeout` and `baseUrl`) with request-specific controls like cancellation via `AbortSignal`. + +Options specified here will override any default [RequestOptions](./ai.requestoptions.md#requestoptions_interface) configured on a model (for example, [GenerativeModel](./ai.generativemodel.md#generativemodel_class)). + +Signature: + +```typescript +export interface SingleRequestOptions extends RequestOptions +``` +Extends: [RequestOptions](./ai.requestoptions.md#requestoptions_interface) + +## Properties + +| Property | Type | Description | +| --- | --- | --- | +| [signal](./ai.singlerequestoptions.md#singlerequestoptionssignal) | AbortSignal | An AbortSignal instance that allows cancelling ongoing requests (like generateContent or generateImages).If provided, calling abort() on the corresponding AbortController will attempt to cancel the underlying HTTP request. An AbortError will be thrown if cancellation is successful.Note that this will not cancel the request in the backend, so any applicable billing charges will still be applied despite cancellation. | + +## SingleRequestOptions.signal + +An `AbortSignal` instance that allows cancelling ongoing requests (like `generateContent` or `generateImages`). + +If provided, calling `abort()` on the corresponding `AbortController` will attempt to cancel the underlying HTTP request. An `AbortError` will be thrown if cancellation is successful. 
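+
+The `signal` can be combined with the inherited `timeout` option; the underlying request is aborted by whichever triggers first. A minimal sketch (the model setup is elided):
+
+```javascript
+const controller = new AbortController();
+// Aborts after 10 seconds, or sooner if the controller aborts first.
+await model.generateContent("Tell me a story.", {
+  signal: controller.signal,
+  timeout: 10000
+});
+```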
+
+Note that this will not cancel the request in the backend, so any applicable billing charges will still be applied despite cancellation.
+
+Signature:
+
+```typescript
+signal?: AbortSignal;
+```
+
+### Example
+
+
+```javascript
+const controller = new AbortController();
+const model = getGenerativeModel({
+  // ...
+});
+model.generateContent(
+  "Write a story about a magic backpack.",
+  { signal: controller.signal }
+);
+
+// To cancel request:
+controller.abort();
+
+```
+
diff --git a/packages/ai/src/methods/chat-session.test.ts b/packages/ai/src/methods/chat-session.test.ts
index 0564aa84ed6..700b94d7d9b 100644
--- a/packages/ai/src/methods/chat-session.test.ts
+++ b/packages/ai/src/methods/chat-session.test.ts
@@ -54,6 +54,64 @@ describe('ChatSession', () => {
         match.any
       );
     });
+    it('singleRequestOptions overrides requestOptions', async () => {
+      const generateContentStub = stub(
+        generateContentMethods,
+        'generateContent'
+      ).rejects('generateContent failed'); // not important
+      const requestOptions = {
+        timeout: 1000
+      };
+      const singleRequestOptions = {
+        timeout: 2000
+      };
+      const chatSession = new ChatSession(
+        fakeApiSettings,
+        'a-model',
+        undefined,
+        requestOptions
+      );
+      await expect(chatSession.sendMessage('hello', singleRequestOptions)).to.be
+        .rejected;
+      expect(generateContentStub).to.be.calledWith(
+        fakeApiSettings,
+        'a-model',
+        match.any,
+        match({
+          timeout: singleRequestOptions.timeout
+        })
+      );
+    });
+    it('singleRequestOptions is merged with requestOptions', async () => {
+      const generateContentStub = stub(
+        generateContentMethods,
+        'generateContent'
+      ).rejects('generateContent failed'); // not important
+      const abortController = new AbortController();
+      const requestOptions = {
+        timeout: 1000
+      };
+      const singleRequestOptions = {
+        signal: abortController.signal
+      };
+      const chatSession = new ChatSession(
+        fakeApiSettings,
+        'a-model',
+        undefined,
+        requestOptions
+      );
+      await expect(chatSession.sendMessage('hello', singleRequestOptions)).to.be
+        .rejected;
+      expect(generateContentStub).to.be.calledWith(
+        fakeApiSettings,
+        'a-model',
+        match.any,
+        match({
+          timeout: requestOptions.timeout,
+          signal: singleRequestOptions.signal
+        })
+      );
+    });
   });
   describe('sendMessageStream()', () => {
     it('generateContentStream errors should be catchable', async () => {
@@ -96,5 +154,63 @@ describe('ChatSession', () => {
       );
       clock.restore();
     });
+    it('singleRequestOptions overrides requestOptions', async () => {
+      const generateContentStreamStub = stub(
+        generateContentMethods,
+        'generateContentStream'
+      ).rejects('generateContentStream failed'); // not important
+      const requestOptions = {
+        timeout: 1000
+      };
+      const singleRequestOptions = {
+        timeout: 2000
+      };
+      const chatSession = new ChatSession(
+        fakeApiSettings,
+        'a-model',
+        undefined,
+        requestOptions
+      );
+      await expect(chatSession.sendMessageStream('hello', singleRequestOptions))
+        .to.be.rejected;
+      expect(generateContentStreamStub).to.be.calledWith(
+        fakeApiSettings,
+        'a-model',
+        match.any,
+        match({
+          timeout: singleRequestOptions.timeout
+        })
+      );
+    });
+    it('singleRequestOptions is merged with requestOptions', async () => {
+      const generateContentStreamStub = stub(
+        generateContentMethods,
+        'generateContentStream'
+      ).rejects('generateContentStream failed'); // not important
+      const abortController = new AbortController();
+      const requestOptions = {
+        timeout: 1000
+      };
+      const singleRequestOptions = {
+        signal: abortController.signal
+      };
+      const chatSession = new ChatSession(
+        fakeApiSettings,
+        'a-model',
+        undefined,
+        requestOptions
+      );
+      await expect(chatSession.sendMessageStream('hello', singleRequestOptions))
+        .to.be.rejected;
+      expect(generateContentStreamStub).to.be.calledWith(
+        fakeApiSettings,
+        'a-model',
+        match.any,
+        match({
+          timeout: requestOptions.timeout,
+          signal: singleRequestOptions.signal
+        })
+      );
+    });
   });
 });
diff --git a/packages/ai/src/methods/chat-session.ts b/packages/ai/src/methods/chat-session.ts
index 60794001e37..43ea0afb692 100644
--- a/packages/ai/src/methods/chat-session.ts
+++ b/packages/ai/src/methods/chat-session.ts
@@ -22,6 +22,7 @@ import {
   GenerateContentStreamResult,
   Part,
   RequestOptions,
+  SingleRequestOptions,
   StartChatParams
 } from '../types';
 import { formatNewContent } from '../requests/request-helpers';
@@ -75,7 +76,8 @@ export class ChatSession {
    * {@link GenerateContentResult}
    */
   async sendMessage(
-    request: string | Array<string | Part>
+    request: string | Array<string | Part>,
+    singleRequestOptions?: SingleRequestOptions
   ): Promise<GenerateContentResult> {
     await this._sendPromise;
     const newContent = formatNewContent(request);
@@ -95,7 +97,11 @@ export class ChatSession {
           this._apiSettings,
           this.model,
           generateContentRequest,
-          this.requestOptions
+          // Merge requestOptions
+          {
+            ...this.requestOptions,
+            ...singleRequestOptions
+          }
         )
       )
       .then(result => {
@@ -130,7 +136,8 @@ export class ChatSession {
    * and a response promise.
    */
   async sendMessageStream(
-    request: string | Array<string | Part>
+    request: string | Array<string | Part>,
+    singleRequestOptions?: SingleRequestOptions
   ): Promise<GenerateContentStreamResult> {
     await this._sendPromise;
     const newContent = formatNewContent(request);
@@ -146,7 +153,11 @@ export class ChatSession {
       this._apiSettings,
       this.model,
       generateContentRequest,
-      this.requestOptions
+      // Merge requestOptions
+      {
+        ...this.requestOptions,
+        ...singleRequestOptions
+      }
     );
 
     // Add onto the chain.
diff --git a/packages/ai/src/methods/count-tokens.ts b/packages/ai/src/methods/count-tokens.ts
index b1e60e3a182..e87fc8d1e64 100644
--- a/packages/ai/src/methods/count-tokens.ts
+++ b/packages/ai/src/methods/count-tokens.ts
@@ -18,7 +18,7 @@
 import {
   CountTokensRequest,
   CountTokensResponse,
-  RequestOptions
+  SingleRequestOptions
 } from '../types';
 import { Task, makeRequest } from '../requests/request';
 import { ApiSettings } from '../types/internal';
@@ -29,7 +29,7 @@ export async function countTokens(
   apiSettings: ApiSettings,
   model: string,
   params: CountTokensRequest,
-  requestOptions?: RequestOptions
+  singleRequestOptions?: SingleRequestOptions
 ): Promise<CountTokensResponse> {
   let body: string = '';
   if (apiSettings.backend.backendType === BackendType.GOOGLE_AI) {
@@ -44,7 +44,7 @@ export async function countTokens(
     apiSettings,
     false,
     body,
-    requestOptions
+    singleRequestOptions
   );
   return response.json();
 }
diff --git a/packages/ai/src/methods/generate-content.ts b/packages/ai/src/methods/generate-content.ts
index 5f7902f5954..6692c15618a 100644
--- a/packages/ai/src/methods/generate-content.ts
+++ b/packages/ai/src/methods/generate-content.ts
@@ -20,7 +20,7 @@ import {
   GenerateContentResponse,
   GenerateContentResult,
   GenerateContentStreamResult,
-  RequestOptions
+  SingleRequestOptions
 } from '../types';
 import { Task, makeRequest } from '../requests/request';
 import { createEnhancedContentResponse } from '../requests/response-helpers';
@@ -33,7 +33,7 @@ export async function generateContentStream(
   apiSettings: ApiSettings,
   model: string,
   params: GenerateContentRequest,
-  requestOptions?: RequestOptions
+  singleRequestOptions?: SingleRequestOptions
 ): Promise<GenerateContentStreamResult> {
   if (apiSettings.backend.backendType === BackendType.GOOGLE_AI) {
     params =
GoogleAIMapper.mapGenerateContentRequest(params); @@ -44,7 +44,7 @@ export async function generateContentStream( apiSettings, /* stream */ true, JSON.stringify(params), - requestOptions + singleRequestOptions ); return processStream(response, apiSettings); // TODO: Map streaming responses } @@ -53,7 +53,7 @@ export async function generateContent( apiSettings: ApiSettings, model: string, params: GenerateContentRequest, - requestOptions?: RequestOptions + singleRequestOptions?: SingleRequestOptions ): Promise { if (apiSettings.backend.backendType === BackendType.GOOGLE_AI) { params = GoogleAIMapper.mapGenerateContentRequest(params); @@ -64,7 +64,7 @@ export async function generateContent( apiSettings, /* stream */ false, JSON.stringify(params), - requestOptions + singleRequestOptions ); const generateContentResponse = await processGenerateContentResponse( response, diff --git a/packages/ai/src/models/generative-model.test.ts b/packages/ai/src/models/generative-model.test.ts index d055b82b1be..4b0a6614a87 100644 --- a/packages/ai/src/models/generative-model.test.ts +++ b/packages/ai/src/models/generative-model.test.ts @@ -15,15 +15,19 @@ * limitations under the License. */ import { use, expect } from 'chai'; +import chaiAsPromised from 'chai-as-promised'; import { GenerativeModel } from './generative-model'; import { FunctionCallingMode, AI } from '../public-types'; import * as request from '../requests/request'; import { match, restore, stub } from 'sinon'; import { getMockResponse } from '../../test-utils/mock-response'; import sinonChai from 'sinon-chai'; +import * as generateContentMethods from '../methods/generate-content'; +import * as countTokens from '../methods/count-tokens'; import { VertexAIBackend } from '../backend'; use(sinonChai); +use(chaiAsPromised); const fakeAI: AI = { app: { @@ -40,6 +44,9 @@ const fakeAI: AI = { }; describe('GenerativeModel', () => { + afterEach(() => { + restore(); + }); it('passes params through to generateContent', async () => { const genModel = new GenerativeModel(fakeAI, { model: 'my-model', @@ -167,182 +174,296 @@ describe('GenerativeModel', () => { ); restore(); }); - it('passes base model params through to ChatSession when there are no startChatParams', async () => { - const genModel = new GenerativeModel(fakeAI, { - model: 'my-model', - generationConfig: { - topK: 1 - } - }); - const chatSession = genModel.startChat(); - expect(chatSession.params?.generationConfig).to.deep.equal({ - topK: 1 - }); - restore(); - }); - it('overrides base model params with startChatParams', () => { - const genModel = new GenerativeModel(fakeAI, { - model: 'my-model', - generationConfig: { - topK: 1 - } - }); - const chatSession = genModel.startChat({ - generationConfig: { - topK: 2 - } - }); - expect(chatSession.params?.generationConfig).to.deep.equal({ - topK: 2 - }); - }); - it('passes params through to chat.sendMessage', async () => { - const genModel = new GenerativeModel(fakeAI, { - model: 'my-model', - tools: [ - { functionDeclarations: [{ name: 'myfunc', description: 'mydesc' }] } - ], - toolConfig: { functionCallingConfig: { mode: FunctionCallingMode.NONE } }, - systemInstruction: { role: 'system', parts: [{ text: 'be friendly' }] }, - generationConfig: { - topK: 1 - } - }); - expect(genModel.tools?.length).to.equal(1); - expect(genModel.toolConfig?.functionCallingConfig?.mode).to.equal( - FunctionCallingMode.NONE - ); - expect(genModel.systemInstruction?.parts[0].text).to.equal('be friendly'); - const mockResponse = getMockResponse( - 'vertexAI', - 
'unary-success-basic-reply-short.json' + it('generateContent singleRequestOptions overrides requestOptions', async () => { + const generateContentStub = stub( + generateContentMethods, + 'generateContent' + ).rejects('generateContent failed'); // not important + const requestOptions = { + timeout: 1000 + }; + const singleRequestOptions = { + timeout: 2000 + }; + const genModel = new GenerativeModel( + fakeAI, + { model: 'my-model' }, + requestOptions ); - const makeRequestStub = stub(request, 'makeRequest').resolves( - mockResponse as Response - ); - await genModel.startChat().sendMessage('hello'); - expect(makeRequestStub).to.be.calledWith( - 'publishers/google/models/my-model', - request.Task.GENERATE_CONTENT, + await expect(genModel.generateContent('hello', singleRequestOptions)).to.be + .rejected; + expect(generateContentStub).to.be.calledWith( match.any, - false, - match((value: string) => { - return ( - value.includes('myfunc') && - value.includes(FunctionCallingMode.NONE) && - value.includes('be friendly') && - value.includes('topK') - ); - }), - {} + match.any, + match.any, + match({ + timeout: singleRequestOptions.timeout + }) ); - restore(); }); - it('passes text-only systemInstruction through to chat.sendMessage', async () => { - const genModel = new GenerativeModel(fakeAI, { - model: 'my-model', - systemInstruction: 'be friendly' - }); - expect(genModel.systemInstruction?.parts[0].text).to.equal('be friendly'); - const mockResponse = getMockResponse( - 'vertexAI', - 'unary-success-basic-reply-short.json' + it('generateContent singleRequestOptions is merged with requestOptions', async () => { + const generateContentStub = stub( + generateContentMethods, + 'generateContent' + ).rejects('generateContent failed'); // not important + const abortController = new AbortController(); + const requestOptions = { + timeout: 1000 + }; + const singleRequestOptions = { + signal: abortController.signal + }; + const genModel = new GenerativeModel( + fakeAI, + { model: 'my-model' }, + requestOptions ); - const makeRequestStub = stub(request, 'makeRequest').resolves( - mockResponse as Response - ); - await genModel.startChat().sendMessage('hello'); - expect(makeRequestStub).to.be.calledWith( - 'publishers/google/models/my-model', - request.Task.GENERATE_CONTENT, + await expect(genModel.generateContent('hello', singleRequestOptions)).to.be + .rejected; + expect(generateContentStub).to.be.calledWith( match.any, - false, - match((value: string) => { - return value.includes('be friendly'); - }), - {} + match.any, + match.any, + match({ + timeout: requestOptions.timeout, + signal: singleRequestOptions.signal + }) ); - restore(); - }); - it('startChat overrides model values', async () => { - const genModel = new GenerativeModel(fakeAI, { - model: 'my-model', - tools: [ - { functionDeclarations: [{ name: 'myfunc', description: 'mydesc' }] } - ], - toolConfig: { functionCallingConfig: { mode: FunctionCallingMode.NONE } }, - systemInstruction: { role: 'system', parts: [{ text: 'be friendly' }] }, - generationConfig: { - responseMimeType: 'image/jpeg' - } + it('passes base model params through to ChatSession when there are no startChatParams', async () => { + const genModel = new GenerativeModel(fakeAI, { + model: 'my-model', + generationConfig: { + topK: 1 + } + }); + const chatSession = genModel.startChat(); + expect(chatSession.params?.generationConfig).to.deep.equal({ + topK: 1 + }); + restore(); }); - expect(genModel.tools?.length).to.equal(1); - 
expect(genModel.toolConfig?.functionCallingConfig?.mode).to.equal( - FunctionCallingMode.NONE - ); - expect(genModel.systemInstruction?.parts[0].text).to.equal('be friendly'); - const mockResponse = getMockResponse( - 'vertexAI', - 'unary-success-basic-reply-short.json' - ); - const makeRequestStub = stub(request, 'makeRequest').resolves( - mockResponse as Response - ); - await genModel - .startChat({ + it('overrides base model params with startChatParams', () => { + const genModel = new GenerativeModel(fakeAI, { + model: 'my-model', + generationConfig: { + topK: 1 + } + }); + const chatSession = genModel.startChat({ + generationConfig: { + topK: 2 + } + }); + expect(chatSession.params?.generationConfig).to.deep.equal({ + topK: 2 + }); + }); + it('passes params through to chat.sendMessage', async () => { + const genModel = new GenerativeModel(fakeAI, { + model: 'my-model', tools: [ - { - functionDeclarations: [ - { name: 'otherfunc', description: 'otherdesc' } - ] - } + { functionDeclarations: [{ name: 'myfunc', description: 'mydesc' }] } ], toolConfig: { - functionCallingConfig: { mode: FunctionCallingMode.AUTO } + functionCallingConfig: { mode: FunctionCallingMode.NONE } }, - systemInstruction: { role: 'system', parts: [{ text: 'be formal' }] }, + systemInstruction: { role: 'system', parts: [{ text: 'be friendly' }] }, generationConfig: { - responseMimeType: 'image/png' + topK: 1 } - }) - .sendMessage('hello'); - expect(makeRequestStub).to.be.calledWith( - 'publishers/google/models/my-model', - request.Task.GENERATE_CONTENT, - match.any, - false, - match((value: string) => { - return ( - value.includes('otherfunc') && - value.includes(FunctionCallingMode.AUTO) && - value.includes('be formal') && - value.includes('image/png') && - !value.includes('image/jpeg') - ); - }), - {} - ); - restore(); - }); - it('calls countTokens', async () => { - const genModel = new GenerativeModel(fakeAI, { model: 'my-model' }); - const mockResponse = getMockResponse( - 'vertexAI', - 'unary-success-total-tokens.json' - ); - const makeRequestStub = stub(request, 'makeRequest').resolves( - mockResponse as Response - ); - await genModel.countTokens('hello'); - expect(makeRequestStub).to.be.calledWith( - 'publishers/google/models/my-model', - request.Task.COUNT_TOKENS, - match.any, - false, - match((value: string) => { - return value.includes('hello'); - }) - ); - restore(); + }); + expect(genModel.tools?.length).to.equal(1); + expect(genModel.toolConfig?.functionCallingConfig?.mode).to.equal( + FunctionCallingMode.NONE + ); + expect(genModel.systemInstruction?.parts[0].text).to.equal('be friendly'); + const mockResponse = getMockResponse( + 'vertexAI', + 'unary-success-basic-reply-short.json' + ); + const makeRequestStub = stub(request, 'makeRequest').resolves( + mockResponse as Response + ); + await genModel.startChat().sendMessage('hello'); + expect(makeRequestStub).to.be.calledWith( + 'publishers/google/models/my-model', + request.Task.GENERATE_CONTENT, + match.any, + false, + match((value: string) => { + return ( + value.includes('myfunc') && + value.includes(FunctionCallingMode.NONE) && + value.includes('be friendly') && + value.includes('topK') + ); + }), + {} + ); + restore(); + }); + it('passes text-only systemInstruction through to chat.sendMessage', async () => { + const genModel = new GenerativeModel(fakeAI, { + model: 'my-model', + systemInstruction: 'be friendly' + }); + expect(genModel.systemInstruction?.parts[0].text).to.equal('be friendly'); + const mockResponse = getMockResponse( + 'vertexAI', + 
'unary-success-basic-reply-short.json' + ); + const makeRequestStub = stub(request, 'makeRequest').resolves( + mockResponse as Response + ); + await genModel.startChat().sendMessage('hello'); + expect(makeRequestStub).to.be.calledWith( + 'publishers/google/models/my-model', + request.Task.GENERATE_CONTENT, + match.any, + false, + match((value: string) => { + return value.includes('be friendly'); + }), + {} + ); + restore(); + }); + it('startChat overrides model values', async () => { + const genModel = new GenerativeModel(fakeAI, { + model: 'my-model', + tools: [ + { functionDeclarations: [{ name: 'myfunc', description: 'mydesc' }] } + ], + toolConfig: { + functionCallingConfig: { mode: FunctionCallingMode.NONE } + }, + systemInstruction: { role: 'system', parts: [{ text: 'be friendly' }] }, + generationConfig: { + responseMimeType: 'image/jpeg' + } + }); + expect(genModel.tools?.length).to.equal(1); + expect(genModel.toolConfig?.functionCallingConfig?.mode).to.equal( + FunctionCallingMode.NONE + ); + expect(genModel.systemInstruction?.parts[0].text).to.equal('be friendly'); + const mockResponse = getMockResponse( + 'vertexAI', + 'unary-success-basic-reply-short.json' + ); + const makeRequestStub = stub(request, 'makeRequest').resolves( + mockResponse as Response + ); + await genModel + .startChat({ + tools: [ + { + functionDeclarations: [ + { name: 'otherfunc', description: 'otherdesc' } + ] + } + ], + toolConfig: { + functionCallingConfig: { mode: FunctionCallingMode.AUTO } + }, + systemInstruction: { role: 'system', parts: [{ text: 'be formal' }] }, + generationConfig: { + responseMimeType: 'image/png' + } + }) + .sendMessage('hello'); + expect(makeRequestStub).to.be.calledWith( + 'publishers/google/models/my-model', + request.Task.GENERATE_CONTENT, + match.any, + false, + match((value: string) => { + return ( + value.includes('otherfunc') && + value.includes(FunctionCallingMode.AUTO) && + value.includes('be formal') && + value.includes('image/png') && + !value.includes('image/jpeg') + ); + }), + {} + ); + restore(); + }); + it('calls countTokens', async () => { + const genModel = new GenerativeModel(fakeAI, { model: 'my-model' }); + const mockResponse = getMockResponse( + 'vertexAI', + 'unary-success-total-tokens.json' + ); + const makeRequestStub = stub(request, 'makeRequest').resolves( + mockResponse as Response + ); + await genModel.countTokens('hello'); + expect(makeRequestStub).to.be.calledWith( + 'publishers/google/models/my-model', + request.Task.COUNT_TOKENS, + match.any, + false, + match((value: string) => { + return value.includes('hello'); + }) + ); + restore(); + }); + it('countTokens singleRequestOptions overrides requestOptions', async () => { + const countTokensStub = stub(countTokens, 'countTokens').rejects( + 'countTokens failed' + ); + const requestOptions = { + timeout: 1000 + }; + const singleRequestOptions = { + timeout: 2000 + }; + const genModel = new GenerativeModel( + fakeAI, + { model: 'my-model' }, + requestOptions + ); + await expect(genModel.countTokens('hello', singleRequestOptions)).to.be + .rejected; + expect(countTokensStub).to.be.calledWith( + match.any, + match.any, + match.any, + match({ + timeout: singleRequestOptions.timeout + }) + ); + }); + it('countTokens singleRequestOptions is merged with requestOptions', async () => { + const countTokensStub = stub(countTokens, 'countTokens').rejects( + 'countTokens failed' + ); + const abortController = new AbortController(); + const requestOptions = { + timeout: 1000 + }; + const singleRequestOptions = { + 
signal: abortController.signal + }; + const genModel = new GenerativeModel( + fakeAI, + { model: 'my-model' }, + requestOptions + ); + await expect(genModel.countTokens('hello', singleRequestOptions)).to.be + .rejected; + expect(countTokensStub).to.be.calledWith( + match.any, + match.any, + match.any, + match({ + timeout: requestOptions.timeout, + signal: singleRequestOptions.signal + }) + ); + }); }); }); diff --git a/packages/ai/src/models/generative-model.ts b/packages/ai/src/models/generative-model.ts index b09a9290aa4..fa1b855e576 100644 --- a/packages/ai/src/models/generative-model.ts +++ b/packages/ai/src/models/generative-model.ts @@ -29,11 +29,12 @@ import { GenerationConfig, ModelParams, Part, - RequestOptions, SafetySetting, + RequestOptions, StartChatParams, Tool, - ToolConfig + ToolConfig, + SingleRequestOptions } from '../types'; import { ChatSession } from '../methods/chat-session'; import { countTokens } from '../methods/count-tokens'; @@ -77,7 +78,8 @@ export class GenerativeModel extends AIModel { * and returns an object containing a single {@link GenerateContentResponse}. */ async generateContent( - request: GenerateContentRequest | string | Array + request: GenerateContentRequest | string | Array, + singleRequestOptions?: SingleRequestOptions ): Promise { const formattedParams = formatGenerateContentInput(request); return generateContent( @@ -91,7 +93,11 @@ export class GenerativeModel extends AIModel { systemInstruction: this.systemInstruction, ...formattedParams }, - this.requestOptions + // Merge request options + { + ...this.requestOptions, + ...singleRequestOptions + } ); } @@ -102,7 +108,8 @@ export class GenerativeModel extends AIModel { * a promise that returns the final aggregated response. */ async generateContentStream( - request: GenerateContentRequest | string | Array + request: GenerateContentRequest | string | Array, + singleRequestOptions?: SingleRequestOptions ): Promise { const formattedParams = formatGenerateContentInput(request); return generateContentStream( @@ -116,7 +123,11 @@ export class GenerativeModel extends AIModel { systemInstruction: this.systemInstruction, ...formattedParams }, - this.requestOptions + // Merge request options + { + ...this.requestOptions, + ...singleRequestOptions + } ); } @@ -149,9 +160,19 @@ export class GenerativeModel extends AIModel { * Counts the tokens in the provided request. 
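+   *
+   * @example
+   * A request-scoped `AbortSignal` can cancel an in-flight call; an illustrative
+   * sketch (model setup elided):
+   * ```javascript
+   * const controller = new AbortController();
+   * const { totalTokens } = await model.countTokens('Why is the sky blue?', {
+   *   signal: controller.signal
+   * });
+   * ```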
*/ async countTokens( - request: CountTokensRequest | string | Array + request: CountTokensRequest | string | Array, + singleRequestOptions?: SingleRequestOptions ): Promise { const formattedParams = formatGenerateContentInput(request); - return countTokens(this._apiSettings, this.model, formattedParams); + return countTokens( + this._apiSettings, + this.model, + formattedParams, + // Merge request options + { + ...this.requestOptions, + ...singleRequestOptions + } + ); } } diff --git a/packages/ai/src/models/imagen-model.test.ts b/packages/ai/src/models/imagen-model.test.ts index f4121e18f2d..fb57d5cf61b 100644 --- a/packages/ai/src/models/imagen-model.test.ts +++ b/packages/ai/src/models/imagen-model.test.ts @@ -47,6 +47,9 @@ const fakeAI: AI = { }; describe('ImagenModel', () => { + afterEach(() => { + restore(); + }); it('generateImages makes a request to predict with default parameters', async () => { const mockResponse = getMockResponse( 'vertexAI', @@ -72,9 +75,8 @@ describe('ImagenModel', () => { value.includes(`"sampleCount":1`) ); }), - undefined + {} ); - restore(); }); it('generateImages makes a request to predict with generation config and safety settings', async () => { const imagenModel = new ImagenModel(fakeAI, { @@ -131,9 +133,74 @@ describe('ImagenModel', () => { ) ); }), - undefined + {} + ); + }); + it('generateImages singleRequestOptions overrides requestOptions', async () => { + const requestOptions = { + timeout: 1000 + }; + const singleRequestOptions = { + timeout: 2000 + }; + const imagenModel = new ImagenModel( + fakeAI, + { model: 'my-model' }, + requestOptions + ); + const mockResponse = getMockResponse( + 'vertexAI', + 'unary-success-generate-images-base64.json' + ); + const makeRequestStub = stub(request, 'makeRequest').resolves( + mockResponse as Response + ); + const prompt = 'A photorealistic image of a toy boat at sea.'; + await imagenModel.generateImages(prompt, singleRequestOptions); + expect(makeRequestStub).to.be.calledWith( + match.any, + request.Task.PREDICT, + match.any, + false, + match.any, + match({ + timeout: singleRequestOptions.timeout + }) + ); + }); + it('generateImages singleRequestOptions is merged with requestOptions', async () => { + const abortController = new AbortController(); + const requestOptions = { + timeout: 1000 + }; + const singleRequestOptions = { + signal: abortController.signal + }; + const imagenModel = new ImagenModel( + fakeAI, + { model: 'my-model' }, + requestOptions + ); + const mockResponse = getMockResponse( + 'vertexAI', + 'unary-success-generate-images-base64.json' + ); + const makeRequestStub = stub(request, 'makeRequest').resolves( + mockResponse as Response + ); + const prompt = 'A photorealistic image of a toy boat at sea.'; + await imagenModel.generateImages(prompt, singleRequestOptions); + expect(makeRequestStub).to.be.calledWith( + match.any, + request.Task.PREDICT, + match.any, + false, + match.any, + match({ + timeout: requestOptions.timeout, + signal: singleRequestOptions.signal + }) ); - restore(); }); it('throws if prompt blocked', async () => { const mockResponse = getMockResponse( @@ -159,8 +226,72 @@ describe('ImagenModel', () => { expect((e as AIError).message).to.include( "Image generation failed with the following error: The prompt could not be submitted. This prompt contains sensitive words that violate Google's Responsible AI practices. Try rephrasing the prompt. If you think this was an error, send feedback." 
      );
-    } finally {
-      restore();
     }
   });
+  it('generateImagesGCS singleRequestOptions overrides requestOptions', async () => {
+    const requestOptions = {
+      timeout: 1000
+    };
+    const singleRequestOptions = {
+      timeout: 2000
+    };
+    const imagenModel = new ImagenModel(
+      fakeAI,
+      { model: 'my-model' },
+      requestOptions
+    );
+    const mockResponse = getMockResponse(
+      'vertexAI',
+      'unary-success-generate-images-gcs.json'
+    );
+    const makeRequestStub = stub(request, 'makeRequest').resolves(
+      mockResponse as Response
+    );
+    const prompt = 'A photorealistic image of a toy boat at sea.';
+    await imagenModel.generateImagesGCS(prompt, '', singleRequestOptions);
+    expect(makeRequestStub).to.be.calledWith(
+      match.any,
+      request.Task.PREDICT,
+      match.any,
+      false,
+      match.any,
+      match({
+        timeout: singleRequestOptions.timeout
+      })
+    );
+  });
+  it('generateImagesGCS singleRequestOptions is merged with requestOptions', async () => {
+    const abortController = new AbortController();
+    const requestOptions = {
+      timeout: 1000
+    };
+    const singleRequestOptions = {
+      signal: abortController.signal
+    };
+    const imagenModel = new ImagenModel(
+      fakeAI,
+      { model: 'my-model' },
+      requestOptions
+    );
+    const mockResponse = getMockResponse(
+      'vertexAI',
+      'unary-success-generate-images-gcs.json'
+    );
+    const makeRequestStub = stub(request, 'makeRequest').resolves(
+      mockResponse as Response
+    );
+    const prompt = 'A photorealistic image of a toy boat at sea.';
+    await imagenModel.generateImagesGCS(prompt, '', singleRequestOptions);
+    expect(makeRequestStub).to.be.calledWith(
+      match.any,
+      request.Task.PREDICT,
+      match.any,
+      false,
+      match.any,
+      match({
+        timeout: requestOptions.timeout,
+        signal: singleRequestOptions.signal
+      })
+    );
+  });
 });
diff --git a/packages/ai/src/models/imagen-model.ts b/packages/ai/src/models/imagen-model.ts
index 3c76a1c721c..eb1df2a6072 100644
--- a/packages/ai/src/models/imagen-model.ts
+++ b/packages/ai/src/models/imagen-model.ts
@@ -26,7 +26,8 @@ import {
   RequestOptions,
   ImagenModelParams,
   ImagenGenerationResponse,
-  ImagenSafetySettings
+  ImagenSafetySettings,
+  SingleRequestOptions
 } from '../types';
 import { AIModel } from './ai-model';
 
@@ -102,7 +103,8 @@ export class ImagenModel extends AIModel {
    * @beta
    */
   async generateImages(
-    prompt: string
+    prompt: string,
+    singleRequestOptions?: SingleRequestOptions
   ): Promise<ImagenGenerationResponse<ImagenInlineImage>> {
     const body = createPredictRequestBody(prompt, {
       ...this.generationConfig,
@@ -114,7 +116,11 @@ export class ImagenModel extends AIModel {
       this._apiSettings,
       /* stream */ false,
       JSON.stringify(body),
-      this.requestOptions
+      // Merge request options
+      {
+        ...this.requestOptions,
+        ...singleRequestOptions
+      }
     );
     return handlePredictResponse<ImagenInlineImage>(response);
   }
@@ -140,7 +146,8 @@ export class ImagenModel extends AIModel {
    */
   async generateImagesGCS(
     prompt: string,
-    gcsURI: string
+    gcsURI: string,
+    singleRequestOptions?: SingleRequestOptions
   ): Promise<ImagenGenerationResponse<ImagenGCSImage>> {
     const body = createPredictRequestBody(prompt, {
       gcsURI,
@@ -153,7 +160,10 @@ export class ImagenModel extends AIModel {
       this._apiSettings,
       /* stream */ false,
       JSON.stringify(body),
-      this.requestOptions
+      {
+        ...this.requestOptions,
+        ...singleRequestOptions
+      }
     );
     return handlePredictResponse<ImagenGCSImage>(response);
   }
diff --git a/packages/ai/src/requests/request.test.ts b/packages/ai/src/requests/request.test.ts
index 0d162906fdc..a7e99ca56c6 100644
--- a/packages/ai/src/requests/request.test.ts
+++ b/packages/ai/src/requests/request.test.ts
@@ -16,7 +16,7 @@
  */
 
 import { expect, use } from 'chai';
-import { match, restore, stub } from 'sinon';
+import Sinon, { match, restore, stub, useFakeTimers } from 'sinon';
 import sinonChai from 'sinon-chai';
 import chaiAsPromised from 'chai-as-promised';
 import { RequestUrl, Task, getHeaders, makeRequest } from './request';
@@ -277,8 +277,36 @@ describe('request methods', () => {
     });
   });
   describe('makeRequest', () => {
+    let fetchStub: Sinon.SinonStub;
+    let clock: Sinon.SinonFakeTimers;
+    const fetchAborter = (
+      _url: string,
+      options?: RequestInit
+    ): Promise<Response> => {
+      expect(options).to.not.be.undefined;
+      expect(options!.signal).to.not.be.undefined;
+      const signal = options!.signal;
+      return new Promise<Response>((_resolve, reject): void => {
+        const abortListener = (): void => {
+          reject(new DOMException(signal?.reason || 'Aborted', 'AbortError'));
+        };
+
+        signal?.addEventListener('abort', abortListener, { once: true });
+      });
+    };
+
+    beforeEach(() => {
+      fetchStub = stub(globalThis, 'fetch');
+      clock = useFakeTimers();
+    });
+
+    afterEach(() => {
+      restore();
+      clock.restore();
+    });
+
     it('no error', async () => {
-      const fetchStub = stub(globalThis, 'fetch').resolves({
+      fetchStub.resolves({
         ok: true
       } as Response);
       const response = await makeRequest(
@@ -292,7 +320,7 @@ describe('request methods', () => {
       expect(response.ok).to.be.true;
     });
     it('error with timeout', async () => {
-      const fetchStub = stub(globalThis, 'fetch').resolves({
+      fetchStub.resolves({
         ok: false,
         status: 500,
         statusText: 'AbortError'
@@ -321,7 +349,7 @@ describe('request methods', () => {
       expect(fetchStub).to.be.calledOnce;
     });
     it('Network error, no response.json()', async () => {
-      const fetchStub = stub(globalThis, 'fetch').resolves({
+      fetchStub.resolves({
        ok: false,
         status: 500,
         statusText: 'Server Error'
@@ -345,7 +373,7 @@ describe('request methods', () => {
       expect(fetchStub).to.be.calledOnce;
     });
     it('Network error, includes response.json()', async () => {
-      const fetchStub = stub(globalThis, 'fetch').resolves({
+      fetchStub.resolves({
         ok: false,
         status: 500,
         statusText: 'Server Error',
@@ -371,7 +399,7 @@ describe('request methods', () => {
       expect(fetchStub).to.be.calledOnce;
     });
     it('Network error, includes response.json() and details', async () => {
-      const fetchStub = stub(globalThis, 'fetch').resolves({
+      fetchStub.resolves({
         ok: false,
         status: 500,
         statusText: 'Server Error',
@@ -409,28 +437,221 @@ describe('request methods', () => {
       }
       expect(fetchStub).to.be.calledOnce;
     });
-  });
-  it('Network error, API not enabled', async () => {
-    const mockResponse = getMockResponse(
-      'vertexAI',
-      'unary-failure-firebasevertexai-api-not-enabled.json'
-    );
-    const fetchStub = stub(globalThis, 'fetch').resolves(
-      mockResponse as Response
-    );
-    try {
+    it('Network error, API not enabled', async () => {
+      const mockResponse = getMockResponse(
+        'vertexAI',
+        'unary-failure-firebasevertexai-api-not-enabled.json'
+      );
+      fetchStub.resolves(mockResponse as Response);
+      try {
+        await makeRequest(
+          'models/model-name',
+          Task.GENERATE_CONTENT,
+          fakeApiSettings,
+          false,
+          ''
+        );
+      } catch (e) {
+        expect((e as AIError).code).to.equal(AIErrorCode.API_NOT_ENABLED);
+        expect((e as AIError).message).to.include('my-project');
+        expect((e as AIError).message).to.include('googleapis.com');
+      }
+      expect(fetchStub).to.be.calledOnce;
+    });
+
+    it('should throw DOMException if external signal is already aborted', async () => {
+      const controller = new AbortController();
+      const abortReason = 'Aborted before request';
+      controller.abort(abortReason);
+
+      const requestPromise = makeRequest(
+
'models/model-name', + Task.GENERATE_CONTENT, + fakeApiSettings, + false, + '{}', + { signal: controller.signal } + ); + + await expect(requestPromise).to.be.rejectedWith( + DOMException, + abortReason + ); + + expect(fetchStub).not.to.have.been.called; + }); + it('should abort fetch if external signal aborts during request', async () => { + fetchStub.callsFake(fetchAborter); + const controller = new AbortController(); + const abortReason = 'Aborted during request'; + + const requestPromise = makeRequest( + 'models/model-name', + Task.GENERATE_CONTENT, + fakeApiSettings, + false, + '{}', + { signal: controller.signal } + ); + + await clock.tickAsync(0); + controller.abort(abortReason); + + await expect(requestPromise).to.be.rejectedWith( + AIError, + `AI: Error fetching from https://firebasevertexai.googleapis.com/v1beta/projects/my-project/locations/us-central1/models/model-name:generateContent: ${abortReason} (AI/error)` + ); + }); + + it('should abort fetch if timeout expires during request', async () => { + const timeoutDuration = 100; + fetchStub.callsFake(fetchAborter); + + const requestPromise = makeRequest( + 'models/model-name', + Task.GENERATE_CONTENT, + fakeApiSettings, + false, + '{}', + { timeout: timeoutDuration } + ); + + await clock.tickAsync(timeoutDuration + 100); + + await expect(requestPromise).to.be.rejectedWith( + AIError, + /Timeout has expired/ + ); + + expect(fetchStub).to.have.been.calledOnce; + const fetchOptions = fetchStub.firstCall.args[1] as RequestInit; + const internalSignal = fetchOptions.signal; + + expect(internalSignal?.aborted).to.be.true; + expect(internalSignal?.reason).to.equal('Timeout has expired.'); + }); + + it('should succeed and clear timeout if fetch completes before timeout', async () => { + const mockResponse = new Response('{}', { + status: 200, + statusText: 'OK' + }); + const fetchPromise = Promise.resolve(mockResponse); + fetchStub.resolves(fetchPromise); + + const requestPromise = makeRequest( + 'models/model-name', + Task.GENERATE_CONTENT, + fakeApiSettings, + false, + '{}', + { timeout: 5000 } // Generous timeout + ); + + // Advance time slightly, well within timeout + await clock.tickAsync(10); + + const response = await requestPromise; + expect(response.ok).to.be.true; + + expect(fetchStub).to.have.been.calledOnce; + }); + + it('should succeed and clear timeout/listener if fetch completes with signal provided but not aborted', async () => { + const controller = new AbortController(); + const mockResponse = new Response('{}', { + status: 200, + statusText: 'OK' + }); + const fetchPromise = Promise.resolve(mockResponse); + fetchStub.resolves(fetchPromise); + + const requestPromise = makeRequest( + 'models/model-name', + Task.GENERATE_CONTENT, + fakeApiSettings, + false, + '{}', + { signal: controller.signal } + ); + + // Advance time slightly + await clock.tickAsync(10); + + const response = await requestPromise; + expect(response.ok).to.be.true; + expect(fetchStub).to.have.been.calledOnce; + }); + + it('should use external signal abort reason if it occurs before timeout', async () => { + const controller = new AbortController(); + const abortReason = 'External Abort Wins'; + const timeoutDuration = 500; + fetchStub.callsFake(fetchAborter); + + const requestPromise = makeRequest( + 'models/model-name', + Task.GENERATE_CONTENT, + fakeApiSettings, + false, + '{}', + { signal: controller.signal, timeout: timeoutDuration } + ); + + // Advance time, but less than the timeout + await clock.tickAsync(timeoutDuration / 2); + 
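// The external abort below fires while the 500ms timeout is still pending, so its abort reason, not the timeout's, should reject the request.
+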
controller.abort(abortReason); + + await expect(requestPromise).to.be.rejectedWith(AIError, abortReason); + }); + + it('should use timeout reason if it occurs before external signal abort', async () => { + const controller = new AbortController(); + const abortReason = 'External Abort Loses'; + const timeoutDuration = 100; + fetchStub.callsFake(fetchAborter); + + const requestPromise = makeRequest( + 'models/model-name', + Task.GENERATE_CONTENT, + fakeApiSettings, + false, + '{}', + { signal: controller.signal, timeout: timeoutDuration } + ); + + // Schedule external abort after timeout + setTimeout(() => controller.abort(abortReason), timeoutDuration * 2); + + // Advance time past the timeout + await clock.tickAsync(timeoutDuration + 1); + + await expect(requestPromise).to.be.rejectedWith( + AIError, + /Timeout has expired/ + ); + }); + + it('should pass internal signal to fetch options', async () => { + const mockResponse = new Response('{}', { + status: 200, + statusText: 'OK' + }); + fetchStub.resolves(mockResponse); + await makeRequest( 'models/model-name', Task.GENERATE_CONTENT, fakeApiSettings, false, - '' + '{}' ); - } catch (e) { - expect((e as AIError).code).to.equal(AIErrorCode.API_NOT_ENABLED); - expect((e as AIError).message).to.include('my-project'); - expect((e as AIError).message).to.include('googleapis.com'); - } - expect(fetchStub).to.be.calledOnce; + + expect(fetchStub).to.have.been.calledOnce; + const fetchOptions = fetchStub.firstCall.args[1] as RequestInit; + expect(fetchOptions.signal).to.exist; + expect(fetchOptions.signal).to.be.instanceOf(AbortSignal); + expect(fetchOptions.signal?.aborted).to.be.false; + }); }); }); diff --git a/packages/ai/src/requests/request.ts b/packages/ai/src/requests/request.ts index 31c5e9b8125..bd977affba8 100644 --- a/packages/ai/src/requests/request.ts +++ b/packages/ai/src/requests/request.ts @@ -15,7 +15,7 @@ * limitations under the License. 
*/ -import { ErrorDetails, RequestOptions, AIErrorCode } from '../types'; +import { SingleRequestOptions, AIErrorCode, ErrorDetails } from '../types'; import { AIError } from '../errors'; import { ApiSettings } from '../types/internal'; import { @@ -28,6 +28,9 @@ import { import { logger } from '../logger'; import { GoogleAIBackend, VertexAIBackend } from '../backend'; +const TIMEOUT_EXPIRED_MESSAGE = 'Timeout has expired.'; +const ABORT_ERROR_NAME = 'AbortError'; + export enum Task { GENERATE_CONTENT = 'generateContent', STREAM_GENERATE_CONTENT = 'streamGenerateContent', @@ -41,7 +44,7 @@ export class RequestUrl { public task: Task, public apiSettings: ApiSettings, public stream: boolean, - public requestOptions?: RequestOptions + public requestOptions?: SingleRequestOptions ) {} toString(): string { const url = new URL(this.baseUrl); // Throws if the URL is invalid @@ -127,9 +130,15 @@ export async function constructRequest( apiSettings: ApiSettings, stream: boolean, body: string, - requestOptions?: RequestOptions + singleRequestOptions?: SingleRequestOptions ): Promise<{ url: string; fetchOptions: RequestInit }> { - const url = new RequestUrl(model, task, apiSettings, stream, requestOptions); + const url = new RequestUrl( + model, + task, + apiSettings, + stream, + singleRequestOptions + ); return { url: url.toString(), fetchOptions: { @@ -146,11 +155,49 @@ export async function makeRequest( apiSettings: ApiSettings, stream: boolean, body: string, - requestOptions?: RequestOptions + singleRequestOptions?: SingleRequestOptions ): Promise { - const url = new RequestUrl(model, task, apiSettings, stream, requestOptions); + const url = new RequestUrl( + model, + task, + apiSettings, + stream, + singleRequestOptions + ); let response; - let fetchTimeoutId: string | number | NodeJS.Timeout | undefined; + + const externalSignal = singleRequestOptions?.signal; + const timeoutMillis = + singleRequestOptions?.timeout != null && singleRequestOptions.timeout >= 0 + ? singleRequestOptions.timeout + : DEFAULT_FETCH_TIMEOUT_MS; + const internalAbortController = new AbortController(); + const fetchTimeoutId = setTimeout(() => { + internalAbortController.abort(TIMEOUT_EXPIRED_MESSAGE); + logger.debug( + `Aborting request to ${url} due to timeout (${timeoutMillis}ms)` + ); + }, timeoutMillis); + + if (externalSignal) { + if (externalSignal.aborted) { + clearTimeout(fetchTimeoutId); + throw new DOMException( + externalSignal.reason ?? 'Aborted externally before fetch', + ABORT_ERROR_NAME + ); + } + + const externalAbortListener = (): void => { + logger.debug(`Aborting request to ${url} due to external abort signal.`); + internalAbortController.abort(externalSignal.reason); + }; + + externalSignal.addEventListener('abort', externalAbortListener, { + once: true + }); + } + try { const request = await constructRequest( model, @@ -158,16 +205,9 @@ export async function makeRequest( apiSettings, stream, body, - requestOptions + singleRequestOptions ); - // Timeout is 180s by default - const timeoutMillis = - requestOptions?.timeout != null && requestOptions.timeout >= 0 - ? 
requestOptions.timeout - : DEFAULT_FETCH_TIMEOUT_MS; - const abortController = new AbortController(); - fetchTimeoutId = setTimeout(() => abortController.abort(), timeoutMillis); - request.fetchOptions.signal = abortController.signal; + request.fetchOptions.signal = internalAbortController.signal; response = await fetch(request.url, request.fetchOptions); if (!response.ok) { diff --git a/packages/ai/src/types/requests.ts b/packages/ai/src/types/requests.ts index 67f45095c2a..a22508a1d72 100644 --- a/packages/ai/src/types/requests.ts +++ b/packages/ai/src/types/requests.ts @@ -161,6 +161,47 @@ export interface RequestOptions { baseUrl?: string; } +/** + * Options that can be provided per-request. + * Extends the base {@link RequestOptions} (like `timeout` and `baseUrl`) + * with request-specific controls like cancellation via `AbortSignal`. + * + * Options specified here will override any default {@link RequestOptions} + * configured on a model (for example, {@link GenerativeModel}). + * + * @public + */ +export interface SingleRequestOptions extends RequestOptions { + /** + * An `AbortSignal` instance that allows cancelling ongoing requests (like `generateContent` or + * `generateImages`). + * + * If provided, calling `abort()` on the corresponding `AbortController` + * will attempt to cancel the underlying HTTP request. An `AbortError` will be thrown + * if cancellation is successful. + * + * Note that this will not cancel the request in the backend, so any applicable billing charges + * will still be applied despite cancellation. + * + * @example + * ```javascript + * const controller = new AbortController(); + * const model = getGenerativeModel({ + * // ... + * }); + * model.generateContent( + * "Write a story about a magic backpack.", + * { signal: controller.signal } + * ); + * + * // To cancel request: + * controller.abort(); + * ``` + * @see https://developer.mozilla.org/en-US/docs/Web/API/AbortSignal + */ + signal?: AbortSignal; +} + /** * Defines a tool that model can call to access external knowledge. * @public
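
Taken together, the new surface can be exercised end to end. A minimal usage sketch (the `ai` instance and the model name are illustrative, not part of this change):

```javascript
import { getGenerativeModel } from 'firebase/ai';

// Assumes `ai` was previously obtained, e.g. via getAI(app).
const model = getGenerativeModel(ai, { model: 'gemini-2.0-flash' });

const controller = new AbortController();
// Abort the request if it has not completed within 3 seconds.
const timer = setTimeout(() => controller.abort('Took too long'), 3000);

try {
  const { stream } = await model.generateContentStream(
    'Write a story about a magic backpack.',
    { signal: controller.signal }
  );
  for await (const chunk of stream) {
    console.log(chunk.text());
  }
} finally {
  clearTimeout(timer);
}
```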