diff --git a/ci/compose.sh b/ci/compose.sh index 354d8e70..407b51dc 100644 --- a/ci/compose.sh +++ b/ci/compose.sh @@ -21,5 +21,5 @@ function compose_down_all { } function all_weaviate_ports { - echo "8079 8080 8081 8082 8083 8085 8086 8087 8088" + echo "8079 8080 8081 8082 8083 8085 8086 8087 8088 8089 8090" } diff --git a/ci/docker-compose-backup.yml b/ci/docker-compose-backup.yml new file mode 100644 index 00000000..9d1e0d64 --- /dev/null +++ b/ci/docker-compose-backup.yml @@ -0,0 +1,19 @@ +--- +version: '3.4' +services: + weaviate-backup: + image: semitechnologies/weaviate:${WEAVIATE_VERSION} + restart: on-failure:0 + ports: + - 8090:8080 + - 50061:50051 + environment: + QUERY_DEFAULTS_LIMIT: 20 + AUTHENTICATION_ANONYMOUS_ACCESS_ENABLED: 'true' + PERSISTENCE_DATA_PATH: "./weaviate-data" + BACKUP_FILESYSTEM_PATH: "/tmp/backups" + ENABLE_MODULES: backup-filesystem + CLUSTER_GOSSIP_BIND_PORT: "7100" + CLUSTER_DATA_BIND_PORT: "7101" + DISABLE_TELEMETRY: 'true' +... diff --git a/src/backup/index.ts b/src/backup/index.ts index bf12582e..cd2f5483 100644 --- a/src/backup/index.ts +++ b/src/backup/index.ts @@ -6,6 +6,7 @@ import Connection from '../connection/index.js'; export type Backend = 'filesystem' | 's3' | 'gcs' | 'azure'; export type BackupStatus = 'STARTED' | 'TRANSFERRING' | 'TRANSFERRED' | 'SUCCESS' | 'FAILED'; +export type BackupCompressionLevel = 'DefaultCompression' | 'BestSpeed' | 'BestCompression'; export interface Backup { creator: () => BackupCreator; diff --git a/src/collections/backup/client.ts b/src/collections/backup/client.ts index ca02008c..8fc4011c 100644 --- a/src/collections/backup/client.ts +++ b/src/collections/backup/client.ts @@ -1,5 +1,6 @@ import { Backend, + BackupCompressionLevel, BackupCreateStatusGetter, BackupCreator, BackupRestoreStatusGetter, @@ -11,11 +12,28 @@ import { WeaviateBackupFailed, WeaviateDeserializationError } from '../../errors import { BackupCreateResponse, BackupCreateStatusResponse, + BackupRestoreResponse, BackupRestoreStatusResponse, } from '../../openapi/types.js'; +/** Configuration options available when creating a backup */ +export type BackupConfigCreate = { + /** The size of the chunks to use for the backup. */ + chunkSize?: number; + /** The standard of compression to use for the backup. */ + compressionLevel?: BackupCompressionLevel; + /** The percentage of CPU to use for the backup creation job. */ + cpuPercentage?: number; +}; + +/** Configuration options available when restoring a backup */ +export type BackupConfigRestore = { + /** The percentage of CPU to use for the backuop restoration job. */ + cpuPercentage?: number; +}; + /** The arguments required to create and restore backups. */ -export interface BackupArgs { +export type BackupArgs = { /** The ID of the backup. */ backupId: string; /** The backend to use for the backup. */ @@ -26,21 +44,16 @@ export interface BackupArgs { excludeCollections?: string[]; /** Whether to wait for the backup to complete. */ waitForCompletion?: boolean; -} + /** The configuration options for the backup. */ + config?: C; +}; /** The arguments required to get the status of a backup. */ -export interface BackupStatusArgs { +export type BackupStatusArgs = { /** The ID of the backup. */ backupId: string; /** The backend to use for the backup. */ backend: Backend; -} - -/** The response from a backup creation request. */ -export type BackupReturn = { - collections: string[]; - status: BackupStatus; - path: string; }; export const backup = (connection: Connection) => { @@ -57,7 +70,7 @@ export const backup = (connection: Connection) => { .do(); }; return { - create: async (args: BackupArgs): Promise => { + create: async (args: BackupArgs): Promise => { let builder = new BackupCreator(connection, new BackupCreateStatusGetter(connection)) .withBackupId(args.backupId) .withBackend(args.backend); @@ -67,13 +80,30 @@ export const backup = (connection: Connection) => { if (args.excludeCollections) { builder = builder.withExcludeClassNames(...args.excludeCollections); } - const res = builder.do(); + if (args.config) { + builder = builder.withConfig({ + ChunkSize: args.config.chunkSize, + CompressionLevel: args.config.compressionLevel, + CPUPercentage: args.config.cpuPercentage, + }); + } + let res: BackupCreateResponse; + try { + res = await builder.do(); + } catch (err) { + throw new Error(`Backup creation failed: ${err}`); + } + if (res.status === 'FAILED') { + throw new Error(`Backup creation failed: ${res.error}`); + } + let status: BackupCreateStatusResponse | undefined; if (args.waitForCompletion) { let wait = true; while (wait) { const res = await getCreateStatus(args); // eslint-disable-line no-await-in-loop if (res.status === 'SUCCESS') { wait = false; + status = res; } if (res.status === 'FAILED') { throw new WeaviateBackupFailed(res.error ? res.error : '', 'creation'); @@ -81,13 +111,11 @@ export const backup = (connection: Connection) => { await new Promise((resolve) => setTimeout(resolve, 1000)); // eslint-disable-line no-await-in-loop } } - return res.then(() => - new BackupCreateStatusGetter(connection).withBackupId(args.backupId).withBackend(args.backend).do() - ); + return status ? { ...status, classes: res.classes } : res; }, getCreateStatus: getCreateStatus, getRestoreStatus: getRestoreStatus, - restore: async (args: BackupArgs): Promise => { + restore: async (args: BackupArgs): Promise => { let builder = new BackupRestorer(connection, new BackupRestoreStatusGetter(connection)) .withBackupId(args.backupId) .withBackend(args.backend); @@ -97,13 +125,28 @@ export const backup = (connection: Connection) => { if (args.excludeCollections) { builder = builder.withExcludeClassNames(...args.excludeCollections); } - const res = builder.do(); + if (args.config) { + builder = builder.withConfig({ + CPUPercentage: args.config.cpuPercentage, + }); + } + let res: BackupRestoreResponse; + try { + res = await builder.do(); + } catch (err) { + throw new Error(`Backup restoration failed: ${err}`); + } + if (res.status === 'FAILED') { + throw new Error(`Backup restoration failed: ${res.error}`); + } + let status: BackupRestoreStatusResponse | undefined; if (args.waitForCompletion) { let wait = true; while (wait) { const res = await getRestoreStatus(args); // eslint-disable-line no-await-in-loop if (res.status === 'SUCCESS') { wait = false; + status = res; } if (res.status === 'FAILED') { throw new WeaviateBackupFailed(res.error ? res.error : '', 'restoration'); @@ -111,9 +154,12 @@ export const backup = (connection: Connection) => { await new Promise((resolve) => setTimeout(resolve, 1000)); // eslint-disable-line no-await-in-loop } } - return res.then(() => - new BackupRestoreStatusGetter(connection).withBackupId(args.backupId).withBackend(args.backend).do() - ); + return status + ? { + ...status, + classes: res.classes, + } + : res; }, }; }; @@ -125,26 +171,26 @@ export interface Backup { * @param {BackupArgs} args The arguments for the request. * @returns {Promise} The response from Weaviate. */ - create(args: BackupArgs): Promise; + create(args: BackupArgs): Promise; /** * Get the status of a backup creation. * * @param {BackupStatusArgs} args The arguments for the request. - * @returns {Promise} The status of the backup creation. + * @returns {Promise} The status of the backup creation. */ getCreateStatus(args: BackupStatusArgs): Promise; /** * Get the status of a backup restore. * * @param {BackupStatusArgs} args The arguments for the request. - * @returns {Promise} The status of the backup restore. + * @returns {Promise} The status of the backup restore. */ getRestoreStatus(args: BackupStatusArgs): Promise; /** * Restore a backup of the database. * * @param {BackupArgs} args The arguments for the request. - * @returns {Promise} The response from Weaviate. + * @returns {Promise} The response from Weaviate. */ - restore(args: BackupArgs): Promise; + restore(args: BackupArgs): Promise; } diff --git a/src/collections/backup/collection.ts b/src/collections/backup/collection.ts index faf30b1b..c4861a8f 100644 --- a/src/collections/backup/collection.ts +++ b/src/collections/backup/collection.ts @@ -3,19 +3,20 @@ import Connection from '../../connection/index.js'; import { BackupCreateResponse, BackupCreateStatusResponse, + BackupRestoreResponse, BackupRestoreStatusResponse, } from '../../openapi/types.js'; import { BackupStatusArgs, backup } from './client.js'; /** The arguments required to create and restore backups. */ -export interface BackupCollectionArgs { +export type BackupCollectionArgs = { /** The ID of the backup. */ backupId: string; /** The backend to use for the backup. */ backend: Backend; /** The collections to include in the backup. */ waitForCompletion?: boolean; -} +}; export const backupCollection = (connection: Connection, name: string) => { const handler = backup(connection); @@ -47,21 +48,21 @@ export interface BackupCollection { * Get the status of a backup. * * @param {BackupStatusArgs} args The arguments for the request. - * @returns {Promise} The status of the backup. + * @returns {Promise} The status of the backup. */ getCreateStatus(args: BackupStatusArgs): Promise; /** * Get the status of a restore. * * @param {BackupStatusArgs} args The arguments for the request. - * @returns {Promise} The status of the restore. + * @returns {Promise} The status of the restore. */ getRestoreStatus(args: BackupStatusArgs): Promise; /** * Restore a backup of this collection. * * @param {BackupArgs} args The arguments for the request. - * @returns {Promise} The response from Weaviate. + * @returns {Promise} The response from Weaviate. */ - restore(args: BackupCollectionArgs): Promise; + restore(args: BackupCollectionArgs): Promise; } diff --git a/src/collections/backup/integration.test.ts b/src/collections/backup/integration.test.ts new file mode 100644 index 00000000..cf8888a6 --- /dev/null +++ b/src/collections/backup/integration.test.ts @@ -0,0 +1,105 @@ +/* eslint-disable @typescript-eslint/no-non-null-assertion */ +/* eslint-disable @typescript-eslint/no-non-null-asserted-optional-chain */ +/* eslint-disable no-await-in-loop */ +import { Backend } from '../../backup/index.js'; +import weaviate, { Collection, WeaviateClient } from '../../index.js'; + +// These must run sequentially because Weaviate is not capable of running multiple backups at the same time +describe('Integration testing of backups', () => { + const clientPromise = weaviate.connectToLocal({ + port: 8090, + grpcPort: 50061, + }); + + const getCollection = (client: WeaviateClient) => client.collections.get('TestBackupCollection'); + + beforeAll(() => + clientPromise.then((client) => + Promise.all([ + client.collections.create({ name: 'TestBackupClient' }).then((col) => col.data.insert()), + client.collections.create({ name: 'TestBackupCollection' }).then((col) => col.data.insert()), + ]) + ) + ); + + afterAll(() => clientPromise.then((client) => client.collections.deleteAll())); + + const testClientWaitForCompletion = async (client: WeaviateClient) => { + const res = await client.backup.create({ + backupId: `test-backup-${randomBackupId()}`, + backend: 'filesystem', + waitForCompletion: true, + }); + expect(res.status).toBe('SUCCESS'); + return client; + }; + + const testClientNoWaitForCompletion = async (client: WeaviateClient) => { + const res = await client.backup.create({ + backupId: `test-backup-${randomBackupId()}`, + backend: 'filesystem', + }); + expect(res.status).toBe('STARTED'); + const status = await client.backup.getCreateStatus({ + backupId: res.id as string, + backend: res.backend as 'filesystem', + }); + expect(status).not.toBe('SUCCESS'); // can be 'STARTED' or 'TRANSFERRING' depending on the speed of the test machine + + // wait to complete so that other tests can run without colliding with Weaviate's lack of simultaneous backups + let wait = true; + while (wait) { + const { status, error } = await client.backup.getCreateStatus({ + backupId: res.id as string, + backend: res.backend as Backend, + }); + if (status === 'SUCCESS') { + wait = false; + } + if (status === 'FAILED') { + throw new Error(`Backup creation failed: ${error}`); + } + await new Promise((resolve) => setTimeout(resolve, 1000)); + } + + return client; + }; + + const testCollectionWaitForCompletion = async (collection: Collection) => { + const res = await collection.backup.create({ + backupId: `test-backup-${randomBackupId()}`, + backend: 'filesystem', + waitForCompletion: true, + }); + expect(res.status).toBe('SUCCESS'); + expect(res.classes).toEqual(['TestBackupCollection']); + return collection; + }; + + const testCollectionNoWaitForCompletion = async (collection: Collection) => { + const res = await collection.backup.create({ + backupId: `test-backup-${randomBackupId()}`, + backend: 'filesystem', + }); + expect(res.status).toBe('STARTED'); + expect(res.classes).toEqual(['TestBackupCollection']); + const status = await collection.backup.getCreateStatus({ + backupId: res.id as string, + backend: res.backend as 'filesystem', + }); + expect(status).not.toBe('SUCCESS'); // can be 'STARTED' or 'TRANSFERRING' depending on the speed of the test machine + return collection; + }; + + it('run', () => + clientPromise + .then(testClientWaitForCompletion) + .then(testClientNoWaitForCompletion) + .then(getCollection) + .then(testCollectionWaitForCompletion) + .then(testCollectionNoWaitForCompletion)); +}); + +function randomBackupId() { + return 'backup-id-' + Math.floor(Math.random() * Number.MAX_SAFE_INTEGER); +} diff --git a/src/collections/collection/index.ts b/src/collections/collection/index.ts index 1ce628cf..fe69a1c6 100644 --- a/src/collections/collection/index.ts +++ b/src/collections/collection/index.ts @@ -102,7 +102,7 @@ const collection = ( dbVersionSupport: DbVersionSupport, consistencyLevel?: ConsistencyLevel, tenant?: Tenant -) => { +): Collection => { if (!isString(name)) { throw new WeaviateInvalidInputError(`The collection name must be a string, got: ${typeof name}`); } @@ -125,7 +125,7 @@ const collection = ( name: name, query: queryCollection, sort: sort(), - tenants: tenants(connection, capitalizedName), + tenants: tenants(connection, capitalizedName, dbVersionSupport), exists: () => new ClassExists(connection).withClassName(capitalizedName).do(), iterator: (opts?: IteratorOptions) => new Iterator((limit: number, after?: string) => diff --git a/src/collections/deserialize/index.ts b/src/collections/deserialize/index.ts index 93d4fd51..e2460532 100644 --- a/src/collections/deserialize/index.ts +++ b/src/collections/deserialize/index.ts @@ -16,18 +16,30 @@ import { DeleteManyReturn, } from '../types/index.js'; import { BatchObject as BatchObjectGRPC, BatchObjectsReply } from '../../proto/v1/batch.js'; -import { Properties as PropertiesGrpc, Value } from '../../proto/v1/properties.js'; +import { ListValue, Properties as PropertiesGrpc, Value } from '../../proto/v1/properties.js'; import { BatchDeleteReply } from '../../proto/v1/batch_delete.js'; import { WeaviateDeserializationError } from '../../errors.js'; +import { DbVersionSupport } from '../../utils/dbVersion.js'; export class Deserialize { - public static query(reply: SearchReply): WeaviateReturn { + private supports125ListValue: boolean; + + private constructor(supports125ListValue: boolean) { + this.supports125ListValue = supports125ListValue; + } + + public static async use(support: DbVersionSupport): Promise { + const supports125ListValue = await support.supports125ListValue().then((res) => res.supports); + return new Deserialize(supports125ListValue); + } + + public query(reply: SearchReply): WeaviateReturn { return { objects: reply.results.map((result) => { return { metadata: Deserialize.metadata(result.metadata), - properties: Deserialize.properties(result.properties), - references: Deserialize.references(result.properties), + properties: this.properties(result.properties), + references: this.references(result.properties), uuid: Deserialize.uuid(result.metadata), vectors: Deserialize.vectors(result.metadata), } as any; @@ -35,14 +47,14 @@ export class Deserialize { }; } - public static generate(reply: SearchReply): GenerativeReturn { + public generate(reply: SearchReply): GenerativeReturn { return { objects: reply.results.map((result) => { return { generated: result.metadata?.generativePresent ? result.metadata?.generative : undefined, metadata: Deserialize.metadata(result.metadata), - properties: Deserialize.properties(result.properties), - references: Deserialize.references(result.properties), + properties: this.properties(result.properties), + references: this.references(result.properties), uuid: Deserialize.uuid(result.metadata), vectors: Deserialize.vectors(result.metadata), } as any; @@ -51,7 +63,7 @@ export class Deserialize { }; } - public static groupBy(reply: SearchReply): GroupByReturn { + public groupBy(reply: SearchReply): GroupByReturn { const objects: GroupByObject[] = []; const groups: Record> = {}; reply.groupByResults.forEach((result) => { @@ -59,8 +71,8 @@ export class Deserialize { return { belongsToGroup: result.name, metadata: Deserialize.metadata(object.metadata), - properties: Deserialize.properties(object.properties), - references: Deserialize.references(object.properties), + properties: this.properties(object.properties), + references: this.references(object.properties), uuid: Deserialize.uuid(object.metadata), vectors: Deserialize.vectors(object.metadata), } as any; @@ -80,7 +92,7 @@ export class Deserialize { }; } - public static generateGroupBy(reply: SearchReply): GenerativeGroupByReturn { + public generateGroupBy(reply: SearchReply): GenerativeGroupByReturn { const objects: GroupByObject[] = []; const groups: Record> = {}; reply.groupByResults.forEach((result) => { @@ -88,8 +100,8 @@ export class Deserialize { return { belongsToGroup: result.name, metadata: Deserialize.metadata(object.metadata), - properties: Deserialize.properties(object.properties), - references: Deserialize.references(object.properties), + properties: this.properties(object.properties), + references: this.references(object.properties), uuid: Deserialize.uuid(object.metadata), vectors: Deserialize.vectors(object.metadata), } as any; @@ -111,12 +123,12 @@ export class Deserialize { }; } - private static properties(properties?: PropertiesResult) { + private properties(properties?: PropertiesResult) { if (!properties) return {}; - return Deserialize.objectProperties(properties.nonRefProps); + return this.objectProperties(properties.nonRefProps); } - private static references(properties?: PropertiesResult) { + private references(properties?: PropertiesResult) { if (!properties) return undefined; if (properties.refProps.length === 0) return properties.refPropsRequested ? {} : undefined; const out: any = {}; @@ -128,8 +140,8 @@ export class Deserialize { uuids.push(uuid); return { metadata: Deserialize.metadata(property.metadata), - properties: Deserialize.properties(property), - references: Deserialize.references(property), + properties: this.properties(property), + references: this.references(property), uuid: uuid, vectors: Deserialize.vectors(property.metadata), }; @@ -141,15 +153,18 @@ export class Deserialize { return out; } - private static parsePropertyValue(value: Value): any { + private parsePropertyValue(value: Value): any { if (value.boolValue !== undefined) return value.boolValue; if (value.dateValue !== undefined) return new Date(value.dateValue); if (value.intValue !== undefined) return value.intValue; if (value.listValue !== undefined) - return value.listValue.values.map((v) => Deserialize.parsePropertyValue(v)); + return this.supports125ListValue + ? this.parseListValue(value.listValue) + : value.listValue.values.map((v) => this.parsePropertyValue(v)); if (value.numberValue !== undefined) return value.numberValue; - if (value.objectValue !== undefined) return Deserialize.objectProperties(value.objectValue); + if (value.objectValue !== undefined) return this.objectProperties(value.objectValue); if (value.stringValue !== undefined) return value.stringValue; + if (value.textValue !== undefined) return value.textValue; if (value.uuidValue !== undefined) return value.uuidValue; if (value.blobValue !== undefined) return value.blobValue; if (value.geoValue !== undefined) return value.geoValue; @@ -158,11 +173,23 @@ export class Deserialize { throw new WeaviateDeserializationError(`Unknown value type: ${JSON.stringify(value, null, 2)}`); } - private static objectProperties(properties?: PropertiesGrpc): Properties { + private parseListValue(value: ListValue): string[] | number[] | boolean[] | Date[] | Properties[] { + if (value.boolValues !== undefined) return value.boolValues.values; + if (value.dateValues !== undefined) return value.dateValues.values.map((date) => new Date(date)); + if (value.intValues !== undefined) return Deserialize.intsFromBytes(value.intValues.values); + if (value.numberValues !== undefined) return Deserialize.numbersFromBytes(value.numberValues.values); + if (value.objectValues !== undefined) + return value.objectValues.values.map((v) => this.objectProperties(v)); + if (value.textValues !== undefined) return value.textValues.values; + if (value.uuidValues !== undefined) return value.uuidValues.values; + throw new Error(`Unknown list value type: ${JSON.stringify(value, null, 2)}`); + } + + private objectProperties(properties?: PropertiesGrpc): Properties { const out: Properties = {}; if (properties) { Object.entries(properties.fields).forEach(([key, value]) => { - out[key] = Deserialize.parsePropertyValue(value); + out[key] = this.parsePropertyValue(value); }); } return out; @@ -194,6 +221,18 @@ export class Deserialize { return Array.from(view); } + private static intsFromBytes(bytes: Uint8Array) { + const buffer = Buffer.from(bytes); + const view = new BigInt64Array(buffer.buffer, buffer.byteOffset, buffer.byteLength / 8); // ints are float64 in weaviate + return Array.from(view).map(Number); + } + + private static numbersFromBytes(bytes: Uint8Array) { + const buffer = Buffer.from(bytes); + const view = new Float64Array(buffer.buffer, buffer.byteOffset, buffer.byteLength / 8); // numbers are float64 in weaviate + return Array.from(view); + } + private static vectors(metadata?: MetadataResult): Record { if (!metadata) return {}; if (metadata.vectorBytes.length === 0 && metadata.vector.length === 0 && metadata.vectors.length === 0) diff --git a/src/collections/generate/index.ts b/src/collections/generate/index.ts index 7d42ccd7..6bd1c1bd 100644 --- a/src/collections/generate/index.ts +++ b/src/collections/generate/index.ts @@ -1,3 +1,4 @@ +export type { Generate } from './types.js'; import Connection from '../../connection/grpc.js'; import { DbVersionSupport } from '../../utils/dbVersion.js'; @@ -13,26 +14,25 @@ import { BaseNearOptions, NearMediaType, NearOptions, - NearTextOptions, GroupByNearOptions, GroupByNearTextOptions, + BaseHybridOptions, + GroupByHybridOptions, + BaseBm25Options, + GroupByBm25Options, + SearchOptions, } from '../query/types.js'; -import { GenerativeReturn, GenerativeGroupByReturn } from '../types/index.js'; +import { GenerativeReturn, GenerativeGroupByReturn, GroupByOptions } from '../types/index.js'; import { SearchReply } from '../../proto/v1/search_get.js'; -import { WeaviateInvalidInputError } from '../../errors.js'; - -export type GenerateOptions = { - singlePrompt?: string; - groupedTask?: string; - groupedProperties?: T extends undefined ? string[] : (keyof T)[]; -}; +import { WeaviateInvalidInputError, WeaviateUnsupportedFeatureError } from '../../errors.js'; +import { GenerateOptions, Generate, GenerateReturn } from './types.js'; class GenerateManager implements Generate { - connection: Connection; - name: string; - dbVersionSupport: DbVersionSupport; - consistencyLevel?: ConsistencyLevel; - tenant?: string; + private connection: Connection; + private name: string; + private dbVersionSupport: DbVersionSupport; + private consistencyLevel?: ConsistencyLevel; + private tenant?: string; private constructor( connection: Connection, @@ -58,55 +58,106 @@ class GenerateManager implements Generate { return new GenerateManager(connection, name, dbVersionSupport, consistencyLevel, tenant); } + private checkSupportForNamedVectors = async (opts?: BaseNearOptions) => { + if (!Serialize.isNamedVectors(opts)) return; + const check = await this.dbVersionSupport.supportsNamedVectors(); + if (!check.supports) throw new WeaviateUnsupportedFeatureError(check.message); + }; + + private checkSupportForBm25AndHybridGroupByQueries = async (query: 'Bm25' | 'Hybrid', opts?: any) => { + if (!Serialize.isGroupBy(opts)) return; + const check = await this.dbVersionSupport.supportsBm25AndHybridGroupByQueries(); + if (!check.supports) throw new WeaviateUnsupportedFeatureError(check.message(query)); + }; + + private async parseReply(reply: SearchReply) { + const deserialize = await Deserialize.use(this.dbVersionSupport); + return deserialize.generate(reply); + } + + private async parseGroupByReply( + opts: SearchOptions | GroupByOptions | undefined, + reply: SearchReply + ) { + const deserialize = await Deserialize.use(this.dbVersionSupport); + return Serialize.isGroupBy(opts) ? deserialize.generateGroupBy(reply) : deserialize.generate(reply); + } + public fetchObjects( generate: GenerateOptions, opts?: FetchObjectsOptions ): Promise> { - return this.connection.search(this.name).then((search) => - search - .withFetch({ + return this.checkSupportForNamedVectors(opts) + .then(() => this.connection.search(this.name, this.consistencyLevel, this.tenant)) + .then((search) => + search.withFetch({ ...Serialize.fetchObjects(opts), generative: Serialize.generative(generate), }) - .then((reply) => Deserialize.generate(reply)) - ); + ) + .then((reply) => this.parseReply(reply)); } public bm25( query: string, generate: GenerateOptions, - opts?: Bm25Options - ): Promise> { - return this.connection.search(this.name).then((search) => - search - .withBm25({ + opts?: BaseBm25Options + ): Promise>; + public bm25( + query: string, + generate: GenerateOptions, + opts: GroupByBm25Options + ): Promise>; + public bm25(query: string, generate: GenerateOptions, opts?: Bm25Options): GenerateReturn { + return Promise.all([ + this.checkSupportForNamedVectors(opts), + this.checkSupportForBm25AndHybridGroupByQueries('Bm25', opts), + ]) + .then(() => this.connection.search(this.name, this.consistencyLevel, this.tenant)) + .then((search) => + search.withBm25({ ...Serialize.bm25({ query, ...opts }), generative: Serialize.generative(generate), + groupBy: Serialize.isGroupBy>(opts) + ? Serialize.groupBy(opts.groupBy) + : undefined, }) - .then((reply) => Deserialize.generate(reply)) - ); + ) + .then((reply) => this.parseGroupByReply(opts, reply)); } public hybrid( query: string, generate: GenerateOptions, - opts?: HybridOptions - ): Promise> { - return this.connection.search(this.name).then((search) => - search - .withHybrid({ + opts?: BaseHybridOptions + ): Promise>; + public hybrid( + query: string, + generate: GenerateOptions, + opts: GroupByHybridOptions + ): Promise>; + public hybrid(query: string, generate: GenerateOptions, opts?: HybridOptions): GenerateReturn { + return Promise.all([ + this.checkSupportForNamedVectors(opts), + this.checkSupportForBm25AndHybridGroupByQueries('Bm25', opts), + ]) + .then(() => this.connection.search(this.name, this.consistencyLevel, this.tenant)) + .then((search) => + search.withHybrid({ ...Serialize.hybrid({ query, ...opts }), generative: Serialize.generative(generate), + groupBy: Serialize.isGroupBy>(opts) + ? Serialize.groupBy(opts.groupBy) + : undefined, }) - .then((reply) => Deserialize.generate(reply)) - ); + ) + .then((reply) => this.parseGroupByReply(opts, reply)); } - public nearImage(image: string, generate: GenerateOptions): Promise>; public nearImage( image: string, generate: GenerateOptions, - opts: BaseNearOptions + opts?: BaseNearOptions ): Promise>; public nearImage( image: string, @@ -114,28 +165,24 @@ class GenerateManager implements Generate { opts: GroupByNearOptions ): Promise>; public nearImage(image: string, generate: GenerateOptions, opts?: NearOptions): GenerateReturn { - return this.connection.search(this.name).then((search) => - search - .withNearImage({ + return this.checkSupportForNamedVectors(opts) + .then(() => this.connection.search(this.name, this.consistencyLevel, this.tenant)) + .then((search) => + search.withNearImage({ ...Serialize.nearImage({ image, ...(opts ? opts : {}) }), generative: Serialize.generative(generate), groupBy: Serialize.isGroupBy>(opts) ? Serialize.groupBy(opts.groupBy) : undefined, }) - .then((reply) => - Serialize.isGroupBy>(opts) - ? Deserialize.generateGroupBy(reply) - : Deserialize.generate(reply) - ) - ); + ) + .then((reply) => this.parseGroupByReply(opts, reply)); } - public nearObject(id: string, generate: GenerateOptions): Promise>; public nearObject( id: string, generate: GenerateOptions, - opts: BaseNearOptions + opts?: BaseNearOptions ): Promise>; public nearObject( id: string, @@ -143,28 +190,24 @@ class GenerateManager implements Generate { opts: GroupByNearOptions ): Promise>; public nearObject(id: string, generate: GenerateOptions, opts?: NearOptions): GenerateReturn { - return this.connection.search(this.name).then((search) => - search - .withNearObject({ + return this.checkSupportForNamedVectors(opts) + .then(() => this.connection.search(this.name, this.consistencyLevel, this.tenant)) + .then((search) => + search.withNearObject({ ...Serialize.nearObject({ id, ...(opts ? opts : {}) }), generative: Serialize.generative(generate), groupBy: Serialize.isGroupBy>(opts) ? Serialize.groupBy(opts.groupBy) : undefined, }) - .then((reply) => - Serialize.isGroupBy>(opts) - ? Deserialize.generateGroupBy(reply) - : Deserialize.generate(reply) - ) - ); + ) + .then((reply) => this.parseGroupByReply(opts, reply)); } - public nearText(query: string | string[], generate: GenerateOptions): Promise>; public nearText( query: string | string[], generate: GenerateOptions, - opts: BaseNearTextOptions + opts?: BaseNearTextOptions ): Promise>; public nearText( query: string | string[], @@ -176,28 +219,24 @@ class GenerateManager implements Generate { generate: GenerateOptions, opts?: NearOptions ): GenerateReturn { - return this.connection.search(this.name).then((search) => - search - .withNearText({ + return this.checkSupportForNamedVectors(opts) + .then(() => this.connection.search(this.name, this.consistencyLevel, this.tenant)) + .then((search) => + search.withNearText({ ...Serialize.nearText({ query, ...(opts ? opts : {}) }), generative: Serialize.generative(generate), groupBy: Serialize.isGroupBy>(opts) ? Serialize.groupBy(opts.groupBy) : undefined, }) - .then((reply) => - Serialize.isGroupBy>(opts) - ? Deserialize.generateGroupBy(reply) - : Deserialize.generate(reply) - ) - ); + ) + .then((reply) => this.parseGroupByReply(opts, reply)); } - public nearVector(vector: number[], generate: GenerateOptions): Promise>; public nearVector( vector: number[], generate: GenerateOptions, - opts: BaseNearOptions + opts?: BaseNearOptions ): Promise>; public nearVector( vector: number[], @@ -209,33 +248,25 @@ class GenerateManager implements Generate { generate: GenerateOptions, opts?: NearOptions ): GenerateReturn { - return this.connection.search(this.name).then((search) => - search - .withNearVector({ + return this.checkSupportForNamedVectors(opts) + .then(() => this.connection.search(this.name, this.consistencyLevel, this.tenant)) + .then((search) => + search.withNearVector({ ...Serialize.nearVector({ vector, ...(opts ? opts : {}) }), generative: Serialize.generative(generate), groupBy: Serialize.isGroupBy>(opts) ? Serialize.groupBy(opts.groupBy) : undefined, }) - .then((reply) => - Serialize.isGroupBy>(opts) - ? Deserialize.generateGroupBy(reply) - : Deserialize.generate(reply) - ) - ); + ) + .then((reply) => this.parseGroupByReply(opts, reply)); } - public nearMedia( - media: string, - type: NearMediaType, - generate: GenerateOptions - ): Promise>; public nearMedia( media: string, type: NearMediaType, generate: GenerateOptions, - opts: BaseNearOptions + opts?: BaseNearOptions ): Promise>; public nearMedia( media: string, @@ -249,151 +280,64 @@ class GenerateManager implements Generate { generate: GenerateOptions, opts?: NearOptions ): GenerateReturn { - return this.connection.search(this.name).then((search) => { - let reply: Promise; - const generative = Serialize.generative(generate); - const groupBy = Serialize.isGroupBy>(opts) - ? Serialize.groupBy(opts.groupBy) - : undefined; - switch (type) { - case 'audio': - reply = search.withNearAudio({ - ...Serialize.nearAudio({ audio: media, ...(opts ? opts : {}) }), - generative, - groupBy, - }); - break; - case 'depth': - reply = search.withNearDepth({ - ...Serialize.nearDepth({ depth: media, ...(opts ? opts : {}) }), - generative, - groupBy, - }); - break; - case 'image': - reply = search.withNearImage({ - ...Serialize.nearImage({ image: media, ...(opts ? opts : {}) }), - generative, - groupBy, - }); - break; - case 'imu': - reply = search.withNearIMU({ - ...Serialize.nearIMU({ imu: media, ...(opts ? opts : {}) }), - generative, - groupBy, - }); - break; - case 'thermal': - reply = search.withNearThermal({ - ...Serialize.nearThermal({ thermal: media, ...(opts ? opts : {}) }), - generative, - groupBy, - }); - break; - case 'video': - reply = search.withNearVideo({ - ...Serialize.nearVideo({ video: media, ...(opts ? opts : {}) }), - generative, - groupBy, - }); - break; - default: - throw new WeaviateInvalidInputError(`Invalid media type: ${type}`); - } - return reply.then((reply) => - groupBy ? Deserialize.generateGroupBy(reply) : Deserialize.generate(reply) - ); - }); + return this.checkSupportForNamedVectors(opts) + .then(() => this.connection.search(this.name, this.consistencyLevel, this.tenant)) + .then((search) => { + let reply: Promise; + const generative = Serialize.generative(generate); + const groupBy = Serialize.isGroupBy>(opts) + ? Serialize.groupBy(opts.groupBy) + : undefined; + switch (type) { + case 'audio': + reply = search.withNearAudio({ + ...Serialize.nearAudio({ audio: media, ...(opts ? opts : {}) }), + generative, + groupBy, + }); + break; + case 'depth': + reply = search.withNearDepth({ + ...Serialize.nearDepth({ depth: media, ...(opts ? opts : {}) }), + generative, + groupBy, + }); + break; + case 'image': + reply = search.withNearImage({ + ...Serialize.nearImage({ image: media, ...(opts ? opts : {}) }), + generative, + groupBy, + }); + break; + case 'imu': + reply = search.withNearIMU({ + ...Serialize.nearIMU({ imu: media, ...(opts ? opts : {}) }), + generative, + groupBy, + }); + break; + case 'thermal': + reply = search.withNearThermal({ + ...Serialize.nearThermal({ thermal: media, ...(opts ? opts : {}) }), + generative, + groupBy, + }); + break; + case 'video': + reply = search.withNearVideo({ + ...Serialize.nearVideo({ video: media, ...(opts ? opts : {}) }), + generative, + groupBy, + }); + break; + default: + throw new WeaviateInvalidInputError(`Invalid media type: ${type}`); + } + return reply; + }) + .then((reply) => this.parseGroupByReply(opts, reply)); } } -export interface Generate { - fetchObjects: (generate: GenerateOptions, opts?: FetchObjectsOptions) => Promise>; - bm25: (query: string, generate: GenerateOptions, opts?: Bm25Options) => Promise>; - hybrid: ( - query: string, - generate: GenerateOptions, - opts?: HybridOptions - ) => Promise>; - - nearImage(image: string, generate: GenerateOptions): Promise>; - nearImage( - image: string, - generate: GenerateOptions, - opts: BaseNearOptions - ): Promise>; - nearImage( - image: string, - generate: GenerateOptions, - opts: GroupByNearOptions - ): Promise>; - nearImage(image: string, generate: GenerateOptions, opts?: NearOptions): GenerateReturn; - - nearMedia(media: string, type: NearMediaType, generate: GenerateOptions): Promise>; - nearMedia( - media: string, - type: NearMediaType, - generate: GenerateOptions, - opts: BaseNearOptions - ): Promise>; - nearMedia( - media: string, - type: NearMediaType, - generate: GenerateOptions, - opts: GroupByNearOptions - ): Promise>; - nearMedia( - media: string, - type: NearMediaType, - generate: GenerateOptions, - opts?: NearOptions - ): GenerateReturn; - - nearObject(id: string, generate: GenerateOptions): Promise>; - nearObject( - id: string, - generate: GenerateOptions, - opts: BaseNearOptions - ): Promise>; - nearObject( - id: string, - generate: GenerateOptions, - opts: GroupByNearOptions - ): Promise>; - nearObject(id: string, generate: GenerateOptions, opts?: NearOptions): GenerateReturn; - - nearText(query: string | string[], generate: GenerateOptions): Promise>; - nearText( - query: string | string[], - generate: GenerateOptions, - opts: BaseNearTextOptions - ): Promise>; - nearText( - query: string | string[], - generate: GenerateOptions, - opts: GroupByNearTextOptions - ): Promise>; - nearText( - query: string | string[], - generate: GenerateOptions, - opts?: NearTextOptions - ): GenerateReturn; - - nearVector(vector: number[], generate: GenerateOptions): Promise>; - nearVector( - vector: number[], - generate: GenerateOptions, - opts: BaseNearOptions - ): Promise>; - nearVector( - vector: number[], - generate: GenerateOptions, - opts: GroupByNearOptions - ): Promise>; - nearVector(vector: number[], generate: GenerateOptions, opts?: NearOptions): GenerateReturn; -} - -export type GenerateReturn = Promise> | Promise>; - export default GenerateManager.use; diff --git a/src/collections/generate/integration.test.ts b/src/collections/generate/integration.test.ts index 2b2a1a80..be8eeac6 100644 --- a/src/collections/generate/integration.test.ts +++ b/src/collections/generate/integration.test.ts @@ -1,9 +1,10 @@ /* eslint-disable @typescript-eslint/no-non-null-assertion */ /* eslint-disable @typescript-eslint/no-non-null-asserted-optional-chain */ import weaviate, { WeaviateClient } from '../../index.js'; -import { GenerateOptions } from './index.js'; +import { GenerateOptions } from './types.js'; import { GroupByOptions } from '../types/index.js'; import { Collection } from '../collection/index.js'; +import { WeaviateUnsupportedFeatureError } from '../../errors.js'; const maybe = process.env.OPENAI_APIKEY ? describe : describe.skip; @@ -70,7 +71,7 @@ maybe('Testing of the collection.generate methods with a simple collection', () }); }); const res = await collection.query.fetchObjectById(id, { includeVector: true }); - vector = res?.vectors.default!; + vector = res?.vectors.vector!; }); describe('using a non-generic collection', () => { @@ -252,32 +253,41 @@ maybe('Testing of the groupBy collection.generate methods with a simple collecti // expect(ret.objects[0].belongsToGroup).toEqual('test'); // }); - // it('should groupBy with bm25', async () => { - // const ret = await collection.groupBy.bm25({ - // query: 'test', - // ...groupByArgs, - // }); - // expect(ret.objects.length).toEqual(1); - // expect(ret.groups).toBeDefined(); - // expect(Object.keys(ret.groups)).toEqual(['test']); - // expect(ret.objects[0].properties.testProp).toEqual('test'); - // expect(ret.objects[0].metadata.uuid).toEqual(id); - // expect(ret.objects[0].belongsToGroup).toEqual('test'); - // }); - - // it('should groupBy with hybrid', async () => { - // const ret = await collection.groupBy.hybrid({ - // query: 'test', - // ...groupByArgs, + it('should groupBy with bm25', async () => { + const query = () => + collection.generate.bm25('test', generateOpts, { + groupBy: groupByArgs, + }); + if (await client.getWeaviateVersion().then((ver) => ver.isLowerThan(1, 25, 0))) { + await expect(query()).rejects.toThrow(WeaviateUnsupportedFeatureError); + return; + } + const ret = await query(); + expect(ret.objects.length).toEqual(1); + expect(ret.groups).toBeDefined(); + expect(Object.keys(ret.groups)).toEqual(['test']); + expect(ret.objects[0].properties.testProp).toEqual('test'); + expect(ret.objects[0].uuid).toEqual(id); + expect(ret.objects[0].belongsToGroup).toEqual('test'); + }); - // }); - // expect(ret.objects.length).toEqual(1); - // expect(ret.groups).toBeDefined(); - // expect(Object.keys(ret.groups)).toEqual(['test']); - // expect(ret.objects[0].properties.testProp).toEqual('test'); - // expect(ret.objects[0].metadata.uuid).toEqual(id); - // expect(ret.objects[0].belongsToGroup).toEqual('test'); - // }); + it('should groupBy with hybrid', async () => { + const query = () => + collection.generate.hybrid('test', generateOpts, { + groupBy: groupByArgs, + }); + if (await client.getWeaviateVersion().then((ver) => ver.isLowerThan(1, 25, 0))) { + await expect(query()).rejects.toThrow(WeaviateUnsupportedFeatureError); + return; + } + const ret = await query(); + expect(ret.objects.length).toEqual(1); + expect(ret.groups).toBeDefined(); + expect(Object.keys(ret.groups)).toEqual(['test']); + expect(ret.objects[0].properties.testProp).toEqual('test'); + expect(ret.objects[0].uuid).toEqual(id); + expect(ret.objects[0].belongsToGroup).toEqual('test'); + }); it('should groupBy with nearObject', async () => { const ret = await collection.generate.nearObject(id, generateOpts, { diff --git a/src/collections/generate/types.ts b/src/collections/generate/types.ts new file mode 100644 index 00000000..a7db554e --- /dev/null +++ b/src/collections/generate/types.ts @@ -0,0 +1,375 @@ +import { + FetchObjectsOptions, + Bm25Options, + BaseHybridOptions, + GroupByHybridOptions, + HybridOptions, + BaseNearTextOptions, + BaseNearOptions, + NearMediaType, + NearOptions, + NearTextOptions, + GroupByNearOptions, + GroupByNearTextOptions, + BaseBm25Options, + GroupByBm25Options, +} from '../query/types.js'; +import { GenerativeReturn, GenerativeGroupByReturn } from '../types/index.js'; + +export type GenerateOptions = { + singlePrompt?: string; + groupedTask?: string; + groupedProperties?: T extends undefined ? string[] : (keyof T)[]; +}; + +interface Bm25 { + /** + * Perform retrieval-augmented generation (RaG) on the results of a keyword-based BM25 search of objects in this collection. + * + * See the [docs](https://weaviate.io/developers/weaviate/search/bm25) for a more detailed explanation. + * + * This overload is for performing a search without the `groupBy` param. + * + * @overload + * @param {string} query - The query to search for. + * @param {GenerateOptions} generate - The available options for performing the generation. + * @param {BaseBm25Options} [opts] - The available options for performing the BM25 search. + * @return {Promise>} - The results of the search including the generated data. + */ + bm25(query: string, generate: GenerateOptions, opts?: BaseBm25Options): Promise>; + /** + * Perform retrieval-augmented generation (RaG) on the results of a keyword-based BM25 search of objects in this collection. + * + * See the [docs](https://weaviate.io/developers/weaviate/search/bm25) for a more detailed explanation. + * + * This overload is for performing a search with the `groupBy` param. + * + * @overload + * @param {string} query - The query to search for. + * @param {GenerateOptions} generate - The available options for performing the generation. + * @param {GroupByBm25Options} opts - The available options for performing the BM25 search. + * @return {Promise>} - The results of the search including the generated data grouped by the specified properties. + */ + bm25( + query: string, + generate: GenerateOptions, + opts: GroupByBm25Options + ): Promise>; + /** + * Perform retrieval-augmented generation (RaG) on the results of a keyword-based BM25 search of objects in this collection. + * + * See the [docs](https://weaviate.io/developers/weaviate/search/bm25) for a more detailed explanation. + * + * This overload is for performing a search with a programmatically defined `opts` param. + * + * @overload + * @param {string} query - The query to search for. + * @param {GenerateOptions} generate - The available options for performing the generation. + * @param {Bm25Options} [opts] - The available options for performing the BM25 search. + * @return {GenerateReturn} - The results of the search including the generated data. + */ + bm25(query: string, generate: GenerateOptions, opts?: Bm25Options): GenerateReturn; +} + +interface Hybrid { + /** + * Perform retrieval-augmented generation (RaG) on the results of an object search in this collection using the hybrid algorithm blending keyword-based BM25 and vector-based similarity. + * + * See the [docs](https://weaviate.io/developers/weaviate/search/hybrid) for a more detailed explanation. + * + * This overload is for performing a search without the `groupBy` param. + * + * @overload + * @param {string} query - The query to search for. + * @param {GenerateOptions} generate - The available options for performing the generation. + * @param {BaseHybridOptions} [opts] - The available options for performing the hybrid search. + * @return {Promise>} - The results of the search including the generated data. + */ + hybrid( + query: string, + generate: GenerateOptions, + opts?: BaseHybridOptions + ): Promise>; + /** + * Perform retrieval-augmented generation (RaG) on the results of an object search in this collection using the hybrid algorithm blending keyword-based BM25 and vector-based similarity. + * + * See the [docs](https://weaviate.io/developers/weaviate/search/hybrid) for a more detailed explanation. + * + * This overload is for performing a search with the `groupBy` param. + * + * @overload + * @param {string} query - The query to search for. + * @param {GenerateOptions} generate - The available options for performing the generation. + * @param {GroupByHybridOptions} opts - The available options for performing the hybrid search. + * @return {Promise>} - The results of the search including the generated data grouped by the specified properties. + */ + hybrid( + query: string, + generate: GenerateOptions, + opts: GroupByHybridOptions + ): Promise>; + /** + * Perform retrieval-augmented generation (RaG) on the results of an object search in this collection using the hybrid algorithm blending keyword-based BM25 and vector-based similarity. + * + * See the [docs](https://weaviate.io/developers/weaviate/search/hybrid) for a more detailed explanation. + * + * This overload is for performing a search with a programmatically defined `opts` param. + * + * @overload + * @param {string} query - The query to search for. + * @param {GenerateOptions} generate - The available options for performing the generation. + * @param {HybridOptions} [opts] - The available options for performing the hybrid search. + * @return {GenerateReturn} - The results of the search including the generated data. + */ + hybrid(query: string, generate: GenerateOptions, opts?: HybridOptions): GenerateReturn; +} + +interface NearMedia { + /** + * Perform retrieval-augmented generation (RaG) on the results of a by-audio object search in this collection using an audio-capable vectorization module and vector-based similarity search. + * + * See the [docs](https://weaviate.io/developers/weaviate/modules/retriever-vectorizer-modules/multi2vec-bind) for a more detailed explanation. + * + * NOTE: You must have a multi-media-capable vectorization module installed in order to use this method, e.g. `multi2vec-bind`. + * + * This overload is for performing a search without the `groupBy` param. + * + * @overload + * @param {string} media - The media file to search for. + * @param {NearMediaType} type - The type of media to search for. + * @param {GenerateOptions} generate - The available options for performing the generation. + * @param {BaseNearOptions} [opts] - The available options for performing the near-media search. + * @return {Promise>} - The results of the search including the generated data. + */ + nearMedia( + media: string, + type: NearMediaType, + generate: GenerateOptions, + opts?: BaseNearOptions + ): Promise>; + /** + * Perform retrieval-augmented generation (RaG) on the results of a by-audio object search in this collection using an audio-capable vectorization module and vector-based similarity search. + * + * See the [docs](https://weaviate.io/developers/weaviate/modules/retriever-vectorizer-modules/multi2vec-bind) for a more detailed explanation. + * + * NOTE: You must have a multi-media-capable vectorization module installed in order to use this method, e.g. `multi2vec-bind`. + * + * This overload is for performing a search with the `groupBy` param. + * + * @overload + * @param {string} media - The media file to search for. + * @param {NearMediaType} type - The type of media to search for. + * @param {GenerateOptions} generate - The available options for performing the generation. + * @param {GroupByNearOptions} opts - The available options for performing the near-media search. + * @return {Promise>} - The results of the search including the generated data grouped by the specified properties. + */ + nearMedia( + media: string, + type: NearMediaType, + generate: GenerateOptions, + opts: GroupByNearOptions + ): Promise>; + /** + * Perform retrieval-augmented generation (RaG) on the results of a by-audio object search in this collection using an audio-capable vectorization module and vector-based similarity search. + * + * See the [docs](https://weaviate.io/developers/weaviate/modules/retriever-vectorizer-modules/multi2vec-bind) for a more detailed explanation. + * + * NOTE: You must have a multi-media-capable vectorization module installed in order to use this method, e.g. `multi2vec-bind`. + * + * This overload is for performing a search with a programmatically defined `opts` param. + * + * @overload + * @param {string} media - The media file to search for. + * @param {NearMediaType} type - The type of media to search for. + * @param {GenerateOptions} generate - The available options for performing the generation. + * @param {NearOptions} [opts] - The available options for performing the near-media search. + * @return {GenerateReturn} - The results of the search including the generated data. + */ + nearMedia( + media: string, + type: NearMediaType, + generate: GenerateOptions, + opts?: NearOptions + ): GenerateReturn; +} + +interface NearObject { + /** + * Perform retrieval-augmented generation (RaG) on the results of a by-object object search in this collection using a vector-based similarity search. + * + * See the [docs](https://weaviate.io/developers/weaviate/api/graphql/search-operators#nearobject) for a more detailed explanation. + * + * This overload is for performing a search without the `groupBy` param. + * + * @overload + * @param {string} id - The ID of the object to search for. + * @param {GenerateOptions} generate - The available options for performing the generation. + * @param {BaseNearOptions} [opts] - The available options for performing the near-object search. + * @return {Promise>} - The results of the search including the generated data. + */ + nearObject( + id: string, + generate: GenerateOptions, + opts?: BaseNearOptions + ): Promise>; + /** + * Perform retrieval-augmented generation (RaG) on the results of a by-object object search in this collection using a vector-based similarity search. + * + * See the [docs](https://weaviate.io/developers/weaviate/api/graphql/search-operators#nearobject) for a more detailed explanation. + * + * This overload is for performing a search with the `groupBy` param. + * + * @overload + * @param {string} id - The ID of the object to search for. + * @param {GenerateOptions} generate - The available options for performing the generation. + * @param {GroupByNearOptions} opts - The available options for performing the near-object search. + * @return {Promise>} - The results of the search including the generated data grouped by the specified properties. + */ + nearObject( + id: string, + generate: GenerateOptions, + opts: GroupByNearOptions + ): Promise>; + /** + * Perform retrieval-augmented generation (RaG) on the results of a by-object object search in this collection using a vector-based similarity search. + * + * See the [docs](https://weaviate.io/developers/weaviate/api/graphql/search-operators#nearobject) for a more detailed explanation. + * + * This overload is for performing a search with a programmatically defined `opts` param. + * + * @overload + * @param {string} id - The ID of the object to search for. + * @param {GenerateOptions} generate - The available options for performing the generation. + * @param {NearOptions} [opts] - The available options for performing the near-object search. + * @return {GenerateReturn} - The results of the search including the generated data. + */ + nearObject(id: string, generate: GenerateOptions, opts?: NearOptions): GenerateReturn; +} + +interface NearText { + /** + * Perform retrieval-augmented generation (RaG) on the results of a by-image object search in this collection using the image-capable vectorization module and vector-based similarity search. + * + * See the [docs](https://weaviate.io/developers/weaviate/api/graphql/search-operators#neartext) for a more detailed explanation. + * + * NOTE: You must have a text-capable vectorization module installed in order to use this method, e.g. any of the `text2vec-` and `multi2vec-` modules. + * + * This overload is for performing a search without the `groupBy` param. + * + * @overload + * @param {string | string[]} query - The query to search for. + * @param {GenerateOptions} generate - The available options for performing the generation. + * @param {BaseNearTextOptions} [opts] - The available options for performing the near-text search. + * @return {Promise>} - The results of the search including the generated data. + */ + nearText( + query: string | string[], + generate: GenerateOptions, + opts?: BaseNearTextOptions + ): Promise>; + /** + * Perform retrieval-augmented generation (RaG) on the results of a by-image object search in this collection using the image-capable vectorization module and vector-based similarity search. + * + * See the [docs](https://weaviate.io/developers/weaviate/api/graphql/search-operators#neartext) for a more detailed explanation. + * + * NOTE: You must have a text-capable vectorization module installed in order to use this method, e.g. any of the `text2vec-` and `multi2vec-` modules. + * + * This overload is for performing a search with the `groupBy` param. + * + * @overload + * @param {string | string[]} query - The query to search for. + * @param {GenerateOptions} generate - The available options for performing the generation. + * @param {GroupByNearTextOptions} opts - The available options for performing the near-text search. + * @return {Promise>} - The results of the search including the generated data grouped by the specified properties. + */ + nearText( + query: string | string[], + generate: GenerateOptions, + opts: GroupByNearTextOptions + ): Promise>; + /** + * Perform retrieval-augmented generation (RaG) on the results of a by-image object search in this collection using the image-capable vectorization module and vector-based similarity search. + * + * See the [docs](https://weaviate.io/developers/weaviate/api/graphql/search-operators#neartext) for a more detailed explanation. + * + * NOTE: You must have a text-capable vectorization module installed in order to use this method, e.g. any of the `text2vec-` and `multi2vec-` modules. + * + * This overload is for performing a search with a programmatically defined `opts` param. + * + * @overload + * @param {string | string[]} query - The query to search for. + * @param {GenerateOptions} generate - The available options for performing the generation. + * @param {NearTextOptions} [opts] - The available options for performing the near-text search. + * @return {GenerateReturn} - The results of the search including the generated data. + */ + nearText( + query: string | string[], + generate: GenerateOptions, + opts?: NearTextOptions + ): GenerateReturn; +} + +interface NearVector { + /** + * Perform retrieval-augmented generation (RaG) on the results of a by-vector object search in this collection using vector-based similarity search. + * + * See the [docs](https://weaviate.io/developers/weaviate/search/similarity) for a more detailed explanation. + * + * This overload is for performing a search without the `groupBy` param. + * + * @overload + * @param {number[]} vector - The vector to search for. + * @param {GenerateOptions} generate - The available options for performing the generation. + * @param {BaseNearOptions} [opts] - The available options for performing the near-vector search. + * @return {Promise>} - The results of the search including the generated data. + */ + nearVector( + vector: number[], + generate: GenerateOptions, + opts?: BaseNearOptions + ): Promise>; + /** + * Perform retrieval-augmented generation (RaG) on the results of a by-vector object search in this collection using vector-based similarity search. + * + * See the [docs](https://weaviate.io/developers/weaviate/search/similarity) for a more detailed explanation. + * + * This overload is for performing a search with the `groupBy` param. + * + * @overload + * @param {number[]} vector - The vector to search for. + * @param {GenerateOptions} generate - The available options for performing the generation. + * @param {GroupByNearOptions} opts - The available options for performing the near-vector search. + * @return {Promise>} - The results of the search including the generated data grouped by the specified properties. + */ + nearVector( + vector: number[], + generate: GenerateOptions, + opts: GroupByNearOptions + ): Promise>; + /** + * Perform retrieval-augmented generation (RaG) on the results of a by-vector object search in this collection using vector-based similarity search. + * + * See the [docs](https://weaviate.io/developers/weaviate/search/similarity) for a more detailed explanation. + * + * This overload is for performing a search with a programmatically defined `opts` param. + * + * @overload + * @param {number[]} vector - The vector to search for. + * @param {GenerateOptions} generate - The available options for performing the generation. + * @param {NearOptions} [opts] - The available options for performing the near-vector search. + * @return {GenerateReturn} - The results of the search including the generated data. + */ + nearVector(vector: number[], generate: GenerateOptions, opts?: NearOptions): GenerateReturn; +} + +export interface Generate + extends Bm25, + Hybrid, + NearMedia, + NearObject, + NearText, + NearVector { + fetchObjects: (generate: GenerateOptions, opts?: FetchObjectsOptions) => Promise>; +} + +export type GenerateReturn = Promise> | Promise>; diff --git a/src/collections/query/index.ts b/src/collections/query/index.ts index 73c403d0..32ca9709 100644 --- a/src/collections/query/index.ts +++ b/src/collections/query/index.ts @@ -4,21 +4,24 @@ import Connection from '../../connection/grpc.js'; import { toBase64FromBlob } from '../../utils/base64.js'; -import { ObjectsPath } from '../../data/path.js'; import { DbVersionSupport } from '../../utils/dbVersion.js'; import { ConsistencyLevel } from '../../data/index.js'; import { Deserialize } from '../deserialize/index.js'; import { Serialize } from '../serialize/index.js'; -import { WeaviateObject, WeaviateReturn, GroupByReturn } from '../types/index.js'; +import { WeaviateObject, WeaviateReturn, GroupByReturn, GroupByOptions } from '../types/index.js'; import { SearchReply } from '../../proto/v1/search_get.js'; import { + BaseBm25Options, + BaseHybridOptions, BaseNearOptions, BaseNearTextOptions, Bm25Options, FetchObjectByIdOptions, FetchObjectsOptions, + GroupByBm25Options, + GroupByHybridOptions, GroupByNearOptions, GroupByNearTextOptions, HybridOptions, @@ -27,15 +30,16 @@ import { NearTextOptions, Query, QueryReturn, + SearchOptions, } from './types.js'; -import { WeaviateInvalidInputError } from '../../errors.js'; +import { WeaviateInvalidInputError, WeaviateUnsupportedFeatureError } from '../../errors.js'; class QueryManager implements Query { - connection: Connection; - name: string; - dbVersionSupport: DbVersionSupport; - consistencyLevel?: ConsistencyLevel; - tenant?: string; + private connection: Connection; + private name: string; + private dbVersionSupport: DbVersionSupport; + private consistencyLevel?: ConsistencyLevel; + private tenant?: string; private constructor( connection: Connection, @@ -61,42 +65,99 @@ class QueryManager implements Query { return new QueryManager(connection, name, dbVersionSupport, consistencyLevel, tenant); } + private checkSupportForNamedVectors = async (opts?: BaseNearOptions) => { + if (!Serialize.isNamedVectors(opts)) return; + const check = await this.dbVersionSupport.supportsNamedVectors(); + if (!check.supports) throw new WeaviateUnsupportedFeatureError(check.message); + }; + + private checkSupportForBm25AndHybridGroupByQueries = async ( + query: 'Bm25' | 'Hybrid', + opts?: SearchOptions | GroupByOptions + ) => { + if (!Serialize.isGroupBy(opts)) return; + const check = await this.dbVersionSupport.supportsBm25AndHybridGroupByQueries(); + if (!check.supports) throw new WeaviateUnsupportedFeatureError(check.message(query)); + }; + + private checkSupportForHybridNearTextAndNearVectorSubSearches = async (opts?: HybridOptions) => { + if (opts?.vector === undefined || Array.isArray(opts.vector)) return; + const check = await this.dbVersionSupport.supportsHybridNearTextAndNearVectorSubsearchQueries(); + if (!check.supports) throw new WeaviateUnsupportedFeatureError(check.message); + }; + + private async parseReply(reply: SearchReply) { + const deserialize = await Deserialize.use(this.dbVersionSupport); + return deserialize.query(reply); + } + + private async parseGroupByReply( + opts: SearchOptions | GroupByOptions | undefined, + reply: SearchReply + ) { + const deserialize = await Deserialize.use(this.dbVersionSupport); + return Serialize.isGroupBy(opts) ? deserialize.groupBy(reply) : deserialize.query(reply); + } + public fetchObjectById(id: string, opts?: FetchObjectByIdOptions): Promise | null> { - const path = new ObjectsPath(this.dbVersionSupport); - return this.connection.search(this.name, this.consistencyLevel, this.tenant).then((search) => - search - .withFetch(Serialize.fetchObjectById({ id, ...opts })) - .then((reply) => Deserialize.generate(reply)) - .then((ret) => (ret.objects.length === 1 ? ret.objects[0] : null)) - ); + return this.checkSupportForNamedVectors(opts) + .then(() => this.connection.search(this.name, this.consistencyLevel, this.tenant)) + .then((search) => search.withFetch(Serialize.fetchObjectById({ id, ...opts }))) + .then((reply) => this.parseReply(reply)) + .then((ret) => (ret.objects.length === 1 ? ret.objects[0] : null)); } public fetchObjects(opts?: FetchObjectsOptions): Promise> { - return this.connection - .search(this.name, this.consistencyLevel, this.tenant) + return this.checkSupportForNamedVectors(opts) + .then(() => this.connection.search(this.name, this.consistencyLevel, this.tenant)) .then((search) => search.withFetch(Serialize.fetchObjects(opts))) - .then((reply) => Deserialize.generate(reply)); + .then((reply) => this.parseReply(reply)); } - public bm25(query: string, opts?: Bm25Options): Promise> { - return this.connection - .search(this.name, this.consistencyLevel, this.tenant) - .then((search) => search.withBm25(Serialize.bm25({ query, ...opts }))) - .then((reply) => Deserialize.generate(reply)); + public bm25(query: string, opts?: BaseBm25Options): Promise>; + public bm25(query: string, opts: GroupByBm25Options): Promise>; + public bm25(query: string, opts?: Bm25Options): QueryReturn { + return Promise.all([ + this.checkSupportForNamedVectors(opts), + this.checkSupportForBm25AndHybridGroupByQueries('Bm25', opts), + ]) + .then(() => this.connection.search(this.name, this.consistencyLevel, this.tenant)) + .then((search) => + search.withBm25({ + ...Serialize.bm25({ query, ...opts }), + groupBy: Serialize.isGroupBy>(opts) + ? Serialize.groupBy(opts.groupBy) + : undefined, + }) + ) + .then((reply) => this.parseGroupByReply(opts, reply)); } - public hybrid(query: string, opts?: HybridOptions): Promise> { - return this.connection - .search(this.name, this.consistencyLevel, this.tenant) - .then((search) => search.withHybrid(Serialize.hybrid({ query, ...opts }))) - .then((reply) => Deserialize.generate(reply)); + public hybrid(query: string, opts?: BaseHybridOptions): Promise>; + public hybrid(query: string, opts: GroupByHybridOptions): Promise>; + public hybrid(query: string, opts?: HybridOptions): QueryReturn { + return Promise.all([ + this.checkSupportForNamedVectors(opts), + this.checkSupportForBm25AndHybridGroupByQueries('Hybrid', opts), + this.checkSupportForHybridNearTextAndNearVectorSubSearches(opts), + ]) + .then(() => this.connection.search(this.name, this.consistencyLevel, this.tenant)) + .then((search) => + search.withHybrid({ + ...Serialize.hybrid({ query, ...opts }), + groupBy: Serialize.isGroupBy>(opts) + ? Serialize.groupBy(opts.groupBy) + : undefined, + }) + ) + .then((reply) => this.parseGroupByReply(opts, reply)); } public nearImage(image: string | Blob, opts?: BaseNearOptions): Promise>; public nearImage(image: string | Blob, opts: GroupByNearOptions): Promise>; public nearImage(image: string | Blob, opts?: NearOptions): QueryReturn { - return this.connection - .search(this.name, this.consistencyLevel, this.tenant) + return this.checkSupportForNamedVectors(opts) + .then(() => this.connection.search(this.name, this.consistencyLevel, this.tenant)) .then((search) => { const imagePromise = typeof image === 'string' ? Promise.resolve(image) : toBase64FromBlob(image); return imagePromise.then((image) => @@ -108,9 +169,7 @@ class QueryManager implements Query { }) ); }) - .then((reply) => - Serialize.isGroupBy(opts) ? Deserialize.groupBy(reply) : Deserialize.query(reply) - ); + .then((reply) => this.parseGroupByReply(opts, reply)); } public nearMedia( @@ -125,53 +184,54 @@ class QueryManager implements Query { ): Promise>; public nearMedia(media: string | Blob, type: NearMediaType, opts?: NearOptions): QueryReturn { const mediaPromise = typeof media === 'string' ? Promise.resolve(media) : toBase64FromBlob(media); - return this.connection.search(this.name, this.consistencyLevel, this.tenant).then((search) => { - let reply: Promise; - switch (type) { - case 'audio': - reply = mediaPromise.then((media) => - search.withNearAudio(Serialize.nearAudio({ audio: media, ...(opts ? opts : {}) })) - ); - break; - case 'depth': - reply = mediaPromise.then((media) => - search.withNearDepth(Serialize.nearDepth({ depth: media, ...(opts ? opts : {}) })) - ); - break; - case 'image': - reply = mediaPromise.then((media) => - search.withNearImage(Serialize.nearImage({ image: media, ...(opts ? opts : {}) })) - ); - break; - case 'imu': - reply = mediaPromise.then((media) => - search.withNearIMU(Serialize.nearIMU({ imu: media, ...(opts ? opts : {}) })) - ); - break; - case 'thermal': - reply = mediaPromise.then((media) => - search.withNearThermal(Serialize.nearThermal({ thermal: media, ...(opts ? opts : {}) })) - ); - break; - case 'video': - reply = mediaPromise.then((media) => - search.withNearVideo(Serialize.nearVideo({ video: media, ...(opts ? opts : {}) })) - ); - break; - default: - throw new WeaviateInvalidInputError(`Invalid media type: ${type}`); - } - return reply.then((reply) => - Serialize.isGroupBy(opts) ? Deserialize.groupBy(reply) : Deserialize.query(reply) - ); - }); + return this.checkSupportForNamedVectors(opts) + .then(() => this.connection.search(this.name, this.consistencyLevel, this.tenant)) + .then((search) => { + let reply: Promise; + switch (type) { + case 'audio': + reply = mediaPromise.then((media) => + search.withNearAudio(Serialize.nearAudio({ audio: media, ...(opts ? opts : {}) })) + ); + break; + case 'depth': + reply = mediaPromise.then((media) => + search.withNearDepth(Serialize.nearDepth({ depth: media, ...(opts ? opts : {}) })) + ); + break; + case 'image': + reply = mediaPromise.then((media) => + search.withNearImage(Serialize.nearImage({ image: media, ...(opts ? opts : {}) })) + ); + break; + case 'imu': + reply = mediaPromise.then((media) => + search.withNearIMU(Serialize.nearIMU({ imu: media, ...(opts ? opts : {}) })) + ); + break; + case 'thermal': + reply = mediaPromise.then((media) => + search.withNearThermal(Serialize.nearThermal({ thermal: media, ...(opts ? opts : {}) })) + ); + break; + case 'video': + reply = mediaPromise.then((media) => + search.withNearVideo(Serialize.nearVideo({ video: media, ...(opts ? opts : {}) })) + ); + break; + default: + throw new WeaviateInvalidInputError(`Invalid media type: ${type}`); + } + return reply; + }) + .then((reply) => this.parseGroupByReply(opts, reply)); } public nearObject(id: string, opts?: BaseNearOptions): Promise>; public nearObject(id: string, opts: GroupByNearOptions): Promise>; public nearObject(id: string, opts?: NearOptions): QueryReturn { - return this.connection - .search(this.name, this.consistencyLevel, this.tenant) + return this.checkSupportForNamedVectors(opts) + .then(() => this.connection.search(this.name, this.consistencyLevel, this.tenant)) .then((search) => search.withNearObject({ ...Serialize.nearObject({ id, ...(opts ? opts : {}) }), @@ -180,16 +240,14 @@ class QueryManager implements Query { : undefined, }) ) - .then((reply) => - Serialize.isGroupBy(opts) ? Deserialize.groupBy(reply) : Deserialize.query(reply) - ); + .then((reply) => this.parseGroupByReply(opts, reply)); } public nearText(query: string | string[], opts?: BaseNearTextOptions): Promise>; public nearText(query: string | string[], opts: GroupByNearTextOptions): Promise>; public nearText(query: string | string[], opts?: NearTextOptions): QueryReturn { - return this.connection - .search(this.name, this.consistencyLevel, this.tenant) + return this.checkSupportForNamedVectors(opts) + .then(() => this.connection.search(this.name, this.consistencyLevel, this.tenant)) .then((search) => search.withNearText({ ...Serialize.nearText({ query, ...(opts ? opts : {}) }), @@ -198,16 +256,14 @@ class QueryManager implements Query { : undefined, }) ) - .then((reply) => - Serialize.isGroupBy(opts) ? Deserialize.groupBy(reply) : Deserialize.query(reply) - ); + .then((reply) => this.parseGroupByReply(opts, reply)); } public nearVector(vector: number[], opts?: BaseNearOptions): Promise>; public nearVector(vector: number[], opts: GroupByNearOptions): Promise>; public nearVector(vector: number[], opts?: NearOptions): QueryReturn { - return this.connection - .search(this.name, this.consistencyLevel, this.tenant) + return this.checkSupportForNamedVectors(opts) + .then(() => this.connection.search(this.name, this.consistencyLevel, this.tenant)) .then((search) => search.withNearVector({ ...Serialize.nearVector({ vector, ...(opts ? opts : {}) }), @@ -216,9 +272,7 @@ class QueryManager implements Query { : undefined, }) ) - .then((reply) => - Serialize.isGroupBy(opts) ? Deserialize.groupBy(reply) : Deserialize.query(reply) - ); + .then((reply) => this.parseGroupByReply(opts, reply)); } } diff --git a/src/collections/query/integration.test.ts b/src/collections/query/integration.test.ts index 60a2f500..b96ddab5 100644 --- a/src/collections/query/integration.test.ts +++ b/src/collections/query/integration.test.ts @@ -1,5 +1,6 @@ /* eslint-disable @typescript-eslint/no-non-null-assertion */ /* eslint-disable @typescript-eslint/no-non-null-asserted-optional-chain */ +import { WeaviateUnsupportedFeatureError } from '../../errors.js'; import weaviate, { WeaviateClient } from '../../index.js'; import { Collection } from '../collection/index.js'; import { CrossReference, Reference } from '../references/index.js'; @@ -112,6 +113,63 @@ describe('Testing of the collection.query methods with a simple collection', () expect(ret.objects[0].uuid).toEqual(id); }); + it('should query with hybrid and vector', async () => { + const ret = await collection.query.hybrid('test', { + limit: 1, + vector: vector, + }); + expect(ret.objects.length).toEqual(1); + expect(ret.objects[0].properties.testProp).toEqual('test'); + expect(ret.objects[0].properties.testProp2).toEqual('test2'); + expect(ret.objects[0].uuid).toEqual(id); + }); + + it('should query with hybrid and near text subsearch', async () => { + const query = () => + collection.query.hybrid('test', { + limit: 1, + vector: { + text: 'apple', + distance: 0.9, + moveTo: { + concepts: ['banana'], + force: 0.9, + }, + moveAway: { + concepts: ['test'], + force: 0.1, + }, + }, + }); + if (await client.getWeaviateVersion().then((ver) => ver.isLowerThan(1, 25, 0))) { + await expect(query()).rejects.toThrow(WeaviateUnsupportedFeatureError); + return; + } + const ret = await query(); + expect(ret.objects.length).toEqual(1); + expect(ret.objects[0].properties.testProp).toEqual('apple'); + expect(ret.objects[0].properties.testProp2).toEqual('banana'); + }); + + it('should query with hybrid and near vector subsearch', async () => { + const query = () => + collection.query.hybrid('test', { + limit: 1, + vector: { + vector: vector, + distance: 0.9, + }, + }); + if (await client.getWeaviateVersion().then((ver) => ver.isLowerThan(1, 25, 0))) { + await expect(query()).rejects.toThrow(WeaviateUnsupportedFeatureError); + return; + } + const ret = await query(); + expect(ret.objects.length).toEqual(1); + expect(ret.objects[0].properties.testProp).toEqual('test'); + expect(ret.objects[0].properties.testProp2).toEqual('test2'); + }); + it('should query with nearObject', async () => { const ret = await collection.query.nearObject(id, { limit: 1, targetVector: 'vector' }); expect(ret.objects.length).toEqual(1); @@ -566,10 +624,16 @@ describe('Testing of the collection.query methods with a collection with a refer }); it('should query without searching returning named vector', async () => { - const ret = await collection.query.fetchObjects({ - returnProperties: ['title'], - includeVector: ['title'], - }); + const query = () => + collection.query.fetchObjects({ + returnProperties: ['title'], + includeVector: ['title'], + }); + if (await client.getWeaviateVersion().then((ver) => ver.isLowerThan(1, 24, 0))) { + await expect(query()).rejects.toThrow(WeaviateUnsupportedFeatureError); + return; + } + const ret = await query(); ret.objects.sort((a, b) => a.properties.title.localeCompare(b.properties.title)); expect(ret.objects.length).toEqual(2); expect(ret.objects[0].properties.title).toEqual('other'); @@ -579,10 +643,16 @@ describe('Testing of the collection.query methods with a collection with a refer }); it('should query with a vector search over the named vector space', async () => { - const ret = await collection.query.nearObject(id1, { - returnProperties: ['title'], - targetVector: 'title', - }); + const query = () => + collection.query.nearObject(id1, { + returnProperties: ['title'], + targetVector: 'title', + }); + if (await client.getWeaviateVersion().then((ver) => ver.isLowerThan(1, 24, 0))) { + await expect(query()).rejects.toThrow(WeaviateUnsupportedFeatureError); + return; + } + const ret = await query(); expect(ret.objects.length).toEqual(2); expect(ret.objects[0].properties.title).toEqual('test'); expect(ret.objects[1].properties.title).toEqual('other'); @@ -676,32 +746,41 @@ describe('Testing of the groupBy collection.query methods with a simple collecti // expect(ret.objects[0].belongsToGroup).toEqual('test'); // }); - // it('should groupBy with bm25', async () => { - // const ret = await collection.groupBy.bm25({ - // query: 'test', - // ...groupByArgs, - // }); - // expect(ret.objects.length).toEqual(1); - // expect(ret.groups).toBeDefined(); - // expect(Object.keys(ret.groups)).toEqual(['test']); - // expect(ret.objects[0].properties.testProp).toEqual('test'); - // expect(ret.objects[0].metadata.uuid).toEqual(id); - // expect(ret.objects[0].belongsToGroup).toEqual('test'); - // }); - - // it('should groupBy with hybrid', async () => { - // const ret = await collection.groupBy.hybrid({ - // query: 'test', - // ...groupByArgs, + it('should groupBy with bm25', async () => { + const query = () => + collection.query.bm25('test', { + groupBy: groupByArgs, + }); + if (await client.getWeaviateVersion().then((ver) => ver.isLowerThan(1, 25, 0))) { + await expect(query()).rejects.toThrow(WeaviateUnsupportedFeatureError); + return; + } + const ret = await query(); + expect(ret.objects.length).toEqual(1); + expect(ret.groups).toBeDefined(); + expect(Object.keys(ret.groups)).toEqual(['test']); + expect(ret.objects[0].properties.testProp).toEqual('test'); + expect(ret.objects[0].uuid).toEqual(id); + expect(ret.objects[0].belongsToGroup).toEqual('test'); + }); - // }); - // expect(ret.objects.length).toEqual(1); - // expect(ret.groups).toBeDefined(); - // expect(Object.keys(ret.groups)).toEqual(['test']); - // expect(ret.objects[0].properties.testProp).toEqual('test'); - // expect(ret.objects[0].metadata.uuid).toEqual(id); - // expect(ret.objects[0].belongsToGroup).toEqual('test'); - // }); + it('should groupBy with hybrid', async () => { + const query = () => + collection.query.hybrid('test', { + groupBy: groupByArgs, + }); + if (await client.getWeaviateVersion().then((ver) => ver.isLowerThan(1, 25, 0))) { + await expect(query()).rejects.toThrow(WeaviateUnsupportedFeatureError); + return; + } + const ret = await query(); + expect(ret.objects.length).toEqual(1); + expect(ret.groups).toBeDefined(); + expect(Object.keys(ret.groups)).toEqual(['test']); + expect(ret.objects[0].properties.testProp).toEqual('test'); + expect(ret.objects[0].uuid).toEqual(id); + expect(ret.objects[0].belongsToGroup).toEqual('test'); + }); it('should groupBy with nearObject', async () => { const ret = await collection.query.nearObject(id, { diff --git a/src/collections/query/types.ts b/src/collections/query/types.ts index 88c8285f..5ed4c8c5 100644 --- a/src/collections/query/types.ts +++ b/src/collections/query/types.ts @@ -75,18 +75,27 @@ export type SearchOptions = { returnReferences?: QueryReference[]; }; -/** Options available in the `query.bm25` method */ -export type Bm25Options = SearchOptions & { +/** Base options available in the `query.bm25` method */ +export type BaseBm25Options = SearchOptions & { /** Which properties of the collection to perform the keyword search on. */ queryProperties?: PrimitiveKeys[]; }; -/** Options available in the `query.hybrid` method */ -export type HybridOptions = SearchOptions & { +/** Options available in the `query.bm25` method when specifying the `groupBy` parameter. */ +export type GroupByBm25Options = BaseBm25Options & { + /** The group by options to apply to the search. */ + groupBy: GroupByOptions; +}; + +/** Options available in the `query.bm25` method */ +export type Bm25Options = BaseBm25Options | GroupByBm25Options | undefined; + +/** Base options available in the `query.hybrid` method */ +export type BaseHybridOptions = SearchOptions & { /** The weight of the BM25 score. If not specified, the default weight specified by the server is used. */ alpha?: number; - /** The specific vector to search for. If not specified, the query is vectorized and used in the similarity search. */ - vector?: number[]; + /** The specific vector to search for or a specific vector subsearch. If not specified, the query is vectorized and used in the similarity search. */ + vector?: number[] | HybridNearTextSubSearch | HybridNearVectorSubSearch; /** The properties to search in. If not specified, all properties are searched. */ queryProperties?: PrimitiveKeys[]; /** The type of fusion to apply. If not specified, the default fusion type specified by the server is used. */ @@ -95,6 +104,30 @@ export type HybridOptions = SearchOptions & { targetVector?: string; }; +export type HybridSubSearchBase = { + certainty?: number; + distance?: number; +}; + +export type HybridNearTextSubSearch = HybridSubSearchBase & { + text: string | string[]; + moveTo?: MoveOptions; + moveAway?: MoveOptions; +}; + +export type HybridNearVectorSubSearch = HybridSubSearchBase & { + vector: number[]; +}; + +/** Options available in the `query.hybrid` method when specifying the `groupBy` parameter. */ +export type GroupByHybridOptions = BaseHybridOptions & { + /** The group by options to apply to the search. */ + groupBy: GroupByOptions; +}; + +/** Options available in the `query.hybrid` method */ +export type HybridOptions = BaseHybridOptions | GroupByHybridOptions | undefined; + /** Base options for the near search queries. */ export type BaseNearOptions = SearchOptions & { /** The minimum similarity score to return. Incompatible with the `distance` param. */ @@ -132,6 +165,90 @@ export type GroupByNearTextOptions = BaseNearTextOptions & { /** The type of the media to search for in the `query.nearMedia` method */ export type NearMediaType = 'audio' | 'depth' | 'image' | 'imu' | 'thermal' | 'video'; +interface Bm25 { + /** + * Search for objects in this collection using the keyword-based BM25 algorithm. + * + * See the [docs](https://weaviate.io/developers/weaviate/search/bm25) for a more detailed explanation. + * + * This overload is for performing a search without the `groupBy` param. + * + * @overload + * @param {string} query - The query to search for. + * @param {BaseBm25Options} [opts] - The available options for the search excluding the `groupBy` param. + * @returns {Promise>} - The result of the search within the fetched collection. + */ + bm25(query: string, opts?: BaseBm25Options): Promise>; + /** + * Search for objects in this collection using the keyword-based BM25 algorithm. + * + * See the [docs](https://weaviate.io/developers/weaviate/search/bm25) for a more detailed explanation. + * + * This overload is for performing a search with the `groupBy` param. + * + * @overload + * @param {string} query - The query to search for. + * @param {GroupByBm25Options} opts - The available options for the search including the `groupBy` param. + * @returns {Promise>} - The result of the search within the fetched collection. + */ + bm25(query: string, opts: GroupByBm25Options): Promise>; + /** + * Search for objects in this collection using the keyword-based BM25 algorithm. + * + * See the [docs](https://weaviate.io/developers/weaviate/search/bm25) for a more detailed explanation. + * + * This overload is for performing a search with a programmatically defined `opts` param. + * + * @overload + * @param {string} query - The query to search for. + * @param {Bm25Options} [opts] - The available options for the search including the `groupBy` param. + * @returns {Promise>} - The result of the search within the fetched collection. + */ + bm25(query: string, opts?: Bm25Options): QueryReturn; +} + +interface Hybrid { + /** + * Search for objects in this collection using the hybrid algorithm blending keyword-based BM25 and vector-based similarity. + * + * See the [docs](https://weaviate.io/developers/weaviate/search/hybrid) for a more detailed explanation. + * + * This overload is for performing a search without the `groupBy` param. + * + * @overload + * @param {string} query - The query to search for. + * @param {BaseHybridOptions} [opts] - The available options for the search excluding the `groupBy` param. + * @returns {Promise>} - The result of the search within the fetched collection. + */ + hybrid(query: string, opts?: BaseHybridOptions): Promise>; + /** + * Search for objects in this collection using the hybrid algorithm blending keyword-based BM25 and vector-based similarity. + * + * See the [docs](https://weaviate.io/developers/weaviate/search/hybrid) for a more detailed explanation. + * + * This overload is for performing a search with the `groupBy` param. + * + * @overload + * @param {string} query - The query to search for. + * @param {GroupByHybridOptions} opts - The available options for the search including the `groupBy` param. + * @returns {Promise>} - The result of the search within the fetched collection. + */ + hybrid(query: string, opts: GroupByHybridOptions): Promise>; + /** + * Search for objects in this collection using the hybrid algorithm blending keyword-based BM25 and vector-based similarity. + * + * See the [docs](https://weaviate.io/developers/weaviate/search/hybrid) for a more detailed explanation. + * + * This overload is for performing a search with a programmatically defined `opts` param. + * + * @overload + * @param {string} query - The query to search for. + * @param {HybridOptions} [opts] - The available options for the search including the `groupBy` param. + * @returns {Promise>} - The result of the search within the fetched collection. + */ + hybrid(query: string, opts?: HybridOptions): QueryReturn; +} + interface NearImage { /** * Search for objects by image in this collection using an image-capable vectorization module and vector-based similarity search. @@ -363,7 +480,14 @@ interface NearVector { } /** All the available methods on the `.query` namespace. */ -export interface Query extends NearImage, NearMedia, NearObject, NearText, NearVector { +export interface Query + extends Bm25, + Hybrid, + NearImage, + NearMedia, + NearObject, + NearText, + NearVector { /** * Retrieve an object from the server by its UUID. * @@ -380,28 +504,6 @@ export interface Query extends NearImage, NearMedia, NearObject, Nea * @returns {Promise>} - The objects within the fetched collection. */ fetchObjects: (opts?: FetchObjectsOptions) => Promise>; - - /** - * Search for objects in this collection using the keyword-based BM25 algorithm. - * - * See the [docs](https://weaviate.io/developers/weaviate/search/bm25) for a more detailed explanation. - * - * @param {string} query - The keyword query to search for. - * @param {Bm25Options} [opts] - The available options for searching for the objects. - * @returns {Promise>} - The objects matching the search within the fetched collection. - */ - bm25: (query: string, opts?: Bm25Options) => Promise>; - - /** - * Search for objects in this collection using the hybrid algorithm blending keyword-based BM25 and vector-based similarity. - * - * See the [docs](https://weaviate.io/developers/weaviate/search/hybrid) for a more detailed explanation. - * - * @param {string} query - The keyword query to search for. - * @param {HybridOptions} [opts] - The available options for searching for the objects. - * @returns {Promise>} - The objects matching the search within the fetched collection. - */ - hybrid: (query: string, opts?: HybridOptions) => Promise>; } /** Options available in the `query.nearImage`, `query.nearMedia`, `query.nearObject`, and `query.nearVector` methods */ export type NearOptions = BaseNearOptions | GroupByNearOptions | undefined; diff --git a/src/collections/serialize/index.ts b/src/collections/serialize/index.ts index ebc4e77d..75abda07 100644 --- a/src/collections/serialize/index.ts +++ b/src/collections/serialize/index.ts @@ -79,8 +79,12 @@ import { NearOptions, SearchOptions, NearTextOptions, + BaseNearOptions, + BaseHybridOptions, + HybridNearTextSubSearch, + HybridNearVectorSubSearch, } from '../query/types.js'; -import { GenerateOptions } from '../generate/index.js'; +import { GenerateOptions } from '../generate/types.js'; import { BooleanArrayProperties, IntArrayProperties, @@ -301,6 +305,10 @@ export class MetadataGuards { } export class Serialize { + public static isNamedVectors = (opts?: BaseNearOptions): boolean => { + return Array.isArray(opts?.includeVector) || opts?.targetVector !== undefined; + }; + private static common = (args?: SearchOptions): BaseSearchArgs => { const out: BaseSearchArgs = { limit: args?.limit, @@ -349,6 +357,43 @@ export class Serialize { }; }; + private static isHybridVectorSearch = (vector: BaseHybridOptions['vector']): vector is number[] => { + return Array.isArray(vector); + }; + + private static isHybridNearTextSearch = ( + vector: BaseHybridOptions['vector'] + ): vector is HybridNearTextSubSearch => { + return (vector as HybridNearTextSubSearch)?.text !== undefined; + }; + + private static isHybridNearVectorSearch = ( + vector: BaseHybridOptions['vector'] + ): vector is HybridNearVectorSubSearch => { + return (vector as HybridNearVectorSubSearch)?.vector !== undefined; + }; + + private static hybridVector = (vector: BaseHybridOptions['vector']): Uint8Array | undefined => { + return Serialize.isHybridVectorSearch(vector) ? Serialize.vectorToBytes(vector) : undefined; + }; + + private static hybridNearText = (vector: BaseHybridOptions['vector']): NearTextSearch | undefined => { + return Serialize.isHybridNearTextSearch(vector) + ? Serialize.nearTextSearch({ + ...vector, + query: vector.text, + }) + : undefined; + }; + + private static hybridNearVector = (vector: BaseHybridOptions['vector']): NearVector | undefined => { + return Serialize.isHybridNearVectorSearch(vector) + ? NearVector.fromPartial({ + vectorBytes: Serialize.vectorToBytes(vector.vector), + }) + : undefined; + }; + public static hybrid = (args: { query: string } & HybridOptions): SearchHybridArgs => { const fusionType = (fusionType?: string): Hybrid_FusionType => { switch (fusionType) { @@ -360,15 +405,18 @@ export class Serialize { return Hybrid_FusionType.FUSION_TYPE_UNSPECIFIED; } }; + return { ...Serialize.common(args), hybridSearch: Hybrid.fromPartial({ query: args.query, alpha: args.alpha ? args.alpha : 0.5, properties: args.queryProperties, - vectorBytes: args.vector ? Serialize.vectorToBytes(args.vector) : undefined, + vectorBytes: Serialize.hybridVector(args.vector), fusionType: fusionType(args.fusionType), targetVectors: args.targetVector ? [args.targetVector] : undefined, + nearText: Serialize.hybridNearText(args.vector), + nearVector: Serialize.hybridNearVector(args.vector), }), autocut: args.autoLimit, }; @@ -439,31 +487,42 @@ export class Serialize { }; }; + private static nearTextSearch = (args: { + query: string | string[]; + certainty?: number; + distance?: number; + targetVector?: string; + moveAway?: { concepts?: string[]; force?: number; objects?: string[] }; + moveTo?: { concepts?: string[]; force?: number; objects?: string[] }; + }) => { + return NearTextSearch.fromPartial({ + query: typeof args.query === 'string' ? [args.query] : args.query, + certainty: args.certainty, + distance: args.distance, + targetVectors: args.targetVector ? [args.targetVector] : undefined, + moveAway: args.moveAway + ? NearTextSearch_Move.fromPartial({ + concepts: args.moveAway.concepts, + force: args.moveAway.force, + uuids: args.moveAway.objects, + }) + : undefined, + moveTo: args.moveTo + ? NearTextSearch_Move.fromPartial({ + concepts: args.moveTo.concepts, + force: args.moveTo.force, + uuids: args.moveTo.objects, + }) + : undefined, + }); + }; + public static nearText = ( args: { query: string | string[] } & NearTextOptions ): SearchNearTextArgs => { return { ...Serialize.common(args), - nearText: NearTextSearch.fromPartial({ - query: typeof args.query === 'string' ? [args.query] : args.query, - certainty: args.certainty, - distance: args.distance, - targetVectors: args.targetVector ? [args.targetVector] : undefined, - moveAway: args.moveAway - ? NearTextSearch_Move.fromPartial({ - concepts: args.moveAway.concepts, - force: args.moveAway.force, - uuids: args.moveAway.objects, - }) - : undefined, - moveTo: args.moveTo - ? NearTextSearch_Move.fromPartial({ - concepts: args.moveTo.concepts, - force: args.moveTo.force, - uuids: args.moveTo.objects, - }) - : undefined, - }), + nearText: Serialize.nearTextSearch(args), autocut: args.autoLimit, }; }; @@ -485,15 +544,24 @@ export class Serialize { return new Uint8Array(new Float32Array(vector).buffer); }; + private static nearVectorSearch = (args: { + vector: number[]; + certainty?: number; + distance?: number; + targetVector?: string; + }) => { + return NearVector.fromPartial({ + vectorBytes: Serialize.vectorToBytes(args.vector), + certainty: args.certainty, + distance: args.distance, + targetVectors: args.targetVector ? [args.targetVector] : undefined, + }); + }; + public static nearVector = (args: { vector: number[] } & NearOptions): SearchNearVectorArgs => { return { ...Serialize.common(args), - nearVector: NearVector.fromPartial({ - vectorBytes: Serialize.vectorToBytes(args.vector), - certainty: args.certainty, - distance: args.distance, - targetVectors: args.targetVector ? [args.targetVector] : undefined, - }), + nearVector: Serialize.nearVectorSearch(args), autocut: args.autoLimit, }; }; diff --git a/src/collections/tenants/index.ts b/src/collections/tenants/index.ts index 82205e93..1e3994da 100644 --- a/src/collections/tenants/index.ts +++ b/src/collections/tenants/index.ts @@ -1,33 +1,91 @@ -import Connection from '../../connection/index.js'; +import { ConnectionGRPC } from '../../connection/index.js'; +import { WeaviateUnsupportedFeatureError } from '../../errors.js'; +import { TenantActivityStatus, TenantsGetReply } from '../../proto/v1/tenants.js'; import { TenantsCreator, TenantsDeleter, TenantsGetter, TenantsUpdater } from '../../schema/index.js'; +import { DbVersionSupport } from '../../utils/dbVersion.js'; export type Tenant = { name: string; activityStatus?: 'COLD' | 'HOT'; }; -const tenants = (connection: Connection, name: string): Tenants => { - const parseTenants = (tenants: Tenant | Tenant[]) => (Array.isArray(tenants) ? tenants : [tenants]); +export type TenantsGetOptions = { + tenants?: string; +}; + +class ActivityStatusMapper { + static from(status: TenantActivityStatus): 'COLD' | 'HOT' { + switch (status) { + case TenantActivityStatus.TENANT_ACTIVITY_STATUS_COLD: + return 'COLD'; + case TenantActivityStatus.TENANT_ACTIVITY_STATUS_HOT: + return 'HOT'; + default: + throw new Error(`Unsupported tenant activity status: ${status}`); + } + } +} + +const mapReply = (reply: TenantsGetReply): Record => { + const tenants: Record = {}; + reply.tenants.forEach((t) => { + tenants[t.name] = { + name: t.name, + activityStatus: ActivityStatusMapper.from(t.activityStatus), + }; + }); + return tenants; +}; + +const checkSupportForGRPCTenantsGetEndpoint = async (dbVersionSupport: DbVersionSupport) => { + const check = await dbVersionSupport.supportsTenantsGetGRPCMethod(); + if (!check.supports) throw new WeaviateUnsupportedFeatureError(check.message); +}; + +const parseTenantOrTenantArray = (tenants: Tenant | Tenant[]) => + Array.isArray(tenants) ? tenants : [tenants]; + +const parseStringOrTenant = (tenant: string | Tenant) => (typeof tenant === 'string' ? tenant : tenant.name); + +const tenants = ( + connection: ConnectionGRPC, + collection: string, + dbVersionSupport: DbVersionSupport +): Tenants => { + const getGRPC = (names?: string[]) => + checkSupportForGRPCTenantsGetEndpoint(dbVersionSupport) + .then(() => connection.tenants(collection)) + .then((builder) => builder.withGet({ names })) + .then(mapReply); + const getREST = () => + new TenantsGetter(connection, collection).do().then((tenants) => { + const result: Record = {}; + tenants.forEach((tenant) => { + if (!tenant.name) return; + result[tenant.name] = tenant as Tenant; + }); + return result; + }); return { create: (tenants: Tenant | Tenant[]) => - new TenantsCreator(connection, name, parseTenants(tenants)).do() as Promise, - get: () => - new TenantsGetter(connection, name).do().then((tenants) => { - const result: Record = {}; - tenants.forEach((tenant) => { - if (!tenant.name) return; - result[tenant.name] = tenant as Tenant; - }); - return result; - }), + new TenantsCreator(connection, collection, parseTenantOrTenantArray(tenants)).do() as Promise, + get: async function () { + const check = await dbVersionSupport.supportsTenantsGetGRPCMethod(); + return check.supports ? getGRPC() : getREST(); + }, + getByNames: (tenants: (string | Tenant)[]) => getGRPC(tenants.map(parseStringOrTenant)), + getByName: (tenant: string | Tenant) => { + const tenantName = parseStringOrTenant(tenant); + return getGRPC([tenantName]).then((tenants) => tenants[tenantName] || null); + }, remove: (tenants: Tenant | Tenant[]) => new TenantsDeleter( connection, - name, - parseTenants(tenants).map((t) => t.name) + collection, + parseTenantOrTenantArray(tenants).map((t) => t.name) ).do(), update: (tenants: Tenant | Tenant[]) => - new TenantsUpdater(connection, name, parseTenants(tenants)).do() as Promise, + new TenantsUpdater(connection, collection, parseTenantOrTenantArray(tenants)).do() as Promise, }; }; @@ -58,6 +116,24 @@ export interface Tenants { * @returns {Promise>} A list of tenants as an object of Tenant types, where the key is the tenant name. */ get: () => Promise>; + /** + * Return the specified tenants from a collection in Weaviate. + * + * The collection must have been created with multi-tenancy enabled. + * + * @param {(string | Tenant)[]} names The tenants to retrieve. + * @returns {Promise} The list of tenants. If the tenant does not exist, it will not be included in the list. + */ + getByNames: (names: (string | Tenant)[]) => Promise>; + /** + * Return the specified tenant from a collection in Weaviate. + * + * The collection must have been created with multi-tenancy enabled. + * + * @param {string | Tenant} name The name of the tenant to retrieve. + * @returns {Promise} The tenant as a Tenant type, or null if the tenant does not exist. + */ + getByName: (name: string | Tenant) => Promise; /** * Remove the specified tenants from a collection in Weaviate. * diff --git a/src/collections/tenants/integration.test.ts b/src/collections/tenants/integration.test.ts index ce70c92a..4bd2d34c 100644 --- a/src/collections/tenants/integration.test.ts +++ b/src/collections/tenants/integration.test.ts @@ -1,8 +1,9 @@ /* eslint-disable @typescript-eslint/no-non-null-assertion */ +import { WeaviateUnsupportedFeatureError } from '../../errors.js'; import weaviate, { WeaviateClient } from '../../index.js'; import { Collection } from '../collection/index.js'; -describe('Testing of the collection.data methods', () => { +describe('Testing of the collection.tenants methods', () => { let client: WeaviateClient; let collection: Collection; const collectionName = 'TestCollectionTenants'; @@ -75,4 +76,83 @@ describe('Testing of the collection.data methods', () => { expect(result[0].name).toBe('cold'); expect(result[0].activityStatus).toBe('HOT'); }); + + describe('getByName and getByNames', () => { + it('should be able to get a tenant by name string', async () => { + const query = () => collection.tenants.getByName('hot'); + if (await client.getWeaviateVersion().then((ver) => ver.isLowerThan(1, 25, 0))) { + await expect(query()).rejects.toThrow(WeaviateUnsupportedFeatureError); + return; + } + const result = await query(); + expect(result).toHaveProperty('name', 'hot'); + expect(result).toHaveProperty('activityStatus', 'HOT'); + }); + + it('should be able to get a tenant by tenant object', async () => { + const query = () => collection.tenants.getByName({ name: 'hot' }); + if (await client.getWeaviateVersion().then((ver) => ver.isLowerThan(1, 25, 0))) { + await expect(query()).rejects.toThrow(WeaviateUnsupportedFeatureError); + return; + } + const result = await query(); + expect(result).toHaveProperty('name', 'hot'); + expect(result).toHaveProperty('activityStatus', 'HOT'); + }); + + it('should fail to get a non-existing tenant', async () => { + const query = () => collection.tenants.getByName('non-existing'); + if (await client.getWeaviateVersion().then((ver) => ver.isLowerThan(1, 25, 0))) { + await expect(query()).rejects.toThrow(WeaviateUnsupportedFeatureError); + return; + } + const result = await query(); + expect(result).toBeNull(); + }); + + it('should be able to get tenants by name strings', async () => { + const query = () => collection.tenants.getByNames(['hot', 'cold']); + if (await client.getWeaviateVersion().then((ver) => ver.isLowerThan(1, 25, 0))) { + await expect(query()).rejects.toThrow(WeaviateUnsupportedFeatureError); + return; + } + const result = await query(); + expect(result).toHaveProperty('hot'); + expect(result).toHaveProperty('cold'); + }); + + it('should be able to get tenants by tenant objects', async () => { + const query = () => collection.tenants.getByNames([{ name: 'hot' }, { name: 'cold' }]); + if (await client.getWeaviateVersion().then((ver) => ver.isLowerThan(1, 25, 0))) { + await expect(query()).rejects.toThrow(WeaviateUnsupportedFeatureError); + return; + } + const result = await query(); + expect(result).toHaveProperty('hot'); + expect(result).toHaveProperty('cold'); + }); + + it('should be able to get tenants by mixed name strings and tenant objects', async () => { + const query = () => collection.tenants.getByNames(['hot', { name: 'cold' }]); + if (await client.getWeaviateVersion().then((ver) => ver.isLowerThan(1, 25, 0))) { + await expect(query()).rejects.toThrow(WeaviateUnsupportedFeatureError); + return; + } + const result = await query(); + expect(result).toHaveProperty('hot'); + expect(result).toHaveProperty('cold'); + }); + + it('should be able to get partial tenants', async () => { + const query = () => collection.tenants.getByNames(['hot', 'non-existing']); + if (await client.getWeaviateVersion().then((ver) => ver.isLowerThan(1, 25, 0))) { + await expect(query()).rejects.toThrow(WeaviateUnsupportedFeatureError); + return; + } + const result = await query(); + expect(result).toHaveProperty('hot'); + expect(result).not.toHaveProperty('cold'); + expect(result).not.toHaveProperty('non-existing'); + }); + }); }); diff --git a/src/connection/grpc.ts b/src/connection/grpc.ts index 1b9405a7..73c53cd1 100644 --- a/src/connection/grpc.ts +++ b/src/connection/grpc.ts @@ -11,6 +11,8 @@ import { HealthDefinition, HealthCheckResponse_ServingStatus } from '../proto/go import Batcher, { Batch } from '../grpc/batcher.js'; import Searcher, { Search } from '../grpc/searcher.js'; +import { DbVersionSupport, initDbVersionProvider } from '../utils/dbVersion.js'; +import TenantsManager, { Tenants } from '../grpc/tenantsManager.js'; import { WeaviateGRPCUnavailableError } from '../errors.js'; @@ -37,8 +39,22 @@ export default class ConnectionGRPC extends ConnectionGQL { static use = async (params: GrpcConnectionParams) => { const connection = new ConnectionGRPC(params); - await connection.connect(); - return connection; + const dbVersionProvider = initDbVersionProvider(connection); + const dbVersionSupport = new DbVersionSupport(dbVersionProvider); + const settled = await Promise.allSettled([ + dbVersionSupport.supportsCompatibleGrpcService().then((check) => { + if (!check.supports) { + throw new Error(check.message); + } + }), + connection.connect(), + ]); + settled.forEach((promise) => { + if (promise.status === 'rejected') { + throw new Error(promise.reason); + } + }); + return { connection, dbVersionProvider, dbVersionSupport }; }; private async connect() { @@ -48,20 +64,29 @@ export default class ConnectionGRPC extends ConnectionGQL { } } - search = (name: string, consistencyLevel?: ConsistencyLevel, tenant?: string) => { + search = (collection: string, consistencyLevel?: ConsistencyLevel, tenant?: string) => { if (this.authEnabled) { return this.login().then((token) => - this.grpc.search(name, consistencyLevel, tenant, `Bearer ${token}`) + this.grpc.search(collection, consistencyLevel, tenant, `Bearer ${token}`) ); } - return new Promise((resolve) => resolve(this.grpc.search(name, consistencyLevel, tenant))); + return new Promise((resolve) => resolve(this.grpc.search(collection, consistencyLevel, tenant))); }; - batch = (name: string, consistencyLevel?: ConsistencyLevel, tenant?: string) => { + batch = (collection: string, consistencyLevel?: ConsistencyLevel, tenant?: string) => { if (this.authEnabled) { - return this.login().then((token) => this.grpc.batch(name, consistencyLevel, tenant, `Bearer ${token}`)); + return this.login().then((token) => + this.grpc.batch(collection, consistencyLevel, tenant, `Bearer ${token}`) + ); + } + return new Promise((resolve) => resolve(this.grpc.batch(collection, consistencyLevel, tenant))); + }; + + tenants = (collection: string) => { + if (this.authEnabled) { + return this.login().then((token) => this.grpc.tenants(collection, `Bearer ${token}`)); } - return new Promise((resolve) => resolve(this.grpc.batch(name, consistencyLevel, tenant))); + return new Promise((resolve) => resolve(this.grpc.tenants(collection))); }; close = () => { @@ -72,14 +97,20 @@ export default class ConnectionGRPC extends ConnectionGQL { export interface GrpcClient { close: () => void; - batch: (name: string, consistencyLevel?: ConsistencyLevel, tenant?: string, bearerToken?: string) => Batch; + batch: ( + collection: string, + consistencyLevel?: ConsistencyLevel, + tenant?: string, + bearerToken?: string + ) => Batch; health: () => Promise; search: ( - name: string, + collection: string, consistencyLevel?: ConsistencyLevel, tenant?: string, bearerToken?: string ) => Search; + tenants: (collection: string, bearerToken?: string) => Tenants; } export const grpcClient = (config: GrpcConnectionParams): GrpcClient => { @@ -102,10 +133,10 @@ export const grpcClient = (config: GrpcConnectionParams): GrpcClient => { const health = clientFactory.create(HealthDefinition, channel); return { close: () => channel.close(), - batch: (name: string, consistencyLevel?: ConsistencyLevel, tenant?: string, bearerToken?: string) => + batch: (collection: string, consistencyLevel?: ConsistencyLevel, tenant?: string, bearerToken?: string) => Batcher.use( client, - name, + collection, new Metadata(bearerToken ? { ...config.headers, authorization: bearerToken } : config.headers), consistencyLevel, tenant @@ -114,13 +145,24 @@ export const grpcClient = (config: GrpcConnectionParams): GrpcClient => { health .check({ service: '/grpc.health.v1.Health/Check' }) .then((res) => res.status === HealthCheckResponse_ServingStatus.SERVING), - search: (name: string, consistencyLevel?: ConsistencyLevel, tenant?: string, bearerToken?: string) => + search: ( + collection: string, + consistencyLevel?: ConsistencyLevel, + tenant?: string, + bearerToken?: string + ) => Searcher.use( client, - name, + collection, new Metadata(bearerToken ? { ...config.headers, authorization: bearerToken } : config.headers), consistencyLevel, tenant ), + tenants: (collection: string, bearerToken?: string) => + TenantsManager.use( + client, + collection, + new Metadata(bearerToken ? { ...config.headers, authorization: bearerToken } : config.headers) + ), }; }; diff --git a/src/connection/journey.test.ts b/src/connection/journey.test.ts index 7073ad47..8bac863e 100644 --- a/src/connection/journey.test.ts +++ b/src/connection/journey.test.ts @@ -12,7 +12,7 @@ describe('connection', () => { it('makes a logged-in request when client host param has trailing slashes', () => { if (process.env.WCS_DUMMY_CI_PW == undefined || process.env.WCS_DUMMY_CI_PW == '') { console.warn('Skipping because `WCS_DUMMY_CI_PW` is not set'); - return; + return Promise.resolve(); } const client = weaviate.client({ @@ -39,7 +39,7 @@ describe('connection', () => { it('makes an Azure logged-in request with client credentials', () => { if (process.env.AZURE_CLIENT_SECRET == undefined || process.env.AZURE_CLIENT_SECRET == '') { console.warn('Skipping because `AZURE_CLIENT_SECRET` is not set'); - return; + return Promise.resolve(); } const client = weaviate.client({ @@ -65,7 +65,7 @@ describe('connection', () => { it('makes an Okta logged-in request with client credentials', () => { if (process.env.OKTA_CLIENT_SECRET == undefined || process.env.OKTA_CLIENT_SECRET == '') { console.warn('Skipping because `OKTA_CLIENT_SECRET` is not set'); - return; + return Promise.resolve(); } const client = weaviate.client({ @@ -92,7 +92,7 @@ describe('connection', () => { it('makes an Okta logged-in request with username/password', () => { if (process.env.OKTA_DUMMY_CI_PW == undefined || process.env.OKTA_DUMMY_CI_PW == '') { console.warn('Skipping because `OKTA_DUMMY_CI_PW` is not set'); - return; + return Promise.resolve(); } const client = weaviate.client({ @@ -119,7 +119,7 @@ describe('connection', () => { it('makes a WCS logged-in request with username/password', () => { if (process.env.WCS_DUMMY_CI_PW == undefined || process.env.WCS_DUMMY_CI_PW == '') { console.warn('Skipping because `WCS_DUMMY_CI_PW` is not set'); - return; + return Promise.resolve(); } const client = weaviate.client({ diff --git a/src/errors.ts b/src/errors.ts index 405bc996..f43edb35 100644 --- a/src/errors.ts +++ b/src/errors.ts @@ -102,3 +102,5 @@ export class WeaviateBackupFailed extends WeaviateError { super(`Backup ${kind} failed with message: ${message}`); } } + +export class WeaviateUnsupportedFeatureError extends WeaviateError {} diff --git a/src/grpc/base.ts b/src/grpc/base.ts index 2355e695..f09a1318 100644 --- a/src/grpc/base.ts +++ b/src/grpc/base.ts @@ -6,20 +6,20 @@ import { Metadata } from 'nice-grpc'; export default class Base { protected connection: WeaviateClient; - protected name: string; + protected collection: string; protected consistencyLevel?: ConsistencyLevelGRPC; protected tenant?: string; protected metadata?: Metadata; protected constructor( connection: WeaviateClient, - name: string, + collection: string, metadata: Metadata, consistencyLevel?: ConsistencyLevel, tenant?: string ) { this.connection = connection; - this.name = name; + this.collection = collection; this.consistencyLevel = this.mapConsistencyLevel(consistencyLevel); this.tenant = tenant; this.metadata = metadata; diff --git a/src/grpc/batcher.ts b/src/grpc/batcher.ts index 9ffccddf..e74afda5 100644 --- a/src/grpc/batcher.ts +++ b/src/grpc/batcher.ts @@ -28,12 +28,12 @@ export interface BatchDeleteArgs { export default class Batcher extends Base implements Batch { public static use( connection: WeaviateClient, - name: string, + collection: string, metadata: Metadata, consistencyLevel?: ConsistencyLevel, tenant?: string ): Batch { - return new Batcher(connection, name, metadata, consistencyLevel, tenant); + return new Batcher(connection, collection, metadata, consistencyLevel, tenant); } public withDelete = (args: BatchDeleteArgs) => this.callDelete(BatchDeleteRequest.fromPartial(args)); @@ -44,7 +44,7 @@ export default class Batcher extends Base implements Batch { .batchDelete( { ...message, - collection: this.name, + collection: this.collection, consistencyLevel: this.consistencyLevel, tenant: this.tenant, }, diff --git a/src/grpc/searcher.ts b/src/grpc/searcher.ts index 172026ef..b1b3372e 100644 --- a/src/grpc/searcher.ts +++ b/src/grpc/searcher.ts @@ -114,12 +114,12 @@ export interface Search { export default class Searcher extends Base implements Search { public static use( connection: WeaviateClient, - name: string, + collection: string, metadata: Metadata, consistencyLevel?: ConsistencyLevel, tenant?: string ): Search { - return new Searcher(connection, name, metadata, consistencyLevel, tenant); + return new Searcher(connection, collection, metadata, consistencyLevel, tenant); } public withFetch = (args: SearchFetchArgs) => this.call(SearchRequest.fromPartial(args)); @@ -140,10 +140,11 @@ export default class Searcher extends Base implements Search { .search( { ...message, - collection: this.name, + collection: this.collection, consistencyLevel: this.consistencyLevel, tenant: this.tenant, uses123Api: true, + uses125Api: true, }, { metadata: this.metadata, diff --git a/src/grpc/tenantsManager.ts b/src/grpc/tenantsManager.ts new file mode 100644 index 00000000..8b156fbf --- /dev/null +++ b/src/grpc/tenantsManager.ts @@ -0,0 +1,33 @@ +import Base from './base.js'; +import { Metadata } from 'nice-grpc'; +import { WeaviateClient } from '../proto/v1/weaviate.js'; +import { TenantsGetReply, TenantsGetRequest } from '../proto/v1/tenants.js'; + +export type TenantsGetArgs = { + names?: string[]; +}; + +export interface Tenants { + withGet: (args: TenantsGetArgs) => Promise; +} + +export default class TenantsManager extends Base implements TenantsManager { + public static use(connection: WeaviateClient, collection: string, metadata: Metadata): Tenants { + return new TenantsManager(connection, collection, metadata); + } + + public withGet = (args: TenantsGetArgs) => + this.call(TenantsGetRequest.fromPartial({ names: args.names ? { values: args.names } : undefined })); + + private call(message: TenantsGetRequest) { + return this.connection.tenantsGet( + { + ...message, + collection: this.collection, + }, + { + metadata: this.metadata, + } + ); + } +} diff --git a/src/index.ts b/src/index.ts index 4d61ec5c..7417b334 100644 --- a/src/index.ts +++ b/src/index.ts @@ -1,5 +1,5 @@ import { ConnectionGRPC } from './connection/index.js'; -import { DbVersionProvider, DbVersionSupport } from './utils/dbVersion.js'; +import { DbVersion, DbVersionProvider, DbVersionSupport } from './utils/dbVersion.js'; import { backup, Backup } from './collections/backup/client.js'; import cluster, { Cluster } from './collections/cluster/index.js'; import { @@ -84,9 +84,10 @@ export interface WeaviateClient { close: () => Promise; getMeta: () => Promise; + getOpenIDConfig?: () => Promise; + getWeaviateVersion: () => Promise; isLive: () => Promise; isReady: () => Promise; - getOpenIDConfig?: () => Promise; } const app = { @@ -130,7 +131,7 @@ const app = { ? new HttpsAgent({ keepAlive: true }) : new HttpAgent({ keepAlive: true }); - const conn = await ConnectionGRPC.use({ + const { connection, dbVersionProvider, dbVersionSupport } = await ConnectionGRPC.use({ host: params.rest.host.startsWith('http') ? `${params.rest.host}${params.rest.path || ''}` : `${scheme}://${params.rest.host}:${params.rest.port}${params.rest.path || ''}`, @@ -144,20 +145,18 @@ const app = { agent, }); - const dbVersionProvider = initDbVersionProvider(conn); - const dbVersionSupport = new DbVersionSupport(dbVersionProvider); - const ifc: WeaviateClient = { - backup: backup(conn), - cluster: cluster(conn), - collections: collections(conn, dbVersionSupport), - close: () => Promise.resolve(conn.close()), // hedge against future changes to add I/O to .close() - getMeta: () => new MetaGetter(conn).do(), - getOpenIDConfig: () => new OpenidConfigurationGetter(conn.http).do(), - isLive: () => new LiveChecker(conn, dbVersionProvider).do(), - isReady: () => new ReadyChecker(conn, dbVersionProvider).do(), + backup: backup(connection), + cluster: cluster(connection), + collections: collections(connection, dbVersionSupport), + close: () => Promise.resolve(connection.close()), // hedge against future changes to add I/O to .close() + getMeta: () => new MetaGetter(connection).do(), + getOpenIDConfig: () => new OpenidConfigurationGetter(connection.http).do(), + getWeaviateVersion: () => dbVersionSupport.getVersion(), + isLive: () => new LiveChecker(connection, dbVersionProvider).do(), + isReady: () => new ReadyChecker(connection, dbVersionProvider).do(), }; - if (conn.oidcAuth) ifc.oidcAuth = conn.oidcAuth; + if (connection.oidcAuth) ifc.oidcAuth = connection.oidcAuth; return ifc; }, @@ -170,21 +169,6 @@ const app = { reconfigure, }; -function initDbVersionProvider(conn: ConnectionGRPC) { - const metaGetter = new MetaGetter(conn); - const versionGetter = () => { - return metaGetter - .do() - .then((result: any) => result.version) - .catch(() => Promise.resolve('')); - }; - - const dbVersionProvider = new DbVersionProvider(versionGetter); - dbVersionProvider.refresh(); - - return dbVersionProvider; -} - export default app; export * from './collections/index.js'; export * from './connection/index.js'; diff --git a/src/misc/metaGetter.ts b/src/misc/metaGetter.ts index dc44b235..9c020da4 100644 --- a/src/misc/metaGetter.ts +++ b/src/misc/metaGetter.ts @@ -1,4 +1,5 @@ import Connection from '../connection/index.js'; +import { Meta } from '../openapi/types.js'; import { CommandBase } from '../validation/commandBase.js'; export default class MetaGetter extends CommandBase { @@ -10,7 +11,7 @@ export default class MetaGetter extends CommandBase { // nothing to validate } - do = () => { + do = (): Promise => { return this.client.get('/meta', true); }; } diff --git a/src/proto/v1/properties.ts b/src/proto/v1/properties.ts index de3abe72..3f21d26e 100644 --- a/src/proto/v1/properties.ts +++ b/src/proto/v1/properties.ts @@ -15,7 +15,10 @@ export interface Properties_FieldsEntry { } export interface Value { - numberValue?: number | undefined; + numberValue?: + | number + | undefined; + /** @deprecated */ stringValue?: string | undefined; boolValue?: boolean | undefined; objectValue?: Properties | undefined; @@ -27,10 +30,55 @@ export interface Value { blobValue?: string | undefined; phoneValue?: PhoneNumber | undefined; nullValue?: NullValue | undefined; + textValue?: string | undefined; } export interface ListValue { + /** @deprecated */ values: Value[]; + numberValues?: NumberValues | undefined; + boolValues?: BoolValues | undefined; + objectValues?: ObjectValues | undefined; + dateValues?: DateValues | undefined; + uuidValues?: UuidValues | undefined; + intValues?: IntValues | undefined; + textValues?: TextValues | undefined; +} + +export interface NumberValues { + /** + * The values are stored as a byte array, where each 8 bytes represent a single float64 value. + * The byte array is stored in little-endian order using uint64 encoding. + */ + values: Uint8Array; +} + +export interface TextValues { + values: string[]; +} + +export interface BoolValues { + values: boolean[]; +} + +export interface ObjectValues { + values: Properties[]; +} + +export interface DateValues { + values: string[]; +} + +export interface UuidValues { + values: string[]; +} + +export interface IntValues { + /** + * The values are stored as a byte array, where each 8 bytes represent a single int64 value. + * The byte array is stored in little-endian order using uint64 encoding. + */ + values: Uint8Array; } export interface GeoCoordinate { @@ -214,6 +262,7 @@ function createBaseValue(): Value { blobValue: undefined, phoneValue: undefined, nullValue: undefined, + textValue: undefined, }; } @@ -255,6 +304,9 @@ export const Value = { if (message.nullValue !== undefined) { writer.uint32(96).int32(message.nullValue); } + if (message.textValue !== undefined) { + writer.uint32(106).string(message.textValue); + } return writer; }, @@ -349,6 +401,13 @@ export const Value = { message.nullValue = reader.int32() as any; continue; + case 13: + if (tag !== 106) { + break; + } + + message.textValue = reader.string(); + continue; } if ((tag & 7) === 4 || tag === 0) { break; @@ -372,6 +431,7 @@ export const Value = { blobValue: isSet(object.blobValue) ? globalThis.String(object.blobValue) : undefined, phoneValue: isSet(object.phoneValue) ? PhoneNumber.fromJSON(object.phoneValue) : undefined, nullValue: isSet(object.nullValue) ? nullValueFromJSON(object.nullValue) : undefined, + textValue: isSet(object.textValue) ? globalThis.String(object.textValue) : undefined, }; }, @@ -413,6 +473,9 @@ export const Value = { if (message.nullValue !== undefined) { obj.nullValue = nullValueToJSON(message.nullValue); } + if (message.textValue !== undefined) { + obj.textValue = message.textValue; + } return obj; }, @@ -441,12 +504,22 @@ export const Value = { ? PhoneNumber.fromPartial(object.phoneValue) : undefined; message.nullValue = object.nullValue ?? undefined; + message.textValue = object.textValue ?? undefined; return message; }, }; function createBaseListValue(): ListValue { - return { values: [] }; + return { + values: [], + numberValues: undefined, + boolValues: undefined, + objectValues: undefined, + dateValues: undefined, + uuidValues: undefined, + intValues: undefined, + textValues: undefined, + }; } export const ListValue = { @@ -454,6 +527,27 @@ export const ListValue = { for (const v of message.values) { Value.encode(v!, writer.uint32(10).fork()).ldelim(); } + if (message.numberValues !== undefined) { + NumberValues.encode(message.numberValues, writer.uint32(18).fork()).ldelim(); + } + if (message.boolValues !== undefined) { + BoolValues.encode(message.boolValues, writer.uint32(26).fork()).ldelim(); + } + if (message.objectValues !== undefined) { + ObjectValues.encode(message.objectValues, writer.uint32(34).fork()).ldelim(); + } + if (message.dateValues !== undefined) { + DateValues.encode(message.dateValues, writer.uint32(42).fork()).ldelim(); + } + if (message.uuidValues !== undefined) { + UuidValues.encode(message.uuidValues, writer.uint32(50).fork()).ldelim(); + } + if (message.intValues !== undefined) { + IntValues.encode(message.intValues, writer.uint32(58).fork()).ldelim(); + } + if (message.textValues !== undefined) { + TextValues.encode(message.textValues, writer.uint32(66).fork()).ldelim(); + } return writer; }, @@ -471,6 +565,55 @@ export const ListValue = { message.values.push(Value.decode(reader, reader.uint32())); continue; + case 2: + if (tag !== 18) { + break; + } + + message.numberValues = NumberValues.decode(reader, reader.uint32()); + continue; + case 3: + if (tag !== 26) { + break; + } + + message.boolValues = BoolValues.decode(reader, reader.uint32()); + continue; + case 4: + if (tag !== 34) { + break; + } + + message.objectValues = ObjectValues.decode(reader, reader.uint32()); + continue; + case 5: + if (tag !== 42) { + break; + } + + message.dateValues = DateValues.decode(reader, reader.uint32()); + continue; + case 6: + if (tag !== 50) { + break; + } + + message.uuidValues = UuidValues.decode(reader, reader.uint32()); + continue; + case 7: + if (tag !== 58) { + break; + } + + message.intValues = IntValues.decode(reader, reader.uint32()); + continue; + case 8: + if (tag !== 66) { + break; + } + + message.textValues = TextValues.decode(reader, reader.uint32()); + continue; } if ((tag & 7) === 4 || tag === 0) { break; @@ -481,7 +624,16 @@ export const ListValue = { }, fromJSON(object: any): ListValue { - return { values: globalThis.Array.isArray(object?.values) ? object.values.map((e: any) => Value.fromJSON(e)) : [] }; + return { + values: globalThis.Array.isArray(object?.values) ? object.values.map((e: any) => Value.fromJSON(e)) : [], + numberValues: isSet(object.numberValues) ? NumberValues.fromJSON(object.numberValues) : undefined, + boolValues: isSet(object.boolValues) ? BoolValues.fromJSON(object.boolValues) : undefined, + objectValues: isSet(object.objectValues) ? ObjectValues.fromJSON(object.objectValues) : undefined, + dateValues: isSet(object.dateValues) ? DateValues.fromJSON(object.dateValues) : undefined, + uuidValues: isSet(object.uuidValues) ? UuidValues.fromJSON(object.uuidValues) : undefined, + intValues: isSet(object.intValues) ? IntValues.fromJSON(object.intValues) : undefined, + textValues: isSet(object.textValues) ? TextValues.fromJSON(object.textValues) : undefined, + }; }, toJSON(message: ListValue): unknown { @@ -489,6 +641,27 @@ export const ListValue = { if (message.values?.length) { obj.values = message.values.map((e) => Value.toJSON(e)); } + if (message.numberValues !== undefined) { + obj.numberValues = NumberValues.toJSON(message.numberValues); + } + if (message.boolValues !== undefined) { + obj.boolValues = BoolValues.toJSON(message.boolValues); + } + if (message.objectValues !== undefined) { + obj.objectValues = ObjectValues.toJSON(message.objectValues); + } + if (message.dateValues !== undefined) { + obj.dateValues = DateValues.toJSON(message.dateValues); + } + if (message.uuidValues !== undefined) { + obj.uuidValues = UuidValues.toJSON(message.uuidValues); + } + if (message.intValues !== undefined) { + obj.intValues = IntValues.toJSON(message.intValues); + } + if (message.textValues !== undefined) { + obj.textValues = TextValues.toJSON(message.textValues); + } return obj; }, @@ -498,6 +671,448 @@ export const ListValue = { fromPartial(object: DeepPartial): ListValue { const message = createBaseListValue(); message.values = object.values?.map((e) => Value.fromPartial(e)) || []; + message.numberValues = (object.numberValues !== undefined && object.numberValues !== null) + ? NumberValues.fromPartial(object.numberValues) + : undefined; + message.boolValues = (object.boolValues !== undefined && object.boolValues !== null) + ? BoolValues.fromPartial(object.boolValues) + : undefined; + message.objectValues = (object.objectValues !== undefined && object.objectValues !== null) + ? ObjectValues.fromPartial(object.objectValues) + : undefined; + message.dateValues = (object.dateValues !== undefined && object.dateValues !== null) + ? DateValues.fromPartial(object.dateValues) + : undefined; + message.uuidValues = (object.uuidValues !== undefined && object.uuidValues !== null) + ? UuidValues.fromPartial(object.uuidValues) + : undefined; + message.intValues = (object.intValues !== undefined && object.intValues !== null) + ? IntValues.fromPartial(object.intValues) + : undefined; + message.textValues = (object.textValues !== undefined && object.textValues !== null) + ? TextValues.fromPartial(object.textValues) + : undefined; + return message; + }, +}; + +function createBaseNumberValues(): NumberValues { + return { values: new Uint8Array(0) }; +} + +export const NumberValues = { + encode(message: NumberValues, writer: _m0.Writer = _m0.Writer.create()): _m0.Writer { + if (message.values.length !== 0) { + writer.uint32(10).bytes(message.values); + } + return writer; + }, + + decode(input: _m0.Reader | Uint8Array, length?: number): NumberValues { + const reader = input instanceof _m0.Reader ? input : _m0.Reader.create(input); + let end = length === undefined ? reader.len : reader.pos + length; + const message = createBaseNumberValues(); + while (reader.pos < end) { + const tag = reader.uint32(); + switch (tag >>> 3) { + case 1: + if (tag !== 10) { + break; + } + + message.values = reader.bytes(); + continue; + } + if ((tag & 7) === 4 || tag === 0) { + break; + } + reader.skipType(tag & 7); + } + return message; + }, + + fromJSON(object: any): NumberValues { + return { values: isSet(object.values) ? bytesFromBase64(object.values) : new Uint8Array(0) }; + }, + + toJSON(message: NumberValues): unknown { + const obj: any = {}; + if (message.values.length !== 0) { + obj.values = base64FromBytes(message.values); + } + return obj; + }, + + create(base?: DeepPartial): NumberValues { + return NumberValues.fromPartial(base ?? {}); + }, + fromPartial(object: DeepPartial): NumberValues { + const message = createBaseNumberValues(); + message.values = object.values ?? new Uint8Array(0); + return message; + }, +}; + +function createBaseTextValues(): TextValues { + return { values: [] }; +} + +export const TextValues = { + encode(message: TextValues, writer: _m0.Writer = _m0.Writer.create()): _m0.Writer { + for (const v of message.values) { + writer.uint32(10).string(v!); + } + return writer; + }, + + decode(input: _m0.Reader | Uint8Array, length?: number): TextValues { + const reader = input instanceof _m0.Reader ? input : _m0.Reader.create(input); + let end = length === undefined ? reader.len : reader.pos + length; + const message = createBaseTextValues(); + while (reader.pos < end) { + const tag = reader.uint32(); + switch (tag >>> 3) { + case 1: + if (tag !== 10) { + break; + } + + message.values.push(reader.string()); + continue; + } + if ((tag & 7) === 4 || tag === 0) { + break; + } + reader.skipType(tag & 7); + } + return message; + }, + + fromJSON(object: any): TextValues { + return { + values: globalThis.Array.isArray(object?.values) ? object.values.map((e: any) => globalThis.String(e)) : [], + }; + }, + + toJSON(message: TextValues): unknown { + const obj: any = {}; + if (message.values?.length) { + obj.values = message.values; + } + return obj; + }, + + create(base?: DeepPartial): TextValues { + return TextValues.fromPartial(base ?? {}); + }, + fromPartial(object: DeepPartial): TextValues { + const message = createBaseTextValues(); + message.values = object.values?.map((e) => e) || []; + return message; + }, +}; + +function createBaseBoolValues(): BoolValues { + return { values: [] }; +} + +export const BoolValues = { + encode(message: BoolValues, writer: _m0.Writer = _m0.Writer.create()): _m0.Writer { + writer.uint32(10).fork(); + for (const v of message.values) { + writer.bool(v); + } + writer.ldelim(); + return writer; + }, + + decode(input: _m0.Reader | Uint8Array, length?: number): BoolValues { + const reader = input instanceof _m0.Reader ? input : _m0.Reader.create(input); + let end = length === undefined ? reader.len : reader.pos + length; + const message = createBaseBoolValues(); + while (reader.pos < end) { + const tag = reader.uint32(); + switch (tag >>> 3) { + case 1: + if (tag === 8) { + message.values.push(reader.bool()); + + continue; + } + + if (tag === 10) { + const end2 = reader.uint32() + reader.pos; + while (reader.pos < end2) { + message.values.push(reader.bool()); + } + + continue; + } + + break; + } + if ((tag & 7) === 4 || tag === 0) { + break; + } + reader.skipType(tag & 7); + } + return message; + }, + + fromJSON(object: any): BoolValues { + return { + values: globalThis.Array.isArray(object?.values) ? object.values.map((e: any) => globalThis.Boolean(e)) : [], + }; + }, + + toJSON(message: BoolValues): unknown { + const obj: any = {}; + if (message.values?.length) { + obj.values = message.values; + } + return obj; + }, + + create(base?: DeepPartial): BoolValues { + return BoolValues.fromPartial(base ?? {}); + }, + fromPartial(object: DeepPartial): BoolValues { + const message = createBaseBoolValues(); + message.values = object.values?.map((e) => e) || []; + return message; + }, +}; + +function createBaseObjectValues(): ObjectValues { + return { values: [] }; +} + +export const ObjectValues = { + encode(message: ObjectValues, writer: _m0.Writer = _m0.Writer.create()): _m0.Writer { + for (const v of message.values) { + Properties.encode(v!, writer.uint32(10).fork()).ldelim(); + } + return writer; + }, + + decode(input: _m0.Reader | Uint8Array, length?: number): ObjectValues { + const reader = input instanceof _m0.Reader ? input : _m0.Reader.create(input); + let end = length === undefined ? reader.len : reader.pos + length; + const message = createBaseObjectValues(); + while (reader.pos < end) { + const tag = reader.uint32(); + switch (tag >>> 3) { + case 1: + if (tag !== 10) { + break; + } + + message.values.push(Properties.decode(reader, reader.uint32())); + continue; + } + if ((tag & 7) === 4 || tag === 0) { + break; + } + reader.skipType(tag & 7); + } + return message; + }, + + fromJSON(object: any): ObjectValues { + return { + values: globalThis.Array.isArray(object?.values) ? object.values.map((e: any) => Properties.fromJSON(e)) : [], + }; + }, + + toJSON(message: ObjectValues): unknown { + const obj: any = {}; + if (message.values?.length) { + obj.values = message.values.map((e) => Properties.toJSON(e)); + } + return obj; + }, + + create(base?: DeepPartial): ObjectValues { + return ObjectValues.fromPartial(base ?? {}); + }, + fromPartial(object: DeepPartial): ObjectValues { + const message = createBaseObjectValues(); + message.values = object.values?.map((e) => Properties.fromPartial(e)) || []; + return message; + }, +}; + +function createBaseDateValues(): DateValues { + return { values: [] }; +} + +export const DateValues = { + encode(message: DateValues, writer: _m0.Writer = _m0.Writer.create()): _m0.Writer { + for (const v of message.values) { + writer.uint32(10).string(v!); + } + return writer; + }, + + decode(input: _m0.Reader | Uint8Array, length?: number): DateValues { + const reader = input instanceof _m0.Reader ? input : _m0.Reader.create(input); + let end = length === undefined ? reader.len : reader.pos + length; + const message = createBaseDateValues(); + while (reader.pos < end) { + const tag = reader.uint32(); + switch (tag >>> 3) { + case 1: + if (tag !== 10) { + break; + } + + message.values.push(reader.string()); + continue; + } + if ((tag & 7) === 4 || tag === 0) { + break; + } + reader.skipType(tag & 7); + } + return message; + }, + + fromJSON(object: any): DateValues { + return { + values: globalThis.Array.isArray(object?.values) ? object.values.map((e: any) => globalThis.String(e)) : [], + }; + }, + + toJSON(message: DateValues): unknown { + const obj: any = {}; + if (message.values?.length) { + obj.values = message.values; + } + return obj; + }, + + create(base?: DeepPartial): DateValues { + return DateValues.fromPartial(base ?? {}); + }, + fromPartial(object: DeepPartial): DateValues { + const message = createBaseDateValues(); + message.values = object.values?.map((e) => e) || []; + return message; + }, +}; + +function createBaseUuidValues(): UuidValues { + return { values: [] }; +} + +export const UuidValues = { + encode(message: UuidValues, writer: _m0.Writer = _m0.Writer.create()): _m0.Writer { + for (const v of message.values) { + writer.uint32(10).string(v!); + } + return writer; + }, + + decode(input: _m0.Reader | Uint8Array, length?: number): UuidValues { + const reader = input instanceof _m0.Reader ? input : _m0.Reader.create(input); + let end = length === undefined ? reader.len : reader.pos + length; + const message = createBaseUuidValues(); + while (reader.pos < end) { + const tag = reader.uint32(); + switch (tag >>> 3) { + case 1: + if (tag !== 10) { + break; + } + + message.values.push(reader.string()); + continue; + } + if ((tag & 7) === 4 || tag === 0) { + break; + } + reader.skipType(tag & 7); + } + return message; + }, + + fromJSON(object: any): UuidValues { + return { + values: globalThis.Array.isArray(object?.values) ? object.values.map((e: any) => globalThis.String(e)) : [], + }; + }, + + toJSON(message: UuidValues): unknown { + const obj: any = {}; + if (message.values?.length) { + obj.values = message.values; + } + return obj; + }, + + create(base?: DeepPartial): UuidValues { + return UuidValues.fromPartial(base ?? {}); + }, + fromPartial(object: DeepPartial): UuidValues { + const message = createBaseUuidValues(); + message.values = object.values?.map((e) => e) || []; + return message; + }, +}; + +function createBaseIntValues(): IntValues { + return { values: new Uint8Array(0) }; +} + +export const IntValues = { + encode(message: IntValues, writer: _m0.Writer = _m0.Writer.create()): _m0.Writer { + if (message.values.length !== 0) { + writer.uint32(10).bytes(message.values); + } + return writer; + }, + + decode(input: _m0.Reader | Uint8Array, length?: number): IntValues { + const reader = input instanceof _m0.Reader ? input : _m0.Reader.create(input); + let end = length === undefined ? reader.len : reader.pos + length; + const message = createBaseIntValues(); + while (reader.pos < end) { + const tag = reader.uint32(); + switch (tag >>> 3) { + case 1: + if (tag !== 10) { + break; + } + + message.values = reader.bytes(); + continue; + } + if ((tag & 7) === 4 || tag === 0) { + break; + } + reader.skipType(tag & 7); + } + return message; + }, + + fromJSON(object: any): IntValues { + return { values: isSet(object.values) ? bytesFromBase64(object.values) : new Uint8Array(0) }; + }, + + toJSON(message: IntValues): unknown { + const obj: any = {}; + if (message.values.length !== 0) { + obj.values = base64FromBytes(message.values); + } + return obj; + }, + + create(base?: DeepPartial): IntValues { + return IntValues.fromPartial(base ?? {}); + }, + fromPartial(object: DeepPartial): IntValues { + const message = createBaseIntValues(); + message.values = object.values ?? new Uint8Array(0); return message; }, }; @@ -735,6 +1350,31 @@ export const PhoneNumber = { }, }; +function bytesFromBase64(b64: string): Uint8Array { + if ((globalThis as any).Buffer) { + return Uint8Array.from(globalThis.Buffer.from(b64, "base64")); + } else { + const bin = globalThis.atob(b64); + const arr = new Uint8Array(bin.length); + for (let i = 0; i < bin.length; ++i) { + arr[i] = bin.charCodeAt(i); + } + return arr; + } +} + +function base64FromBytes(arr: Uint8Array): string { + if ((globalThis as any).Buffer) { + return globalThis.Buffer.from(arr).toString("base64"); + } else { + const bin: string[] = []; + arr.forEach((byte) => { + bin.push(globalThis.String.fromCharCode(byte)); + }); + return globalThis.btoa(bin.join("")); + } +} + type Builtin = Date | Function | Uint8Array | string | number | boolean | undefined; export type DeepPartial = T extends Builtin ? T diff --git a/src/proto/v1/search_get.ts b/src/proto/v1/search_get.ts index 02f6a950..24cbbc36 100644 --- a/src/proto/v1/search_get.ts +++ b/src/proto/v1/search_get.ts @@ -59,6 +59,7 @@ export interface SearchRequest { | undefined; /** @deprecated */ uses123Api: boolean; + uses125Api: boolean; } export interface GroupBy { @@ -128,6 +129,12 @@ export interface Hybrid { fusionType: Hybrid_FusionType; vectorBytes: Uint8Array; targetVectors: string[]; + /** target_vector in msg is ignored and should not be set for hybrid */ + nearText: + | NearTextSearch + | undefined; + /** same as above. Use the target vector in the hybrid message */ + nearVector: NearVector | undefined; } export enum Hybrid_FusionType { @@ -382,6 +389,7 @@ function createBaseSearchRequest(): SearchRequest { generative: undefined, rerank: undefined, uses123Api: false, + uses125Api: false, }; } @@ -465,6 +473,9 @@ export const SearchRequest = { if (message.uses123Api === true) { writer.uint32(800).bool(message.uses123Api); } + if (message.uses125Api === true) { + writer.uint32(808).bool(message.uses125Api); + } return writer; }, @@ -657,6 +668,13 @@ export const SearchRequest = { message.uses123Api = reader.bool(); continue; + case 101: + if (tag !== 808) { + break; + } + + message.uses125Api = reader.bool(); + continue; } if ((tag & 7) === 4 || tag === 0) { break; @@ -694,6 +712,7 @@ export const SearchRequest = { generative: isSet(object.generative) ? GenerativeSearch.fromJSON(object.generative) : undefined, rerank: isSet(object.rerank) ? Rerank.fromJSON(object.rerank) : undefined, uses123Api: isSet(object.uses123Api) ? globalThis.Boolean(object.uses123Api) : false, + uses125Api: isSet(object.uses125Api) ? globalThis.Boolean(object.uses125Api) : false, }; }, @@ -777,6 +796,9 @@ export const SearchRequest = { if (message.uses123Api === true) { obj.uses123Api = message.uses123Api; } + if (message.uses125Api === true) { + obj.uses125Api = message.uses125Api; + } return obj; }, @@ -845,6 +867,7 @@ export const SearchRequest = { ? Rerank.fromPartial(object.rerank) : undefined; message.uses123Api = object.uses123Api ?? false; + message.uses125Api = object.uses125Api ?? false; return message; }, }; @@ -1522,6 +1545,8 @@ function createBaseHybrid(): Hybrid { fusionType: 0, vectorBytes: new Uint8Array(0), targetVectors: [], + nearText: undefined, + nearVector: undefined, }; } @@ -1550,6 +1575,12 @@ export const Hybrid = { for (const v of message.targetVectors) { writer.uint32(58).string(v!); } + if (message.nearText !== undefined) { + NearTextSearch.encode(message.nearText, writer.uint32(66).fork()).ldelim(); + } + if (message.nearVector !== undefined) { + NearVector.encode(message.nearVector, writer.uint32(74).fork()).ldelim(); + } return writer; }, @@ -1619,6 +1650,20 @@ export const Hybrid = { message.targetVectors.push(reader.string()); continue; + case 8: + if (tag !== 66) { + break; + } + + message.nearText = NearTextSearch.decode(reader, reader.uint32()); + continue; + case 9: + if (tag !== 74) { + break; + } + + message.nearVector = NearVector.decode(reader, reader.uint32()); + continue; } if ((tag & 7) === 4 || tag === 0) { break; @@ -1641,6 +1686,8 @@ export const Hybrid = { targetVectors: globalThis.Array.isArray(object?.targetVectors) ? object.targetVectors.map((e: any) => globalThis.String(e)) : [], + nearText: isSet(object.nearText) ? NearTextSearch.fromJSON(object.nearText) : undefined, + nearVector: isSet(object.nearVector) ? NearVector.fromJSON(object.nearVector) : undefined, }; }, @@ -1667,6 +1714,12 @@ export const Hybrid = { if (message.targetVectors?.length) { obj.targetVectors = message.targetVectors; } + if (message.nearText !== undefined) { + obj.nearText = NearTextSearch.toJSON(message.nearText); + } + if (message.nearVector !== undefined) { + obj.nearVector = NearVector.toJSON(message.nearVector); + } return obj; }, @@ -1682,6 +1735,12 @@ export const Hybrid = { message.fusionType = object.fusionType ?? 0; message.vectorBytes = object.vectorBytes ?? new Uint8Array(0); message.targetVectors = object.targetVectors?.map((e) => e) || []; + message.nearText = (object.nearText !== undefined && object.nearText !== null) + ? NearTextSearch.fromPartial(object.nearText) + : undefined; + message.nearVector = (object.nearVector !== undefined && object.nearVector !== null) + ? NearVector.fromPartial(object.nearVector) + : undefined; return message; }, }; diff --git a/src/proto/v1/tenants.ts b/src/proto/v1/tenants.ts new file mode 100644 index 00000000..4b292921 --- /dev/null +++ b/src/proto/v1/tenants.ts @@ -0,0 +1,369 @@ +/* eslint-disable */ +import _m0 from "protobufjs/minimal.js"; + +export const protobufPackage = "weaviate.v1"; + +export enum TenantActivityStatus { + TENANT_ACTIVITY_STATUS_UNSPECIFIED = 0, + TENANT_ACTIVITY_STATUS_HOT = 1, + TENANT_ACTIVITY_STATUS_COLD = 2, + TENANT_ACTIVITY_STATUS_WARM = 3, + TENANT_ACTIVITY_STATUS_FROZEN = 4, + UNRECOGNIZED = -1, +} + +export function tenantActivityStatusFromJSON(object: any): TenantActivityStatus { + switch (object) { + case 0: + case "TENANT_ACTIVITY_STATUS_UNSPECIFIED": + return TenantActivityStatus.TENANT_ACTIVITY_STATUS_UNSPECIFIED; + case 1: + case "TENANT_ACTIVITY_STATUS_HOT": + return TenantActivityStatus.TENANT_ACTIVITY_STATUS_HOT; + case 2: + case "TENANT_ACTIVITY_STATUS_COLD": + return TenantActivityStatus.TENANT_ACTIVITY_STATUS_COLD; + case 3: + case "TENANT_ACTIVITY_STATUS_WARM": + return TenantActivityStatus.TENANT_ACTIVITY_STATUS_WARM; + case 4: + case "TENANT_ACTIVITY_STATUS_FROZEN": + return TenantActivityStatus.TENANT_ACTIVITY_STATUS_FROZEN; + case -1: + case "UNRECOGNIZED": + default: + return TenantActivityStatus.UNRECOGNIZED; + } +} + +export function tenantActivityStatusToJSON(object: TenantActivityStatus): string { + switch (object) { + case TenantActivityStatus.TENANT_ACTIVITY_STATUS_UNSPECIFIED: + return "TENANT_ACTIVITY_STATUS_UNSPECIFIED"; + case TenantActivityStatus.TENANT_ACTIVITY_STATUS_HOT: + return "TENANT_ACTIVITY_STATUS_HOT"; + case TenantActivityStatus.TENANT_ACTIVITY_STATUS_COLD: + return "TENANT_ACTIVITY_STATUS_COLD"; + case TenantActivityStatus.TENANT_ACTIVITY_STATUS_WARM: + return "TENANT_ACTIVITY_STATUS_WARM"; + case TenantActivityStatus.TENANT_ACTIVITY_STATUS_FROZEN: + return "TENANT_ACTIVITY_STATUS_FROZEN"; + case TenantActivityStatus.UNRECOGNIZED: + default: + return "UNRECOGNIZED"; + } +} + +export interface TenantsGetRequest { + collection: string; + names?: TenantNames | undefined; +} + +export interface TenantNames { + values: string[]; +} + +export interface TenantsGetReply { + took: number; + tenants: Tenant[]; +} + +export interface Tenant { + name: string; + activityStatus: TenantActivityStatus; +} + +function createBaseTenantsGetRequest(): TenantsGetRequest { + return { collection: "", names: undefined }; +} + +export const TenantsGetRequest = { + encode(message: TenantsGetRequest, writer: _m0.Writer = _m0.Writer.create()): _m0.Writer { + if (message.collection !== "") { + writer.uint32(10).string(message.collection); + } + if (message.names !== undefined) { + TenantNames.encode(message.names, writer.uint32(18).fork()).ldelim(); + } + return writer; + }, + + decode(input: _m0.Reader | Uint8Array, length?: number): TenantsGetRequest { + const reader = input instanceof _m0.Reader ? input : _m0.Reader.create(input); + let end = length === undefined ? reader.len : reader.pos + length; + const message = createBaseTenantsGetRequest(); + while (reader.pos < end) { + const tag = reader.uint32(); + switch (tag >>> 3) { + case 1: + if (tag !== 10) { + break; + } + + message.collection = reader.string(); + continue; + case 2: + if (tag !== 18) { + break; + } + + message.names = TenantNames.decode(reader, reader.uint32()); + continue; + } + if ((tag & 7) === 4 || tag === 0) { + break; + } + reader.skipType(tag & 7); + } + return message; + }, + + fromJSON(object: any): TenantsGetRequest { + return { + collection: isSet(object.collection) ? globalThis.String(object.collection) : "", + names: isSet(object.names) ? TenantNames.fromJSON(object.names) : undefined, + }; + }, + + toJSON(message: TenantsGetRequest): unknown { + const obj: any = {}; + if (message.collection !== "") { + obj.collection = message.collection; + } + if (message.names !== undefined) { + obj.names = TenantNames.toJSON(message.names); + } + return obj; + }, + + create(base?: DeepPartial): TenantsGetRequest { + return TenantsGetRequest.fromPartial(base ?? {}); + }, + fromPartial(object: DeepPartial): TenantsGetRequest { + const message = createBaseTenantsGetRequest(); + message.collection = object.collection ?? ""; + message.names = (object.names !== undefined && object.names !== null) + ? TenantNames.fromPartial(object.names) + : undefined; + return message; + }, +}; + +function createBaseTenantNames(): TenantNames { + return { values: [] }; +} + +export const TenantNames = { + encode(message: TenantNames, writer: _m0.Writer = _m0.Writer.create()): _m0.Writer { + for (const v of message.values) { + writer.uint32(10).string(v!); + } + return writer; + }, + + decode(input: _m0.Reader | Uint8Array, length?: number): TenantNames { + const reader = input instanceof _m0.Reader ? input : _m0.Reader.create(input); + let end = length === undefined ? reader.len : reader.pos + length; + const message = createBaseTenantNames(); + while (reader.pos < end) { + const tag = reader.uint32(); + switch (tag >>> 3) { + case 1: + if (tag !== 10) { + break; + } + + message.values.push(reader.string()); + continue; + } + if ((tag & 7) === 4 || tag === 0) { + break; + } + reader.skipType(tag & 7); + } + return message; + }, + + fromJSON(object: any): TenantNames { + return { + values: globalThis.Array.isArray(object?.values) ? object.values.map((e: any) => globalThis.String(e)) : [], + }; + }, + + toJSON(message: TenantNames): unknown { + const obj: any = {}; + if (message.values?.length) { + obj.values = message.values; + } + return obj; + }, + + create(base?: DeepPartial): TenantNames { + return TenantNames.fromPartial(base ?? {}); + }, + fromPartial(object: DeepPartial): TenantNames { + const message = createBaseTenantNames(); + message.values = object.values?.map((e) => e) || []; + return message; + }, +}; + +function createBaseTenantsGetReply(): TenantsGetReply { + return { took: 0, tenants: [] }; +} + +export const TenantsGetReply = { + encode(message: TenantsGetReply, writer: _m0.Writer = _m0.Writer.create()): _m0.Writer { + if (message.took !== 0) { + writer.uint32(13).float(message.took); + } + for (const v of message.tenants) { + Tenant.encode(v!, writer.uint32(18).fork()).ldelim(); + } + return writer; + }, + + decode(input: _m0.Reader | Uint8Array, length?: number): TenantsGetReply { + const reader = input instanceof _m0.Reader ? input : _m0.Reader.create(input); + let end = length === undefined ? reader.len : reader.pos + length; + const message = createBaseTenantsGetReply(); + while (reader.pos < end) { + const tag = reader.uint32(); + switch (tag >>> 3) { + case 1: + if (tag !== 13) { + break; + } + + message.took = reader.float(); + continue; + case 2: + if (tag !== 18) { + break; + } + + message.tenants.push(Tenant.decode(reader, reader.uint32())); + continue; + } + if ((tag & 7) === 4 || tag === 0) { + break; + } + reader.skipType(tag & 7); + } + return message; + }, + + fromJSON(object: any): TenantsGetReply { + return { + took: isSet(object.took) ? globalThis.Number(object.took) : 0, + tenants: globalThis.Array.isArray(object?.tenants) ? object.tenants.map((e: any) => Tenant.fromJSON(e)) : [], + }; + }, + + toJSON(message: TenantsGetReply): unknown { + const obj: any = {}; + if (message.took !== 0) { + obj.took = message.took; + } + if (message.tenants?.length) { + obj.tenants = message.tenants.map((e) => Tenant.toJSON(e)); + } + return obj; + }, + + create(base?: DeepPartial): TenantsGetReply { + return TenantsGetReply.fromPartial(base ?? {}); + }, + fromPartial(object: DeepPartial): TenantsGetReply { + const message = createBaseTenantsGetReply(); + message.took = object.took ?? 0; + message.tenants = object.tenants?.map((e) => Tenant.fromPartial(e)) || []; + return message; + }, +}; + +function createBaseTenant(): Tenant { + return { name: "", activityStatus: 0 }; +} + +export const Tenant = { + encode(message: Tenant, writer: _m0.Writer = _m0.Writer.create()): _m0.Writer { + if (message.name !== "") { + writer.uint32(10).string(message.name); + } + if (message.activityStatus !== 0) { + writer.uint32(16).int32(message.activityStatus); + } + return writer; + }, + + decode(input: _m0.Reader | Uint8Array, length?: number): Tenant { + const reader = input instanceof _m0.Reader ? input : _m0.Reader.create(input); + let end = length === undefined ? reader.len : reader.pos + length; + const message = createBaseTenant(); + while (reader.pos < end) { + const tag = reader.uint32(); + switch (tag >>> 3) { + case 1: + if (tag !== 10) { + break; + } + + message.name = reader.string(); + continue; + case 2: + if (tag !== 16) { + break; + } + + message.activityStatus = reader.int32() as any; + continue; + } + if ((tag & 7) === 4 || tag === 0) { + break; + } + reader.skipType(tag & 7); + } + return message; + }, + + fromJSON(object: any): Tenant { + return { + name: isSet(object.name) ? globalThis.String(object.name) : "", + activityStatus: isSet(object.activityStatus) ? tenantActivityStatusFromJSON(object.activityStatus) : 0, + }; + }, + + toJSON(message: Tenant): unknown { + const obj: any = {}; + if (message.name !== "") { + obj.name = message.name; + } + if (message.activityStatus !== 0) { + obj.activityStatus = tenantActivityStatusToJSON(message.activityStatus); + } + return obj; + }, + + create(base?: DeepPartial): Tenant { + return Tenant.fromPartial(base ?? {}); + }, + fromPartial(object: DeepPartial): Tenant { + const message = createBaseTenant(); + message.name = object.name ?? ""; + message.activityStatus = object.activityStatus ?? 0; + return message; + }, +}; + +type Builtin = Date | Function | Uint8Array | string | number | boolean | undefined; + +export type DeepPartial = T extends Builtin ? T + : T extends globalThis.Array ? globalThis.Array> + : T extends ReadonlyArray ? ReadonlyArray> + : T extends {} ? { [K in keyof T]?: DeepPartial } + : Partial; + +function isSet(value: any): boolean { + return value !== null && value !== undefined; +} diff --git a/src/proto/v1/weaviate.ts b/src/proto/v1/weaviate.ts index 79c8754e..aeb4a23c 100644 --- a/src/proto/v1/weaviate.ts +++ b/src/proto/v1/weaviate.ts @@ -3,6 +3,7 @@ import type { CallContext, CallOptions } from "nice-grpc-common"; import { BatchObjectsReply, BatchObjectsRequest } from "./batch.js"; import { BatchDeleteReply, BatchDeleteRequest } from "./batch_delete.js"; import { SearchReply, SearchRequest } from "./search_get.js"; +import { TenantsGetReply, TenantsGetRequest } from "./tenants.js"; export const protobufPackage = "weaviate.v1"; @@ -35,6 +36,14 @@ export const WeaviateDefinition = { responseStream: false, options: {}, }, + tenantsGet: { + name: "TenantsGet", + requestType: TenantsGetRequest, + requestStream: false, + responseType: TenantsGetReply, + responseStream: false, + options: {}, + }, }, } as const; @@ -48,6 +57,7 @@ export interface WeaviateServiceImplementation { request: BatchDeleteRequest, context: CallContext & CallContextExt, ): Promise>; + tenantsGet(request: TenantsGetRequest, context: CallContext & CallContextExt): Promise>; } export interface WeaviateClient { @@ -60,6 +70,7 @@ export interface WeaviateClient { request: DeepPartial, options?: CallOptions & CallOptionsExt, ): Promise; + tenantsGet(request: DeepPartial, options?: CallOptions & CallOptionsExt): Promise; } type Builtin = Date | Function | Uint8Array | string | number | boolean | undefined; diff --git a/src/schema/journey.test.ts b/src/schema/journey.test.ts index 8fb77ce3..31432878 100644 --- a/src/schema/journey.test.ts +++ b/src/schema/journey.test.ts @@ -15,7 +15,10 @@ const is125 = (client: WeaviateClient) => .metaGetter() .do() .then((res: Meta) => res.version) - .then((version: string) => { + .then((version: string | undefined) => { + if (!version) { + return false; + } const semver = version.split('.').map((v) => parseInt(v, 10)); return semver[1] >= 25; }); @@ -681,11 +684,6 @@ describe('multi tenancy', () => { scheme: 'http', host: 'localhost:8080', }); - const versionPromise: Promise = client.misc - .metaGetter() - .do() - .then((res: Meta) => res.version); - const classObj: WeaviateClass = { class: 'MultiTenancy', properties: [ diff --git a/src/utils/dbVersion.ts b/src/utils/dbVersion.ts index 1d50b051..972d0503 100644 --- a/src/utils/dbVersion.ts +++ b/src/utils/dbVersion.ts @@ -1,3 +1,7 @@ +import ConnectionGRPC from '../connection/grpc.js'; +import MetaGetter from '../misc/metaGetter.js'; +import { Meta } from '../openapi/types.js'; + export class DbVersionSupport { private dbVersionProvider: VersionProvider; @@ -5,45 +9,50 @@ export class DbVersionSupport { this.dbVersionProvider = dbVersionProvider; } + getVersion = () => this.dbVersionProvider.getVersion(); + supportsClassNameNamespacedEndpointsPromise() { - return this.dbVersionProvider.getVersionPromise().then((version?: string) => ({ - version, - supports: this.supportsClassNameNamespacedEndpoints(version), - warns: { - deprecatedNonClassNameNamespacedEndpointsForObjects: () => - console.warn( - `Usage of objects paths without className is deprecated in Weaviate ${version}. Please provide className parameter` - ), - deprecatedNonClassNameNamespacedEndpointsForReferences: () => - console.warn( - `Usage of references paths without className is deprecated in Weaviate ${version}. Please provide className parameter` - ), - deprecatedNonClassNameNamespacedEndpointsForBeacons: () => - console.warn( - `Usage of beacons paths without className is deprecated in Weaviate ${version}. Please provide className parameter` - ), - deprecatedWeaviateTooOld: () => - console.warn( - `Usage of weaviate ${version} is deprecated. Please consider upgrading to the latest version. See https://www.weaviate.io/developers/weaviate for details.` - ), - notSupportedClassNamespacedEndpointsForObjects: () => - console.warn( - `Usage of objects paths with className is not supported in Weaviate ${version}. className parameter is ignored` - ), - notSupportedClassNamespacedEndpointsForReferences: () => - console.warn( - `Usage of references paths with className is not supported in Weaviate ${version}. className parameter is ignored` - ), - notSupportedClassNamespacedEndpointsForBeacons: () => - console.warn( - `Usage of beacons paths with className is not supported in Weaviate ${version}. className parameter is ignored` - ), - notSupportedClassParameterInEndpointsForObjects: () => - console.warn( - `Usage of objects paths with class query parameter is not supported in Weaviate ${version}. class query parameter is ignored` - ), - }, - })); + return this.dbVersionProvider + .getVersion() + .then((version) => version.show()) + .then((version) => ({ + version: version, + supports: this.supportsClassNameNamespacedEndpoints(version), + warns: { + deprecatedNonClassNameNamespacedEndpointsForObjects: () => + console.warn( + `Usage of objects paths without className is deprecated in Weaviate ${version}. Please provide className parameter` + ), + deprecatedNonClassNameNamespacedEndpointsForReferences: () => + console.warn( + `Usage of references paths without className is deprecated in Weaviate ${version}. Please provide className parameter` + ), + deprecatedNonClassNameNamespacedEndpointsForBeacons: () => + console.warn( + `Usage of beacons paths without className is deprecated in Weaviate ${version}. Please provide className parameter` + ), + deprecatedWeaviateTooOld: () => + console.warn( + `Usage of weaviate ${version} is deprecated. Please consider upgrading to the latest version. See https://www.weaviate.io/developers/weaviate for details.` + ), + notSupportedClassNamespacedEndpointsForObjects: () => + console.warn( + `Usage of objects paths with className is not supported in Weaviate ${version}. className parameter is ignored` + ), + notSupportedClassNamespacedEndpointsForReferences: () => + console.warn( + `Usage of references paths with className is not supported in Weaviate ${version}. className parameter is ignored` + ), + notSupportedClassNamespacedEndpointsForBeacons: () => + console.warn( + `Usage of beacons paths with className is not supported in Weaviate ${version}. className parameter is ignored` + ), + notSupportedClassParameterInEndpointsForObjects: () => + console.warn( + `Usage of objects paths with class query parameter is not supported in Weaviate ${version}. class query parameter is ignored` + ), + }, + })); } // >= 1.14 @@ -58,48 +67,184 @@ export class DbVersionSupport { } return false; } + + private errorMessage = (feature: string, current: string, required: string) => + `${feature} is not supported with Weaviate version v${current}. Please use version v${required} or higher.`; + + supportsCompatibleGrpcService = () => + this.dbVersionProvider.getVersion().then((version) => { + return { + version: version, + supports: version.isAtLeast(1, 23, 7), + message: this.errorMessage('gRPC', version.show(), '1.23.7'), + }; + }); + + supportsBm25AndHybridGroupByQueries = () => + this.dbVersionProvider.getVersion().then((version) => { + return { + version: version, + supports: version.isAtLeast(1, 25, 0), + message: (query: 'Bm25' | 'Hybrid') => + this.errorMessage(`GroupBy with ${query}`, version.show(), '1.25.0'), + }; + }); + + supportsHybridNearTextAndNearVectorSubsearchQueries = () => { + return this.dbVersionProvider.getVersion().then((version) => { + return { + version: version, + supports: version.isAtLeast(1, 25, 0), + message: this.errorMessage('Hybrid nearText/nearVector subsearching', version.show(), '1.25.0'), + }; + }); + }; + + supports125ListValue = () => { + return this.dbVersionProvider.getVersion().then((version) => { + return { + version: version, + supports: version.isAtLeast(1, 25, 0), + message: undefined, + }; + }); + }; + + supportsNamedVectors = () => { + return this.dbVersionProvider.getVersion().then((version) => { + return { + version: version, + supports: version.isAtLeast(1, 24, 0), + message: this.errorMessage('Named vectors', version.show(), '1.24.0'), + }; + }); + }; + + supportsTenantsGetGRPCMethod = () => { + return this.dbVersionProvider.getVersion().then((version) => { + return { + version: version, + supports: version.isAtLeast(1, 25, 0), + message: this.errorMessage('Tenants get method', version.show(), '1.25.0'), + }; + }); + }; } const EMPTY_VERSION = ''; export interface VersionProvider { - getVersionPromise(): Promise; + getVersionString(): Promise; + getVersion(): Promise; } export class DbVersionProvider implements VersionProvider { - private versionPromise?: Promise; - private readonly emptyVersionPromise: Promise; - private versionGetter: () => Promise; + private versionPromise?: Promise; + private versionStringGetter: () => Promise; - constructor(versionGetter: () => Promise) { - this.versionGetter = versionGetter; - - this.emptyVersionPromise = Promise.resolve(EMPTY_VERSION); + constructor(versionStringGetter: () => Promise) { + this.versionStringGetter = versionStringGetter; this.versionPromise = undefined; } - getVersionPromise(): Promise { + getVersionString(): Promise { + return this.getVersion().then((version) => version.show()); + } + + getVersion(): Promise { if (this.versionPromise) { return this.versionPromise; } - return this.versionGetter().then((version) => this.assignPromise(version)); + return this.versionStringGetter().then((version) => this.cache(version)); } refresh(force = false): Promise { if (force || !this.versionPromise) { this.versionPromise = undefined; - return this.versionGetter() - .then((version) => this.assignPromise(version)) + return this.versionStringGetter() + .then((version) => this.cache(version)) .then(() => Promise.resolve(true)); } return Promise.resolve(false); } - assignPromise(version: string): Promise { + cache(version: string): Promise { if (version === EMPTY_VERSION) { - return this.emptyVersionPromise; + return Promise.resolve(new DbVersion(0, 0, 0)); } - this.versionPromise = Promise.resolve(version); + this.versionPromise = Promise.resolve(DbVersion.fromString(version)); return this.versionPromise; } } + +export function initDbVersionProvider(conn: ConnectionGRPC) { + const metaGetter = new MetaGetter(conn); + const versionGetter = () => { + return metaGetter.do().then((result) => (result.version ? result.version : '')); + }; + + const dbVersionProvider = new DbVersionProvider(versionGetter); + dbVersionProvider.refresh(); + + return dbVersionProvider; +} + +export class DbVersion { + private major: number; + private minor: number; + private patch?: number; + + constructor(major: number, minor: number, patch?: number) { + this.major = major; + this.minor = minor; + this.patch = patch; + } + + static fromString = (version: string) => { + let regex = /^v?(\d+)\.(\d+)\.(\d+)(?:-([0-9A-Za-z-]+(?:\.[0-9A-Za-z-]+)*))?$/; + let match = version.match(regex); + if (match) { + const [_, major, minor, patch] = match; + return new DbVersion(parseInt(major, 10), parseInt(minor, 10), parseInt(patch, 10)); + } + + regex = /^v?(\d+)\.(\d+)$/; + match = version.match(regex); + if (match) { + const [_, major, minor] = match; + return new DbVersion(parseInt(major, 10), parseInt(minor, 10)); + } + + throw new Error(`Invalid version string: ${version}`); + }; + + private checkNumber = (num: number) => { + if (!Number.isSafeInteger(num)) { + throw new Error(`Invalid number: ${num}`); + } + }; + + show = () => + this.major === 0 && this.major === this.minor && this.minor === this.patch + ? '' + : `${this.major}.${this.minor}${this.patch !== undefined ? `.${this.patch}` : ''}`; + + isAtLeast = (major: number, minor: number, patch?: number) => { + this.checkNumber(major); + this.checkNumber(minor); + + if (this.major > major) return true; + if (this.major < major) return false; + + if (this.minor > minor) return true; + if (this.minor < minor) return false; + + if (this.patch !== undefined && patch !== undefined && this.patch >= patch) { + this.checkNumber(patch); + return true; + } + return false; + }; + + isLowerThan = (major: number, minor: number, patch: number) => !this.isAtLeast(major, minor, patch); +} diff --git a/src/utils/journey.test.ts b/src/utils/journey.test.ts index 714dc42b..832669c0 100644 --- a/src/utils/journey.test.ts +++ b/src/utils/journey.test.ts @@ -9,24 +9,14 @@ describe('db version provider', () => { const versionGetter = () => Promise.resolve(EMPTY_VERSION); const dbVersionProvider = new DbVersionProvider(versionGetter); - return dbVersionProvider - .getVersionPromise() - .then((version) => expect(version).toBe(EMPTY_VERSION)) - .catch(() => { - throw new Error('version should always resolve successfully'); - }); + return dbVersionProvider.getVersionString().then((version) => expect(version).toBe(EMPTY_VERSION)); }); it('should return proper version', () => { const versionGetter = () => Promise.resolve(VERSION_1); const dbVersionProvider = new DbVersionProvider(versionGetter); - return dbVersionProvider - .getVersionPromise() - .then((version) => expect(version).toBe(VERSION_1)) - .catch(() => { - throw new Error('version should always resolve successfully'); - }); + return dbVersionProvider.getVersionString().then((version) => expect(version).toBe(VERSION_1)); }); it('should return new version after refresh', async () => { @@ -44,19 +34,9 @@ describe('db version provider', () => { }; const dbVersionProvider = new DbVersionProvider(versionGetter); - await dbVersionProvider - .getVersionPromise() - .then((version) => expect(version).toBe(VERSION_1)) - .catch(() => { - throw new Error('version should always resolve successfully'); - }); + await dbVersionProvider.getVersionString().then((version) => expect(version).toBe(VERSION_1)); await dbVersionProvider.refresh(true); - await dbVersionProvider - .getVersionPromise() - .then((version) => expect(version).toBe(VERSION_2)) - .catch(() => { - throw new Error('version should always resolve successfully'); - }); + await dbVersionProvider.getVersionString().then((version) => expect(version).toBe(VERSION_2)); }); it('should fetch version once', async () => { @@ -72,24 +52,9 @@ describe('db version provider', () => { }; const dbVersionProvider = new DbVersionProvider(versionGetter); - await dbVersionProvider - .getVersionPromise() - .then((version) => expect(version).toBe(VERSION_1)) - .catch(() => { - throw new Error('version should always resolve successfully'); - }); - await dbVersionProvider - .getVersionPromise() - .then((version) => expect(version).toBe(VERSION_1)) - .catch(() => { - throw new Error('version should always resolve successfully'); - }); - await dbVersionProvider - .getVersionPromise() - .then((version) => expect(version).toBe(VERSION_1)) - .catch(() => { - throw new Error('version should always resolve successfully'); - }); + await dbVersionProvider.getVersionString().then((version) => expect(version).toBe(VERSION_1)); + await dbVersionProvider.getVersionString().then((version) => expect(version).toBe(VERSION_1)); + await dbVersionProvider.getVersion().then((version) => expect(version.show()).toBe(VERSION_1)); expect(callsCounter).toBe(1); }); @@ -110,36 +75,11 @@ describe('db version provider', () => { }; const dbVersionProvider = new DbVersionProvider(versionGetter); - await dbVersionProvider - .getVersionPromise() - .then((version) => expect(version).toBe(EMPTY_VERSION)) - .catch(() => { - throw new Error('version should always resolve successfully'); - }); - await dbVersionProvider - .getVersionPromise() - .then((version) => expect(version).toBe(EMPTY_VERSION)) - .catch(() => { - throw new Error('version should always resolve successfully'); - }); - await dbVersionProvider - .getVersionPromise() - .then((version) => expect(version).toBe(VERSION_1)) - .catch(() => { - throw new Error('version should always resolve successfully'); - }); - await dbVersionProvider - .getVersionPromise() - .then((version) => expect(version).toBe(VERSION_1)) - .catch(() => { - throw new Error('version should always resolve successfully'); - }); - await dbVersionProvider - .getVersionPromise() - .then((version) => expect(version).toBe(VERSION_1)) - .catch(() => { - throw new Error('version should always resolve successfully'); - }); + await dbVersionProvider.getVersionString().then((version) => expect(version).toBe(EMPTY_VERSION)); + await dbVersionProvider.getVersionString().then((version) => expect(version).toBe(EMPTY_VERSION)); + await dbVersionProvider.getVersionString().then((version) => expect(version).toBe(VERSION_1)); + await dbVersionProvider.getVersionString().then((version) => expect(version).toBe(VERSION_1)); + await dbVersionProvider.getVersionString().then((version) => expect(version).toBe(VERSION_1)); expect(callsCounter).toBe(3); }); @@ -152,33 +92,47 @@ describe('db version support', () => { const dbVersionProvider = new DbVersionProvider(() => Promise.resolve(version)); const dbVersionSupport = new DbVersionSupport(dbVersionProvider); - await dbVersionSupport - .supportsClassNameNamespacedEndpointsPromise() - .then((support) => { - expect(support.supports).toBe(false); - expect(support.version).toBe(version); - }) - .catch(() => { - throw new Error('version should always resolve successfully'); - }); + const support = await dbVersionSupport.supportsClassNameNamespacedEndpointsPromise(); + expect(support.supports).toBe(false); + expect(support.version).toBe(version); }); }); it('should support', () => { - const supportedVersions = ['1.14.0', '1.14.9', '1.100', '2.0', '10.11.12']; - supportedVersions.forEach(async (version) => { - const dbVersionProvider = new DbVersionProvider(() => Promise.resolve(version)); + const supportedVersions = [ + { + in: '1.14.0', + exp: '1.14.0', + }, + { + in: '1.14.9', + exp: '1.14.9', + }, + { + in: '1.100', + exp: '1.100', + }, + { + in: '2.0', + exp: '2.0', + }, + { + in: '10.11.12', + exp: '10.11.12', + }, + { + in: '1.25.0-raft', + exp: '1.25.0', + }, + ]; + return supportedVersions.forEach(async (version) => { + const dbVersionProvider = new DbVersionProvider(() => Promise.resolve(version.in)); const dbVersionSupport = new DbVersionSupport(dbVersionProvider); - await dbVersionSupport - .supportsClassNameNamespacedEndpointsPromise() - .then((support) => { - expect(support.supports).toBe(true); - expect(support.version).toBe(version); - }) - .catch(() => { - throw new Error('version should always resolve successfully'); - }); + const support = await dbVersionSupport.supportsClassNameNamespacedEndpointsPromise(); + + expect(support.supports).toBe(true); + expect(support.version).toBe(version.exp); }); }); }); diff --git a/test/dbVersionProvider.ts b/test/dbVersionProvider.ts index a3267e1f..1cbaec28 100644 --- a/test/dbVersionProvider.ts +++ b/test/dbVersionProvider.ts @@ -1,4 +1,4 @@ -import { VersionProvider } from '../src/utils/dbVersion.js'; +import { VersionProvider, DbVersion } from '../src/utils/dbVersion.js'; export class TestDbVersionProvider implements VersionProvider { private version: string; @@ -7,7 +7,11 @@ export class TestDbVersionProvider implements VersionProvider { this.version = version; } - getVersionPromise(): Promise { + getVersionString(): Promise { return Promise.resolve(this.version); } + + getVersion(): Promise { + return Promise.resolve(DbVersion.fromString(this.version)); + } } diff --git a/tools/refresh_protos.sh b/tools/refresh_protos.sh index 5e741e5d..b8857b1f 100755 --- a/tools/refresh_protos.sh +++ b/tools/refresh_protos.sh @@ -33,5 +33,6 @@ sed -i '' 's/\".\/batch\"/\".\/batch.js\"/g' src/proto/v1/*.ts sed -i '' 's/\".\/batch_delete\"/\".\/batch_delete.js\"/g' src/proto/v1/*.ts sed -i '' 's/\".\/properties\"/\".\/properties.js\"/g' src/proto/v1/*.ts sed -i '' 's/\".\/search_get\"/\".\/search_get.js\"/g' src/proto/v1/*.ts +sed -i '' 's/\".\/tenants\"/\".\/tenants.js\"/g' src/proto/v1/*.ts echo "done"