Skip to content

Commit f9d0d83

Browse files
committed
Infer expected inputs from prompt
1 parent 58d92df commit f9d0d83

File tree

2 files changed

+53
-12
lines changed

2 files changed

+53
-12
lines changed

packages/ai/src/methods/chrome-adapter.test.ts

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@ async function toStringArray(
5454

5555
describe('ChromeAdapter', () => {
5656
describe('constructor', () => {
57-
it('sets image as expected input type by default', async () => {
57+
it('determines expected inputs by request inspection', async () => {
5858
const languageModelProvider = {
5959
availability: () => Promise.resolve(Availability.available)
6060
} as LanguageModel;
@@ -70,7 +70,11 @@ describe('ChromeAdapter', () => {
7070
contents: [
7171
{
7272
role: 'user',
73-
parts: [{ text: 'hi' }]
73+
parts: [
74+
{ text: 'hi' },
75+
// Triggers image as expected type.
76+
{ inlineData: { mimeType: 'image/jpeg', data: 'asd' } }
77+
]
7478
}
7579
]
7680
});

packages/ai/src/methods/chrome-adapter.ts

Lines changed: 47 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -30,9 +30,11 @@ import {
3030
import {
3131
Availability,
3232
LanguageModel,
33+
LanguageModelExpected,
3334
LanguageModelMessage,
3435
LanguageModelMessageContent,
35-
LanguageModelMessageRole
36+
LanguageModelMessageRole,
37+
LanguageModelMessageType
3638
} from '../types/language-model';
3739

3840
/**
@@ -48,13 +50,10 @@ export class ChromeAdapter {
4850
constructor(
4951
private languageModelProvider?: LanguageModel,
5052
private mode?: InferenceMode,
51-
private onDeviceParams: OnDeviceParams = {
52-
createOptions: {
53-
// Defaults to support image inputs for convenience.
54-
expectedInputs: [{ type: 'image' }]
55-
}
56-
}
57-
) {}
53+
private onDeviceParams: OnDeviceParams = {}
54+
) {
55+
this.onDeviceParams.createOptions ??= {};
56+
}
5857

5958
/**
6059
* Checks if a given request can be made on-device.
@@ -85,8 +84,10 @@ export class ChromeAdapter {
8584
return false;
8685
}
8786

87+
const expectedInputs = ChromeAdapter.extractExpectedInputs(request);
88+
8889
// Triggers out-of-band download so model will eventually become available.
89-
const availability = await this.downloadIfAvailable();
90+
const availability = await this.downloadIfAvailable(expectedInputs);
9091

9192
if (this.mode === 'only_on_device') {
9293
return true;
@@ -158,6 +159,33 @@ export class ChromeAdapter {
158159
);
159160
}
160161

162+
/**
 * Maps
 * <a href="https://cloud.google.com/vertex-ai/generative-ai/docs/model-reference/inference#blob">
 * Vertex's input mime types</a> to
 * <a href="https://github.com/webmachinelearning/prompt-api?tab=readme-ov-file#full-api-surface-in-web-idl">
 * Chrome's expected types</a>.
 *
 * <p>Chrome's API checks availability by type. It's tedious to specify the types in advance, so
 * this method infers the types from the request's inline data parts.</p>
 *
 * <p>The top-level segment of each inline part's mime type (e.g. <code>image</code> for
 * <code>image/jpeg</code>) is used as the expected type. Each distinct type appears exactly once
 * in the result, in order of first appearance. Parts without <code>inlineData</code> contribute
 * nothing.</p>
 *
 * @param request - The request whose contents are inspected.
 * @returns One entry per distinct inferred input type; empty if no part carries inline data.
 */
private static extractExpectedInputs(
  request: GenerateContentRequest
): LanguageModelExpected[] {
  // Dedupe on the type string rather than on `{ type }` objects: a Set compares
  // objects by reference, so freshly created literals would never be collapsed.
  const types = new Set<LanguageModelMessageType>();
  for (const content of request.contents) {
    for (const part of content.parts) {
      if (part.inlineData) {
        // NOTE(review): the cast is unchecked, so a mime type such as
        // "application/pdf" yields "application" — confirm whether such values
        // should be filtered out before being handed to Chrome.
        const type = part.inlineData.mimeType.split(
          '/'
        )[0] as LanguageModelMessageType;
        types.add(type);
      }
    }
  }
  return Array.from(types, type => ({ type }));
}
188+
161189
/**
162190
* Asserts inference for the given request can be performed by an on-device model.
163191
*/
@@ -196,12 +224,21 @@ export class ChromeAdapter {
196224
/**
197225
* Encapsulates logic to get availability and download a model if one is downloadable.
198226
*/
199-
private async downloadIfAvailable(): Promise<Availability | undefined> {
227+
private async downloadIfAvailable(
228+
expectedInputs: LanguageModelExpected[]
229+
): Promise<Availability | undefined> {
230+
// Side-effect: updates construction-time params with request-time params.
231+
// This is required because params are referenced through multiple flows.
232+
// TODO: remove this side effect, since we need to also pass options when creating a session.
233+
Object.assign(this.onDeviceParams.createOptions!, { expectedInputs });
234+
200235
const availability = await this.languageModelProvider?.availability(
201236
this.onDeviceParams.createOptions
202237
);
203238

204239
if (availability === Availability.downloadable) {
240+
// Side-effect: triggers out-of-band model download.
241+
// This is required because Chrome manages the model download.
205242
this.download();
206243
}
207244

0 commit comments

Comments
 (0)