From 4829df2c65ccb1de14a511c2332092d1691904cf Mon Sep 17 00:00:00 2001 From: Tommy Smith Date: Wed, 15 May 2024 13:00:35 +0100 Subject: [PATCH] Add missing generative modules and unit tests --- src/collections/config/types/generative.ts | 35 ++++ src/collections/configure/generative.ts | 94 ++++++++- src/collections/configure/types/generative.ts | 19 +- src/collections/configure/unit.test.ts | 187 +++++++++++++++--- 4 files changed, 303 insertions(+), 32 deletions(-) diff --git a/src/collections/config/types/generative.ts b/src/collections/config/types/generative.ts index cc149292..63db5697 100644 --- a/src/collections/config/types/generative.ts +++ b/src/collections/config/types/generative.ts @@ -7,6 +7,36 @@ type GenerativeOpenAIConfigBase = { topPProperty?: number; }; +export type GenerativeAWSConfig = { + region: string; + service: string; + model?: string; + endpoint?: string; +}; + +export type GenerativeAnyscaleConfig = { + model?: string; + temperature?: number; +}; + +export type GenerativeMistralConfig = { + maxTokens?: number; + model?: string; + temperature?: number; +}; + +export type GenerativeOctoAIConfig = { + baseURL?: string; + maxTokens?: number; + model?: string; + temperature?: number; +}; + +export type GenerativeOllamaConfig = { + apiEndpoint?: string; + model?: string; +}; + export type GenerativeOpenAIConfig = GenerativeOpenAIConfigBase & { model?: string; }; @@ -53,6 +83,11 @@ export type GenerativeConfigType = G extends 'generative-openai' : Record | undefined; export type GenerativeSearch = + | 'generative-anyscale' + | 'generative-aws' + | 'generative-mistral' + | 'generative-octoai' + | 'generative-ollama' | 'generative-openai' | 'generative-cohere' | 'generative-palm' diff --git a/src/collections/configure/generative.ts b/src/collections/configure/generative.ts index abea28ca..e51447a2 100644 --- a/src/collections/configure/generative.ts +++ b/src/collections/configure/generative.ts @@ -1,18 +1,58 @@ import { + GenerativeAWSConfig, + GenerativeAnyscaleConfig, GenerativeAzureOpenAIConfig, GenerativeCohereConfig, + GenerativeMistralConfig, + GenerativeOctoAIConfig, + GenerativeOllamaConfig, GenerativeOpenAIConfig, GenerativePaLMConfig, ModuleConfig, } from '../config/types/index.js'; import { + GenerativeAWSConfigCreate, + GenerativeAnyscaleConfigCreate, GenerativeAzureOpenAIConfigCreate, GenerativeCohereConfigCreate, + GenerativeMistralConfigCreate, + GenerativeOctoAIConfigCreate, + GenerativeOllamaConfigCreate, GenerativeOpenAIConfigCreate, GenerativePaLMConfigCreate, } from '../index.js'; export default { + /** + * Create a `ModuleConfig<'generative-anyscale', GenerativeAnyscaleConfig | undefined>` object for use when performing AI generation using the `generative-anyscale` module. + * + * See the [documentation](https://weaviate.io/developers/weaviate/modules/reader-generator-modules/generative-anyscale) for detailed usage. + * + * @param {GenerativeAnyscaleConfigCreate} config The configuration for the `generative-aws` module. + * @returns {ModuleConfig<'generative-anyscale', GenerativeAnyscaleConfig | undefined>} The configuration object. + */ + anyscale( + config?: GenerativeAnyscaleConfigCreate + ): ModuleConfig<'generative-anyscale', GenerativeAnyscaleConfig | undefined> { + return { + name: 'generative-anyscale', + config, + }; + }, + /** + * Create a `ModuleConfig<'generative-aws', GenerativeAWSConfig>` object for use when performing AI generation using the `generative-aws` module. + * + * See the [documentation](https://weaviate.io/developers/weaviate/modules/reader-generator-modules/generative-aws) for detailed usage. + * + * @param {GenerativeAWSConfigCreate} config The configuration for the `generative-aws` module. + * @returns {ModuleConfig<'generative-aws', GenerativeAWSConfig>} The configuration object. + */ + aws(config: GenerativeAWSConfigCreate): ModuleConfig<'generative-aws', GenerativeAWSConfig> { + return { + name: 'generative-aws', + config, + }; + }, /** * Create a `ModuleConfig<'generative-openai', GenerativeAzureOpenAIConfig>` object for use when performing AI generation using the `generative-openai` module. * @@ -64,12 +104,60 @@ export default { }; }, /** - * Create a `ModuleConfig<'generative-openai', GenerativeOpenAIConfig>` object for use when performing AI generation using the `generative-openai` module. + * Create a `ModuleConfig<'generative-mistral', GenerativeMistralConfig | undefined>` object for use when performing AI generation using the `generative-mistral` module. + * + * See the [documentation](https://weaviate.io/developers/weaviate/modules/reader-generator-modules/generative-mistral) for detailed usage. + * + * @param {GenerativeMistralConfigCreate} [config] The configuration for the `generative-mistral` module. + * @returns {ModuleConfig<'generative-mistral', GenerativeMistralConfig | undefined>} The configuration object. + */ + mistral( + config?: GenerativeMistralConfigCreate + ): ModuleConfig<'generative-mistral', GenerativeMistralConfig | undefined> { + return { + name: 'generative-mistral', + config, + }; + }, + /** + * Create a `ModuleConfig<'generative-octoai', GenerativeOpenAIConfig | undefined>` object for use when performing AI generation using the `generative-octoai` module. + * + * See the [documentation](https://weaviate.io/developers/weaviate/modules/reader-generator-modules/generative-octoai) for detailed usage. + * + * @param {GenerativeOctoAIConfigCreate} [config] The configuration for the `generative-octoai` module. + * @returns {ModuleConfig<'generative-octoai', GenerativeOctoAIConfig | undefined>} The configuration object. + */ + octoai( + config?: GenerativeOctoAIConfigCreate + ): ModuleConfig<'generative-octoai', GenerativeOctoAIConfig | undefined> { + return { + name: 'generative-octoai', + config, + }; + }, + /** + * Create a `ModuleConfig<'generative-ollama', GenerativeOllamaConfig | undefined>` object for use when performing AI generation using the `generative-ollama` module. + * + * See the [documentation](https://weaviate.io/developers/weaviate/modules/reader-generator-modules/generative-ollama) for detailed usage. + * + * @param {GenerativeOllamaConfigCreate} [config] The configuration for the `generative-openai` module. + * @returns {ModuleConfig<'generative-ollama', GenerativeOllamaConfig | undefined>} The configuration object. + */ + ollama( + config?: GenerativeOllamaConfigCreate + ): ModuleConfig<'generative-ollama', GenerativeOllamaConfig | undefined> { + return { + name: 'generative-ollama', + config, + }; + }, + /** + * Create a `ModuleConfig<'generative-openai', GenerativeOpenAIConfig | undefined>` object for use when performing AI generation using the `generative-openai` module. * * See the [documentation](https://weaviate.io/developers/weaviate/modules/reader-generator-modules/generative-openai) for detailed usage. * * @param {GenerativeOpenAIConfigCreate} [config] The configuration for the `generative-openai` module. - * @returns {ModuleConfig<'generative-openai', GenerativeOpenAIConfig>} The configuration object. + * @returns {ModuleConfig<'generative-openai', GenerativeOpenAIConfig | undefined>} The configuration object. */ openAI: ( config?: GenerativeOpenAIConfigCreate @@ -100,7 +188,7 @@ export default { palm: (config: GenerativePaLMConfigCreate): ModuleConfig<'generative-palm', GenerativePaLMConfig> => { return { name: 'generative-palm', - config: config, + config, }; }, }; diff --git a/src/collections/configure/types/generative.ts b/src/collections/configure/types/generative.ts index c52a5341..638c469e 100644 --- a/src/collections/configure/types/generative.ts +++ b/src/collections/configure/types/generative.ts @@ -1,4 +1,11 @@ -import { GenerativePaLMConfig } from '../../index.js'; +import { + GenerativeAWSConfig, + GenerativeAnyscaleConfig, + GenerativeMistralConfig, + GenerativeOctoAIConfig, + GenerativeOllamaConfig, + GenerativePaLMConfig, +} from '../../index.js'; type GenerativeOpenAIConfigBaseCreate = { baseURL?: string; @@ -27,6 +34,16 @@ export type GenerativeCohereConfigCreate = { temperature?: number; }; +export type GenerativeAnyscaleConfigCreate = GenerativeAnyscaleConfig; + +export type GenerativeAWSConfigCreate = GenerativeAWSConfig; + +export type GenerativeMistralConfigCreate = GenerativeMistralConfig; + +export type GenerativeOctoAIConfigCreate = GenerativeOctoAIConfig; + +export type GenerativeOllamaConfigCreate = GenerativeOllamaConfig; + export type GenerativePaLMConfigCreate = GenerativePaLMConfig; export type GenerativeConfigCreate = diff --git a/src/collections/configure/unit.test.ts b/src/collections/configure/unit.test.ts index 5b43842f..aef8a4d1 100644 --- a/src/collections/configure/unit.test.ts +++ b/src/collections/configure/unit.test.ts @@ -1,7 +1,12 @@ import { configure } from './index.js'; import { + GenerativeAWSConfig, + GenerativeAnyscaleConfig, GenerativeAzureOpenAIConfig, GenerativeCohereConfig, + GenerativeMistralConfig, + GenerativeOctoAIConfig, + GenerativeOllamaConfig, GenerativeOpenAIConfig, GenerativePaLMConfig, ModuleConfig, @@ -36,7 +41,7 @@ describe('Unit testing of the configure factory class', () => { }); }); - it('should create the correct InvertedIndexConfig type with custom values', () => { + it('should create the correct InvertedIndexConfig type with all values', () => { const config = configure.invertedIndex({ bm25b: 0.5, bm25k1: 1.5, @@ -72,7 +77,7 @@ describe('Unit testing of the configure factory class', () => { }); }); - it('should create the correct MultiTenancyConfig type with custom values', () => { + it('should create the correct MultiTenancyConfig type with all values', () => { const config = configure.multiTenancy({ enabled: false, }); @@ -88,7 +93,7 @@ describe('Unit testing of the configure factory class', () => { }); }); - it('should create the correct ReplicationConfig type with custom values', () => { + it('should create the correct ReplicationConfig type with all values', () => { const config = configure.replication({ factor: 2, }); @@ -106,7 +111,7 @@ describe('Unit testing of the configure factory class', () => { }); }); - it('should create the correct ShardingConfig type with custom values', () => { + it('should create the correct ShardingConfig type with all values', () => { const config = configure.sharding({ virtualPerPhysical: 256, desiredCount: 2, @@ -151,7 +156,7 @@ describe('Unit testing of the configure factory class', () => { }); }); - it('should create the correct HNSW VectorIndexConfig type with custom values', () => { + it('should create the correct HNSW VectorIndexConfig type with all values', () => { const config = configure.vectorIndex.hnsw({ cleanupIntervalSeconds: 120, distanceMetric: 'dot', @@ -234,7 +239,7 @@ describe('Unit testing of the vectorizer factory class', () => { }); }); - it('should create the correct Img2VecNeuralConfig type with custom values', () => { + it('should create the correct Img2VecNeuralConfig type with all values', () => { const config = configure.vectorizer.img2VecNeural('test', { imageFields: ['field1', 'field2'], }); @@ -268,7 +273,7 @@ describe('Unit testing of the vectorizer factory class', () => { }); }); - it('should create the correct Multi2VecClipConfig type with custom values', () => { + it('should create the correct Multi2VecClipConfig type with all values', () => { const config = configure.vectorizer.multi2VecClip('test', { imageFields: ['field1', 'field2'], textFields: ['field3', 'field4'], @@ -306,7 +311,7 @@ describe('Unit testing of the vectorizer factory class', () => { }); }); - it('should create the correct Multi2VecBindConfig type with custom values', () => { + it('should create the correct Multi2VecBindConfig type with all values', () => { const config = configure.vectorizer.multi2VecBind('test', { audioFields: ['field1', 'field2'], depthFields: ['field3', 'field4'], @@ -358,7 +363,7 @@ describe('Unit testing of the vectorizer factory class', () => { }); }); - it('should create the correct Multi2VecPalmConfig type with custom values', () => { + it('should create the correct Multi2VecPalmConfig type with all values', () => { const config = configure.vectorizer.multi2VecPalm('test', { projectId: 'project-id', imageFields: ['field1', 'field2'], @@ -412,7 +417,7 @@ describe('Unit testing of the vectorizer factory class', () => { }); }); - it('should create the correct Text2VecAWSConfig type with custom values', () => { + it('should create the correct Text2VecAWSConfig type with all values', () => { const config = configure.vectorizer.text2VecAWS('test', { endpoint: 'endpoint', model: 'model', @@ -460,7 +465,7 @@ describe('Unit testing of the vectorizer factory class', () => { }); }); - it('should create the correct Text2VecAzureOpenAIConfig type with custom values', () => { + it('should create the correct Text2VecAzureOpenAIConfig type with all values', () => { const config = configure.vectorizer.text2VecAzureOpenAI('test', { baseURL: 'base-url', deploymentID: 'deployment-id', @@ -500,7 +505,7 @@ describe('Unit testing of the vectorizer factory class', () => { }); }); - it('should create the correct Text2VecCohereConfig type with custom values', () => { + it('should create the correct Text2VecCohereConfig type with all values', () => { const config = configure.vectorizer.text2VecCohere('test', { baseURL: 'base-url', model: 'model', @@ -540,7 +545,7 @@ describe('Unit testing of the vectorizer factory class', () => { }); }); - it('should create the correct Text2VecContextionaryConfig type with custom values', () => { + it('should create the correct Text2VecContextionaryConfig type with all values', () => { const config = configure.vectorizer.text2VecContextionary('test', { vectorizeCollectionName: true, }); @@ -574,7 +579,7 @@ describe('Unit testing of the vectorizer factory class', () => { }); }); - it('should create the correct Text2VecGPT4AllConfig type with custom values', () => { + it('should create the correct Text2VecGPT4AllConfig type with all values', () => { const config = configure.vectorizer.text2VecGPT4All('test', { vectorizeCollectionName: true, }); @@ -608,7 +613,7 @@ describe('Unit testing of the vectorizer factory class', () => { }); }); - it('should create the correct Text2VecHuggingFaceConfig type with custom values', () => { + it('should create the correct Text2VecHuggingFaceConfig type with all values', () => { const config = configure.vectorizer.text2VecHuggingFace('test', { endpointURL: 'endpoint-url', model: 'model', @@ -656,7 +661,7 @@ describe('Unit testing of the vectorizer factory class', () => { }); }); - it('should create the correct Text2VecJinaConfig type with custom values', () => { + it('should create the correct Text2VecJinaConfig type with all values', () => { const config = configure.vectorizer.text2VecJina('test', { model: 'model', vectorizeCollectionName: true, @@ -692,7 +697,7 @@ describe('Unit testing of the vectorizer factory class', () => { }); }); - it('should create the correct Text2VecOpenAIConfig type with custom values', () => { + it('should create the correct Text2VecOpenAIConfig type with all values', () => { const config = configure.vectorizer.text2VecOpenAI('test', { baseURL: 'base-url', dimensions: 256, @@ -736,7 +741,7 @@ describe('Unit testing of the vectorizer factory class', () => { }); }); - it('should create the correct Text2VecPalmConfig type with custom values', () => { + it('should create the correct Text2VecPalmConfig type with all values', () => { const config = configure.vectorizer.text2VecPalm('test', { apiEndpoint: 'api-endpoint', modelId: 'model-id', @@ -776,7 +781,7 @@ describe('Unit testing of the vectorizer factory class', () => { }); }); - it('should create the correct Text2VecTransformersConfig type with custom values', () => { + it('should create the correct Text2VecTransformersConfig type with all values', () => { const config = configure.vectorizer.text2VecTransformers('test', { poolingStrategy: 'pooling-strategy', vectorizeCollectionName: true, @@ -812,7 +817,7 @@ describe('Unit testing of the vectorizer factory class', () => { }); }); - it('should create the correct Text2VecVoyageConfig type with custom values', () => { + it('should create the correct Text2VecVoyageConfig type with all values', () => { const config = configure.vectorizer.text2VecVoyageAI('test', { baseURL: 'base-url', model: 'model', @@ -837,7 +842,61 @@ describe('Unit testing of the vectorizer factory class', () => { }); }); - it('should create the correct GenerativeAzureOpenAIConfig type with default values', () => { + it('should create the correct GenerativeAnyscaleConfig type with required & default values', () => { + const config = configure.generative.anyscale(); + expect(config).toEqual>({ + name: 'generative-anyscale', + config: undefined, + }); + }); + + it('should create the correct GenerativeAnyscaleConfig type with all values', () => { + const config = configure.generative.anyscale({ + model: 'model', + temperature: 0.5, + }); + expect(config).toEqual>({ + name: 'generative-anyscale', + config: { + model: 'model', + temperature: 0.5, + }, + }); + }); + + it('should create the correct GenerativeAWSConfig type with required & default values', () => { + const config = configure.generative.aws({ + region: 'region', + service: 'service', + }); + expect(config).toEqual>({ + name: 'generative-aws', + config: { + region: 'region', + service: 'service', + }, + }); + }); + + it('should create the correct GenerativeAWSConfig type with all values', () => { + const config = configure.generative.aws({ + endpoint: 'endpoint', + model: 'model', + region: 'region', + service: 'service', + }); + expect(config).toEqual>({ + name: 'generative-aws', + config: { + endpoint: 'endpoint', + model: 'model', + region: 'region', + service: 'service', + }, + }); + }); + + it('should create the correct GenerativeAzureOpenAIConfig type with required & default values', () => { const config = configure.generative.azureOpenAI({ resourceName: 'resource-name', deploymentId: 'deployment-id', @@ -851,7 +910,7 @@ describe('Unit testing of the vectorizer factory class', () => { }); }); - it('should create the correct GenerativeAzureOpenAIConfig type with custom values', () => { + it('should create the correct GenerativeAzureOpenAIConfig type with all values', () => { const config = configure.generative.azureOpenAI({ resourceName: 'resource-name', deploymentId: 'deployment-id', @@ -877,7 +936,7 @@ describe('Unit testing of the vectorizer factory class', () => { }); }); - it('should create the correct GenerativeCohereConfig type with default values', () => { + it('should create the correct GenerativeCohereConfig type with required & default values', () => { const config = configure.generative.cohere(); expect(config).toEqual>({ name: 'generative-cohere', @@ -885,7 +944,7 @@ describe('Unit testing of the vectorizer factory class', () => { }); }); - it('should create the correct GenerativeCohereConfig type with custom values', () => { + it('should create the correct GenerativeCohereConfig type with all values', () => { const config = configure.generative.cohere({ k: 5, maxTokens: 100, @@ -907,7 +966,79 @@ describe('Unit testing of the vectorizer factory class', () => { }); }); - it('should create the correct GenerativeOpenAIConfig type with default values', () => { + it('should create the correct GenerativeMistralConfig type with required & default values', () => { + const config = configure.generative.mistral(); + expect(config).toEqual>({ + name: 'generative-mistral', + config: undefined, + }); + }); + + it('should create the correct GenerativeMistralConfig type with all values', () => { + const config = configure.generative.mistral({ + maxTokens: 100, + model: 'model', + temperature: 0.5, + }); + expect(config).toEqual>({ + name: 'generative-mistral', + config: { + maxTokens: 100, + model: 'model', + temperature: 0.5, + }, + }); + }); + + it('should create the correct GenerativeOctoAIConfig type with required & default values', () => { + const config = configure.generative.octoai(); + expect(config).toEqual>({ + name: 'generative-octoai', + config: undefined, + }); + }); + + it('should create the correct GenerativeOctoAIConfig type with all values', () => { + const config = configure.generative.octoai({ + baseURL: 'base-url', + maxTokens: 100, + model: 'model', + temperature: 0.5, + }); + expect(config).toEqual>({ + name: 'generative-octoai', + config: { + baseURL: 'base-url', + maxTokens: 100, + model: 'model', + temperature: 0.5, + }, + }); + }); + + it('should create the correct GenerativeOllamaConfig type with required & default values', () => { + const config = configure.generative.ollama(); + expect(config).toEqual>({ + name: 'generative-ollama', + config: undefined, + }); + }); + + it('should create the correct GenerativeOllamaConfig type with all values', () => { + const config = configure.generative.ollama({ + apiEndpoint: 'api-endpoint', + model: 'model', + }); + expect(config).toEqual>({ + name: 'generative-ollama', + config: { + apiEndpoint: 'api-endpoint', + model: 'model', + }, + }); + }); + + it('should create the correct GenerativeOpenAIConfig type with required & default values', () => { const config = configure.generative.openAI(); expect(config).toEqual>({ name: 'generative-openai', @@ -915,7 +1046,7 @@ describe('Unit testing of the vectorizer factory class', () => { }); }); - it('should create the correct GenerativeOpenAIConfig type with custom values', () => { + it('should create the correct GenerativeOpenAIConfig type with all values', () => { const config = configure.generative.openAI({ baseURL: 'base-url', frequencyPenalty: 0.5, @@ -939,7 +1070,7 @@ describe('Unit testing of the vectorizer factory class', () => { }); }); - it('should create the correct GenerativePaLMConfig type with default values', () => { + it('should create the correct GenerativePaLMConfig type with required & default values', () => { const config = configure.generative.palm({ projectId: 'project-id', }); @@ -951,7 +1082,7 @@ describe('Unit testing of the vectorizer factory class', () => { }); }); - it('should create the correct GenerativePaLMConfig type with custom values', () => { + it('should create the correct GenerativePaLMConfig type with all values', () => { const config = configure.generative.palm({ apiEndpoint: 'api-endpoint', maxOutputTokens: 100,