diff --git a/.gitignore b/.gitignore index 5aaf5c0b5be..a989576ccf3 100644 --- a/.gitignore +++ b/.gitignore @@ -103,4 +103,4 @@ vertexai-sdk-test-data mocks-lookup.ts # temp changeset output -changeset-temp.json \ No newline at end of file +changeset-temp.json diff --git a/.vscode/launch.json b/.vscode/launch.json index df14c96daa6..1f627304b61 100644 --- a/.vscode/launch.json +++ b/.vscode/launch.json @@ -9,7 +9,7 @@ "type": "node", "request": "launch", "program": "${workspaceFolder}/node_modules/.bin/_mocha", - "cwd": "${workspaceRoot}/packages/vertexai", + "cwd": "${workspaceRoot}/packages/ai", "args": [ "--require", "ts-node/register", @@ -24,6 +24,26 @@ }, "sourceMaps": true }, + { + "name": "AI Integration Tests (node)", + "type": "node", + "request": "launch", + "program": "${workspaceFolder}/node_modules/.bin/_mocha", + "cwd": "${workspaceRoot}/packages/ai", + "args": [ + "--require", + "ts-node/register", + "--require", + "src/index.node.ts", + "--timeout", + "5000", + "integration/**/*.test.ts" + ], + "env": { + "TS_NODE_COMPILER_OPTIONS": "{\"module\":\"commonjs\"}" + }, + "sourceMaps": true + }, { "type": "node", "request": "launch", diff --git a/packages/ai/integration/chat.test.ts b/packages/ai/integration/chat.test.ts new file mode 100644 index 00000000000..6af0e7a9af9 --- /dev/null +++ b/packages/ai/integration/chat.test.ts @@ -0,0 +1,132 @@ +/** + * @license + * Copyright 2025 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import { expect } from 'chai'; +import { + Content, + GenerationConfig, + HarmBlockThreshold, + HarmCategory, + SafetySetting, + getGenerativeModel +} from '../src'; +import { testConfigs, TOKEN_COUNT_DELTA } from './constants'; + +describe('Chat Session', () => { + testConfigs.forEach(testConfig => { + describe(`${testConfig.toString()}`, () => { + const commonGenerationConfig: GenerationConfig = { + temperature: 0, + topP: 0, + responseMimeType: 'text/plain' + }; + + const commonSafetySettings: SafetySetting[] = [ + { + category: HarmCategory.HARM_CATEGORY_HARASSMENT, + threshold: HarmBlockThreshold.BLOCK_LOW_AND_ABOVE + }, + { + category: HarmCategory.HARM_CATEGORY_HATE_SPEECH, + threshold: HarmBlockThreshold.BLOCK_LOW_AND_ABOVE + }, + { + category: HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT, + threshold: HarmBlockThreshold.BLOCK_LOW_AND_ABOVE + }, + { + category: HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT, + threshold: HarmBlockThreshold.BLOCK_LOW_AND_ABOVE + } + ]; + + const commonSystemInstruction: Content = { + role: 'system', + parts: [ + { + text: 'You are a friendly and helpful assistant.' + } + ] + }; + + it('startChat and sendMessage: text input, text output', async () => { + const model = getGenerativeModel(testConfig.ai, { + model: testConfig.model, + generationConfig: commonGenerationConfig, + safetySettings: commonSafetySettings, + systemInstruction: commonSystemInstruction + }); + + const chat = model.startChat(); + const result1 = await chat.sendMessage( + 'What is the capital of France?' + ); + const response1 = result1.response; + expect(response1.text().trim().toLowerCase()).to.include('paris'); + + let history = await chat.getHistory(); + expect(history.length).to.equal(2); + expect(history[0].role).to.equal('user'); + expect(history[0].parts[0].text).to.equal( + 'What is the capital of France?' + ); + expect(history[1].role).to.equal('model'); + expect(history[1].parts[0].text?.toLowerCase()).to.include('paris'); + + expect(response1.usageMetadata).to.not.be.null; + // Token counts can vary slightly in chat context + expect(response1.usageMetadata!.promptTokenCount).to.be.closeTo( + 15, // "What is the capital of France?" + system instruction + TOKEN_COUNT_DELTA + 2 // More variance for chat context + ); + expect(response1.usageMetadata!.candidatesTokenCount).to.be.closeTo( + 8, // "Paris" + TOKEN_COUNT_DELTA + ); + expect(response1.usageMetadata!.totalTokenCount).to.be.closeTo( + 23, // "What is the capital of France?" + system instruction + "Paris" + TOKEN_COUNT_DELTA + 3 // More variance for chat context + ); + + const result2 = await chat.sendMessage('And what about Italy?'); + const response2 = result2.response; + expect(response2.text().trim().toLowerCase()).to.include('rome'); + + history = await chat.getHistory(); + expect(history.length).to.equal(4); + expect(history[2].role).to.equal('user'); + expect(history[2].parts[0].text).to.equal('And what about Italy?'); + expect(history[3].role).to.equal('model'); + expect(history[3].parts[0].text?.toLowerCase()).to.include('rome'); + + expect(response2.usageMetadata).to.not.be.null; + expect(response2.usageMetadata!.promptTokenCount).to.be.closeTo( + 28, // History + "And what about Italy?" + system instruction + TOKEN_COUNT_DELTA + 5 // More variance for chat context with history + ); + expect(response2.usageMetadata!.candidatesTokenCount).to.be.closeTo( + 8, + TOKEN_COUNT_DELTA + ); + expect(response2.usageMetadata!.totalTokenCount).to.be.closeTo( + 36, + TOKEN_COUNT_DELTA + ); + }); + }); + }); +}); diff --git a/packages/ai/integration/constants.ts b/packages/ai/integration/constants.ts new file mode 100644 index 00000000000..68aebf9eddc --- /dev/null +++ b/packages/ai/integration/constants.ts @@ -0,0 +1,81 @@ +/** + * @license + * Copyright 2025 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import { initializeApp } from '@firebase/app'; +import { + AI, + Backend, + BackendType, + GoogleAIBackend, + VertexAIBackend, + getAI +} from '../src'; +import { FIREBASE_CONFIG } from './firebase-config'; + +const app = initializeApp(FIREBASE_CONFIG); + +/** + * Test config that all tests will be ran against. + */ +export type TestConfig = Readonly<{ + ai: AI; + model: string; + /** This will be used to output the test config at runtime */ + toString: () => string; +}>; + +function formatConfigAsString(config: { ai: AI; model: string }): string { + return `${backendNames.get(config.ai.backend.backendType)} ${config.model}`; +} + +const backends: readonly Backend[] = [ + new GoogleAIBackend(), + new VertexAIBackend() +]; + +const backendNames: Map = new Map([ + [BackendType.GOOGLE_AI, 'Google AI'], + [BackendType.VERTEX_AI, 'Vertex AI'] +]); + +const modelNames: readonly string[] = ['gemini-2.0-flash']; + +/** + * Array of test configurations that is iterated over to get full coverage + * of backends and models. Contains all combinations of backends and models. + */ +export const testConfigs: readonly TestConfig[] = backends.flatMap(backend => { + return modelNames.map(modelName => { + const ai = getAI(app, { backend }); + return { + ai: getAI(app, { backend }), + model: modelName, + toString: () => formatConfigAsString({ ai, model: modelName }) + }; + }); +}); + +export const TINY_IMG_BASE64 = + 'iVBORw0KGgoAAAANSUhEUgAAAAEAAAABAQMAAAAl21bKAAAAA1BMVEUAAACnej3aAAAAAXRSTlMAQObYZgAAAApJREFUCNdjYAAAAAIAAeIhvDMAAAAASUVORK5CYII='; +export const IMAGE_MIME_TYPE = 'image/png'; +export const TINY_MP3_BASE64 = + 'SUQzBAAAAAAAIlRTU0UAAAAOAAADTGF2ZjYxLjcuMTAwAAAAAAAAAAAAAAD/+0DAAAAAAAAAAAAAAAAAAAAAAABJbmZvAAAADwAAAAUAAAK+AGhoaGhoaGhoaGhoaGhoaGhoaGiOjo6Ojo6Ojo6Ojo6Ojo6Ojo6OjrS0tLS0tLS0tLS0tLS0tLS0tLS02tra2tra2tra2tra2tra2tra2tr//////////////////////////wAAAABMYXZjNjEuMTkAAAAAAAAAAAAAAAAkAwYAAAAAAAACvhC6DYoAAAAAAP/7EMQAA8AAAaQAAAAgAAA0gAAABExBTUUzLjEwMFVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVV//sQxCmDwAABpAAAACAAADSAAAAEVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVX/+xDEUwPAAAGkAAAAIAAANIAAAARVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVf/7EMR8g8AAAaQAAAAgAAA0gAAABFVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVV//sQxKYDwAABpAAAACAAADSAAAAEVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVU='; +export const AUDIO_MIME_TYPE = 'audio/mpeg'; + +// Token counts are only expected to differ by at most this number of tokens. +// Set to 1 for whitespace that is not always present. +export const TOKEN_COUNT_DELTA = 1; diff --git a/packages/ai/integration/count-tokens.test.ts b/packages/ai/integration/count-tokens.test.ts new file mode 100644 index 00000000000..3256a9ed9d7 --- /dev/null +++ b/packages/ai/integration/count-tokens.test.ts @@ -0,0 +1,286 @@ +/** + * @license + * Copyright 2025 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import { expect } from 'chai'; +import { + Content, + GenerationConfig, + HarmBlockMethod, + HarmBlockThreshold, + HarmCategory, + Modality, + SafetySetting, + getGenerativeModel, + Part, + CountTokensRequest, + InlineDataPart, + FileDataPart, + BackendType +} from '../src'; +import { + AUDIO_MIME_TYPE, + IMAGE_MIME_TYPE, + TINY_IMG_BASE64, + TINY_MP3_BASE64, + testConfigs +} from './constants'; +import { FIREBASE_CONFIG } from './firebase-config'; + +describe('Count Tokens', () => { + testConfigs.forEach(testConfig => { + describe(`${testConfig.toString()}`, () => { + it('text input', async () => { + const generationConfig: GenerationConfig = { + temperature: 0, + topP: 0, + responseMimeType: 'text/plain' + }; + + const safetySettings: SafetySetting[] = [ + { + category: HarmCategory.HARM_CATEGORY_HARASSMENT, + threshold: HarmBlockThreshold.BLOCK_LOW_AND_ABOVE, + method: HarmBlockMethod.PROBABILITY + }, + { + category: HarmCategory.HARM_CATEGORY_HATE_SPEECH, + threshold: HarmBlockThreshold.BLOCK_LOW_AND_ABOVE, + method: HarmBlockMethod.SEVERITY + }, + { + category: HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT, + threshold: HarmBlockThreshold.BLOCK_LOW_AND_ABOVE + }, + { + category: HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT, + threshold: HarmBlockThreshold.BLOCK_LOW_AND_ABOVE + } + ]; + + const systemInstruction: Content = { + role: 'system', + parts: [ + { + text: 'You are a friendly and helpful assistant.' + } + ] + }; + const model = getGenerativeModel(testConfig.ai, { + model: testConfig.model, + generationConfig, + systemInstruction, + safetySettings + }); + + const response = await model.countTokens('Why is the sky blue?'); + + expect(response.promptTokensDetails).to.exist; + expect(response.promptTokensDetails!.length).to.equal(1); + expect(response.promptTokensDetails![0].modality).to.equal( + Modality.TEXT + ); + if (testConfig.ai.backend.backendType === BackendType.GOOGLE_AI) { + expect(response.totalTokens).to.equal(7); + expect(response.totalBillableCharacters).to.be.undefined; + expect(response.promptTokensDetails![0].tokenCount).to.equal(7); + } else if ( + testConfig.ai.backend.backendType === BackendType.VERTEX_AI + ) { + expect(response.totalTokens).to.equal(6); + expect(response.totalBillableCharacters).to.equal(16); + expect(response.promptTokensDetails![0].tokenCount).to.equal(6); + } + }); + + it('image input', async () => { + const model = getGenerativeModel(testConfig.ai, { + model: testConfig.model + }); + const imagePart: Part = { + inlineData: { + mimeType: IMAGE_MIME_TYPE, + data: TINY_IMG_BASE64 + } + }; + const response = await model.countTokens([imagePart]); + + if (testConfig.ai.backend.backendType === BackendType.GOOGLE_AI) { + const expectedImageTokens = 259; + expect(response.totalTokens).to.equal(expectedImageTokens); + expect(response.totalBillableCharacters).to.be.undefined; // Incorrect behavior + expect(response.promptTokensDetails!.length).to.equal(2); + expect(response.promptTokensDetails![0]).to.deep.equal({ + modality: Modality.TEXT, // Note: 1 unexpected text token observed for Google AI with image-only input. + tokenCount: 1 + }); + expect(response.promptTokensDetails![1]).to.deep.equal({ + modality: Modality.IMAGE, + tokenCount: 258 + }); + } else if ( + testConfig.ai.backend.backendType === BackendType.VERTEX_AI + ) { + const expectedImageTokens = 258; + expect(response.totalTokens).to.equal(expectedImageTokens); + expect(response.totalBillableCharacters).to.be.undefined; // Incorrect behavior + expect(response.promptTokensDetails!.length).to.equal(1); + // Note: No text tokens are present for Vertex AI with image-only input. + expect(response.promptTokensDetails![0]).to.deep.equal({ + modality: Modality.IMAGE, + tokenCount: 258 + }); + expect(response.promptTokensDetails![0].tokenCount).to.equal( + expectedImageTokens + ); + } + }); + + it('audio input', async () => { + const model = getGenerativeModel(testConfig.ai, { + model: testConfig.model + }); + const audioPart: InlineDataPart = { + inlineData: { + mimeType: AUDIO_MIME_TYPE, + data: TINY_MP3_BASE64 + } + }; + + const response = await model.countTokens([audioPart]); + + expect(response.promptTokensDetails).to.exist; + const textDetails = response.promptTokensDetails!.find( + d => d.modality === Modality.TEXT + ); + const audioDetails = response.promptTokensDetails!.find( + d => d.modality === Modality.AUDIO + ); + + if (testConfig.ai.backend.backendType === BackendType.GOOGLE_AI) { + expect(response.totalTokens).to.equal(6); + expect(response.promptTokensDetails!.length).to.equal(2); + expect(textDetails).to.deep.equal({ + modality: Modality.TEXT, + tokenCount: 1 + }); + expect(audioDetails).to.deep.equal({ + modality: Modality.AUDIO, + tokenCount: 5 + }); + } else if ( + testConfig.ai.backend.backendType === BackendType.VERTEX_AI + ) { + expect(response.totalTokens).to.be.undefined; + expect(response.promptTokensDetails!.length).to.equal(1); // Note: Text modality details absent for Vertex AI with audio-only input. + expect(audioDetails).to.deep.equal({ modality: Modality.AUDIO }); // Note: Audio tokenCount is undefined for Vertex AI with audio-only input. + } + + expect(response.totalBillableCharacters).to.be.undefined; // Incorrect behavior + }); + + it('text, image, and audio input', async () => { + const model = getGenerativeModel(testConfig.ai, { + model: testConfig.model + }); + const textPart: Part = { text: 'Describe these:' }; + const imagePart: Part = { + inlineData: { mimeType: IMAGE_MIME_TYPE, data: TINY_IMG_BASE64 } + }; + const audioPart: Part = { + inlineData: { mimeType: AUDIO_MIME_TYPE, data: TINY_MP3_BASE64 } + }; + + const request: CountTokensRequest = { + contents: [{ role: 'user', parts: [textPart, imagePart, audioPart] }] + }; + const response = await model.countTokens(request); + const textDetails = response.promptTokensDetails!.find( + d => d.modality === Modality.TEXT + ); + const imageDetails = response.promptTokensDetails!.find( + d => d.modality === Modality.IMAGE + ); + const audioDetails = response.promptTokensDetails!.find( + d => d.modality === Modality.AUDIO + ); + expect(response.promptTokensDetails).to.exist; + expect(response.promptTokensDetails!.length).to.equal(3); + + expect(imageDetails).to.deep.equal({ + modality: Modality.IMAGE, + tokenCount: 258 + }); + + if (testConfig.ai.backend.backendType === BackendType.GOOGLE_AI) { + expect(response.totalTokens).to.equal(267); + expect(response.totalBillableCharacters).to.be.undefined; + expect(textDetails).to.deep.equal({ + modality: Modality.TEXT, + tokenCount: 4 + }); + expect(audioDetails).to.deep.equal({ + modality: Modality.AUDIO, + tokenCount: 5 + }); + } else if ( + testConfig.ai.backend.backendType === BackendType.VERTEX_AI + ) { + expect(response.totalTokens).to.equal(261); + expect(textDetails).to.deep.equal({ + modality: Modality.TEXT, + tokenCount: 3 + }); + const expectedText = 'Describe these:'; + expect(response.totalBillableCharacters).to.equal( + expectedText.length - 1 + ); // Note: BillableCharacters observed as (text length - 1) for Vertex AI. + expect(audioDetails).to.deep.equal({ modality: Modality.AUDIO }); // Incorrect behavior because there's no tokenCount + } + }); + + it('public storage reference', async () => { + // This test is not expected to pass when using Google AI. + if (testConfig.ai.backend.backendType === BackendType.GOOGLE_AI) { + return; + } + const model = getGenerativeModel(testConfig.ai, { + model: testConfig.model + }); + const filePart: FileDataPart = { + fileData: { + mimeType: IMAGE_MIME_TYPE, + fileUri: `gs://${FIREBASE_CONFIG.storageBucket}/images/tree.png` + } + }; + + const response = await model.countTokens([filePart]); + + const expectedFileTokens = 258; + expect(response.totalTokens).to.equal(expectedFileTokens); + expect(response.totalBillableCharacters).to.be.undefined; + expect(response.promptTokensDetails).to.exist; + expect(response.promptTokensDetails!.length).to.equal(1); + expect(response.promptTokensDetails![0].modality).to.equal( + Modality.IMAGE + ); + expect(response.promptTokensDetails![0].tokenCount).to.equal( + expectedFileTokens + ); + }); + }); + }); +}); diff --git a/packages/ai/integration/firebase-config.ts b/packages/ai/integration/firebase-config.ts new file mode 100644 index 00000000000..2527f020530 --- /dev/null +++ b/packages/ai/integration/firebase-config.ts @@ -0,0 +1,20 @@ +/** + * @license + * Copyright 2025 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import * as config from '../../../config/ci.config.json'; + +export const FIREBASE_CONFIG = config; diff --git a/packages/ai/integration/generate-content.test.ts b/packages/ai/integration/generate-content.test.ts new file mode 100644 index 00000000000..af877396cc8 --- /dev/null +++ b/packages/ai/integration/generate-content.test.ts @@ -0,0 +1,141 @@ +/** + * @license + * Copyright 2025 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import { expect } from 'chai'; +import { + Content, + GenerationConfig, + HarmBlockThreshold, + HarmCategory, + Modality, + SafetySetting, + getGenerativeModel +} from '../src'; +import { testConfigs, TOKEN_COUNT_DELTA } from './constants'; + +describe('Generate Content', () => { + testConfigs.forEach(testConfig => { + describe(`${testConfig.toString()}`, () => { + const commonGenerationConfig: GenerationConfig = { + temperature: 0, + topP: 0, + responseMimeType: 'text/plain' + }; + + const commonSafetySettings: SafetySetting[] = [ + { + category: HarmCategory.HARM_CATEGORY_HARASSMENT, + threshold: HarmBlockThreshold.BLOCK_LOW_AND_ABOVE + }, + { + category: HarmCategory.HARM_CATEGORY_HATE_SPEECH, + threshold: HarmBlockThreshold.BLOCK_LOW_AND_ABOVE + }, + { + category: HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT, + threshold: HarmBlockThreshold.BLOCK_LOW_AND_ABOVE + }, + { + category: HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT, + threshold: HarmBlockThreshold.BLOCK_LOW_AND_ABOVE + } + ]; + + const commonSystemInstruction: Content = { + role: 'system', + parts: [ + { + text: 'You are a friendly and helpful assistant.' + } + ] + }; + + it('generateContent: text input, text output', async () => { + const model = getGenerativeModel(testConfig.ai, { + model: testConfig.model, + generationConfig: commonGenerationConfig, + safetySettings: commonSafetySettings, + systemInstruction: commonSystemInstruction + }); + + const result = await model.generateContent( + 'Where is Google headquarters located? Answer with the city name only.' + ); + const response = result.response; + + const trimmedText = response.text().trim(); + expect(trimmedText).to.equal('Mountain View'); + + expect(response.usageMetadata).to.not.be.null; + expect(response.usageMetadata!.promptTokenCount).to.be.closeTo( + 21, + TOKEN_COUNT_DELTA + ); + expect(response.usageMetadata!.candidatesTokenCount).to.be.closeTo( + 4, + TOKEN_COUNT_DELTA + ); + expect(response.usageMetadata!.totalTokenCount).to.be.closeTo( + 25, + TOKEN_COUNT_DELTA * 2 + ); + expect(response.usageMetadata!.promptTokensDetails).to.not.be.null; + expect(response.usageMetadata!.promptTokensDetails!.length).to.equal(1); + expect( + response.usageMetadata!.promptTokensDetails![0].modality + ).to.equal(Modality.TEXT); + expect( + response.usageMetadata!.promptTokensDetails![0].tokenCount + ).to.equal(21); + expect(response.usageMetadata!.candidatesTokensDetails).to.not.be.null; + expect( + response.usageMetadata!.candidatesTokensDetails!.length + ).to.equal(1); + expect( + response.usageMetadata!.candidatesTokensDetails![0].modality + ).to.equal(Modality.TEXT); + expect( + response.usageMetadata!.candidatesTokensDetails![0].tokenCount + ).to.be.closeTo(4, TOKEN_COUNT_DELTA); + }); + + it('generateContentStream: text input, text output', async () => { + const model = getGenerativeModel(testConfig.ai, { + model: testConfig.model, + generationConfig: commonGenerationConfig, + safetySettings: commonSafetySettings, + systemInstruction: commonSystemInstruction + }); + + const result = await model.generateContentStream( + 'Where is Google headquarters located? Answer with the city name only.' + ); + + let streamText = ''; + for await (const chunk of result.stream) { + streamText += chunk.text(); + } + expect(streamText.trim()).to.equal('Mountain View'); + + const response = await result.response; + const trimmedText = response.text().trim(); + expect(trimmedText).to.equal('Mountain View'); + expect(response.usageMetadata).to.be.undefined; // Note: This is incorrect behavior. + }); + }); + }); +}); diff --git a/packages/ai/karma.conf.js b/packages/ai/karma.conf.js index 3fe2a2f9633..e64048d5f6d 100644 --- a/packages/ai/karma.conf.js +++ b/packages/ai/karma.conf.js @@ -16,14 +16,28 @@ */ const karmaBase = require('../../config/karma.base'); +const { argv } = require('yargs'); const files = [`src/**/*.test.ts`]; module.exports = function (config) { const karmaConfig = { ...karmaBase, + + preprocessors: { + ...karmaBase.preprocessors, + 'integration/**/*.ts': ['webpack', 'sourcemap'] + }, + // files to load into karma - files, + files: (() => { + if (argv.integration) { + return ['integration/**']; + } else { + return ['src/**/*.test.ts']; + } + })(), + // frameworks to use // available frameworks: https://npmjs.org/browse/keyword/karma-adapter frameworks: ['mocha'] diff --git a/packages/ai/package.json b/packages/ai/package.json index 98e204320bc..69beb618d97 100644 --- a/packages/ai/package.json +++ b/packages/ai/package.json @@ -39,6 +39,7 @@ "test:ci": "yarn testsetup && node ../../scripts/run_tests_in_ci.js -s test", "test:skip-clone": "karma start", "test:browser": "yarn testsetup && karma start", + "test:integration": "karma start --integration", "api-report": "api-extractor run --local --verbose", "typings:public": "node ../../scripts/build/use_typings.js ./dist/ai-public.d.ts", "trusted-type-check": "tsec -p tsconfig.json --noEmit" diff --git a/packages/ai/rollup.config.js b/packages/ai/rollup.config.js index b3a1e5ebcf8..3b155335898 100644 --- a/packages/ai/rollup.config.js +++ b/packages/ai/rollup.config.js @@ -32,7 +32,12 @@ const buildPlugins = [ typescriptPlugin({ typescript, tsconfigOverride: { - exclude: [...tsconfig.exclude, '**/*.test.ts', 'test-utils'], + exclude: [ + ...tsconfig.exclude, + '**/*.test.ts', + 'test-utils', + 'integration' + ], compilerOptions: { target: 'es2017' }