diff --git a/packages/ai/package.json b/packages/ai/package.json
index d159793b206..8382025a68e 100644
--- a/packages/ai/package.json
+++ b/packages/ai/package.json
@@ -52,6 +52,7 @@
"@firebase/component": "0.6.14",
"@firebase/logger": "0.4.4",
"@firebase/util": "1.11.1",
+ "deepmerge": "4.3.1",
"tslib": "^2.1.0"
},
"license": "Apache-2.0",
diff --git a/packages/ai/src/methods/chrome-adapter.test.ts b/packages/ai/src/methods/chrome-adapter.test.ts
index f8ea80b0e09..5b245ac1ffb 100644
--- a/packages/ai/src/methods/chrome-adapter.test.ts
+++ b/packages/ai/src/methods/chrome-adapter.test.ts
@@ -54,30 +54,6 @@ async function toStringArray(
describe('ChromeAdapter', () => {
describe('constructor', () => {
- it('sets image as expected input type by default', async () => {
- const languageModelProvider = {
- availability: () => Promise.resolve(Availability.available)
- } as LanguageModel;
- const availabilityStub = stub(
- languageModelProvider,
- 'availability'
- ).resolves(Availability.available);
- const adapter = new ChromeAdapter(
- languageModelProvider,
- 'prefer_on_device'
- );
- await adapter.isAvailable({
- contents: [
- {
- role: 'user',
- parts: [{ text: 'hi' }]
- }
- ]
- });
- expect(availabilityStub).to.have.been.calledWith({
- expectedInputs: [{ type: 'image' }]
- });
- });
it('honors explicitly set expected inputs', async () => {
const languageModelProvider = {
availability: () => Promise.resolve(Availability.available)
@@ -299,6 +275,39 @@ describe('ChromeAdapter', () => {
})
).to.be.false;
});
+ it('extracts and merges expected inputs from the request', async () => {
+ const languageModelProvider = {
+ availability: () => Promise.resolve(Availability.available)
+ } as LanguageModel;
+ const availabilityStub = stub(
+ languageModelProvider,
+ 'availability'
+ ).resolves(Availability.available);
+ const adapter = new ChromeAdapter(
+ languageModelProvider,
+ 'prefer_on_device',
+ {
+ createOptions: {
+ expectedInputs: [{ type: 'text' }]
+ }
+ }
+ );
+ await adapter.isAvailable({
+ contents: [
+ {
+ role: 'user',
+ parts: [
+ { text: 'hi' },
+ // Triggers image as expected type.
+ { inlineData: { mimeType: 'image/jpeg', data: 'asd' } }
+ ]
+ }
+ ]
+ });
+ expect(availabilityStub).to.have.been.calledWith({
+ expectedInputs: [{ type: 'text' }, { type: 'image' }]
+ });
+ });
});
describe('generateContent', () => {
it('throws if Chrome API is undefined', async () => {
@@ -378,14 +387,9 @@ describe('ChromeAdapter', () => {
);
const promptOutput = 'hi';
const promptStub = stub(languageModel, 'prompt').resolves(promptOutput);
- const createOptions = {
- systemPrompt: 'be yourself',
- expectedInputs: [{ type: 'image' }]
- } as LanguageModelCreateOptions;
const adapter = new ChromeAdapter(
languageModelProvider,
- 'prefer_on_device',
- { createOptions }
+ 'prefer_on_device'
);
const request = {
contents: [
@@ -405,7 +409,9 @@ describe('ChromeAdapter', () => {
} as GenerateContentRequest;
const response = await adapter.generateContent(request);
// Asserts initialization params are proxied.
- expect(createStub).to.have.been.calledOnceWith(createOptions);
+ expect(createStub).to.have.been.calledOnceWith({
+ expectedInputs: [{ type: 'image' }]
+ });
// Asserts Vertex input type is mapped to Chrome type.
expect(promptStub).to.have.been.calledOnceWith([
{
@@ -606,13 +612,9 @@ describe('ChromeAdapter', () => {
}
})
);
- const createOptions = {
- expectedInputs: [{ type: 'image' }]
- } as LanguageModelCreateOptions;
const adapter = new ChromeAdapter(
languageModelProvider,
- 'prefer_on_device',
- { createOptions }
+ 'prefer_on_device'
);
const request = {
contents: [
@@ -631,7 +633,9 @@ describe('ChromeAdapter', () => {
]
} as GenerateContentRequest;
const response = await adapter.generateContentStream(request);
- expect(createStub).to.have.been.calledOnceWith(createOptions);
+ expect(createStub).to.have.been.calledOnceWith({
+ expectedInputs: [{ type: 'image' }]
+ });
expect(promptStub).to.have.been.calledOnceWith([
{
role: request.contents[0].role,
diff --git a/packages/ai/src/methods/chrome-adapter.ts b/packages/ai/src/methods/chrome-adapter.ts
index e7bb39c34c8..7f9cb2d7a75 100644
--- a/packages/ai/src/methods/chrome-adapter.ts
+++ b/packages/ai/src/methods/chrome-adapter.ts
@@ -30,10 +30,14 @@ import {
import {
Availability,
LanguageModel,
+ LanguageModelCreateOptions,
+ LanguageModelExpected,
LanguageModelMessage,
LanguageModelMessageContent,
- LanguageModelMessageRole
+ LanguageModelMessageRole,
+ LanguageModelMessageType
} from '../types/language-model';
+import deepMerge from 'deepmerge';
/**
* Defines an inference "backend" that uses Chrome's on-device model,
@@ -48,12 +52,7 @@ export class ChromeAdapter {
constructor(
private languageModelProvider?: LanguageModel,
private mode?: InferenceMode,
- private onDeviceParams: OnDeviceParams = {
- createOptions: {
- // Defaults to support image inputs for convenience.
- expectedInputs: [{ type: 'image' }]
- }
- }
+ private onDeviceParams: OnDeviceParams = {}
) {}
/**
@@ -85,8 +84,11 @@ export class ChromeAdapter {
return false;
}
+ const extractedOptions = this.extractCreateOptions(request);
+ const mergedOptions = this.mergeCreateOptions(extractedOptions);
+
// Triggers out-of-band download so model will eventually become available.
- const availability = await this.downloadIfAvailable();
+ const availability = await this.downloadIfAvailable(mergedOptions);
if (this.mode === 'only_on_device') {
return true;
@@ -118,7 +120,9 @@ export class ChromeAdapter {
* @returns {@link Response}, so we can reuse common response formatting.
*/
async generateContent(request: GenerateContentRequest): Promise In particular, this method maps
+ *
+ * Vertex's input mime types to
+ *
+ * Chrome's expected input types. Chrome's API checks availability by type. It's tedious to specify the types in advance, so
+ * this method infers the types. Chrome only downloads models as needed. Chrome knows a model is needed when code calls
- * LanguageModel.create.
Since Chrome manages the download, the SDK can only avoid redundant download requests by * tracking if a download has previously been requested.
*/ - private download(): void { + private download(createOptions: LanguageModelCreateOptions): void { if (this.isDownloading) { return; } this.isDownloading = true; this.downloadPromise = this.languageModelProvider - ?.create(this.onDeviceParams.createOptions) + ?.create(createOptions) .then(() => { this.isDownloading = false; }); @@ -291,16 +343,16 @@ export class ChromeAdapter { *Chrome will remove a model from memory if it's no longer in use, so this method ensures a * new session is created before an old session is destroyed.
*/ - private async createSession(): Promise