From 6bdad50c7d15de90bfbe518a2d7d839cb0986d6a Mon Sep 17 00:00:00 2001 From: Tommy Smith Date: Fri, 7 Mar 2025 11:22:50 +0000 Subject: [PATCH 01/11] Add support for defining generative config at query-time (dynamic RAG) --- .github/workflows/main.yaml | 12 +- src/collections/deserialize/index.ts | 36 +- src/collections/generate/index.ts | 283 +++++----- src/collections/generate/integration.test.ts | 84 ++- src/collections/generate/types.ts | 186 ++++--- src/collections/query/check.ts | 6 + src/collections/serialize/index.ts | 136 ++++- src/collections/types/generate.ts | 169 +++++- src/grpc/searcher.ts | 1 + src/proto/v1/generative.ts | 519 ++++++++++++++++++- src/utils/dbVersion.ts | 11 + 11 files changed, 1190 insertions(+), 253 deletions(-) diff --git a/.github/workflows/main.yaml b/.github/workflows/main.yaml index 87c01d56..8968c421 100644 --- a/.github/workflows/main.yaml +++ b/.github/workflows/main.yaml @@ -8,11 +8,12 @@ on: env: WEAVIATE_124: 1.24.26 - WEAVIATE_125: 1.25.30 - WEAVIATE_126: 1.26.14 - WEAVIATE_127: 1.27.11 - WEAVIATE_128: 1.28.4 + WEAVIATE_125: 1.25.34 + WEAVIATE_126: 1.26.17 + WEAVIATE_127: 1.27.14 + WEAVIATE_128: 1.28.8 WEAVIATE_129: 1.29.0 + WEAVIATE_130: 1.30.0-dev-680e323 jobs: checks: @@ -43,7 +44,8 @@ jobs: { node: "22.x", weaviate: $WEAVIATE_128}, { node: "18.x", weaviate: $WEAVIATE_129}, { node: "20.x", weaviate: $WEAVIATE_129}, - { node: "22.x", weaviate: $WEAVIATE_129} + { node: "22.x", weaviate: $WEAVIATE_129}, + { node: "22.x", weaviate: $WEAVIATE_130} ] steps: - uses: actions/checkout@v3 diff --git a/src/collections/deserialize/index.ts b/src/collections/deserialize/index.ts index 588c2642..41f9beff 100644 --- a/src/collections/deserialize/index.ts +++ b/src/collections/deserialize/index.ts @@ -25,6 +25,8 @@ import { AggregateResult, AggregateText, AggregateType, + GenerativeConfigRuntime, + GenerativeMetadata, PropertiesMetrics, } from '../index.js'; import { referenceFromObjects } from '../references/utils.js'; @@ -207,11 +209,24 @@ export class Deserialize { }; } - public generate(reply: SearchReply): GenerativeReturn { + public generate( + reply: SearchReply + ): GenerativeReturn { return { objects: reply.results.map((result) => { return { - generated: result.metadata?.generativePresent ? result.metadata?.generative : undefined, + generated: result.metadata?.generativePresent + ? result.metadata?.generative + : result.generative + ? result.generative.values[0].result + : undefined, + generative: result.generative + ? { + text: result.generative.values[0].result, + debug: result.generative.values[0].debug, + metadata: result.generative.values[0].metadata as GenerativeMetadata, + } + : undefined, metadata: Deserialize.metadata(result.metadata), properties: this.properties(result.properties), references: this.references(result.properties), @@ -219,7 +234,18 @@ export class Deserialize { vectors: Deserialize.vectors(result.metadata), } as any; }), - generated: reply.generativeGroupedResult, + generated: + reply.generativeGroupedResult !== '' + ? reply.generativeGroupedResult + : reply.generativeGroupedResults + ? reply.generativeGroupedResults.values[0].result + : undefined, + generative: reply.generativeGroupedResults + ? { + text: reply.generativeGroupedResults?.values[0].result, + metadata: reply.generativeGroupedResults?.values[0].metadata as GenerativeMetadata, + } + : undefined, }; } @@ -252,9 +278,9 @@ export class Deserialize { }; } - public generateGroupBy(reply: SearchReply): GenerativeGroupByReturn { + public generateGroupBy(reply: SearchReply): GenerativeGroupByReturn { const objects: GroupByObject[] = []; - const groups: Record> = {}; + const groups: Record> = {}; reply.groupByResults.forEach((result) => { const objs = result.objects.map((object) => { return { diff --git a/src/collections/generate/index.ts b/src/collections/generate/index.ts index 3af6fef1..60de4f94 100644 --- a/src/collections/generate/index.ts +++ b/src/collections/generate/index.ts @@ -5,6 +5,7 @@ import { DbVersionSupport } from '../../utils/dbVersion.js'; import { WeaviateInvalidInputError } from '../../errors.js'; import { toBase64FromMedia } from '../../index.js'; +import { GenerativeSearch } from '../../proto/v1/generative.js'; import { SearchReply } from '../../proto/v1/search_get.js'; import { Deserialize } from '../deserialize/index.js'; import { Check } from '../query/check.js'; @@ -28,6 +29,7 @@ import { Serialize } from '../serialize/index.js'; import { GenerateOptions, GenerateReturn, + GenerativeConfigRuntime, GenerativeGroupByReturn, GenerativeReturn, GroupByOptions, @@ -51,107 +53,118 @@ class GenerateManager implements Generate { return new GenerateManager(new Check(connection, name, dbVersionSupport, consistencyLevel, tenant)); } - private async parseReply(reply: SearchReply) { + private async parseReply(reply: SearchReply) { const deserialize = await Deserialize.use(this.check.dbVersionSupport); - return deserialize.generate(reply); + return deserialize.generate(reply); } - private async parseGroupByReply( + private async parseGroupByReply( opts: SearchOptions | GroupByOptions | undefined, reply: SearchReply ) { const deserialize = await Deserialize.use(this.check.dbVersionSupport); return Serialize.search.isGroupBy(opts) ? deserialize.generateGroupBy(reply) - : deserialize.generate(reply); + : deserialize.generate(reply); } - public fetchObjects( - generate: GenerateOptions, + public fetchObjects( + generate: GenerateOptions, opts?: FetchObjectsOptions - ): Promise> { - return this.check - .fetchObjects(opts) - .then(({ search }) => + ): Promise> { + return Promise.all([this.check.fetchObjects(opts), this.check.supportForSingleGrouped()]) + .then(async ([{ search }, supportsSingleGrouped]) => search.withFetch({ ...Serialize.search.fetchObjects(opts), - generative: Serialize.generative(generate), + generative: await Serialize.generative({ supportsSingleGrouped }, generate), }) ) .then((reply) => this.parseReply(reply)); } - public bm25( + public bm25( query: string, - generate: GenerateOptions, + generate: GenerateOptions, opts?: BaseBm25Options - ): Promise>; - public bm25( + ): Promise>; + public bm25( query: string, - generate: GenerateOptions, + generate: GenerateOptions, opts: GroupByBm25Options - ): Promise>; - public bm25(query: string, generate: GenerateOptions, opts?: Bm25Options): GenerateReturn { - return this.check - .bm25(opts) - .then(({ search }) => + ): Promise>; + public bm25( + query: string, + generate: GenerateOptions, + opts?: Bm25Options + ): GenerateReturn { + return Promise.all([this.check.bm25(opts), this.check.supportForSingleGrouped()]) + .then(async ([{ search }, supportsSingleGrouped]) => search.withBm25({ ...Serialize.search.bm25(query, opts), - generative: Serialize.generative(generate), + generative: await Serialize.generative({ supportsSingleGrouped }, generate), }) ) .then((reply) => this.parseGroupByReply(opts, reply)); } - public hybrid( + public hybrid( query: string, - generate: GenerateOptions, + generate: GenerateOptions, opts?: BaseHybridOptions - ): Promise>; - public hybrid( + ): Promise>; + public hybrid( query: string, - generate: GenerateOptions, + generate: GenerateOptions, opts: GroupByHybridOptions - ): Promise>; - public hybrid(query: string, generate: GenerateOptions, opts?: HybridOptions): GenerateReturn { - return this.check - .hybridSearch(opts) - .then(({ search, supportsTargets, supportsVectorsForTargets, supportsWeightsForTargets }) => - search.withHybrid({ - ...Serialize.search.hybrid( - { - query, - supportsTargets, - supportsVectorsForTargets, - supportsWeightsForTargets, - }, - opts - ), - generative: Serialize.generative(generate), - }) + ): Promise>; + public hybrid( + query: string, + generate: GenerateOptions, + opts?: HybridOptions + ): GenerateReturn { + return Promise.all([this.check.hybridSearch(opts), this.check.supportForSingleGrouped()]) + .then( + async ([ + { search, supportsTargets, supportsVectorsForTargets, supportsWeightsForTargets }, + supportsSingleGrouped, + ]) => + search.withHybrid({ + ...Serialize.search.hybrid( + { + query, + supportsTargets, + supportsVectorsForTargets, + supportsWeightsForTargets, + }, + opts + ), + generative: await Serialize.generative({ supportsSingleGrouped }, generate), + }) ) .then((reply) => this.parseGroupByReply(opts, reply)); } - public nearImage( + public nearImage( image: string | Buffer, - generate: GenerateOptions, + generate: GenerateOptions, opts?: BaseNearOptions - ): Promise>; - public nearImage( + ): Promise>; + public nearImage( image: string | Buffer, - generate: GenerateOptions, + generate: GenerateOptions, opts: GroupByNearOptions - ): Promise>; - public nearImage( + ): Promise>; + public nearImage( image: string | Buffer, - generate: GenerateOptions, + generate: GenerateOptions, opts?: NearOptions - ): GenerateReturn { - return this.check - .nearSearch(opts) - .then(({ search, supportsTargets, supportsWeightsForTargets }) => - toBase64FromMedia(image).then((image) => + ): GenerateReturn { + return Promise.all([this.check.nearSearch(opts), this.check.supportForSingleGrouped()]) + .then(([{ search, supportsTargets, supportsWeightsForTargets }, supportsSingleGrouped]) => + Promise.all([ + toBase64FromMedia(image), + Serialize.generative({ supportsSingleGrouped }, generate), + ]).then(([image, generative]) => search.withNearImage({ ...Serialize.search.nearImage( { @@ -161,27 +174,30 @@ class GenerateManager implements Generate { }, opts ), - generative: Serialize.generative(generate), + generative, }) ) ) .then((reply) => this.parseGroupByReply(opts, reply)); } - public nearObject( + public nearObject( id: string, - generate: GenerateOptions, + generate: GenerateOptions, opts?: BaseNearOptions - ): Promise>; - public nearObject( + ): Promise>; + public nearObject( id: string, - generate: GenerateOptions, + generate: GenerateOptions, opts: GroupByNearOptions - ): Promise>; - public nearObject(id: string, generate: GenerateOptions, opts?: NearOptions): GenerateReturn { - return this.check - .nearSearch(opts) - .then(({ search, supportsTargets, supportsWeightsForTargets }) => + ): Promise>; + public nearObject( + id: string, + generate: GenerateOptions, + opts?: NearOptions + ): GenerateReturn { + return Promise.all([this.check.nearSearch(opts), this.check.supportForSingleGrouped()]) + .then(async ([{ search, supportsTargets, supportsWeightsForTargets }, supportsSingleGrouped]) => search.withNearObject({ ...Serialize.search.nearObject( { @@ -191,30 +207,29 @@ class GenerateManager implements Generate { }, opts ), - generative: Serialize.generative(generate), + generative: await Serialize.generative({ supportsSingleGrouped }, generate), }) ) .then((reply) => this.parseGroupByReply(opts, reply)); } - public nearText( + public nearText( query: string | string[], - generate: GenerateOptions, + generate: GenerateOptions, opts?: BaseNearTextOptions - ): Promise>; - public nearText( + ): Promise>; + public nearText( query: string | string[], - generate: GenerateOptions, + generate: GenerateOptions, opts: GroupByNearTextOptions - ): Promise>; - public nearText( + ): Promise>; + public nearText( query: string | string[], - generate: GenerateOptions, + generate: GenerateOptions, opts?: NearOptions - ): GenerateReturn { - return this.check - .nearSearch(opts) - .then(({ search, supportsTargets, supportsWeightsForTargets }) => + ): GenerateReturn { + return Promise.all([this.check.nearSearch(opts), this.check.supportForSingleGrouped()]) + .then(async ([{ search, supportsTargets, supportsWeightsForTargets }, supportsSingleGrouped]) => search.withNearText({ ...Serialize.search.nearText( { @@ -224,114 +239,124 @@ class GenerateManager implements Generate { }, opts ), - generative: Serialize.generative(generate), + generative: await Serialize.generative({ supportsSingleGrouped }, generate), }) ) .then((reply) => this.parseGroupByReply(opts, reply)); } - public nearVector( + public nearVector( vector: number[], - generate: GenerateOptions, + generate: GenerateOptions, opts?: BaseNearOptions - ): Promise>; - public nearVector( + ): Promise>; + public nearVector( vector: number[], - generate: GenerateOptions, + generate: GenerateOptions, opts: GroupByNearOptions - ): Promise>; - public nearVector( + ): Promise>; + public nearVector( vector: number[], - generate: GenerateOptions, + generate: GenerateOptions, opts?: NearOptions - ): GenerateReturn { - return this.check - .nearVector(vector, opts) - .then(({ search, supportsTargets, supportsVectorsForTargets, supportsWeightsForTargets }) => - search.withNearVector({ - ...Serialize.search.nearVector( - { - vector, - supportsTargets, - supportsVectorsForTargets, - supportsWeightsForTargets, - }, - opts - ), - generative: Serialize.generative(generate), - }) + ): GenerateReturn { + return Promise.all([this.check.nearVector(vector, opts), this.check.supportForSingleGrouped()]) + .then( + async ([ + { search, supportsTargets, supportsVectorsForTargets, supportsWeightsForTargets }, + supportsSingleGrouped, + ]) => + search.withNearVector({ + ...Serialize.search.nearVector( + { + vector, + supportsTargets, + supportsVectorsForTargets, + supportsWeightsForTargets, + }, + opts + ), + generative: await Serialize.generative({ supportsSingleGrouped }, generate), + }) ) .then((reply) => this.parseGroupByReply(opts, reply)); } - public nearMedia( + public nearMedia( media: string | Buffer, type: NearMediaType, - generate: GenerateOptions, + generate: GenerateOptions, opts?: BaseNearOptions - ): Promise>; - public nearMedia( + ): Promise>; + public nearMedia( media: string | Buffer, type: NearMediaType, - generate: GenerateOptions, + generate: GenerateOptions, opts: GroupByNearOptions - ): Promise>; - public nearMedia( + ): Promise>; + public nearMedia( media: string | Buffer, type: NearMediaType, - generate: GenerateOptions, + generate: GenerateOptions, opts?: NearOptions - ): GenerateReturn { - return this.check - .nearSearch(opts) - .then(({ search, supportsTargets, supportsWeightsForTargets }) => { + ): GenerateReturn { + return Promise.all([this.check.nearSearch(opts), this.check.supportForSingleGrouped()]) + .then(([{ search, supportsTargets, supportsWeightsForTargets }, supportsSingleGrouped]) => { const args = { supportsTargets, supportsWeightsForTargets, }; - const generative = Serialize.generative(generate); - let send: (media: string) => Promise; + let send: (media: string, generative: GenerativeSearch) => Promise; switch (type) { case 'audio': - send = (media) => + send = (media, generative) => search.withNearAudio({ ...Serialize.search.nearAudio({ audio: media, ...args }, opts), generative, }); break; case 'depth': - send = (media) => + send = (media, generative) => search.withNearDepth({ ...Serialize.search.nearDepth({ depth: media, ...args }, opts), generative, }); break; case 'image': - send = (media) => + send = (media, generative) => search.withNearImage({ ...Serialize.search.nearImage({ image: media, ...args }, opts), generative, }); break; case 'imu': - send = (media) => - search.withNearIMU({ ...Serialize.search.nearIMU({ imu: media, ...args }, opts), generative }); + send = (media, generative) => + search.withNearIMU({ + ...Serialize.search.nearIMU({ imu: media, ...args }, opts), + generative, + }); break; case 'thermal': - send = (media) => + send = (media, generative) => search.withNearThermal({ ...Serialize.search.nearThermal({ thermal: media, ...args }, opts), generative, }); break; case 'video': - send = (media) => - search.withNearVideo({ ...Serialize.search.nearVideo({ video: media, ...args }), generative }); + send = (media, generative) => + search.withNearVideo({ + ...Serialize.search.nearVideo({ video: media, ...args }), + generative, + }); break; default: throw new WeaviateInvalidInputError(`Invalid media type: ${type}`); } - return toBase64FromMedia(media).then(send); + return Promise.all([ + toBase64FromMedia(media), + Serialize.generative({ supportsSingleGrouped }, generate), + ]).then(([media, generative]) => send(media, generative)); }) .then((reply) => this.parseGroupByReply(opts, reply)); } diff --git a/src/collections/generate/integration.test.ts b/src/collections/generate/integration.test.ts index 1be98451..5e478093 100644 --- a/src/collections/generate/integration.test.ts +++ b/src/collections/generate/integration.test.ts @@ -27,10 +27,10 @@ maybe('Testing of the collection.generate methods with a simple collection', () testProp: string; }; - const generateOpts: GenerateOptions = { + const generateOpts = { singlePrompt: 'Write a haiku about ducks for {testProp}', groupedTask: 'What is the value of testProp here?', - groupedProperties: ['testProp'], + groupedProperties: ['testProp'] as 'testProp'[], }; afterAll(() => { @@ -162,7 +162,7 @@ maybe('Testing of the groupBy collection.generate methods with a simple collecti testProp: string; }; - const generateOpts: GenerateOptions = { + const generateOpts: GenerateOptions = { singlePrompt: 'Write a haiku about ducks for {testProp}', groupedTask: 'What is the value of testProp here?', groupedProperties: ['testProp'], @@ -421,3 +421,81 @@ maybe('Testing of the collection.generate methods with a multi vector collection expect(ret.objects[1].generated).toBeDefined(); }); }); + +maybe('Testing of the collection.generate methods with runtime generative config', () => { + let client: WeaviateClient; + let collection: Collection; + const collectionName = 'TestCollectionGenerateConfigRuntime'; + + type TestCollectionGenerateConfigRuntime = { + testProp: string; + }; + + afterAll(() => { + return client.collections.delete(collectionName).catch((err) => { + console.error(err); + throw err; + }); + }); + + beforeAll(async () => { + client = await makeOpenAIClient(); + collection = client.collections.get(collectionName); + return client.collections + .create({ + name: collectionName, + properties: [ + { + name: 'testProp', + dataType: 'text', + }, + ], + }) + .then(() => { + return collection.data.insert({ + properties: { + testProp: 'test', + }, + }); + }); + }); + + it('should generate using a runtime config without search', async () => { + const query = () => + collection.generate.fetchObjects({ + singlePrompt: { + prompt: 'Write a haiku about ducks for {testProp}', + debug: true, + metadata: true, + }, + groupedTask: { + prompt: 'What is the value of testProp here?', + nonBlobProperties: ['testProp'], + metadata: true, + }, + config: { + name: 'generative-openai', + config: { + model: 'gpt-4o-mini', + }, + }, + }); + + if (await client.getWeaviateVersion().then((ver) => ver.isLowerThan(1, 30, 0))) { + await expect(query()).rejects.toThrow(WeaviateUnsupportedFeatureError); + return; + } + + const res = await query(); + expect(res.objects.length).toEqual(1); + expect(res.generated).toBeDefined(); + expect(res.generative?.text).toBeDefined(); + expect(res.generative?.metadata).toBeDefined(); + res.objects.forEach((obj) => { + expect(obj.generated).toBeDefined(); + expect(obj.generative?.text).toBeDefined(); + expect(obj.generative?.metadata).toBeDefined(); + expect(obj.generative?.debug).toBeDefined(); + }); + }); +}); diff --git a/src/collections/generate/types.ts b/src/collections/generate/types.ts index b211a46a..27548bfb 100644 --- a/src/collections/generate/types.ts +++ b/src/collections/generate/types.ts @@ -18,6 +18,7 @@ import { import { GenerateOptions, GenerateReturn, + GenerativeConfigRuntime, GenerativeGroupByReturn, GenerativeReturn, } from '../types/index.js'; @@ -31,11 +32,15 @@ interface Bm25 { * This overload is for performing a search without the `groupBy` param. * * @param {string} query - The query to search for. - * @param {GenerateOptions} generate - The available options for performing the generation. + * @param {GenerateOptions} generate - The available options for performing the generation. * @param {BaseBm25Options} [opts] - The available options for performing the BM25 search. - * @return {Promise>} - The results of the search including the generated data. + * @return {Promise>} - The results of the search including the generated data. */ - bm25(query: string, generate: GenerateOptions, opts?: BaseBm25Options): Promise>; + bm25( + query: string, + generate: GenerateOptions, + opts?: BaseBm25Options + ): Promise>; /** * Perform retrieval-augmented generation (RaG) on the results of a keyword-based BM25 search of objects in this collection. * @@ -44,15 +49,15 @@ interface Bm25 { * This overload is for performing a search with the `groupBy` param. * * @param {string} query - The query to search for. - * @param {GenerateOptions} generate - The available options for performing the generation. + * @param {GenerateOptions} generate - The available options for performing the generation. * @param {GroupByBm25Options} opts - The available options for performing the BM25 search. - * @return {Promise>} - The results of the search including the generated data grouped by the specified properties. + * @return {Promise>} - The results of the search including the generated data grouped by the specified properties. */ - bm25( + bm25( query: string, - generate: GenerateOptions, + generate: GenerateOptions, opts: GroupByBm25Options - ): Promise>; + ): Promise>; /** * Perform retrieval-augmented generation (RaG) on the results of a keyword-based BM25 search of objects in this collection. * @@ -61,11 +66,15 @@ interface Bm25 { * This overload is for performing a search with a programmatically defined `opts` param. * * @param {string} query - The query to search for. - * @param {GenerateOptions} generate - The available options for performing the generation. + * @param {GenerateOptions} generate - The available options for performing the generation. * @param {Bm25Options} [opts] - The available options for performing the BM25 search. - * @return {GenerateReturn} - The results of the search including the generated data. + * @return {GenerateReturn} - The results of the search including the generated data. */ - bm25(query: string, generate: GenerateOptions, opts?: Bm25Options): GenerateReturn; + bm25( + query: string, + generate: GenerateOptions, + opts?: Bm25Options + ): GenerateReturn; } interface Hybrid { @@ -77,15 +86,15 @@ interface Hybrid { * This overload is for performing a search without the `groupBy` param. * * @param {string} query - The query to search for. - * @param {GenerateOptions} generate - The available options for performing the generation. + * @param {GenerateOptions} generate - The available options for performing the generation. * @param {BaseHybridOptions} [opts] - The available options for performing the hybrid search. - * @return {Promise>} - The results of the search including the generated data. + * @return {Promise>} - The results of the search including the generated data. */ - hybrid( + hybrid( query: string, - generate: GenerateOptions, + generate: GenerateOptions, opts?: BaseHybridOptions - ): Promise>; + ): Promise>; /** * Perform retrieval-augmented generation (RaG) on the results of an object search in this collection using the hybrid algorithm blending keyword-based BM25 and vector-based similarity. * @@ -94,15 +103,15 @@ interface Hybrid { * This overload is for performing a search with the `groupBy` param. * * @param {string} query - The query to search for. - * @param {GenerateOptions} generate - The available options for performing the generation. + * @param {GenerateOptions} generate - The available options for performing the generation. * @param {GroupByHybridOptions} opts - The available options for performing the hybrid search. - * @return {Promise>} - The results of the search including the generated data grouped by the specified properties. + * @return {Promise>} - The results of the search including the generated data grouped by the specified properties. */ - hybrid( + hybrid( query: string, - generate: GenerateOptions, + generate: GenerateOptions, opts: GroupByHybridOptions - ): Promise>; + ): Promise>; /** * Perform retrieval-augmented generation (RaG) on the results of an object search in this collection using the hybrid algorithm blending keyword-based BM25 and vector-based similarity. * @@ -111,11 +120,15 @@ interface Hybrid { * This overload is for performing a search with a programmatically defined `opts` param. * * @param {string} query - The query to search for. - * @param {GenerateOptions} generate - The available options for performing the generation. + * @param {GenerateOptions} generate - The available options for performing the generation. * @param {HybridOptions} [opts] - The available options for performing the hybrid search. - * @return {GenerateReturn} - The results of the search including the generated data. + * @return {GenerateReturn} - The results of the search including the generated data. */ - hybrid(query: string, generate: GenerateOptions, opts?: HybridOptions): GenerateReturn; + hybrid( + query: string, + generate: GenerateOptions, + opts?: HybridOptions + ): GenerateReturn; } interface NearMedia { @@ -130,16 +143,16 @@ interface NearMedia { * * @param {string | Buffer} media - The media file to search on. This can be a base64 string, a file path string, or a buffer. * @param {NearMediaType} type - The type of media to search on. - * @param {GenerateOptions} generate - The available options for performing the generation. + * @param {GenerateOptions} generate - The available options for performing the generation. * @param {BaseNearOptions} [opts] - The available options for performing the near-media search. - * @return {Promise>} - The results of the search including the generated data. + * @return {Promise>} - The results of the search including the generated data. */ - nearMedia( + nearMedia( media: string | Buffer, type: NearMediaType, - generate: GenerateOptions, + generate: GenerateOptions, opts?: BaseNearOptions - ): Promise>; + ): Promise>; /** * Perform retrieval-augmented generation (RaG) on the results of a by-audio object search in this collection using an audio-capable vectorization module and vector-based similarity search. * @@ -151,16 +164,16 @@ interface NearMedia { * * @param {string | Buffer} media - The media file to search on. This can be a base64 string, a file path string, or a buffer. * @param {NearMediaType} type - The type of media to search on. - * @param {GenerateOptions} generate - The available options for performing the generation. + * @param {GenerateOptions} generate - The available options for performing the generation. * @param {GroupByNearOptions} opts - The available options for performing the near-media search. - * @return {Promise>} - The results of the search including the generated data grouped by the specified properties. + * @return {Promise>} - The results of the search including the generated data grouped by the specified properties. */ - nearMedia( + nearMedia( media: string | Buffer, type: NearMediaType, - generate: GenerateOptions, + generate: GenerateOptions, opts: GroupByNearOptions - ): Promise>; + ): Promise>; /** * Perform retrieval-augmented generation (RaG) on the results of a by-audio object search in this collection using an audio-capable vectorization module and vector-based similarity search. * @@ -172,16 +185,16 @@ interface NearMedia { * * @param {string | Buffer} media - The media to search on. This can be a base64 string, a file path string, or a buffer. * @param {NearMediaType} type - The type of media to search on. - * @param {GenerateOptions} generate - The available options for performing the generation. + * @param {GenerateOptions} generate - The available options for performing the generation. * @param {NearOptions} [opts] - The available options for performing the near-media search. - * @return {GenerateReturn} - The results of the search including the generated data. + * @return {GenerateReturn} - The results of the search including the generated data. */ - nearMedia( + nearMedia( media: string | Buffer, type: NearMediaType, - generate: GenerateOptions, + generate: GenerateOptions, opts?: NearOptions - ): GenerateReturn; + ): GenerateReturn; } interface NearObject { @@ -193,15 +206,15 @@ interface NearObject { * This overload is for performing a search without the `groupBy` param. * * @param {string} id - The ID of the object to search for. - * @param {GenerateOptions} generate - The available options for performing the generation. + * @param {GenerateOptions} generate - The available options for performing the generation. * @param {BaseNearOptions} [opts] - The available options for performing the near-object search. - * @return {Promise>} - The results of the search including the generated data. + * @return {Promise>} - The results of the search including the generated data. */ - nearObject( + nearObject( id: string, - generate: GenerateOptions, + generate: GenerateOptions, opts?: BaseNearOptions - ): Promise>; + ): Promise>; /** * Perform retrieval-augmented generation (RaG) on the results of a by-object object search in this collection using a vector-based similarity search. * @@ -210,15 +223,15 @@ interface NearObject { * This overload is for performing a search with the `groupBy` param. * * @param {string} id - The ID of the object to search for. - * @param {GenerateOptions} generate - The available options for performing the generation. + * @param {GenerateOptions} generate - The available options for performing the generation. * @param {GroupByNearOptions} opts - The available options for performing the near-object search. - * @return {Promise>} - The results of the search including the generated data grouped by the specified properties. + * @return {Promise>} - The results of the search including the generated data grouped by the specified properties. */ - nearObject( + nearObject( id: string, - generate: GenerateOptions, + generate: GenerateOptions, opts: GroupByNearOptions - ): Promise>; + ): Promise>; /** * Perform retrieval-augmented generation (RaG) on the results of a by-object object search in this collection using a vector-based similarity search. * @@ -227,11 +240,15 @@ interface NearObject { * This overload is for performing a search with a programmatically defined `opts` param. * * @param {string} id - The ID of the object to search for. - * @param {GenerateOptions} generate - The available options for performing the generation. + * @param {GenerateOptions} generate - The available options for performing the generation. * @param {NearOptions} [opts] - The available options for performing the near-object search. - * @return {GenerateReturn} - The results of the search including the generated data. + * @return {GenerateReturn} - The results of the search including the generated data. */ - nearObject(id: string, generate: GenerateOptions, opts?: NearOptions): GenerateReturn; + nearObject( + id: string, + generate: GenerateOptions, + opts?: NearOptions + ): GenerateReturn; } interface NearText { @@ -245,15 +262,15 @@ interface NearText { * This overload is for performing a search without the `groupBy` param. * * @param {string | string[]} query - The query to search for. - * @param {GenerateOptions} generate - The available options for performing the generation. + * @param {GenerateOptions} generate - The available options for performing the generation. * @param {BaseNearTextOptions} [opts] - The available options for performing the near-text search. - * @return {Promise>} - The results of the search including the generated data. + * @return {Promise>} - The results of the search including the generated data. */ - nearText( + nearText( query: string | string[], - generate: GenerateOptions, + generate: GenerateOptions, opts?: BaseNearTextOptions - ): Promise>; + ): Promise>; /** * Perform retrieval-augmented generation (RaG) on the results of a by-image object search in this collection using the image-capable vectorization module and vector-based similarity search. * @@ -264,15 +281,15 @@ interface NearText { * This overload is for performing a search with the `groupBy` param. * * @param {string | string[]} query - The query to search for. - * @param {GenerateOptions} generate - The available options for performing the generation. + * @param {GenerateOptions} generate - The available options for performing the generation. * @param {GroupByNearTextOptions} opts - The available options for performing the near-text search. - * @return {Promise>} - The results of the search including the generated data grouped by the specified properties. + * @return {Promise>} - The results of the search including the generated data grouped by the specified properties. */ - nearText( + nearText( query: string | string[], - generate: GenerateOptions, + generate: GenerateOptions, opts: GroupByNearTextOptions - ): Promise>; + ): Promise>; /** * Perform retrieval-augmented generation (RaG) on the results of a by-image object search in this collection using the image-capable vectorization module and vector-based similarity search. * @@ -283,15 +300,15 @@ interface NearText { * This overload is for performing a search with a programmatically defined `opts` param. * * @param {string | string[]} query - The query to search for. - * @param {GenerateOptions} generate - The available options for performing the generation. + * @param {GenerateOptions} generate - The available options for performing the generation. * @param {NearTextOptions} [opts] - The available options for performing the near-text search. - * @return {GenerateReturn} - The results of the search including the generated data. + * @return {GenerateReturn} - The results of the search including the generated data. */ - nearText( + nearText( query: string | string[], - generate: GenerateOptions, + generate: GenerateOptions, opts?: NearTextOptions - ): GenerateReturn; + ): GenerateReturn; } interface NearVector { @@ -303,15 +320,15 @@ interface NearVector { * This overload is for performing a search without the `groupBy` param. * * @param {NearVectorInputType} vector - The vector(s) to search for. - * @param {GenerateOptions} generate - The available options for performing the generation. + * @param {GenerateOptions} generate - The available options for performing the generation. * @param {BaseNearOptions} [opts] - The available options for performing the near-vector search. - * @return {Promise>} - The results of the search including the generated data. + * @return {Promise>} - The results of the search including the generated data. */ - nearVector( + nearVector( vector: NearVectorInputType, - generate: GenerateOptions, + generate: GenerateOptions, opts?: BaseNearOptions - ): Promise>; + ): Promise>; /** * Perform retrieval-augmented generation (RaG) on the results of a by-vector object search in this collection using vector-based similarity search. * @@ -320,15 +337,15 @@ interface NearVector { * This overload is for performing a search with the `groupBy` param. * * @param {NearVectorInputType} vector - The vector(s) to search for. - * @param {GenerateOptions} generate - The available options for performing the generation. + * @param {GenerateOptions} generate - The available options for performing the generation. * @param {GroupByNearOptions} opts - The available options for performing the near-vector search. - * @return {Promise>} - The results of the search including the generated data grouped by the specified properties. + * @return {Promise>} - The results of the search including the generated data grouped by the specified properties. */ - nearVector( + nearVector( vector: NearVectorInputType, - generate: GenerateOptions, + generate: GenerateOptions, opts: GroupByNearOptions - ): Promise>; + ): Promise>; /** * Perform retrieval-augmented generation (RaG) on the results of a by-vector object search in this collection using vector-based similarity search. * @@ -337,15 +354,15 @@ interface NearVector { * This overload is for performing a search with a programmatically defined `opts` param. * * @param {NearVectorInputType} vector - The vector(s) to search for. - * @param {GenerateOptions} generate - The available options for performing the generation. + * @param {GenerateOptions} generate - The available options for performing the generation. * @param {NearOptions} [opts] - The available options for performing the near-vector search. - * @return {GenerateReturn} - The results of the search including the generated data. + * @return {GenerateReturn} - The results of the search including the generated data. */ - nearVector( + nearVector( vector: NearVectorInputType, - generate: GenerateOptions, + generate: GenerateOptions, opts?: NearOptions - ): GenerateReturn; + ): GenerateReturn; } export interface Generate @@ -355,5 +372,8 @@ export interface Generate NearObject, NearText, NearVector { - fetchObjects: (generate: GenerateOptions, opts?: FetchObjectsOptions) => Promise>; + fetchObjects: ( + generate: GenerateOptions, + opts?: FetchObjectsOptions + ) => Promise>; } diff --git a/src/collections/query/check.ts b/src/collections/query/check.ts index 291738de..81084f66 100644 --- a/src/collections/query/check.ts +++ b/src/collections/query/check.ts @@ -98,6 +98,12 @@ export class Check { return check.supports; }; + public supportForSingleGrouped = async () => { + const check = await this.dbVersionSupport.supportsSingleGrouped(); + if (!check.supports) throw new WeaviateUnsupportedFeatureError(check.message); + return check.supports; + }; + public nearSearch = (opts?: BaseNearOptions) => { return Promise.all([ this.getSearcher(), diff --git a/src/collections/serialize/index.ts b/src/collections/serialize/index.ts index 747bf00a..6ca6585c 100644 --- a/src/collections/serialize/index.ts +++ b/src/collections/serialize/index.ts @@ -25,7 +25,12 @@ import { BatchObject_Properties, BatchObject_SingleTargetRefProps, } from '../../proto/v1/batch.js'; -import { GenerativeSearch } from '../../proto/v1/generative.js'; +import { + GenerativeProvider, + GenerativeSearch, + GenerativeSearch_Grouped, + GenerativeSearch_Single, +} from '../../proto/v1/generative.js'; import { GroupBy, MetadataRequest, @@ -63,6 +68,7 @@ import { SearchNearVectorArgs, SearchNearVideoArgs, } from '../../grpc/searcher.js'; +import { toBase64FromMedia } from '../../index.js'; import { AggregateRequest_Aggregation, AggregateRequest_Aggregation_Boolean, @@ -82,6 +88,7 @@ import { ObjectArrayProperties, ObjectProperties, ObjectPropertiesValue, + TextArray, TextArrayProperties, Vectors as VectorsGrpc, } from '../../proto/v1/base.js'; @@ -97,10 +104,13 @@ import { AggregateBaseOptions, AggregateHybridOptions, AggregateNearOptions, + GenerativeConfigRuntime, GroupByAggregate, + GroupedTask, MultiTargetVectorJoin, PrimitiveKeys, PropertiesMetrics, + SinglePrompt, } from '../index.js'; import { BaseHybridOptions, @@ -818,14 +828,126 @@ export class Serialize { return vec !== undefined && !Array.isArray(vec) && Object.values(vec).some(ArrayInputGuards.is2DArray); }; - public static generative = (generative?: GenerateOptions): GenerativeSearch => { - return GenerativeSearch.fromPartial({ - singleResponsePrompt: generative?.singlePrompt, - groupedResponseTask: generative?.groupedTask, - groupedProperties: generative?.groupedProperties as string[], - }); + private static generativeQuery = async ( + generative: GenerativeConfigRuntime, + opts?: { metadata?: boolean; images?: (string | Buffer)[]; imageProperties?: string[] } + ): Promise => { + const withImages = async >( + config: T, + imgs?: (string | Buffer)[], + imgProps?: string[] + ): Promise => { + if (imgs == undefined && imgProps == undefined) { + return config; + } + return { + ...config, + images: TextArray.fromPartial({ + values: imgs ? await Promise.all(imgs.map(toBase64FromMedia)) : undefined, + }), + imageProperties: TextArray.fromPartial({ values: imgProps }), + }; + }; + + const provider = GenerativeProvider.fromPartial({ returnMetadata: opts?.metadata }); + switch (generative.name) { + case 'generative-anthropic': + provider.anthropic = await withImages(generative.config, opts?.images, opts?.imageProperties); + break; + case 'generative-anyscale': + provider.anyscale = generative.config; + break; + case 'generative-aws': + provider.aws = await withImages(generative.config, opts?.images, opts?.imageProperties); + break; + case 'generative-cohere': + provider.cohere = generative.config; + break; + case 'generative-databricks': + provider.databricks = generative.config; + break; + case 'generative-dummy': + provider.dummy = generative.config; + break; + case 'generative-friendliai': + provider.friendliai = generative.config; + break; + case 'generative-google': + provider.google = await withImages(generative.config, opts?.images, opts?.imageProperties); + break; + case 'generative-mistral': + provider.mistral = generative.config; + break; + case 'generative-nvidia': + provider.nvidia = generative.config; + break; + case 'generative-ollama': + provider.ollama = await withImages(generative.config, opts?.images, opts?.imageProperties); + break; + case 'generative-openai': + provider.openai = await withImages(generative.config, opts?.images, opts?.imageProperties); + break; + } + return provider; + }; + + public static generative = async ( + args: { supportsSingleGrouped: boolean }, + opts?: GenerateOptions + ): Promise => { + const singlePrompt = Serialize.isSinglePrompt(opts?.singlePrompt) + ? opts.singlePrompt.prompt + : opts?.singlePrompt; + const singlePromptDebug = Serialize.isSinglePrompt(opts?.singlePrompt) + ? opts.singlePrompt.debug + : undefined; + + const groupedTask = Serialize.isGroupedTask(opts?.groupedTask) + ? opts.groupedTask.prompt + : opts?.groupedTask; + const groupedProperties = Serialize.isGroupedTask(opts?.groupedTask) + ? opts.groupedTask.nonBlobProperties + : opts?.groupedProperties; + + const singleOpts = Serialize.isSinglePrompt(opts?.singlePrompt) ? opts.singlePrompt : undefined; + const groupedOpts = Serialize.isGroupedTask(opts?.groupedTask) ? opts.groupedTask : undefined; + + return args.supportsSingleGrouped + ? GenerativeSearch.fromPartial({ + single: opts?.singlePrompt + ? GenerativeSearch_Single.fromPartial({ + prompt: singlePrompt, + debug: singlePromptDebug, + queries: opts.config ? [await Serialize.generativeQuery(opts.config, singleOpts)] : undefined, + }) + : undefined, + grouped: opts?.groupedTask + ? GenerativeSearch_Grouped.fromPartial({ + task: groupedTask, + queries: opts.config + ? [await Serialize.generativeQuery(opts.config, groupedOpts)] + : undefined, + properties: groupedProperties + ? TextArray.fromPartial({ values: groupedProperties as string[] }) + : undefined, + }) + : undefined, + }) + : GenerativeSearch.fromPartial({ + singleResponsePrompt: singlePrompt, + groupedResponseTask: groupedTask, + groupedProperties: groupedProperties as string[], + }); }; + public static isSinglePrompt(arg?: string | SinglePrompt): arg is SinglePrompt { + return typeof arg !== 'string' && arg !== undefined && arg.prompt !== undefined; + } + + public static isGroupedTask(arg?: string | GroupedTask): arg is GroupedTask { + return typeof arg !== 'string' && arg !== undefined && arg.prompt !== undefined; + } + private static bm25QueryProperties = ( properties?: (PrimitiveKeys | Bm25QueryProperty)[] ): string[] | undefined => { diff --git a/src/collections/types/generate.ts b/src/collections/types/generate.ts index b3f6bac2..31eca2e2 100644 --- a/src/collections/types/generate.ts +++ b/src/collections/types/generate.ts @@ -1,54 +1,183 @@ +import { + GenerativeAWS as GenerativeAWSGRPC, + GenerativeAWSMetadata, + GenerativeAnthropic as GenerativeAnthropicGRPC, + GenerativeAnthropicMetadata, + GenerativeAnyscale as GenerativeAnyscaleGRPC, + GenerativeAnyscaleMetadata, + GenerativeCohere as GenerativeCohereGRPC, + GenerativeCohereMetadata, + GenerativeDatabricks as GenerativeDatabricksGRPC, + GenerativeDatabricksMetadata, + GenerativeDebug, + GenerativeDummy as GenerativeDummyGRPC, + GenerativeDummyMetadata, + GenerativeFriendliAI as GenerativeFriendliAIGRPC, + GenerativeFriendliAIMetadata, + GenerativeGoogle as GenerativeGoogleGRPC, + GenerativeGoogleMetadata, + GenerativeMistral as GenerativeMistralGRPC, + GenerativeMistralMetadata, + GenerativeNvidia as GenerativeNvidiaGRPC, + GenerativeNvidiaMetadata, + GenerativeOllama as GenerativeOllamaGRPC, + GenerativeOllamaMetadata, + GenerativeOpenAI as GenerativeOpenAIGRPC, + GenerativeOpenAIMetadata, +} from '../../proto/v1/generative.js'; +import { ModuleConfig } from '../index.js'; import { GroupByObject, GroupByResult, WeaviateGenericObject, WeaviateNonGenericObject } from './query.js'; -export type GenerativeGenericObject = WeaviateGenericObject & { - /** The LLM-generated output applicable to this single object. */ +export type GenerativeGenericObject< + T, + C extends GenerativeConfigRuntime | undefined +> = WeaviateGenericObject & { + /** @deprecated (use `generative.text` instead) The LLM-generated output applicable to this single object. */ generated?: string; + /** Generative data returned from the LLM inference on this object. */ + generative?: GenerativeSingle; }; -export type GenerativeNonGenericObject = WeaviateNonGenericObject & { - /** The LLM-generated output applicable to this single object. */ - generated?: string; -}; +export type GenerativeNonGenericObject = + WeaviateNonGenericObject & { + /** @deprecated (use `generative.text` instead) The LLM-generated output applicable to this single object. */ + generated?: string; + /** Generative data returned from the LLM inference on this object. */ + generative?: GenerativeSingle; + }; /** An object belonging to a collection as returned by the methods in the `collection.generate` namespace. * * Depending on the generic type `T`, the object will have subfields that map from `T`'s specific type definition. * If not, then the object will be non-generic and have a `properties` field that maps from a generic string to a `WeaviateField`. */ -export type GenerativeObject = T extends undefined - ? GenerativeNonGenericObject - : GenerativeGenericObject; +export type GenerativeObject = T extends undefined + ? GenerativeNonGenericObject + : GenerativeGenericObject; + +export type GenerativeSingle = { + debug?: GenerativeDebug; + metadata?: GenerativeMetadata; + text?: string; +}; + +export type GenerativeGrouped = { + metadata?: GenerativeMetadata; + text?: string; +}; /** The return of a query method in the `collection.generate` namespace. */ -export type GenerativeReturn = { +export type GenerativeReturn = { /** The objects that were found by the query. */ - objects: GenerativeObject[]; - /** The LLM-generated output applicable to this query as a whole. */ + objects: GenerativeObject[]; + /** @deprecated (use `generative.text` instead) The LLM-generated output applicable to this query as a whole. */ generated?: string; + generative?: GenerativeGrouped; }; -export type GenerativeGroupByResult = GroupByResult & { +export type GenerativeGroupByResult = GroupByResult & { + /** @deprecated (use `generative.text` instead) The LLM-generated output applicable to this query as a whole. */ generated?: string; + generative?: GenerativeSingle; }; /** The return of a query method in the `collection.generate` namespace where the `groupBy` argument was specified. */ -export type GenerativeGroupByReturn = { +export type GenerativeGroupByReturn = { /** The objects that were found by the query. */ objects: GroupByObject[]; /** The groups that were created by the query. */ - groups: Record>; - /** The LLM-generated output applicable to this query as a whole. */ + groups: Record>; + /** @deprecated (use `generative.text` instead) The LLM-generated output applicable to this query as a whole. */ generated?: string; + generative?: GenerativeGrouped; }; /** Options available when defining queries using methods in the `collection.generate` namespace. */ -export type GenerateOptions = { +export type GenerateOptions = { /** The prompt to use when generating content relevant to each object of the collection individually. */ - singlePrompt?: string; + singlePrompt?: string | SinglePrompt; /** The prompt to use when generating content relevant to objects returned by the query as a whole. */ - groupedTask?: string; + groupedTask?: string | GroupedTask; /** The properties to use as context to be injected into the `groupedTask` prompt when performing the grouped generation. */ groupedProperties?: T extends undefined ? string[] : (keyof T)[]; + config?: C; +}; + +export type SinglePrompt = { + prompt: string; + debug?: boolean; + metadata?: boolean; + images?: (string | Buffer)[]; + imageProperties?: string[]; +}; + +export type GroupedTask = { + prompt: string; + metadata?: boolean; + nonBlobProperties?: T extends undefined ? string[] : (keyof T)[]; + images?: (string | Buffer)[]; + imageProperties?: string[]; }; -export type GenerateReturn = Promise> | Promise>; +export type GenerativeConfigRuntime = + | ModuleConfig<'generative-anthropic', GenerativeAnthropicConfigRuntime> + | ModuleConfig<'generative-anyscale', GenerativeAnyscaleConfigRuntime> + | ModuleConfig<'generative-aws', GenerativeAWSConfigRuntime> + | ModuleConfig<'generative-cohere', GenerativeCohereConfigRuntime> + | ModuleConfig<'generative-databricks', GenerativeDatabricksConfigRuntime> + | ModuleConfig<'generative-dummy', GenerativeDummyConfigRuntime> + | ModuleConfig<'generative-friendliai', GenerativeFriendliAIConfigRuntime> + | ModuleConfig<'generative-google', GenerativeGoogleConfigRuntime> + | ModuleConfig<'generative-mistral', GenerativeMistralConfigRuntime> + | ModuleConfig<'generative-nvidia', GenerativeNvidiaConfigRuntime> + | ModuleConfig<'generative-ollama', GenerativeOllamaConfigRuntime> + | ModuleConfig<'generative-openai', GenerativeOpenAIConfigRuntime>; + +export type GenerativeMetadata = C extends undefined + ? never + : C extends infer R extends GenerativeConfigRuntime + ? R['name'] extends 'generative-anthropic' + ? GenerativeAnthropicMetadata + : R['name'] extends 'generative-anyscale' + ? GenerativeAnyscaleMetadata + : R['name'] extends 'generative-aws' + ? GenerativeAWSMetadata + : R['name'] extends 'generative-cohere' + ? GenerativeCohereMetadata + : R['name'] extends 'generative-databricks' + ? GenerativeDatabricksMetadata + : R['name'] extends 'generative-dummy' + ? GenerativeDummyMetadata + : R['name'] extends 'generative-friendliai' + ? GenerativeFriendliAIMetadata + : R['name'] extends 'generative-google' + ? GenerativeGoogleMetadata + : R['name'] extends 'generative-mistral' + ? GenerativeMistralMetadata + : R['name'] extends 'generative-nvidia' + ? GenerativeNvidiaMetadata + : R['name'] extends 'generative-ollama' + ? GenerativeOllamaMetadata + : R['name'] extends 'generative-openai' + ? GenerativeOpenAIMetadata + : never + : never; + +export type GenerateReturn = + | Promise> + | Promise>; + +type omitFields = 'images' | 'imageProperties'; + +export type GenerativeAnthropicConfigRuntime = Omit; +export type GenerativeAnyscaleConfigRuntime = Omit; +export type GenerativeAWSConfigRuntime = Omit; +export type GenerativeCohereConfigRuntime = Omit; +export type GenerativeDatabricksConfigRuntime = Omit; +export type GenerativeDummyConfigRuntime = Omit; +export type GenerativeFriendliAIConfigRuntime = Omit; +export type GenerativeGoogleConfigRuntime = Omit; +export type GenerativeMistralConfigRuntime = Omit; +export type GenerativeNvidiaConfigRuntime = Omit; +export type GenerativeOllamaConfigRuntime = Omit; +export type GenerativeOpenAIConfigRuntime = Omit; diff --git a/src/grpc/searcher.ts b/src/grpc/searcher.ts index 6dc6103e..20496c6f 100644 --- a/src/grpc/searcher.ts +++ b/src/grpc/searcher.ts @@ -171,6 +171,7 @@ export default class Searcher extends Base implements Search { tenant: this.tenant, uses123Api: true, uses125Api: true, + uses127Api: true, }, { metadata: this.metadata, diff --git a/src/proto/v1/generative.ts b/src/proto/v1/generative.ts index 12b1619f..2abae4ba 100644 --- a/src/proto/v1/generative.ts +++ b/src/proto/v1/generative.ts @@ -51,6 +51,7 @@ export interface GenerativeProvider { google?: GenerativeGoogle | undefined; databricks?: GenerativeDatabricks | undefined; friendliai?: GenerativeFriendliAI | undefined; + nvidia?: GenerativeNvidia | undefined; } export interface GenerativeAnthropic { @@ -61,6 +62,8 @@ export interface GenerativeAnthropic { topK?: number | undefined; topP?: number | undefined; stopSequences?: TextArray | undefined; + images?: TextArray | undefined; + imageProperties?: TextArray | undefined; } export interface GenerativeAnyscale { @@ -77,6 +80,8 @@ export interface GenerativeAWS { endpoint?: string | undefined; targetModel?: string | undefined; targetVariant?: string | undefined; + images?: TextArray | undefined; + imageProperties?: TextArray | undefined; } export interface GenerativeCohere { @@ -106,6 +111,8 @@ export interface GenerativeOllama { apiEndpoint?: string | undefined; model?: string | undefined; temperature?: number | undefined; + images?: TextArray | undefined; + imageProperties?: TextArray | undefined; } export interface GenerativeOpenAI { @@ -122,6 +129,8 @@ export interface GenerativeOpenAI { resourceName?: string | undefined; deploymentId?: string | undefined; isAzure?: boolean | undefined; + images?: TextArray | undefined; + imageProperties?: TextArray | undefined; } export interface GenerativeGoogle { @@ -137,6 +146,8 @@ export interface GenerativeGoogle { projectId?: string | undefined; endpointId?: string | undefined; region?: string | undefined; + images?: TextArray | undefined; + imageProperties?: TextArray | undefined; } export interface GenerativeDatabricks { @@ -162,6 +173,14 @@ export interface GenerativeFriendliAI { topP?: number | undefined; } +export interface GenerativeNvidia { + baseUrl?: string | undefined; + model?: string | undefined; + temperature?: number | undefined; + topP?: number | undefined; + maxTokens?: number | undefined; +} + export interface GenerativeAnthropicMetadata { usage: GenerativeAnthropicMetadata_Usage | undefined; } @@ -273,6 +292,16 @@ export interface GenerativeFriendliAIMetadata_Usage { totalTokens?: number | undefined; } +export interface GenerativeNvidiaMetadata { + usage?: GenerativeNvidiaMetadata_Usage | undefined; +} + +export interface GenerativeNvidiaMetadata_Usage { + promptTokens?: number | undefined; + completionTokens?: number | undefined; + totalTokens?: number | undefined; +} + export interface GenerativeMetadata { anthropic?: GenerativeAnthropicMetadata | undefined; anyscale?: GenerativeAnyscaleMetadata | undefined; @@ -285,6 +314,7 @@ export interface GenerativeMetadata { google?: GenerativeGoogleMetadata | undefined; databricks?: GenerativeDatabricksMetadata | undefined; friendliai?: GenerativeFriendliAIMetadata | undefined; + nvidia?: GenerativeNvidiaMetadata | undefined; } export interface GenerativeReply { @@ -630,6 +660,7 @@ function createBaseGenerativeProvider(): GenerativeProvider { google: undefined, databricks: undefined, friendliai: undefined, + nvidia: undefined, }; } @@ -671,6 +702,9 @@ export const GenerativeProvider = { if (message.friendliai !== undefined) { GenerativeFriendliAI.encode(message.friendliai, writer.uint32(98).fork()).ldelim(); } + if (message.nvidia !== undefined) { + GenerativeNvidia.encode(message.nvidia, writer.uint32(106).fork()).ldelim(); + } return writer; }, @@ -765,6 +799,13 @@ export const GenerativeProvider = { message.friendliai = GenerativeFriendliAI.decode(reader, reader.uint32()); continue; + case 13: + if (tag !== 106) { + break; + } + + message.nvidia = GenerativeNvidia.decode(reader, reader.uint32()); + continue; } if ((tag & 7) === 4 || tag === 0) { break; @@ -788,6 +829,7 @@ export const GenerativeProvider = { google: isSet(object.google) ? GenerativeGoogle.fromJSON(object.google) : undefined, databricks: isSet(object.databricks) ? GenerativeDatabricks.fromJSON(object.databricks) : undefined, friendliai: isSet(object.friendliai) ? GenerativeFriendliAI.fromJSON(object.friendliai) : undefined, + nvidia: isSet(object.nvidia) ? GenerativeNvidia.fromJSON(object.nvidia) : undefined, }; }, @@ -829,6 +871,9 @@ export const GenerativeProvider = { if (message.friendliai !== undefined) { obj.friendliai = GenerativeFriendliAI.toJSON(message.friendliai); } + if (message.nvidia !== undefined) { + obj.nvidia = GenerativeNvidia.toJSON(message.nvidia); + } return obj; }, @@ -869,6 +914,9 @@ export const GenerativeProvider = { message.friendliai = (object.friendliai !== undefined && object.friendliai !== null) ? GenerativeFriendliAI.fromPartial(object.friendliai) : undefined; + message.nvidia = (object.nvidia !== undefined && object.nvidia !== null) + ? GenerativeNvidia.fromPartial(object.nvidia) + : undefined; return message; }, }; @@ -882,6 +930,8 @@ function createBaseGenerativeAnthropic(): GenerativeAnthropic { topK: undefined, topP: undefined, stopSequences: undefined, + images: undefined, + imageProperties: undefined, }; } @@ -908,6 +958,12 @@ export const GenerativeAnthropic = { if (message.stopSequences !== undefined) { TextArray.encode(message.stopSequences, writer.uint32(58).fork()).ldelim(); } + if (message.images !== undefined) { + TextArray.encode(message.images, writer.uint32(66).fork()).ldelim(); + } + if (message.imageProperties !== undefined) { + TextArray.encode(message.imageProperties, writer.uint32(74).fork()).ldelim(); + } return writer; }, @@ -967,6 +1023,20 @@ export const GenerativeAnthropic = { message.stopSequences = TextArray.decode(reader, reader.uint32()); continue; + case 8: + if (tag !== 66) { + break; + } + + message.images = TextArray.decode(reader, reader.uint32()); + continue; + case 9: + if (tag !== 74) { + break; + } + + message.imageProperties = TextArray.decode(reader, reader.uint32()); + continue; } if ((tag & 7) === 4 || tag === 0) { break; @@ -985,6 +1055,8 @@ export const GenerativeAnthropic = { topK: isSet(object.topK) ? globalThis.Number(object.topK) : undefined, topP: isSet(object.topP) ? globalThis.Number(object.topP) : undefined, stopSequences: isSet(object.stopSequences) ? TextArray.fromJSON(object.stopSequences) : undefined, + images: isSet(object.images) ? TextArray.fromJSON(object.images) : undefined, + imageProperties: isSet(object.imageProperties) ? TextArray.fromJSON(object.imageProperties) : undefined, }; }, @@ -1011,6 +1083,12 @@ export const GenerativeAnthropic = { if (message.stopSequences !== undefined) { obj.stopSequences = TextArray.toJSON(message.stopSequences); } + if (message.images !== undefined) { + obj.images = TextArray.toJSON(message.images); + } + if (message.imageProperties !== undefined) { + obj.imageProperties = TextArray.toJSON(message.imageProperties); + } return obj; }, @@ -1028,6 +1106,12 @@ export const GenerativeAnthropic = { message.stopSequences = (object.stopSequences !== undefined && object.stopSequences !== null) ? TextArray.fromPartial(object.stopSequences) : undefined; + message.images = (object.images !== undefined && object.images !== null) + ? TextArray.fromPartial(object.images) + : undefined; + message.imageProperties = (object.imageProperties !== undefined && object.imageProperties !== null) + ? TextArray.fromPartial(object.imageProperties) + : undefined; return message; }, }; @@ -1130,6 +1214,8 @@ function createBaseGenerativeAWS(): GenerativeAWS { endpoint: undefined, targetModel: undefined, targetVariant: undefined, + images: undefined, + imageProperties: undefined, }; } @@ -1156,6 +1242,12 @@ export const GenerativeAWS = { if (message.targetVariant !== undefined) { writer.uint32(106).string(message.targetVariant); } + if (message.images !== undefined) { + TextArray.encode(message.images, writer.uint32(114).fork()).ldelim(); + } + if (message.imageProperties !== undefined) { + TextArray.encode(message.imageProperties, writer.uint32(122).fork()).ldelim(); + } return writer; }, @@ -1215,6 +1307,20 @@ export const GenerativeAWS = { message.targetVariant = reader.string(); continue; + case 14: + if (tag !== 114) { + break; + } + + message.images = TextArray.decode(reader, reader.uint32()); + continue; + case 15: + if (tag !== 122) { + break; + } + + message.imageProperties = TextArray.decode(reader, reader.uint32()); + continue; } if ((tag & 7) === 4 || tag === 0) { break; @@ -1233,6 +1339,8 @@ export const GenerativeAWS = { endpoint: isSet(object.endpoint) ? globalThis.String(object.endpoint) : undefined, targetModel: isSet(object.targetModel) ? globalThis.String(object.targetModel) : undefined, targetVariant: isSet(object.targetVariant) ? globalThis.String(object.targetVariant) : undefined, + images: isSet(object.images) ? TextArray.fromJSON(object.images) : undefined, + imageProperties: isSet(object.imageProperties) ? TextArray.fromJSON(object.imageProperties) : undefined, }; }, @@ -1259,6 +1367,12 @@ export const GenerativeAWS = { if (message.targetVariant !== undefined) { obj.targetVariant = message.targetVariant; } + if (message.images !== undefined) { + obj.images = TextArray.toJSON(message.images); + } + if (message.imageProperties !== undefined) { + obj.imageProperties = TextArray.toJSON(message.imageProperties); + } return obj; }, @@ -1274,6 +1388,12 @@ export const GenerativeAWS = { message.endpoint = object.endpoint ?? undefined; message.targetModel = object.targetModel ?? undefined; message.targetVariant = object.targetVariant ?? undefined; + message.images = (object.images !== undefined && object.images !== null) + ? TextArray.fromPartial(object.images) + : undefined; + message.imageProperties = (object.imageProperties !== undefined && object.imageProperties !== null) + ? TextArray.fromPartial(object.imageProperties) + : undefined; return message; }, }; @@ -1632,7 +1752,13 @@ export const GenerativeMistral = { }; function createBaseGenerativeOllama(): GenerativeOllama { - return { apiEndpoint: undefined, model: undefined, temperature: undefined }; + return { + apiEndpoint: undefined, + model: undefined, + temperature: undefined, + images: undefined, + imageProperties: undefined, + }; } export const GenerativeOllama = { @@ -1646,6 +1772,12 @@ export const GenerativeOllama = { if (message.temperature !== undefined) { writer.uint32(25).double(message.temperature); } + if (message.images !== undefined) { + TextArray.encode(message.images, writer.uint32(34).fork()).ldelim(); + } + if (message.imageProperties !== undefined) { + TextArray.encode(message.imageProperties, writer.uint32(42).fork()).ldelim(); + } return writer; }, @@ -1677,6 +1809,20 @@ export const GenerativeOllama = { message.temperature = reader.double(); continue; + case 4: + if (tag !== 34) { + break; + } + + message.images = TextArray.decode(reader, reader.uint32()); + continue; + case 5: + if (tag !== 42) { + break; + } + + message.imageProperties = TextArray.decode(reader, reader.uint32()); + continue; } if ((tag & 7) === 4 || tag === 0) { break; @@ -1691,6 +1837,8 @@ export const GenerativeOllama = { apiEndpoint: isSet(object.apiEndpoint) ? globalThis.String(object.apiEndpoint) : undefined, model: isSet(object.model) ? globalThis.String(object.model) : undefined, temperature: isSet(object.temperature) ? globalThis.Number(object.temperature) : undefined, + images: isSet(object.images) ? TextArray.fromJSON(object.images) : undefined, + imageProperties: isSet(object.imageProperties) ? TextArray.fromJSON(object.imageProperties) : undefined, }; }, @@ -1705,6 +1853,12 @@ export const GenerativeOllama = { if (message.temperature !== undefined) { obj.temperature = message.temperature; } + if (message.images !== undefined) { + obj.images = TextArray.toJSON(message.images); + } + if (message.imageProperties !== undefined) { + obj.imageProperties = TextArray.toJSON(message.imageProperties); + } return obj; }, @@ -1716,6 +1870,12 @@ export const GenerativeOllama = { message.apiEndpoint = object.apiEndpoint ?? undefined; message.model = object.model ?? undefined; message.temperature = object.temperature ?? undefined; + message.images = (object.images !== undefined && object.images !== null) + ? TextArray.fromPartial(object.images) + : undefined; + message.imageProperties = (object.imageProperties !== undefined && object.imageProperties !== null) + ? TextArray.fromPartial(object.imageProperties) + : undefined; return message; }, }; @@ -1735,6 +1895,8 @@ function createBaseGenerativeOpenAI(): GenerativeOpenAI { resourceName: undefined, deploymentId: undefined, isAzure: undefined, + images: undefined, + imageProperties: undefined, }; } @@ -1779,6 +1941,12 @@ export const GenerativeOpenAI = { if (message.isAzure !== undefined) { writer.uint32(104).bool(message.isAzure); } + if (message.images !== undefined) { + TextArray.encode(message.images, writer.uint32(114).fork()).ldelim(); + } + if (message.imageProperties !== undefined) { + TextArray.encode(message.imageProperties, writer.uint32(122).fork()).ldelim(); + } return writer; }, @@ -1880,6 +2048,20 @@ export const GenerativeOpenAI = { message.isAzure = reader.bool(); continue; + case 14: + if (tag !== 114) { + break; + } + + message.images = TextArray.decode(reader, reader.uint32()); + continue; + case 15: + if (tag !== 122) { + break; + } + + message.imageProperties = TextArray.decode(reader, reader.uint32()); + continue; } if ((tag & 7) === 4 || tag === 0) { break; @@ -1904,6 +2086,8 @@ export const GenerativeOpenAI = { resourceName: isSet(object.resourceName) ? globalThis.String(object.resourceName) : undefined, deploymentId: isSet(object.deploymentId) ? globalThis.String(object.deploymentId) : undefined, isAzure: isSet(object.isAzure) ? globalThis.Boolean(object.isAzure) : undefined, + images: isSet(object.images) ? TextArray.fromJSON(object.images) : undefined, + imageProperties: isSet(object.imageProperties) ? TextArray.fromJSON(object.imageProperties) : undefined, }; }, @@ -1948,6 +2132,12 @@ export const GenerativeOpenAI = { if (message.isAzure !== undefined) { obj.isAzure = message.isAzure; } + if (message.images !== undefined) { + obj.images = TextArray.toJSON(message.images); + } + if (message.imageProperties !== undefined) { + obj.imageProperties = TextArray.toJSON(message.imageProperties); + } return obj; }, @@ -1969,6 +2159,12 @@ export const GenerativeOpenAI = { message.resourceName = object.resourceName ?? undefined; message.deploymentId = object.deploymentId ?? undefined; message.isAzure = object.isAzure ?? undefined; + message.images = (object.images !== undefined && object.images !== null) + ? TextArray.fromPartial(object.images) + : undefined; + message.imageProperties = (object.imageProperties !== undefined && object.imageProperties !== null) + ? TextArray.fromPartial(object.imageProperties) + : undefined; return message; }, }; @@ -1987,6 +2183,8 @@ function createBaseGenerativeGoogle(): GenerativeGoogle { projectId: undefined, endpointId: undefined, region: undefined, + images: undefined, + imageProperties: undefined, }; } @@ -2028,6 +2226,12 @@ export const GenerativeGoogle = { if (message.region !== undefined) { writer.uint32(98).string(message.region); } + if (message.images !== undefined) { + TextArray.encode(message.images, writer.uint32(106).fork()).ldelim(); + } + if (message.imageProperties !== undefined) { + TextArray.encode(message.imageProperties, writer.uint32(114).fork()).ldelim(); + } return writer; }, @@ -2122,6 +2326,20 @@ export const GenerativeGoogle = { message.region = reader.string(); continue; + case 13: + if (tag !== 106) { + break; + } + + message.images = TextArray.decode(reader, reader.uint32()); + continue; + case 14: + if (tag !== 114) { + break; + } + + message.imageProperties = TextArray.decode(reader, reader.uint32()); + continue; } if ((tag & 7) === 4 || tag === 0) { break; @@ -2145,6 +2363,8 @@ export const GenerativeGoogle = { projectId: isSet(object.projectId) ? globalThis.String(object.projectId) : undefined, endpointId: isSet(object.endpointId) ? globalThis.String(object.endpointId) : undefined, region: isSet(object.region) ? globalThis.String(object.region) : undefined, + images: isSet(object.images) ? TextArray.fromJSON(object.images) : undefined, + imageProperties: isSet(object.imageProperties) ? TextArray.fromJSON(object.imageProperties) : undefined, }; }, @@ -2186,6 +2406,12 @@ export const GenerativeGoogle = { if (message.region !== undefined) { obj.region = message.region; } + if (message.images !== undefined) { + obj.images = TextArray.toJSON(message.images); + } + if (message.imageProperties !== undefined) { + obj.imageProperties = TextArray.toJSON(message.imageProperties); + } return obj; }, @@ -2208,6 +2434,12 @@ export const GenerativeGoogle = { message.projectId = object.projectId ?? undefined; message.endpointId = object.endpointId ?? undefined; message.region = object.region ?? undefined; + message.images = (object.images !== undefined && object.images !== null) + ? TextArray.fromPartial(object.images) + : undefined; + message.imageProperties = (object.imageProperties !== undefined && object.imageProperties !== null) + ? TextArray.fromPartial(object.imageProperties) + : undefined; return message; }, }; @@ -2574,6 +2806,125 @@ export const GenerativeFriendliAI = { }, }; +function createBaseGenerativeNvidia(): GenerativeNvidia { + return { baseUrl: undefined, model: undefined, temperature: undefined, topP: undefined, maxTokens: undefined }; +} + +export const GenerativeNvidia = { + encode(message: GenerativeNvidia, writer: _m0.Writer = _m0.Writer.create()): _m0.Writer { + if (message.baseUrl !== undefined) { + writer.uint32(10).string(message.baseUrl); + } + if (message.model !== undefined) { + writer.uint32(18).string(message.model); + } + if (message.temperature !== undefined) { + writer.uint32(25).double(message.temperature); + } + if (message.topP !== undefined) { + writer.uint32(33).double(message.topP); + } + if (message.maxTokens !== undefined) { + writer.uint32(40).int64(message.maxTokens); + } + return writer; + }, + + decode(input: _m0.Reader | Uint8Array, length?: number): GenerativeNvidia { + const reader = input instanceof _m0.Reader ? input : _m0.Reader.create(input); + let end = length === undefined ? reader.len : reader.pos + length; + const message = createBaseGenerativeNvidia(); + while (reader.pos < end) { + const tag = reader.uint32(); + switch (tag >>> 3) { + case 1: + if (tag !== 10) { + break; + } + + message.baseUrl = reader.string(); + continue; + case 2: + if (tag !== 18) { + break; + } + + message.model = reader.string(); + continue; + case 3: + if (tag !== 25) { + break; + } + + message.temperature = reader.double(); + continue; + case 4: + if (tag !== 33) { + break; + } + + message.topP = reader.double(); + continue; + case 5: + if (tag !== 40) { + break; + } + + message.maxTokens = longToNumber(reader.int64() as Long); + continue; + } + if ((tag & 7) === 4 || tag === 0) { + break; + } + reader.skipType(tag & 7); + } + return message; + }, + + fromJSON(object: any): GenerativeNvidia { + return { + baseUrl: isSet(object.baseUrl) ? globalThis.String(object.baseUrl) : undefined, + model: isSet(object.model) ? globalThis.String(object.model) : undefined, + temperature: isSet(object.temperature) ? globalThis.Number(object.temperature) : undefined, + topP: isSet(object.topP) ? globalThis.Number(object.topP) : undefined, + maxTokens: isSet(object.maxTokens) ? globalThis.Number(object.maxTokens) : undefined, + }; + }, + + toJSON(message: GenerativeNvidia): unknown { + const obj: any = {}; + if (message.baseUrl !== undefined) { + obj.baseUrl = message.baseUrl; + } + if (message.model !== undefined) { + obj.model = message.model; + } + if (message.temperature !== undefined) { + obj.temperature = message.temperature; + } + if (message.topP !== undefined) { + obj.topP = message.topP; + } + if (message.maxTokens !== undefined) { + obj.maxTokens = Math.round(message.maxTokens); + } + return obj; + }, + + create(base?: DeepPartial): GenerativeNvidia { + return GenerativeNvidia.fromPartial(base ?? {}); + }, + fromPartial(object: DeepPartial): GenerativeNvidia { + const message = createBaseGenerativeNvidia(); + message.baseUrl = object.baseUrl ?? undefined; + message.model = object.model ?? undefined; + message.temperature = object.temperature ?? undefined; + message.topP = object.topP ?? undefined; + message.maxTokens = object.maxTokens ?? undefined; + return message; + }, +}; + function createBaseGenerativeAnthropicMetadata(): GenerativeAnthropicMetadata { return { usage: undefined }; } @@ -4246,6 +4597,154 @@ export const GenerativeFriendliAIMetadata_Usage = { }, }; +function createBaseGenerativeNvidiaMetadata(): GenerativeNvidiaMetadata { + return { usage: undefined }; +} + +export const GenerativeNvidiaMetadata = { + encode(message: GenerativeNvidiaMetadata, writer: _m0.Writer = _m0.Writer.create()): _m0.Writer { + if (message.usage !== undefined) { + GenerativeNvidiaMetadata_Usage.encode(message.usage, writer.uint32(10).fork()).ldelim(); + } + return writer; + }, + + decode(input: _m0.Reader | Uint8Array, length?: number): GenerativeNvidiaMetadata { + const reader = input instanceof _m0.Reader ? input : _m0.Reader.create(input); + let end = length === undefined ? reader.len : reader.pos + length; + const message = createBaseGenerativeNvidiaMetadata(); + while (reader.pos < end) { + const tag = reader.uint32(); + switch (tag >>> 3) { + case 1: + if (tag !== 10) { + break; + } + + message.usage = GenerativeNvidiaMetadata_Usage.decode(reader, reader.uint32()); + continue; + } + if ((tag & 7) === 4 || tag === 0) { + break; + } + reader.skipType(tag & 7); + } + return message; + }, + + fromJSON(object: any): GenerativeNvidiaMetadata { + return { usage: isSet(object.usage) ? GenerativeNvidiaMetadata_Usage.fromJSON(object.usage) : undefined }; + }, + + toJSON(message: GenerativeNvidiaMetadata): unknown { + const obj: any = {}; + if (message.usage !== undefined) { + obj.usage = GenerativeNvidiaMetadata_Usage.toJSON(message.usage); + } + return obj; + }, + + create(base?: DeepPartial): GenerativeNvidiaMetadata { + return GenerativeNvidiaMetadata.fromPartial(base ?? {}); + }, + fromPartial(object: DeepPartial): GenerativeNvidiaMetadata { + const message = createBaseGenerativeNvidiaMetadata(); + message.usage = (object.usage !== undefined && object.usage !== null) + ? GenerativeNvidiaMetadata_Usage.fromPartial(object.usage) + : undefined; + return message; + }, +}; + +function createBaseGenerativeNvidiaMetadata_Usage(): GenerativeNvidiaMetadata_Usage { + return { promptTokens: undefined, completionTokens: undefined, totalTokens: undefined }; +} + +export const GenerativeNvidiaMetadata_Usage = { + encode(message: GenerativeNvidiaMetadata_Usage, writer: _m0.Writer = _m0.Writer.create()): _m0.Writer { + if (message.promptTokens !== undefined) { + writer.uint32(8).int64(message.promptTokens); + } + if (message.completionTokens !== undefined) { + writer.uint32(16).int64(message.completionTokens); + } + if (message.totalTokens !== undefined) { + writer.uint32(24).int64(message.totalTokens); + } + return writer; + }, + + decode(input: _m0.Reader | Uint8Array, length?: number): GenerativeNvidiaMetadata_Usage { + const reader = input instanceof _m0.Reader ? input : _m0.Reader.create(input); + let end = length === undefined ? reader.len : reader.pos + length; + const message = createBaseGenerativeNvidiaMetadata_Usage(); + while (reader.pos < end) { + const tag = reader.uint32(); + switch (tag >>> 3) { + case 1: + if (tag !== 8) { + break; + } + + message.promptTokens = longToNumber(reader.int64() as Long); + continue; + case 2: + if (tag !== 16) { + break; + } + + message.completionTokens = longToNumber(reader.int64() as Long); + continue; + case 3: + if (tag !== 24) { + break; + } + + message.totalTokens = longToNumber(reader.int64() as Long); + continue; + } + if ((tag & 7) === 4 || tag === 0) { + break; + } + reader.skipType(tag & 7); + } + return message; + }, + + fromJSON(object: any): GenerativeNvidiaMetadata_Usage { + return { + promptTokens: isSet(object.promptTokens) ? globalThis.Number(object.promptTokens) : undefined, + completionTokens: isSet(object.completionTokens) ? globalThis.Number(object.completionTokens) : undefined, + totalTokens: isSet(object.totalTokens) ? globalThis.Number(object.totalTokens) : undefined, + }; + }, + + toJSON(message: GenerativeNvidiaMetadata_Usage): unknown { + const obj: any = {}; + if (message.promptTokens !== undefined) { + obj.promptTokens = Math.round(message.promptTokens); + } + if (message.completionTokens !== undefined) { + obj.completionTokens = Math.round(message.completionTokens); + } + if (message.totalTokens !== undefined) { + obj.totalTokens = Math.round(message.totalTokens); + } + return obj; + }, + + create(base?: DeepPartial): GenerativeNvidiaMetadata_Usage { + return GenerativeNvidiaMetadata_Usage.fromPartial(base ?? {}); + }, + fromPartial(object: DeepPartial): GenerativeNvidiaMetadata_Usage { + const message = createBaseGenerativeNvidiaMetadata_Usage(); + message.promptTokens = object.promptTokens ?? undefined; + message.completionTokens = object.completionTokens ?? undefined; + message.totalTokens = object.totalTokens ?? undefined; + return message; + }, +}; + function createBaseGenerativeMetadata(): GenerativeMetadata { return { anthropic: undefined, @@ -4259,6 +4758,7 @@ function createBaseGenerativeMetadata(): GenerativeMetadata { google: undefined, databricks: undefined, friendliai: undefined, + nvidia: undefined, }; } @@ -4297,6 +4797,9 @@ export const GenerativeMetadata = { if (message.friendliai !== undefined) { GenerativeFriendliAIMetadata.encode(message.friendliai, writer.uint32(90).fork()).ldelim(); } + if (message.nvidia !== undefined) { + GenerativeNvidiaMetadata.encode(message.nvidia, writer.uint32(98).fork()).ldelim(); + } return writer; }, @@ -4384,6 +4887,13 @@ export const GenerativeMetadata = { message.friendliai = GenerativeFriendliAIMetadata.decode(reader, reader.uint32()); continue; + case 12: + if (tag !== 98) { + break; + } + + message.nvidia = GenerativeNvidiaMetadata.decode(reader, reader.uint32()); + continue; } if ((tag & 7) === 4 || tag === 0) { break; @@ -4406,6 +4916,7 @@ export const GenerativeMetadata = { google: isSet(object.google) ? GenerativeGoogleMetadata.fromJSON(object.google) : undefined, databricks: isSet(object.databricks) ? GenerativeDatabricksMetadata.fromJSON(object.databricks) : undefined, friendliai: isSet(object.friendliai) ? GenerativeFriendliAIMetadata.fromJSON(object.friendliai) : undefined, + nvidia: isSet(object.nvidia) ? GenerativeNvidiaMetadata.fromJSON(object.nvidia) : undefined, }; }, @@ -4444,6 +4955,9 @@ export const GenerativeMetadata = { if (message.friendliai !== undefined) { obj.friendliai = GenerativeFriendliAIMetadata.toJSON(message.friendliai); } + if (message.nvidia !== undefined) { + obj.nvidia = GenerativeNvidiaMetadata.toJSON(message.nvidia); + } return obj; }, @@ -4485,6 +4999,9 @@ export const GenerativeMetadata = { message.friendliai = (object.friendliai !== undefined && object.friendliai !== null) ? GenerativeFriendliAIMetadata.fromPartial(object.friendliai) : undefined; + message.nvidia = (object.nvidia !== undefined && object.nvidia !== null) + ? GenerativeNvidiaMetadata.fromPartial(object.nvidia) + : undefined; return message; }, }; diff --git a/src/utils/dbVersion.ts b/src/utils/dbVersion.ts index 279537e2..707b5f26 100644 --- a/src/utils/dbVersion.ts +++ b/src/utils/dbVersion.ts @@ -219,6 +219,17 @@ export class DbVersionSupport { }; }); }; + + supportsSingleGrouped = () => + this.dbVersionProvider.getVersion().then((version) => ({ + version, + supports: + (version.isAtLeast(1, 27, 14) && version.isLowerThan(1, 28, 0)) || + (version.isAtLeast(1, 28, 8) && version.isLowerThan(1, 29, 0)) || + (version.isAtLeast(1, 29, 0) && version.isLowerThan(1, 30, 0)) || + version.isAtLeast(1, 30, 0), + message: this.errorMessage('Single/Grouped fields in gRPC', version.show(), '1.30.0'), + })); } const EMPTY_VERSION = ''; From a50f56f32e4fc7a9b41c5740a5ec437857cf8a63 Mon Sep 17 00:00:00 2001 From: Tommy Smith Date: Fri, 7 Mar 2025 11:53:46 +0000 Subject: [PATCH 02/11] Fix unit test --- src/collections/serialize/unit.test.ts | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/src/collections/serialize/unit.test.ts b/src/collections/serialize/unit.test.ts index 721d1e46..d6d5d627 100644 --- a/src/collections/serialize/unit.test.ts +++ b/src/collections/serialize/unit.test.ts @@ -442,11 +442,14 @@ describe('Unit testing of Serialize', () => { }); it('should parse args for generative', () => { - const args = Serialize.generative({ - singlePrompt: 'test', - groupedProperties: ['name'], - groupedTask: 'testing', - }); + const args = Serialize.generative( + { supportsSingleGrouped: false }, + { + singlePrompt: 'test', + groupedProperties: ['name'], + groupedTask: 'testing', + } + ); expect(args).toEqual({ singleResponsePrompt: 'test', groupedProperties: ['name'], From 5eb09da5b65fdb06f5be96eae4f3d1e9a784502e Mon Sep 17 00:00:00 2001 From: Tommy Smith Date: Fri, 7 Mar 2025 13:55:38 +0000 Subject: [PATCH 03/11] Again fix unit test --- src/collections/serialize/unit.test.ts | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/collections/serialize/unit.test.ts b/src/collections/serialize/unit.test.ts index d6d5d627..6d9f9612 100644 --- a/src/collections/serialize/unit.test.ts +++ b/src/collections/serialize/unit.test.ts @@ -441,8 +441,8 @@ describe('Unit testing of Serialize', () => { }); }); - it('should parse args for generative', () => { - const args = Serialize.generative( + it('should parse args for generative', async () => { + const args = await Serialize.generative( { supportsSingleGrouped: false }, { singlePrompt: 'test', From c16cbd3199304492b20cc4b9ce2a37a0091c0c6c Mon Sep 17 00:00:00 2001 From: Tommy Smith Date: Fri, 7 Mar 2025 14:40:38 +0000 Subject: [PATCH 04/11] Add concurrency limit to CI on a per branch basis to cancel old runs --- .github/workflows/main.yaml | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/.github/workflows/main.yaml b/.github/workflows/main.yaml index 3595cb1f..cecd2521 100644 --- a/.github/workflows/main.yaml +++ b/.github/workflows/main.yaml @@ -15,6 +15,10 @@ env: WEAVIATE_129: 1.29.0 WEAVIATE_130: 1.30.0-dev-680e323 +concurrency: + group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }} + cancel-in-progress: true + jobs: checks: runs-on: ubuntu-latest From befa2affbab51f7505634ce3ba555fa8875a7641 Mon Sep 17 00:00:00 2001 From: Tommy Smith Date: Mon, 10 Mar 2025 09:15:59 +0000 Subject: [PATCH 05/11] Add test of string-only usage with runtime generative --- src/collections/generate/integration.test.ts | 33 +++++++++++++++++++- 1 file changed, 32 insertions(+), 1 deletion(-) diff --git a/src/collections/generate/integration.test.ts b/src/collections/generate/integration.test.ts index 5e478093..b93a2286 100644 --- a/src/collections/generate/integration.test.ts +++ b/src/collections/generate/integration.test.ts @@ -460,7 +460,7 @@ maybe('Testing of the collection.generate methods with runtime generative config }); }); - it('should generate using a runtime config without search', async () => { + it('should generate using a runtime config without search and with extras', async () => { const query = () => collection.generate.fetchObjects({ singlePrompt: { @@ -498,4 +498,35 @@ maybe('Testing of the collection.generate methods with runtime generative config expect(obj.generative?.debug).toBeDefined(); }); }); + + it('should generate using a runtime config without search nor extras', async () => { + const query = () => + collection.generate.fetchObjects({ + singlePrompt: 'Write a haiku about ducks for {testProp}', + groupedTask: 'What is the value of testProp here?', + config: { + name: 'generative-openai', + config: { + model: 'gpt-4o-mini', + }, + }, + }); + + if (await client.getWeaviateVersion().then((ver) => ver.isLowerThan(1, 30, 0))) { + await expect(query()).rejects.toThrow(WeaviateUnsupportedFeatureError); + return; + } + + const res = await query(); + expect(res.objects.length).toEqual(1); + expect(res.generated).toBeDefined(); + expect(res.generative?.text).toBeDefined(); + expect(res.generative?.metadata).toBeUndefined(); + res.objects.forEach((obj) => { + expect(obj.generated).toBeDefined(); + expect(obj.generative?.text).toBeDefined(); + expect(obj.generative?.metadata).toBeUndefined(); + expect(obj.generative?.debug).toBeUndefined(); + }); + }); }); From c688c838f87700016a77a08913bc8e6008be718f Mon Sep 17 00:00:00 2001 From: Tommy Smith Date: Mon, 10 Mar 2025 11:47:28 +0000 Subject: [PATCH 06/11] Add factory to produce user friendly gen runtime config objects --- src/collections/config/types/generative.ts | 7 + src/collections/configure/generative.ts | 18 ++ src/collections/configure/types/generative.ts | 6 + src/collections/generate/config.ts | 282 ++++++++++++++++++ src/collections/generate/index.ts | 1 + src/collections/generate/integration.test.ts | 15 +- src/collections/generate/unit.test.ts | 280 +++++++++++++++++ src/collections/types/generate.ts | 181 +++++++++-- 8 files changed, 757 insertions(+), 33 deletions(-) create mode 100644 src/collections/generate/config.ts create mode 100644 src/collections/generate/unit.test.ts diff --git a/src/collections/config/types/generative.ts b/src/collections/config/types/generative.ts index 667bc347..7ff426b5 100644 --- a/src/collections/config/types/generative.ts +++ b/src/collections/config/types/generative.ts @@ -58,6 +58,13 @@ export type GenerativeMistralConfig = { temperature?: number; }; +export type GenerativeNvidiaConfig = { + baseURL?: string; + maxTokens?: number; + model?: string; + temperature?: number; +}; + export type GenerativeOllamaConfig = { apiEndpoint?: string; model?: string; diff --git a/src/collections/configure/generative.ts b/src/collections/configure/generative.ts index 730f2bcb..d4ee3154 100644 --- a/src/collections/configure/generative.ts +++ b/src/collections/configure/generative.ts @@ -8,6 +8,7 @@ import { GenerativeFriendliAIConfig, GenerativeGoogleConfig, GenerativeMistralConfig, + GenerativeNvidiaConfig, GenerativeOllamaConfig, GenerativeOpenAIConfig, GenerativePaLMConfig, @@ -22,6 +23,7 @@ import { GenerativeDatabricksConfigCreate, GenerativeFriendliAIConfigCreate, GenerativeMistralConfigCreate, + GenerativeNvidiaConfigCreate, GenerativeOllamaConfigCreate, GenerativeOpenAIConfigCreate, GenerativePaLMConfigCreate, @@ -169,6 +171,22 @@ export default { config, }; }, + /** + * Create a `ModuleConfig<'generative-nvidia', GenerativeNvidiaConfig | undefined>` object for use when performing AI generation using the `generative-mistral` module. + * + * See the [documentation](https://weaviate.io/developers/weaviate/model-providers/nvidia/generative) for detailed usage. + * + * @param {GenerativeNvidiaConfigCreate} [config] The configuration for the `generative-nvidia` module. + * @returns {ModuleConfig<'generative-nvidia', GenerativeNvidiaConfig | undefined>} The configuration object. + */ + nvidia( + config?: GenerativeNvidiaConfigCreate + ): ModuleConfig<'generative-nvidia', GenerativeNvidiaConfig | undefined> { + return { + name: 'generative-nvidia', + config, + }; + }, /** * Create a `ModuleConfig<'generative-ollama', GenerativeOllamaConfig | undefined>` object for use when performing AI generation using the `generative-ollama` module. * diff --git a/src/collections/configure/types/generative.ts b/src/collections/configure/types/generative.ts index 2b1a18cf..ccf22ec6 100644 --- a/src/collections/configure/types/generative.ts +++ b/src/collections/configure/types/generative.ts @@ -5,6 +5,7 @@ import { GenerativeDatabricksConfig, GenerativeFriendliAIConfig, GenerativeMistralConfig, + GenerativeNvidiaConfig, GenerativeOllamaConfig, GenerativePaLMConfig, } from '../../index.js'; @@ -44,6 +45,8 @@ export type GenerativeFriendliAIConfigCreate = GenerativeFriendliAIConfig; export type GenerativeMistralConfigCreate = GenerativeMistralConfig; +export type GenerativeNvidiaConfigCreate = GenerativeNvidiaConfig; + export type GenerativeOllamaConfigCreate = GenerativeOllamaConfig; export type GenerativeOpenAIConfigCreate = GenerativeOpenAIConfigBaseCreate & { @@ -61,6 +64,7 @@ export type GenerativeConfigCreate = | GenerativeDatabricksConfigCreate | GenerativeFriendliAIConfigCreate | GenerativeMistralConfigCreate + | GenerativeNvidiaConfigCreate | GenerativeOllamaConfigCreate | GenerativeOpenAIConfigCreate | GenerativePaLMConfigCreate @@ -81,6 +85,8 @@ export type GenerativeConfigCreateType = G extends 'generative-anthropic' ? GenerativeFriendliAIConfigCreate : G extends 'generative-mistral' ? GenerativeMistralConfigCreate + : G extends 'generative-nvidia' + ? GenerativeNvidiaConfigCreate : G extends 'generative-ollama' ? GenerativeOllamaConfigCreate : G extends 'generative-openai' diff --git a/src/collections/generate/config.ts b/src/collections/generate/config.ts new file mode 100644 index 00000000..a633bacc --- /dev/null +++ b/src/collections/generate/config.ts @@ -0,0 +1,282 @@ +import { TextArray } from '../../proto/v1/base.js'; +import { ModuleConfig } from '../config/types/index.js'; +import { + GenerativeAWSConfigRuntime, + GenerativeAnthropicConfigRuntime, + GenerativeAnyscaleConfigRuntime, + GenerativeCohereConfigRuntime, + GenerativeConfigRuntimeType, + GenerativeDatabricksConfigRuntime, + GenerativeFriendliAIConfigRuntime, + GenerativeGoogleConfigRuntime, + GenerativeMistralConfigRuntime, + GenerativeNvidiaConfigRuntime, + GenerativeOllamaConfigRuntime, + GenerativeOpenAIConfigRuntime, +} from '../index.js'; + +export const generativeConfigRuntime = { + /** + * Create a `ModuleConfig<'generative-anthropic', GenerativeConfigRuntimeType<'generative-anthropic'> | undefined>` object for use when performing runtime-specific AI generation using the `generative-anthropic` module. + * + * See the [documentation](https://weaviate.io/developers/weaviate/model-providers/anthropic/generative) for detailed usage. + * + * @param {GenerativeAnthropicConfigCreateRuntime} [config] The configuration for the `generative-anthropic` module. + * @returns {ModuleConfig<'generative-anthropic', GenerativeAnthropicConfigCreateRuntime | undefined>} The configuration object. + */ + anthropic( + config?: GenerativeAnthropicConfigRuntime + ): ModuleConfig<'generative-anthropic', GenerativeConfigRuntimeType<'generative-anthropic'> | undefined> { + const { baseURL, stopSequences, ...rest } = config || {}; + return { + name: 'generative-anthropic', + config: config + ? { + ...rest, + baseUrl: baseURL, + stopSequences: TextArray.fromPartial({ values: stopSequences }), + } + : undefined, + }; + }, + /** + * Create a `ModuleConfig<'generative-anyscale', GenerativeConfigRuntimeType<'generative-anyscale'> | undefined>` object for use when performing runtime-specific AI generation using the `generative-anyscale` module. + * + * See the [documentation](https://weaviate.io/developers/weaviate/model-providers/anyscale/generative) for detailed usage. + * + * @param {GenerativeAnyscaleConfigRuntime} [config] The configuration for the `generative-aws` module. + * @returns {ModuleConfig<'generative-anyscale', GenerativeConfigRuntimeType<'generative-anyscale'> | undefined>} The configuration object. + */ + anyscale( + config?: GenerativeAnyscaleConfigRuntime + ): ModuleConfig<'generative-anyscale', GenerativeConfigRuntimeType<'generative-anyscale'> | undefined> { + const { baseURL, ...rest } = config || {}; + return { + name: 'generative-anyscale', + config: config + ? { + ...rest, + baseUrl: baseURL, + } + : undefined, + }; + }, + /** + * Create a `ModuleConfig<'generative-aws', GenerativeConfigRuntimeType<'generative-aws'> | undefined>` object for use when performing runtime-specific AI generation using the `generative-aws` module. + * + * See the [documentation](https://weaviate.io/developers/weaviate/model-providers/aws/generative) for detailed usage. + * + * @param {GenerativeAWSConfigRuntime} [config] The configuration for the `generative-aws` module. + * @returns {ModuleConfig<'generative-aws', GenerativeConfigRuntimeType<'generative-aws'> | undefined>} The configuration object. + */ + aws( + config?: GenerativeAWSConfigRuntime + ): ModuleConfig<'generative-aws', GenerativeConfigRuntimeType<'generative-aws'> | undefined> { + return { + name: 'generative-aws', + config, + }; + }, + /** + * Create a `ModuleConfig<'generative-openai', GenerativeConfigRuntimeType<'generative-azure-openai'>>` object for use when performing runtime-specific AI generation using the `generative-openai` module. + * + * See the [documentation](https://weaviate.io/developers/weaviate/model-providers/openai/generative) for detailed usage. + * + * @param {GenerativeAzureOpenAIConfigRuntime} [config] The configuration for the `generative-openai` module. + * @returns {ModuleConfig<'generative-azure-openai', GenerativeConfigRuntimeType<'generative-azure-openai'>>} The configuration object. + */ + azureOpenAI: ( + config?: GenerativeOpenAIConfigRuntime + ): ModuleConfig<'generative-azure-openai', GenerativeConfigRuntimeType<'generative-azure-openai'>> => { + const { baseURL, model, stop, ...rest } = config || {}; + return { + name: 'generative-azure-openai', + config: config + ? { + ...rest, + baseUrl: baseURL, + model: model ?? '', + isAzure: true, + stop: TextArray.fromPartial({ values: stop }), + } + : { model: '', isAzure: true }, + }; + }, + /** + * Create a `ModuleConfig<'generative-cohere', GenerativeConfigRuntimeType<'generative-cohere'> | undefined>` object for use when performing runtime-specific AI generation using the `generative-cohere` module. + * + * See the [documentation](https://weaviate.io/developers/weaviate/model-providers/cohere/generative) for detailed usage. + * + * @param {GenerativeCohereConfigRuntime} [config] The configuration for the `generative-cohere` module. + * @returns {ModuleConfig<'generative-cohere', GenerativeConfigRuntimeType<'generative-cohere'> | undefined>} The configuration object. + */ + cohere: ( + config?: GenerativeCohereConfigRuntime + ): ModuleConfig<'generative-cohere', GenerativeConfigRuntimeType<'generative-cohere'> | undefined> => { + const { baseURL, stopSequences, ...rest } = config || {}; + return { + name: 'generative-cohere', + config: config + ? { + ...rest, + baseUrl: baseURL, + stopSequences: TextArray.fromPartial({ values: stopSequences }), + } + : undefined, + }; + }, + /** + * Create a `ModuleConfig<'generative-databricks', GenerativeConfigRuntimeType<'generative-databricks'> | undefined>` object for use when performing runtime-specific AI generation using the `generative-databricks` module. + * + * See the [documentation](https://weaviate.io/developers/weaviate/model-providers/databricks/generative) for detailed usage. + * + * @param {GenerativeDatabricksConfigRuntime} [config] The configuration for the `generative-databricks` module. + * @returns {ModuleConfig<'generative-databricks', GenerativeConfigRuntimeType<'generative-databricks'> | undefined>} The configuration object. + */ + databricks: ( + config?: GenerativeDatabricksConfigRuntime + ): ModuleConfig< + 'generative-databricks', + GenerativeConfigRuntimeType<'generative-databricks'> | undefined + > => { + const { stop, ...rest } = config || {}; + return { + name: 'generative-databricks', + config: config + ? { + ...rest, + stop: TextArray.fromPartial({ values: stop }), + } + : undefined, + }; + }, + /** + * Create a `ModuleConfig<'generative-friendliai', GenerativeConfigRuntimeType<'generative-friendliai'> | undefined>` object for use when performing runtime-specific AI generation using the `generative-friendliai` module. + * + * See the [documentation](https://weaviate.io/developers/weaviate/model-providers/friendliai/generative) for detailed usage. + * + * @param {GenerativeFriendliAIConfigRuntime} [config] The configuration for the `generative-friendliai` module. + * @returns {ModuleConfig<'generative-databricks', GenerativeConfigRuntimeType<'generative-friendliai'> | undefined>} The configuration object. + */ + friendliai( + config?: GenerativeFriendliAIConfigRuntime + ): ModuleConfig<'generative-friendliai', GenerativeConfigRuntimeType<'generative-friendliai'> | undefined> { + const { baseURL, ...rest } = config || {}; + return { + name: 'generative-friendliai', + config: config + ? { + ...rest, + baseUrl: baseURL, + } + : undefined, + }; + }, + /** + * Create a `ModuleConfig<'generative-mistral', GenerativeConfigRuntimeType<'generative-mistral'> | undefined>` object for use when performing runtime-specific AI generation using the `generative-mistral` module. + * + * See the [documentation](https://weaviate.io/developers/weaviate/model-providers/mistral/generative) for detailed usage. + * + * @param {GenerativeMistralConfigRuntime} [config] The configuration for the `generative-mistral` module. + * @returns {ModuleConfig<'generative-mistral', GenerativeConfigRuntimeType<'generative-mistral'> | undefined>} The configuration object. + */ + mistral( + config?: GenerativeMistralConfigRuntime + ): ModuleConfig<'generative-mistral', GenerativeConfigRuntimeType<'generative-mistral'> | undefined> { + const { baseURL, ...rest } = config || {}; + return { + name: 'generative-mistral', + config: config + ? { + baseUrl: baseURL, + ...rest, + } + : undefined, + }; + }, + /** + * Create a `ModuleConfig<'generative-nvidia', GenerativeConfigRuntimeType<'generative-nvidia'> | undefined>` object for use when performing runtime-specific AI generation using the `generative-mistral` module. + * + * See the [documentation](https://weaviate.io/developers/weaviate/model-providers/nvidia/generative) for detailed usage. + * + * @param {GenerativeNvidiaConfigCreate} [config] The configuration for the `generative-nvidia` module. + * @returns {ModuleConfig<'generative-nvidia', GenerativeConfigRuntimeType<'generative-nvidia'> | undefined>} The configuration object. + */ + nvidia( + config?: GenerativeNvidiaConfigRuntime + ): ModuleConfig<'generative-nvidia', GenerativeConfigRuntimeType<'generative-nvidia'> | undefined> { + const { baseURL, ...rest } = config || {}; + return { + name: 'generative-nvidia', + config: config + ? { + ...rest, + baseUrl: baseURL, + } + : undefined, + }; + }, + /** + * Create a `ModuleConfig<'generative-ollama', GenerativeConfigRuntimeType<'generative-ollama'> | undefined>` object for use when performing runtime-specific AI generation using the `generative-ollama` module. + * + * See the [documentation](https://weaviate.io/developers/weaviate/model-providers/ollama/generative) for detailed usage. + * + * @param {GenerativeOllamaConfigRuntime} [config] The configuration for the `generative-openai` module. + * @returns {ModuleConfig<'generative-ollama', GenerativeConfigRuntimeType<'generative-ollama'> | undefined>} The configuration object. + */ + ollama( + config?: GenerativeOllamaConfigRuntime + ): ModuleConfig<'generative-ollama', GenerativeConfigRuntimeType<'generative-ollama'> | undefined> { + return { + name: 'generative-ollama', + config, + }; + }, + /** + * Create a `ModuleConfig<'generative-openai', GenerativeConfigRuntimeType<'generative-openai'>>` object for use when performing runtime-specific AI generation using the `generative-openai` module. + * + * See the [documentation](https://weaviate.io/developers/weaviate/model-providers/openai/generative) for detailed usage. + * + * @param {GenerativeOpenAIConfigRuntime} [config] The configuration for the `generative-openai` module. + * @returns {ModuleConfig<'generative-openai', GenerativeConfigRuntimeType<'generative-openai'>>} The configuration object. + */ + openAI: ( + config?: GenerativeOpenAIConfigRuntime + ): ModuleConfig<'generative-openai', GenerativeConfigRuntimeType<'generative-openai'>> => { + const { baseURL, model, stop, ...rest } = config || {}; + return { + name: 'generative-openai', + config: config + ? { + ...rest, + baseUrl: baseURL, + model: model ?? '', + isAzure: false, + stop: TextArray.fromPartial({ values: stop }), + } + : { model: '', isAzure: false }, + }; + }, + /** + * Create a `ModuleConfig<'generative-google', GenerativeConfigRuntimeType<'generative-openai'> | undefined>` object for use when performing runtime-specific AI generation using the `generative-google` module. + * + * See the [documentation](https://weaviate.io/developers/weaviate/model-providers/google/generative) for detailed usage. + * + * @param {GenerativeGoogleConfigRuntime} [config] The configuration for the `generative-palm` module. + * @returns {ModuleConfig<'generative-google', GenerativeConfigRuntimeType<'generative-google'> | undefined>} The configuration object. + */ + google: ( + config?: GenerativeGoogleConfigRuntime + ): ModuleConfig<'generative-google', GenerativeConfigRuntimeType<'generative-google'> | undefined> => { + const { stopSequences, ...rest } = config || {}; + return { + name: 'generative-google', + config: config + ? { + ...rest, + stopSequences: TextArray.fromPartial({ values: stopSequences }), + } + : undefined, + }; + }, +}; diff --git a/src/collections/generate/index.ts b/src/collections/generate/index.ts index 60de4f94..3ebcff77 100644 --- a/src/collections/generate/index.ts +++ b/src/collections/generate/index.ts @@ -364,4 +364,5 @@ class GenerateManager implements Generate { export default GenerateManager.use; +export { generativeConfigRuntime } from './config.js'; export { Generate } from './types.js'; diff --git a/src/collections/generate/integration.test.ts b/src/collections/generate/integration.test.ts index b93a2286..3b6708fd 100644 --- a/src/collections/generate/integration.test.ts +++ b/src/collections/generate/integration.test.ts @@ -1,7 +1,7 @@ /* eslint-disable @typescript-eslint/no-non-null-assertion */ /* eslint-disable @typescript-eslint/no-non-null-asserted-optional-chain */ import { WeaviateUnsupportedFeatureError } from '../../errors.js'; -import weaviate, { WeaviateClient } from '../../index.js'; +import weaviate, { WeaviateClient, generativeConfigRuntime } from '../../index.js'; import { Collection } from '../collection/index.js'; import { GenerateOptions, GroupByOptions } from '../types/index.js'; @@ -460,7 +460,7 @@ maybe('Testing of the collection.generate methods with runtime generative config }); }); - it('should generate using a runtime config without search and with extras', async () => { + it.only('should generate using a runtime config without search and with extras', async () => { const query = () => collection.generate.fetchObjects({ singlePrompt: { @@ -473,12 +473,10 @@ maybe('Testing of the collection.generate methods with runtime generative config nonBlobProperties: ['testProp'], metadata: true, }, - config: { - name: 'generative-openai', - config: { - model: 'gpt-4o-mini', - }, - }, + config: generativeConfigRuntime.openAI({ + model: 'gpt-4o-mini', + stop: ['\n'], + }), }); if (await client.getWeaviateVersion().then((ver) => ver.isLowerThan(1, 30, 0))) { @@ -508,6 +506,7 @@ maybe('Testing of the collection.generate methods with runtime generative config name: 'generative-openai', config: { model: 'gpt-4o-mini', + stop: { values: ['\n'] }, }, }, }); diff --git a/src/collections/generate/unit.test.ts b/src/collections/generate/unit.test.ts new file mode 100644 index 00000000..e2e296f0 --- /dev/null +++ b/src/collections/generate/unit.test.ts @@ -0,0 +1,280 @@ +import { GenerativeConfigRuntimeType, ModuleConfig } from '../types'; +import { generativeConfigRuntime } from './config'; + +// only tests fields that must be mapped from some public name to a gRPC name, e.g. baseURL -> baseUrl and stop: string[] -> stop: TextArray +describe('Unit testing of the generativeConfigRuntime factory methods', () => { + describe('anthropic', () => { + it('with defaults', () => { + const config = generativeConfigRuntime.anthropic(); + expect(config).toEqual< + ModuleConfig<'generative-anthropic', GenerativeConfigRuntimeType<'generative-anthropic'> | undefined> + >({ + name: 'generative-anthropic', + config: undefined, + }); + }); + it('with values', () => { + const config = generativeConfigRuntime.anthropic({ + baseURL: 'http://localhost:8080', + stopSequences: ['a', 'b', 'c'], + }); + expect(config).toEqual< + ModuleConfig<'generative-anthropic', GenerativeConfigRuntimeType<'generative-anthropic'> | undefined> + >({ + name: 'generative-anthropic', + config: { + baseUrl: 'http://localhost:8080', + stopSequences: { values: ['a', 'b', 'c'] }, + }, + }); + }); + }); + + describe('anyscale', () => { + it('with defaults', () => { + const config = generativeConfigRuntime.anyscale(); + expect(config).toEqual< + ModuleConfig<'generative-anyscale', GenerativeConfigRuntimeType<'generative-anyscale'> | undefined> + >({ + name: 'generative-anyscale', + config: undefined, + }); + }); + it('with values', () => { + const config = generativeConfigRuntime.anyscale({ + baseURL: 'http://localhost:8080', + }); + expect(config).toEqual< + ModuleConfig<'generative-anyscale', GenerativeConfigRuntimeType<'generative-anyscale'> | undefined> + >({ + name: 'generative-anyscale', + config: { + baseUrl: 'http://localhost:8080', + }, + }); + }); + }); + + describe('aws', () => { + it('with defaults', () => { + const config = generativeConfigRuntime.aws(); + expect(config).toEqual< + ModuleConfig<'generative-aws', GenerativeConfigRuntimeType<'generative-aws'> | undefined> + >({ + name: 'generative-aws', + config: undefined, + }); + }); + }); + + describe('azure-openai', () => { + it('with defaults', () => { + const config = generativeConfigRuntime.azureOpenAI(); + expect(config).toEqual< + ModuleConfig<'generative-azure-openai', GenerativeConfigRuntimeType<'generative-azure-openai'>> + >({ + name: 'generative-azure-openai', + config: { model: '', isAzure: true }, + }); + }); + it('with values', () => { + const config = generativeConfigRuntime.azureOpenAI({ + baseURL: 'http://localhost:8080', + model: 'model', + stop: ['a', 'b', 'c'], + }); + expect(config).toEqual< + ModuleConfig<'generative-azure-openai', GenerativeConfigRuntimeType<'generative-azure-openai'>> + >({ + name: 'generative-azure-openai', + config: { + baseUrl: 'http://localhost:8080', + stop: { values: ['a', 'b', 'c'] }, + model: 'model', + isAzure: true, + }, + }); + }); + }); + + describe('cohere', () => { + it('with defaults', () => { + const config = generativeConfigRuntime.cohere(); + expect(config).toEqual< + ModuleConfig<'generative-cohere', GenerativeConfigRuntimeType<'generative-cohere'> | undefined> + >({ + name: 'generative-cohere', + config: undefined, + }); + }); + it('with values', () => { + const config = generativeConfigRuntime.cohere({ + baseURL: 'http://localhost:8080', + stopSequences: ['a', 'b', 'c'], + }); + expect(config).toEqual< + ModuleConfig<'generative-cohere', GenerativeConfigRuntimeType<'generative-cohere'> | undefined> + >({ + name: 'generative-cohere', + config: { + baseUrl: 'http://localhost:8080', + stopSequences: { values: ['a', 'b', 'c'] }, + }, + }); + }); + }); + + describe('databricks', () => { + it('with defaults', () => { + const config = generativeConfigRuntime.databricks(); + expect(config).toEqual< + ModuleConfig< + 'generative-databricks', + GenerativeConfigRuntimeType<'generative-databricks'> | undefined + > + >({ + name: 'generative-databricks', + config: undefined, + }); + }); + it('with values', () => { + const config = generativeConfigRuntime.databricks({ + stop: ['a', 'b', 'c'], + }); + expect(config).toEqual< + ModuleConfig< + 'generative-databricks', + GenerativeConfigRuntimeType<'generative-databricks'> | undefined + > + >({ + name: 'generative-databricks', + config: { + stop: { values: ['a', 'b', 'c'] }, + }, + }); + }); + }); + + describe('friendliai', () => { + it('with defaults', () => { + const config = generativeConfigRuntime.friendliai(); + expect(config).toEqual< + ModuleConfig< + 'generative-friendliai', + GenerativeConfigRuntimeType<'generative-friendliai'> | undefined + > + >({ + name: 'generative-friendliai', + config: undefined, + }); + }); + it('with values', () => { + const config = generativeConfigRuntime.friendliai({ + baseURL: 'http://localhost:8080', + }); + expect(config).toEqual< + ModuleConfig< + 'generative-friendliai', + GenerativeConfigRuntimeType<'generative-friendliai'> | undefined + > + >({ + name: 'generative-friendliai', + config: { + baseUrl: 'http://localhost:8080', + }, + }); + }); + }); + + describe('mistral', () => { + it('with defaults', () => { + const config = generativeConfigRuntime.mistral(); + expect(config).toEqual< + ModuleConfig<'generative-mistral', GenerativeConfigRuntimeType<'generative-mistral'> | undefined> + >({ + name: 'generative-mistral', + config: undefined, + }); + }); + it('with values', () => { + const config = generativeConfigRuntime.mistral({ + baseURL: 'http://localhost:8080', + }); + expect(config).toEqual< + ModuleConfig<'generative-mistral', GenerativeConfigRuntimeType<'generative-mistral'> | undefined> + >({ + name: 'generative-mistral', + config: { + baseUrl: 'http://localhost:8080', + }, + }); + }); + }); + + describe('nvidia', () => { + it('with defaults', () => { + const config = generativeConfigRuntime.nvidia(); + expect(config).toEqual< + ModuleConfig<'generative-nvidia', GenerativeConfigRuntimeType<'generative-nvidia'> | undefined> + >({ + name: 'generative-nvidia', + config: undefined, + }); + }); + it('with values', () => { + const config = generativeConfigRuntime.nvidia({ + baseURL: 'http://localhost:8080', + }); + expect(config).toEqual< + ModuleConfig<'generative-nvidia', GenerativeConfigRuntimeType<'generative-nvidia'> | undefined> + >({ + name: 'generative-nvidia', + config: { + baseUrl: 'http://localhost:8080', + }, + }); + }); + }); + + describe('ollama', () => { + it('with defaults', () => { + const config = generativeConfigRuntime.ollama(); + expect(config).toEqual< + ModuleConfig<'generative-ollama', GenerativeConfigRuntimeType<'generative-ollama'> | undefined> + >({ + name: 'generative-ollama', + config: undefined, + }); + }); + }); + + describe('openai', () => { + it('with defaults', () => { + const config = generativeConfigRuntime.openAI(); + expect(config).toEqual< + ModuleConfig<'generative-openai', GenerativeConfigRuntimeType<'generative-openai'>> + >({ + name: 'generative-openai', + config: { model: '', isAzure: false }, + }); + }); + it('with values', () => { + const config = generativeConfigRuntime.openAI({ + baseURL: 'http://localhost:8080', + model: 'model', + stop: ['a', 'b', 'c'], + }); + expect(config).toEqual< + ModuleConfig<'generative-openai', GenerativeConfigRuntimeType<'generative-openai'>> + >({ + name: 'generative-openai', + config: { + baseUrl: 'http://localhost:8080', + isAzure: false, + model: 'model', + stop: { values: ['a', 'b', 'c'] }, + }, + }); + }); + }); +}); diff --git a/src/collections/types/generate.ts b/src/collections/types/generate.ts index 31eca2e2..53bf208c 100644 --- a/src/collections/types/generate.ts +++ b/src/collections/types/generate.ts @@ -119,19 +119,50 @@ export type GroupedTask = { imageProperties?: string[]; }; +type omitFields = 'images' | 'imageProperties'; + export type GenerativeConfigRuntime = - | ModuleConfig<'generative-anthropic', GenerativeAnthropicConfigRuntime> - | ModuleConfig<'generative-anyscale', GenerativeAnyscaleConfigRuntime> - | ModuleConfig<'generative-aws', GenerativeAWSConfigRuntime> - | ModuleConfig<'generative-cohere', GenerativeCohereConfigRuntime> - | ModuleConfig<'generative-databricks', GenerativeDatabricksConfigRuntime> - | ModuleConfig<'generative-dummy', GenerativeDummyConfigRuntime> - | ModuleConfig<'generative-friendliai', GenerativeFriendliAIConfigRuntime> - | ModuleConfig<'generative-google', GenerativeGoogleConfigRuntime> - | ModuleConfig<'generative-mistral', GenerativeMistralConfigRuntime> - | ModuleConfig<'generative-nvidia', GenerativeNvidiaConfigRuntime> - | ModuleConfig<'generative-ollama', GenerativeOllamaConfigRuntime> - | ModuleConfig<'generative-openai', GenerativeOpenAIConfigRuntime>; + | ModuleConfig<'generative-anthropic', GenerativeConfigRuntimeType<'generative-anthropic'>> + | ModuleConfig<'generative-anyscale', GenerativeConfigRuntimeType<'generative-anyscale'>> + | ModuleConfig<'generative-aws', GenerativeConfigRuntimeType<'generative-aws'>> + | ModuleConfig<'generative-azure-openai', GenerativeConfigRuntimeType<'generative-azure-openai'>> + | ModuleConfig<'generative-cohere', GenerativeConfigRuntimeType<'generative-cohere'>> + | ModuleConfig<'generative-databricks', GenerativeConfigRuntimeType<'generative-databricks'>> + | ModuleConfig<'generative-dummy', GenerativeConfigRuntimeType<'generative-dummy'>> + | ModuleConfig<'generative-friendliai', GenerativeConfigRuntimeType<'generative-friendliai'>> + | ModuleConfig<'generative-google', GenerativeConfigRuntimeType<'generative-google'>> + | ModuleConfig<'generative-mistral', GenerativeConfigRuntimeType<'generative-mistral'>> + | ModuleConfig<'generative-nvidia', GenerativeConfigRuntimeType<'generative-nvidia'>> + | ModuleConfig<'generative-ollama', GenerativeConfigRuntimeType<'generative-ollama'>> + | ModuleConfig<'generative-openai', GenerativeConfigRuntimeType<'generative-openai'>>; + +export type GenerativeConfigRuntimeType = G extends 'generative-anthropic' + ? Omit + : G extends 'generative-anyscale' + ? Omit + : G extends 'generative-aws' + ? Omit + : G extends 'generative-azure-openai' + ? Omit & { isAzure: true } + : G extends 'generative-cohere' + ? Omit + : G extends 'generative-databricks' + ? Omit + : G extends 'generative-google' + ? Omit + : G extends 'generative-friendliai' + ? Omit + : G extends 'generative-mistral' + ? Omit + : G extends 'generative-nvidia' + ? Omit + : G extends 'generative-ollama' + ? Omit + : G extends 'generative-openai' + ? Omit & { isAzure?: false } + : G extends 'none' + ? undefined + : Record | undefined; export type GenerativeMetadata = C extends undefined ? never @@ -167,17 +198,117 @@ export type GenerateReturn = | Promise> | Promise>; -type omitFields = 'images' | 'imageProperties'; +export type GenerativeAnthropicConfigRuntime = { + baseURL?: string | undefined; + maxTokens?: number | undefined; + model?: string | undefined; + temperature?: number | undefined; + topK?: number | undefined; + topP?: number | undefined; + stopSequences?: string[] | undefined; +}; + +export type GenerativeAnyscaleConfigRuntime = { + baseURL?: string | undefined; + model?: string | undefined; + temperature?: number | undefined; +}; + +export type GenerativeAWSConfigRuntime = { + model?: string | undefined; + temperature?: number | undefined; + service?: string | undefined; + region?: string | undefined; + endpoint?: string | undefined; + targetModel?: string | undefined; + targetVariant?: string | undefined; +}; + +export type GenerativeCohereConfigRuntime = { + baseURL?: string | undefined; + frequencyPenalty?: number | undefined; + maxTokens?: number | undefined; + model?: string | undefined; + k?: number | undefined; + p?: number | undefined; + presencePenalty?: number | undefined; + stopSequences?: string[] | undefined; + temperature?: number | undefined; +}; + +export type GenerativeDatabricksConfigRuntime = { + endpoint?: string | undefined; + model?: string | undefined; + frequencyPenalty?: number | undefined; + logProbs?: boolean | undefined; + topLogProbs?: number | undefined; + maxTokens?: number | undefined; + n?: number | undefined; + presencePenalty?: number | undefined; + stop?: string[] | undefined; + temperature?: number | undefined; + topP?: number | undefined; +}; -export type GenerativeAnthropicConfigRuntime = Omit; -export type GenerativeAnyscaleConfigRuntime = Omit; -export type GenerativeAWSConfigRuntime = Omit; -export type GenerativeCohereConfigRuntime = Omit; -export type GenerativeDatabricksConfigRuntime = Omit; -export type GenerativeDummyConfigRuntime = Omit; -export type GenerativeFriendliAIConfigRuntime = Omit; -export type GenerativeGoogleConfigRuntime = Omit; -export type GenerativeMistralConfigRuntime = Omit; -export type GenerativeNvidiaConfigRuntime = Omit; -export type GenerativeOllamaConfigRuntime = Omit; -export type GenerativeOpenAIConfigRuntime = Omit; +export type GenerativeDummyConfigRuntime = GenerativeDummyGRPC; + +export type GenerativeFriendliAIConfigRuntime = { + baseURL?: string | undefined; + model?: string | undefined; + maxTokens?: number | undefined; + temperature?: number | undefined; + n?: number | undefined; + topP?: number | undefined; +}; + +export type GenerativeGoogleConfigRuntime = { + frequencyPenalty?: number | undefined; + maxTokens?: number | undefined; + model?: string | undefined; + presencePenalty?: number | undefined; + temperature?: number | undefined; + topK?: number | undefined; + topP?: number | undefined; + stopSequences?: string[] | undefined; + apiEndpoint?: string | undefined; + projectId?: string | undefined; + endpointId?: string | undefined; + region?: string | undefined; +}; + +export type GenerativeMistralConfigRuntime = { + baseURL?: string | undefined; + maxTokens?: number | undefined; + model?: string | undefined; + temperature?: number | undefined; + topP?: number | undefined; +}; + +export type GenerativeNvidiaConfigRuntime = { + baseURL?: string | undefined; + model?: string | undefined; + temperature?: number | undefined; + topP?: number | undefined; + maxTokens?: number | undefined; +}; + +export type GenerativeOllamaConfigRuntime = { + apiEndpoint?: string | undefined; + model?: string | undefined; + temperature?: number | undefined; +}; + +export type GenerativeOpenAIConfigRuntime = { + frequencyPenalty?: number | undefined; + maxTokens?: number | undefined; + model?: string; + n?: number | undefined; + presencePenalty?: number | undefined; + stop?: string[] | undefined; + temperature?: number | undefined; + topP?: number | undefined; + baseURL?: string | undefined; + apiVersion?: string | undefined; + resourceName?: string | undefined; + deploymentId?: string | undefined; +}; From 98a932bbf1baadf1df0bd78be633ca982a77cc71 Mon Sep 17 00:00:00 2001 From: Tommy Smith Date: Mon, 10 Mar 2025 11:50:49 +0000 Subject: [PATCH 07/11] Remove `it.only` from test --- src/collections/generate/integration.test.ts | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/collections/generate/integration.test.ts b/src/collections/generate/integration.test.ts index 3b6708fd..468e93bf 100644 --- a/src/collections/generate/integration.test.ts +++ b/src/collections/generate/integration.test.ts @@ -460,7 +460,7 @@ maybe('Testing of the collection.generate methods with runtime generative config }); }); - it.only('should generate using a runtime config without search and with extras', async () => { + it('should generate using a runtime config without search and with extras', async () => { const query = () => collection.generate.fetchObjects({ singlePrompt: { @@ -474,7 +474,6 @@ maybe('Testing of the collection.generate methods with runtime generative config metadata: true, }, config: generativeConfigRuntime.openAI({ - model: 'gpt-4o-mini', stop: ['\n'], }), }); From f533aaa6fae21d9a835650e598111bdf4acf4b46 Mon Sep 17 00:00:00 2001 From: Tommy Smith Date: Mon, 17 Mar 2025 10:45:36 +0000 Subject: [PATCH 08/11] Update CI images --- .github/workflows/main.yaml | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/.github/workflows/main.yaml b/.github/workflows/main.yaml index cecd2521..554171ab 100644 --- a/.github/workflows/main.yaml +++ b/.github/workflows/main.yaml @@ -10,9 +10,9 @@ env: WEAVIATE_124: 1.24.26 WEAVIATE_125: 1.25.34 WEAVIATE_126: 1.26.17 - WEAVIATE_127: 1.27.14 - WEAVIATE_128: 1.28.8 - WEAVIATE_129: 1.29.0 + WEAVIATE_127: 1.27.15 + WEAVIATE_128: 1.28.11 + WEAVIATE_129: 1.29.1 WEAVIATE_130: 1.30.0-dev-680e323 concurrency: From c5c93d9d1e2d57413ec9d6e3687b7bade5377e9a Mon Sep 17 00:00:00 2001 From: Tommy Smith Date: Mon, 17 Mar 2025 10:45:49 +0000 Subject: [PATCH 09/11] Update name of version checker method --- src/collections/generate/index.ts | 16 ++++++++-------- src/collections/query/check.ts | 2 +- 2 files changed, 9 insertions(+), 9 deletions(-) diff --git a/src/collections/generate/index.ts b/src/collections/generate/index.ts index 3ebcff77..799313f8 100644 --- a/src/collections/generate/index.ts +++ b/src/collections/generate/index.ts @@ -72,7 +72,7 @@ class GenerateManager implements Generate { generate: GenerateOptions, opts?: FetchObjectsOptions ): Promise> { - return Promise.all([this.check.fetchObjects(opts), this.check.supportForSingleGrouped()]) + return Promise.all([this.check.fetchObjects(opts), this.check.supportForSingleGroupedGenerative()]) .then(async ([{ search }, supportsSingleGrouped]) => search.withFetch({ ...Serialize.search.fetchObjects(opts), @@ -97,7 +97,7 @@ class GenerateManager implements Generate { generate: GenerateOptions, opts?: Bm25Options ): GenerateReturn { - return Promise.all([this.check.bm25(opts), this.check.supportForSingleGrouped()]) + return Promise.all([this.check.bm25(opts), this.check.supportForSingleGroupedGenerative()]) .then(async ([{ search }, supportsSingleGrouped]) => search.withBm25({ ...Serialize.search.bm25(query, opts), @@ -122,7 +122,7 @@ class GenerateManager implements Generate { generate: GenerateOptions, opts?: HybridOptions ): GenerateReturn { - return Promise.all([this.check.hybridSearch(opts), this.check.supportForSingleGrouped()]) + return Promise.all([this.check.hybridSearch(opts), this.check.supportForSingleGroupedGenerative()]) .then( async ([ { search, supportsTargets, supportsVectorsForTargets, supportsWeightsForTargets }, @@ -159,7 +159,7 @@ class GenerateManager implements Generate { generate: GenerateOptions, opts?: NearOptions ): GenerateReturn { - return Promise.all([this.check.nearSearch(opts), this.check.supportForSingleGrouped()]) + return Promise.all([this.check.nearSearch(opts), this.check.supportForSingleGroupedGenerative()]) .then(([{ search, supportsTargets, supportsWeightsForTargets }, supportsSingleGrouped]) => Promise.all([ toBase64FromMedia(image), @@ -196,7 +196,7 @@ class GenerateManager implements Generate { generate: GenerateOptions, opts?: NearOptions ): GenerateReturn { - return Promise.all([this.check.nearSearch(opts), this.check.supportForSingleGrouped()]) + return Promise.all([this.check.nearSearch(opts), this.check.supportForSingleGroupedGenerative()]) .then(async ([{ search, supportsTargets, supportsWeightsForTargets }, supportsSingleGrouped]) => search.withNearObject({ ...Serialize.search.nearObject( @@ -228,7 +228,7 @@ class GenerateManager implements Generate { generate: GenerateOptions, opts?: NearOptions ): GenerateReturn { - return Promise.all([this.check.nearSearch(opts), this.check.supportForSingleGrouped()]) + return Promise.all([this.check.nearSearch(opts), this.check.supportForSingleGroupedGenerative()]) .then(async ([{ search, supportsTargets, supportsWeightsForTargets }, supportsSingleGrouped]) => search.withNearText({ ...Serialize.search.nearText( @@ -260,7 +260,7 @@ class GenerateManager implements Generate { generate: GenerateOptions, opts?: NearOptions ): GenerateReturn { - return Promise.all([this.check.nearVector(vector, opts), this.check.supportForSingleGrouped()]) + return Promise.all([this.check.nearVector(vector, opts), this.check.supportForSingleGroupedGenerative()]) .then( async ([ { search, supportsTargets, supportsVectorsForTargets, supportsWeightsForTargets }, @@ -300,7 +300,7 @@ class GenerateManager implements Generate { generate: GenerateOptions, opts?: NearOptions ): GenerateReturn { - return Promise.all([this.check.nearSearch(opts), this.check.supportForSingleGrouped()]) + return Promise.all([this.check.nearSearch(opts), this.check.supportForSingleGroupedGenerative()]) .then(([{ search, supportsTargets, supportsWeightsForTargets }, supportsSingleGrouped]) => { const args = { supportsTargets, diff --git a/src/collections/query/check.ts b/src/collections/query/check.ts index 81084f66..ebe87835 100644 --- a/src/collections/query/check.ts +++ b/src/collections/query/check.ts @@ -98,7 +98,7 @@ export class Check { return check.supports; }; - public supportForSingleGrouped = async () => { + public supportForSingleGroupedGenerative = async () => { const check = await this.dbVersionSupport.supportsSingleGrouped(); if (!check.supports) throw new WeaviateUnsupportedFeatureError(check.message); return check.supports; From 72f1f22aef1f6f9e4eb01b109b02ffced898b1fa Mon Sep 17 00:00:00 2001 From: Tommy Smith Date: Mon, 17 Mar 2025 10:51:18 +0000 Subject: [PATCH 10/11] Update to use latest proto with optional model fix in openai --- src/collections/generate/config.ts | 10 ++++------ src/proto/v1/generative.ts | 12 ++++++------ 2 files changed, 10 insertions(+), 12 deletions(-) diff --git a/src/collections/generate/config.ts b/src/collections/generate/config.ts index a633bacc..1ab12b04 100644 --- a/src/collections/generate/config.ts +++ b/src/collections/generate/config.ts @@ -88,18 +88,17 @@ export const generativeConfigRuntime = { azureOpenAI: ( config?: GenerativeOpenAIConfigRuntime ): ModuleConfig<'generative-azure-openai', GenerativeConfigRuntimeType<'generative-azure-openai'>> => { - const { baseURL, model, stop, ...rest } = config || {}; + const { baseURL, stop, ...rest } = config || {}; return { name: 'generative-azure-openai', config: config ? { ...rest, baseUrl: baseURL, - model: model ?? '', isAzure: true, stop: TextArray.fromPartial({ values: stop }), } - : { model: '', isAzure: true }, + : { isAzure: true }, }; }, /** @@ -243,18 +242,17 @@ export const generativeConfigRuntime = { openAI: ( config?: GenerativeOpenAIConfigRuntime ): ModuleConfig<'generative-openai', GenerativeConfigRuntimeType<'generative-openai'>> => { - const { baseURL, model, stop, ...rest } = config || {}; + const { baseURL, stop, ...rest } = config || {}; return { name: 'generative-openai', config: config ? { ...rest, baseUrl: baseURL, - model: model ?? '', isAzure: false, stop: TextArray.fromPartial({ values: stop }), } - : { model: '', isAzure: false }, + : { isAzure: false }, }; }, /** diff --git a/src/proto/v1/generative.ts b/src/proto/v1/generative.ts index 2abae4ba..fe0805fa 100644 --- a/src/proto/v1/generative.ts +++ b/src/proto/v1/generative.ts @@ -118,7 +118,7 @@ export interface GenerativeOllama { export interface GenerativeOpenAI { frequencyPenalty?: number | undefined; maxTokens?: number | undefined; - model: string; + model?: string | undefined; n?: number | undefined; presencePenalty?: number | undefined; stop?: TextArray | undefined; @@ -1884,7 +1884,7 @@ function createBaseGenerativeOpenAI(): GenerativeOpenAI { return { frequencyPenalty: undefined, maxTokens: undefined, - model: "", + model: undefined, n: undefined, presencePenalty: undefined, stop: undefined, @@ -1908,7 +1908,7 @@ export const GenerativeOpenAI = { if (message.maxTokens !== undefined) { writer.uint32(16).int64(message.maxTokens); } - if (message.model !== "") { + if (message.model !== undefined) { writer.uint32(26).string(message.model); } if (message.n !== undefined) { @@ -2075,7 +2075,7 @@ export const GenerativeOpenAI = { return { frequencyPenalty: isSet(object.frequencyPenalty) ? globalThis.Number(object.frequencyPenalty) : undefined, maxTokens: isSet(object.maxTokens) ? globalThis.Number(object.maxTokens) : undefined, - model: isSet(object.model) ? globalThis.String(object.model) : "", + model: isSet(object.model) ? globalThis.String(object.model) : undefined, n: isSet(object.n) ? globalThis.Number(object.n) : undefined, presencePenalty: isSet(object.presencePenalty) ? globalThis.Number(object.presencePenalty) : undefined, stop: isSet(object.stop) ? TextArray.fromJSON(object.stop) : undefined, @@ -2099,7 +2099,7 @@ export const GenerativeOpenAI = { if (message.maxTokens !== undefined) { obj.maxTokens = Math.round(message.maxTokens); } - if (message.model !== "") { + if (message.model !== undefined) { obj.model = message.model; } if (message.n !== undefined) { @@ -2148,7 +2148,7 @@ export const GenerativeOpenAI = { const message = createBaseGenerativeOpenAI(); message.frequencyPenalty = object.frequencyPenalty ?? undefined; message.maxTokens = object.maxTokens ?? undefined; - message.model = object.model ?? ""; + message.model = object.model ?? undefined; message.n = object.n ?? undefined; message.presencePenalty = object.presencePenalty ?? undefined; message.stop = (object.stop !== undefined && object.stop !== null) ? TextArray.fromPartial(object.stop) : undefined; From 1a11acd5984557ab6253fbf50e2b18fa4aca1548 Mon Sep 17 00:00:00 2001 From: Tommy Smith Date: Mon, 17 Mar 2025 11:51:46 +0000 Subject: [PATCH 11/11] Fix unit test --- src/collections/generate/unit.test.ts | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/collections/generate/unit.test.ts b/src/collections/generate/unit.test.ts index e2e296f0..63ff17b9 100644 --- a/src/collections/generate/unit.test.ts +++ b/src/collections/generate/unit.test.ts @@ -74,7 +74,7 @@ describe('Unit testing of the generativeConfigRuntime factory methods', () => { ModuleConfig<'generative-azure-openai', GenerativeConfigRuntimeType<'generative-azure-openai'>> >({ name: 'generative-azure-openai', - config: { model: '', isAzure: true }, + config: { isAzure: true }, }); }); it('with values', () => { @@ -255,7 +255,7 @@ describe('Unit testing of the generativeConfigRuntime factory methods', () => { ModuleConfig<'generative-openai', GenerativeConfigRuntimeType<'generative-openai'>> >({ name: 'generative-openai', - config: { model: '', isAzure: false }, + config: { isAzure: false }, }); }); it('with values', () => {