@@ -30,12 +30,14 @@ import {
30
30
import {
31
31
Availability ,
32
32
LanguageModel ,
33
+ LanguageModelCreateOptions ,
33
34
LanguageModelExpected ,
34
35
LanguageModelMessage ,
35
36
LanguageModelMessageContent ,
36
37
LanguageModelMessageRole ,
37
38
LanguageModelMessageType
38
39
} from '../types/language-model' ;
40
+ import { deepExtend } from '@firebase/util' ;
39
41
40
42
/**
41
43
* Defines an inference "backend" that uses Chrome's on-device model,
@@ -51,9 +53,7 @@ export class ChromeAdapter {
51
53
private languageModelProvider ?: LanguageModel ,
52
54
private mode ?: InferenceMode ,
53
55
private onDeviceParams : OnDeviceParams = { }
54
- ) {
55
- this . onDeviceParams . createOptions ??= { } ;
56
- }
56
+ ) { }
57
57
58
58
/**
59
59
* Checks if a given request can be made on-device.
@@ -84,10 +84,11 @@ export class ChromeAdapter {
84
84
return false ;
85
85
}
86
86
87
- const expectedInputs = ChromeAdapter . extractExpectedInputs ( request ) ;
87
+ const requestOptions = this . inferCreateOptions ( request ) ;
88
+ const mergedOptions = this . mergeCreateOptions ( requestOptions ) ;
88
89
89
90
// Triggers out-of-band download so model will eventually become available.
90
- const availability = await this . downloadIfAvailable ( expectedInputs ) ;
91
+ const availability = await this . downloadIfAvailable ( mergedOptions ) ;
91
92
92
93
if ( this . mode === 'only_on_device' ) {
93
94
return true ;
@@ -119,7 +120,9 @@ export class ChromeAdapter {
119
120
* @returns {@link Response }, so we can reuse common response formatting.
120
121
*/
121
122
async generateContent ( request : GenerateContentRequest ) : Promise < Response > {
122
- const session = await this . createSession ( ) ;
123
+ const requestOptions = this . inferCreateOptions ( request ) ;
124
+ const mergedOptions = this . mergeCreateOptions ( requestOptions ) ;
125
+ const session = await this . createSession ( mergedOptions ) ;
123
126
const contents = await Promise . all (
124
127
request . contents . map ( ChromeAdapter . toLanguageModelMessage )
125
128
) ;
@@ -141,7 +144,9 @@ export class ChromeAdapter {
141
144
async generateContentStream (
142
145
request : GenerateContentRequest
143
146
) : Promise < Response > {
144
- const session = await this . createSession ( ) ;
147
+ const inferredOptions = this . inferCreateOptions ( request ) ;
148
+ const mergedOptions = this . mergeCreateOptions ( inferredOptions ) ;
149
+ const session = await this . createSession ( mergedOptions ) ;
145
150
const contents = await Promise . all (
146
151
request . contents . map ( ChromeAdapter . toLanguageModelMessage )
147
152
) ;
@@ -164,14 +169,14 @@ export class ChromeAdapter {
164
169
* <a href="https://cloud.google.com/vertex-ai/generative-ai/docs/model-reference/inference#blob">
165
170
* Vertex's input mime types</a> to
166
171
* <a href="https://github.com/webmachinelearning/prompt-api?tab=readme-ov-file#full-api-surface-in-web-idl">
167
- * Chrome's expected types</a>.
172
+ * Chrome's expected input types</a>.
168
173
*
169
174
* <p>Chrome's API checks availability by type. It's tedious to specify the types in advance, so
170
175
* this method infers the types.</p>
171
176
*/
172
- private static extractExpectedInputs (
177
+ private inferCreateOptions (
173
178
request : GenerateContentRequest
174
- ) : LanguageModelExpected [ ] {
179
+ ) : LanguageModelCreateOptions {
175
180
const inputSet = new Set < LanguageModelExpected > ( ) ;
176
181
for ( const content of request . contents ) {
177
182
for ( const part of content . parts ) {
@@ -183,7 +188,23 @@ export class ChromeAdapter {
183
188
}
184
189
}
185
190
}
186
- return Array . from ( inputSet ) ;
191
+
192
+ return {
193
+ expectedInputs : Array . from ( inputSet )
194
+ } ;
195
+ }
196
+
197
+ /**
198
+ * Assembles a unified {@link LanguageModelCreateOptions} from create- and request-time options.
199
+ * Request-time options take priority over create-time options.
200
+ */
201
+ private mergeCreateOptions (
202
+ requestOptions : LanguageModelCreateOptions
203
+ ) : LanguageModelCreateOptions {
204
+ return deepExtend (
205
+ this . onDeviceParams . createOptions ,
206
+ requestOptions
207
+ ) as LanguageModelCreateOptions ;
187
208
}
188
209
189
210
/**
@@ -225,15 +246,10 @@ export class ChromeAdapter {
225
246
* Encapsulates logic to get availability and download a model if one is downloadable.
226
247
*/
227
248
private async downloadIfAvailable (
228
- expectedInputs : LanguageModelExpected [ ]
249
+ createOptions : LanguageModelCreateOptions
229
250
) : Promise < Availability | undefined > {
230
- // Side-effect: updates construction-time params with request-time params.
231
- // This is required because params are referenced through multiple flows.
232
- // TODO: remove this side effect, since we need to also pass options when creating a session.
233
- Object . assign ( this . onDeviceParams . createOptions ! , { expectedInputs } ) ;
234
-
235
251
const availability = await this . languageModelProvider ?. availability (
236
- this . onDeviceParams . createOptions
252
+ createOptions
237
253
) ;
238
254
239
255
if ( availability === Availability . downloadable ) {
@@ -328,16 +344,16 @@ export class ChromeAdapter {
328
344
* <p>Chrome will remove a model from memory if it's no longer in use, so this method ensures a
329
345
* new session is created before an old session is destroyed.</p>
330
346
*/
331
- private async createSession ( ) : Promise < LanguageModel > {
347
+ private async createSession (
348
+ createOptions : LanguageModelCreateOptions
349
+ ) : Promise < LanguageModel > {
332
350
if ( ! this . languageModelProvider ) {
333
351
throw new AIError (
334
352
AIErrorCode . REQUEST_ERROR ,
335
353
'Chrome AI requested for unsupported browser version.'
336
354
) ;
337
355
}
338
- const newSession = await this . languageModelProvider . create (
339
- this . onDeviceParams . createOptions
340
- ) ;
356
+ const newSession = await this . languageModelProvider . create ( createOptions ) ;
341
357
if ( this . oldSession ) {
342
358
this . oldSession . destroy ( ) ;
343
359
}
0 commit comments