diff --git a/packages/ai/package.json b/packages/ai/package.json index d159793b206..8382025a68e 100644 --- a/packages/ai/package.json +++ b/packages/ai/package.json @@ -52,6 +52,7 @@ "@firebase/component": "0.6.14", "@firebase/logger": "0.4.4", "@firebase/util": "1.11.1", + "deepmerge": "4.3.1", "tslib": "^2.1.0" }, "license": "Apache-2.0", diff --git a/packages/ai/src/methods/chrome-adapter.test.ts b/packages/ai/src/methods/chrome-adapter.test.ts index f8ea80b0e09..5b245ac1ffb 100644 --- a/packages/ai/src/methods/chrome-adapter.test.ts +++ b/packages/ai/src/methods/chrome-adapter.test.ts @@ -54,30 +54,6 @@ async function toStringArray( describe('ChromeAdapter', () => { describe('constructor', () => { - it('sets image as expected input type by default', async () => { - const languageModelProvider = { - availability: () => Promise.resolve(Availability.available) - } as LanguageModel; - const availabilityStub = stub( - languageModelProvider, - 'availability' - ).resolves(Availability.available); - const adapter = new ChromeAdapter( - languageModelProvider, - 'prefer_on_device' - ); - await adapter.isAvailable({ - contents: [ - { - role: 'user', - parts: [{ text: 'hi' }] - } - ] - }); - expect(availabilityStub).to.have.been.calledWith({ - expectedInputs: [{ type: 'image' }] - }); - }); it('honors explicitly set expected inputs', async () => { const languageModelProvider = { availability: () => Promise.resolve(Availability.available) @@ -299,6 +275,39 @@ describe('ChromeAdapter', () => { }) ).to.be.false; }); + it('extracts and merges expected inputs from the request', async () => { + const languageModelProvider = { + availability: () => Promise.resolve(Availability.available) + } as LanguageModel; + const availabilityStub = stub( + languageModelProvider, + 'availability' + ).resolves(Availability.available); + const adapter = new ChromeAdapter( + languageModelProvider, + 'prefer_on_device', + { + createOptions: { + expectedInputs: [{ type: 'text' }] + } + } + ); + await adapter.isAvailable({ + contents: [ + { + role: 'user', + parts: [ + { text: 'hi' }, + // Triggers image as expected type. + { inlineData: { mimeType: 'image/jpeg', data: 'asd' } } + ] + } + ] + }); + expect(availabilityStub).to.have.been.calledWith({ + expectedInputs: [{ type: 'text' }, { type: 'image' }] + }); + }); }); describe('generateContent', () => { it('throws if Chrome API is undefined', async () => { @@ -378,14 +387,9 @@ describe('ChromeAdapter', () => { ); const promptOutput = 'hi'; const promptStub = stub(languageModel, 'prompt').resolves(promptOutput); - const createOptions = { - systemPrompt: 'be yourself', - expectedInputs: [{ type: 'image' }] - } as LanguageModelCreateOptions; const adapter = new ChromeAdapter( languageModelProvider, - 'prefer_on_device', - { createOptions } + 'prefer_on_device' ); const request = { contents: [ @@ -405,7 +409,9 @@ describe('ChromeAdapter', () => { } as GenerateContentRequest; const response = await adapter.generateContent(request); // Asserts initialization params are proxied. - expect(createStub).to.have.been.calledOnceWith(createOptions); + expect(createStub).to.have.been.calledOnceWith({ + expectedInputs: [{ type: 'image' }] + }); // Asserts Vertex input type is mapped to Chrome type. expect(promptStub).to.have.been.calledOnceWith([ { @@ -606,13 +612,9 @@ describe('ChromeAdapter', () => { } }) ); - const createOptions = { - expectedInputs: [{ type: 'image' }] - } as LanguageModelCreateOptions; const adapter = new ChromeAdapter( languageModelProvider, - 'prefer_on_device', - { createOptions } + 'prefer_on_device' ); const request = { contents: [ @@ -631,7 +633,9 @@ describe('ChromeAdapter', () => { ] } as GenerateContentRequest; const response = await adapter.generateContentStream(request); - expect(createStub).to.have.been.calledOnceWith(createOptions); + expect(createStub).to.have.been.calledOnceWith({ + expectedInputs: [{ type: 'image' }] + }); expect(promptStub).to.have.been.calledOnceWith([ { role: request.contents[0].role, diff --git a/packages/ai/src/methods/chrome-adapter.ts b/packages/ai/src/methods/chrome-adapter.ts index e7bb39c34c8..7f9cb2d7a75 100644 --- a/packages/ai/src/methods/chrome-adapter.ts +++ b/packages/ai/src/methods/chrome-adapter.ts @@ -30,10 +30,14 @@ import { import { Availability, LanguageModel, + LanguageModelCreateOptions, + LanguageModelExpected, LanguageModelMessage, LanguageModelMessageContent, - LanguageModelMessageRole + LanguageModelMessageRole, + LanguageModelMessageType } from '../types/language-model'; +import deepMerge from 'deepmerge'; /** * Defines an inference "backend" that uses Chrome's on-device model, @@ -48,12 +52,7 @@ export class ChromeAdapter { constructor( private languageModelProvider?: LanguageModel, private mode?: InferenceMode, - private onDeviceParams: OnDeviceParams = { - createOptions: { - // Defaults to support image inputs for convenience. - expectedInputs: [{ type: 'image' }] - } - } + private onDeviceParams: OnDeviceParams = {} ) {} /** @@ -85,8 +84,11 @@ export class ChromeAdapter { return false; } + const extractedOptions = this.extractCreateOptions(request); + const mergedOptions = this.mergeCreateOptions(extractedOptions); + // Triggers out-of-band download so model will eventually become available. - const availability = await this.downloadIfAvailable(); + const availability = await this.downloadIfAvailable(mergedOptions); if (this.mode === 'only_on_device') { return true; @@ -118,7 +120,9 @@ export class ChromeAdapter { * @returns {@link Response}, so we can reuse common response formatting. */ async generateContent(request: GenerateContentRequest): Promise { - const session = await this.createSession(); + const extractedOptions = this.extractCreateOptions(request); + const mergedOptions = this.mergeCreateOptions(extractedOptions); + const session = await this.createSession(mergedOptions); const contents = await Promise.all( request.contents.map(ChromeAdapter.toLanguageModelMessage) ); @@ -140,7 +144,9 @@ export class ChromeAdapter { async generateContentStream( request: GenerateContentRequest ): Promise { - const session = await this.createSession(); + const extractedOptions = this.extractCreateOptions(request); + const mergedOptions = this.mergeCreateOptions(extractedOptions); + const session = await this.createSession(mergedOptions); const contents = await Promise.all( request.contents.map(ChromeAdapter.toLanguageModelMessage) ); @@ -158,6 +164,48 @@ export class ChromeAdapter { ); } + /** + * Extracts session creation options specified at request-time. + * + *

In particular, this method maps + * + * Vertex's input mime types to + * + * Chrome's expected input types.

+ * + *

Chrome's API checks availability by type. It's tedious to specify the types in advance, so + * this method infers the types.

+ */ + private extractCreateOptions( + request: GenerateContentRequest + ): LanguageModelCreateOptions { + const inputSet = new Set(); + for (const content of request.contents) { + for (const part of content.parts) { + if (part.inlineData) { + const type = part.inlineData.mimeType.split( + '/' + )[0] as LanguageModelMessageType; + inputSet.add({ type }); + } + } + } + + return { + expectedInputs: Array.from(inputSet) + }; + } + + /** + * Assembles a unified {@link LanguageModelCreateOptions} from create- and request-time options. + * Request-time options take priority over create-time options. + */ + private mergeCreateOptions( + requestOptions: LanguageModelCreateOptions + ): LanguageModelCreateOptions { + return deepMerge(this.onDeviceParams.createOptions || {}, requestOptions); + } + /** * Asserts inference for the given request can be performed by an on-device model. */ @@ -196,13 +244,17 @@ export class ChromeAdapter { /** * Encapsulates logic to get availability and download a model if one is downloadable. */ - private async downloadIfAvailable(): Promise { + private async downloadIfAvailable( + createOptions: LanguageModelCreateOptions + ): Promise { const availability = await this.languageModelProvider?.availability( - this.onDeviceParams.createOptions + createOptions ); if (availability === Availability.downloadable) { - this.download(); + // Side-effect: triggers out-of-band model download. + // This is required because Chrome manages the model download. + this.download(createOptions); } return availability; @@ -212,18 +264,18 @@ export class ChromeAdapter { * Triggers out-of-band download of an on-device model. * *

Chrome only downloads models as needed. Chrome knows a model is needed when code calls - * LanguageModel.create.

+ * {@link LanguageModel.create}.

* *

Since Chrome manages the download, the SDK can only avoid redundant download requests by * tracking if a download has previously been requested.

*/ - private download(): void { + private download(createOptions: LanguageModelCreateOptions): void { if (this.isDownloading) { return; } this.isDownloading = true; this.downloadPromise = this.languageModelProvider - ?.create(this.onDeviceParams.createOptions) + ?.create(createOptions) .then(() => { this.isDownloading = false; }); @@ -291,16 +343,16 @@ export class ChromeAdapter { *

Chrome will remove a model from memory if it's no longer in use, so this method ensures a * new session is created before an old session is destroyed.

*/ - private async createSession(): Promise { + private async createSession( + createOptions: LanguageModelCreateOptions + ): Promise { if (!this.languageModelProvider) { throw new AIError( AIErrorCode.REQUEST_ERROR, 'Chrome AI requested for unsupported browser version.' ); } - const newSession = await this.languageModelProvider.create( - this.onDeviceParams.createOptions - ); + const newSession = await this.languageModelProvider.create(createOptions); if (this.oldSession) { this.oldSession.destroy(); } diff --git a/yarn.lock b/yarn.lock index 51ede769d03..09d7a2eda0e 100644 --- a/yarn.lock +++ b/yarn.lock @@ -6250,7 +6250,7 @@ deep-is@^0.1.3: resolved "https://registry.npmjs.org/deep-is/-/deep-is-0.1.4.tgz#a6f2dce612fadd2ef1f519b73551f17e85199831" integrity sha512-oIPzksmTg4/MriiaYGO+okXDT7ztn/w3Eptv/+gSIdMdKsJo0u4CfYNFJPy+4SKMuCqGw2wxnA+URMg3t8a/bQ== -deepmerge@^4.2.2: +deepmerge@4.3.1, deepmerge@^4.2.2: version "4.3.1" resolved "https://registry.npmjs.org/deepmerge/-/deepmerge-4.3.1.tgz#44b5f2147cd3b00d4b56137685966f26fd25dd4a" integrity sha512-3sUqbMEc77XqpdNO7FRyRog+eW3ph+GYCbj+rK+uYyRMuwsVy0rMiVtPn+QJlKFvWP/1PYpapqYn0Me2knFn+A==