@@ -30,9 +30,11 @@ import {
30
30
import {
31
31
Availability ,
32
32
LanguageModel ,
33
+ LanguageModelExpected ,
33
34
LanguageModelMessage ,
34
35
LanguageModelMessageContent ,
35
- LanguageModelMessageRole
36
+ LanguageModelMessageRole ,
37
+ LanguageModelMessageType
36
38
} from '../types/language-model' ;
37
39
38
40
/**
@@ -48,13 +50,10 @@ export class ChromeAdapter {
48
50
constructor (
49
51
private languageModelProvider ?: LanguageModel ,
50
52
private mode ?: InferenceMode ,
51
- private onDeviceParams : OnDeviceParams = {
52
- createOptions : {
53
- // Defaults to support image inputs for convenience.
54
- expectedInputs : [ { type : 'image' } ]
55
- }
56
- }
57
- ) { }
53
+ private onDeviceParams : OnDeviceParams = { }
54
+ ) {
55
+ this . onDeviceParams . createOptions ??= { } ;
56
+ }
58
57
59
58
/**
60
59
* Checks if a given request can be made on-device.
@@ -85,8 +84,10 @@ export class ChromeAdapter {
85
84
return false ;
86
85
}
87
86
87
+ const expectedInputs = ChromeAdapter . extractExpectedInputs ( request ) ;
88
+
88
89
// Triggers out-of-band download so model will eventually become available.
89
- const availability = await this . downloadIfAvailable ( ) ;
90
+ const availability = await this . downloadIfAvailable ( expectedInputs ) ;
90
91
91
92
if ( this . mode === 'only_on_device' ) {
92
93
return true ;
@@ -158,6 +159,33 @@ export class ChromeAdapter {
158
159
) ;
159
160
}
160
161
162
+ /**
163
+ * Maps
164
+ * <a href="https://cloud.google.com/vertex-ai/generative-ai/docs/model-reference/inference#blob">
165
+ * Vertex's input mime types</a> to
166
+ * <a href="https://github.com/webmachinelearning/prompt-api?tab=readme-ov-file#full-api-surface-in-web-idl">
167
+ * Chrome's expected types</a>.
168
+ *
169
+ * <p>Chrome's API checks availability by type. It's tedious to specify the types in advance, so
170
+ * this method infers the types.</p>
171
+ */
172
+ private static extractExpectedInputs (
173
+ request : GenerateContentRequest
174
+ ) : LanguageModelExpected [ ] {
175
+ const inputSet = new Set < LanguageModelExpected > ( ) ;
176
+ for ( const content of request . contents ) {
177
+ for ( const part of content . parts ) {
178
+ if ( part . inlineData ) {
179
+ const type = part . inlineData . mimeType . split (
180
+ '/'
181
+ ) [ 0 ] as LanguageModelMessageType ;
182
+ inputSet . add ( { type } ) ;
183
+ }
184
+ }
185
+ }
186
+ return Array . from ( inputSet ) ;
187
+ }
188
+
161
189
/**
162
190
* Asserts inference for the given request can be performed by an on-device model.
163
191
*/
@@ -196,12 +224,21 @@ export class ChromeAdapter {
196
224
/**
197
225
* Encapsulates logic to get availability and download a model if one is downloadable.
198
226
*/
199
- private async downloadIfAvailable ( ) : Promise < Availability | undefined > {
227
+ private async downloadIfAvailable (
228
+ expectedInputs : LanguageModelExpected [ ]
229
+ ) : Promise < Availability | undefined > {
230
+ // Side-effect: updates construction-time params with request-time params.
231
+ // This is required because params are referenced through multiple flows.
232
+ // TODO: remove this side effect, since we need to also pass options when creating a session.
233
+ Object . assign ( this . onDeviceParams . createOptions ! , { expectedInputs } ) ;
234
+
200
235
const availability = await this . languageModelProvider ?. availability (
201
236
this . onDeviceParams . createOptions
202
237
) ;
203
238
204
239
if ( availability === Availability . downloadable ) {
240
+ // Side-effect: triggers out-of-band model download.
241
+ // This is required because Chrome manages the model download.
205
242
this . download ( ) ;
206
243
}
207
244
0 commit comments