Skip to content

Commit e002c40

Browse files
committed
Refactor to clarify options merging
1 parent f9d0d83 commit e002c40

File tree

1 file changed

+38
-22
lines changed

1 file changed

+38
-22
lines changed

packages/ai/src/methods/chrome-adapter.ts

Lines changed: 38 additions & 22 deletions
Original file line number | Diff line number | Diff line change
@@ -30,12 +30,14 @@ import {
3030
import {
3131
Availability,
3232
LanguageModel,
33+
LanguageModelCreateOptions,
3334
LanguageModelExpected,
3435
LanguageModelMessage,
3536
LanguageModelMessageContent,
3637
LanguageModelMessageRole,
3738
LanguageModelMessageType
3839
} from '../types/language-model';
40+
import { deepExtend } from '@firebase/util';
3941

4042
/**
4143
* Defines an inference "backend" that uses Chrome's on-device model,
@@ -51,9 +53,7 @@ export class ChromeAdapter {
5153
private languageModelProvider?: LanguageModel,
5254
private mode?: InferenceMode,
5355
private onDeviceParams: OnDeviceParams = {}
54-
) {
55-
this.onDeviceParams.createOptions ??= {};
56-
}
56+
) {}
5757

5858
/**
5959
* Checks if a given request can be made on-device.
@@ -84,10 +84,11 @@ export class ChromeAdapter {
8484
return false;
8585
}
8686

87-
const expectedInputs = ChromeAdapter.extractExpectedInputs(request);
87+
const requestOptions = this.inferCreateOptions(request);
88+
const mergedOptions = this.mergeCreateOptions(requestOptions);
8889

8990
// Triggers out-of-band download so model will eventually become available.
90-
const availability = await this.downloadIfAvailable(expectedInputs);
91+
const availability = await this.downloadIfAvailable(mergedOptions);
9192

9293
if (this.mode === 'only_on_device') {
9394
return true;
@@ -119,7 +120,9 @@ export class ChromeAdapter {
119120
* @returns {@link Response}, so we can reuse common response formatting.
120121
*/
121122
async generateContent(request: GenerateContentRequest): Promise<Response> {
122-
const session = await this.createSession();
123+
const requestOptions = this.inferCreateOptions(request);
124+
const mergedOptions = this.mergeCreateOptions(requestOptions);
125+
const session = await this.createSession(mergedOptions);
123126
const contents = await Promise.all(
124127
request.contents.map(ChromeAdapter.toLanguageModelMessage)
125128
);
@@ -141,7 +144,9 @@ export class ChromeAdapter {
141144
async generateContentStream(
142145
request: GenerateContentRequest
143146
): Promise<Response> {
144-
const session = await this.createSession();
147+
const inferredOptions = this.inferCreateOptions(request);
148+
const mergedOptions = this.mergeCreateOptions(inferredOptions);
149+
const session = await this.createSession(mergedOptions);
145150
const contents = await Promise.all(
146151
request.contents.map(ChromeAdapter.toLanguageModelMessage)
147152
);
@@ -164,14 +169,14 @@ export class ChromeAdapter {
164169
* <a href="https://cloud.google.com/vertex-ai/generative-ai/docs/model-reference/inference#blob">
165170
* Vertex's input mime types</a> to
166171
* <a href="https://github.com/webmachinelearning/prompt-api?tab=readme-ov-file#full-api-surface-in-web-idl">
167-
* Chrome's expected types</a>.
172+
* Chrome's expected input types</a>.
168173
*
169174
* <p>Chrome's API checks availability by type. It's tedious to specify the types in advance, so
170175
* this method infers the types.</p>
171176
*/
172-
private static extractExpectedInputs(
177+
private inferCreateOptions(
173178
request: GenerateContentRequest
174-
): LanguageModelExpected[] {
179+
): LanguageModelCreateOptions {
175180
const inputSet = new Set<LanguageModelExpected>();
176181
for (const content of request.contents) {
177182
for (const part of content.parts) {
@@ -183,7 +188,23 @@ export class ChromeAdapter {
183188
}
184189
}
185190
}
186-
return Array.from(inputSet);
191+
192+
return {
193+
expectedInputs: Array.from(inputSet)
194+
};
195+
}
196+
197+
/**
198+
* Assembles a unified {@link LanguageModelCreateOptions} from create- and request-time options.
199+
* Request-time options take priority over create-time options.
200+
*/
201+
private mergeCreateOptions(
202+
requestOptions: LanguageModelCreateOptions
203+
): LanguageModelCreateOptions {
204+
return deepExtend(
205+
this.onDeviceParams.createOptions,
206+
requestOptions
207+
) as LanguageModelCreateOptions;
187208
}
188209

189210
/**
@@ -225,15 +246,10 @@ export class ChromeAdapter {
225246
* Encapsulates logic to get availability and download a model if one is downloadable.
226247
*/
227248
private async downloadIfAvailable(
228-
expectedInputs: LanguageModelExpected[]
249+
createOptions: LanguageModelCreateOptions
229250
): Promise<Availability | undefined> {
230-
// Side-effect: updates construction-time params with request-time params.
231-
// This is required because params are referenced through multiple flows.
232-
// TODO: remove this side effect, since we need to also pass options when creating a session.
233-
Object.assign(this.onDeviceParams.createOptions!, { expectedInputs });
234-
235251
const availability = await this.languageModelProvider?.availability(
236-
this.onDeviceParams.createOptions
252+
createOptions
237253
);
238254

239255
if (availability === Availability.downloadable) {
@@ -328,16 +344,16 @@ export class ChromeAdapter {
328344
* <p>Chrome will remove a model from memory if it's no longer in use, so this method ensures a
329345
* new session is created before an old session is destroyed.</p>
330346
*/
331-
private async createSession(): Promise<LanguageModel> {
347+
private async createSession(
348+
createOptions: LanguageModelCreateOptions
349+
): Promise<LanguageModel> {
332350
if (!this.languageModelProvider) {
333351
throw new AIError(
334352
AIErrorCode.REQUEST_ERROR,
335353
'Chrome AI requested for unsupported browser version.'
336354
);
337355
}
338-
const newSession = await this.languageModelProvider.create(
339-
this.onDeviceParams.createOptions
340-
);
356+
const newSession = await this.languageModelProvider.create(createOptions);
341357
if (this.oldSession) {
342358
this.oldSession.destroy();
343359
}

0 commit comments

Comments
 (0)