diff --git a/.github/workflows/main.yaml b/.github/workflows/main.yaml index 69c5bdd5..335f89dd 100644 --- a/.github/workflows/main.yaml +++ b/.github/workflows/main.yaml @@ -14,6 +14,7 @@ env: WEAVIATE_128: 1.28.11 WEAVIATE_129: 1.29.1 WEAVIATE_130: 1.30.1 + WEAVIATE_131: 1.31.0 concurrency: group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }} diff --git a/src/collections/query/index.ts b/src/collections/query/index.ts index 3be179d3..9fb5b58e 100644 --- a/src/collections/query/index.ts +++ b/src/collections/query/index.ts @@ -247,6 +247,7 @@ export { BaseHybridOptions, BaseNearOptions, BaseNearTextOptions, + Bm25OperatorOptions, Bm25Options, FetchObjectByIdOptions, FetchObjectsOptions, @@ -266,3 +267,5 @@ export { QueryReturn, SearchOptions, } from './types.js'; + +export { Bm25Operator } from './utils.js'; diff --git a/src/collections/query/integration.test.ts b/src/collections/query/integration.test.ts index 09698c05..965afa32 100644 --- a/src/collections/query/integration.test.ts +++ b/src/collections/query/integration.test.ts @@ -1,10 +1,12 @@ /* eslint-disable @typescript-eslint/no-non-null-assertion */ /* eslint-disable @typescript-eslint/no-non-null-asserted-optional-chain */ +import { requireAtLeast } from '../../../test/version.js'; import { WeaviateUnsupportedFeatureError } from '../../errors.js'; import weaviate, { WeaviateClient } from '../../index.js'; import { Collection } from '../collection/index.js'; import { CrossReference, Reference } from '../references/index.js'; import { GroupByOptions } from '../types/index.js'; +import { Bm25Operator } from './utils.js'; describe('Testing of the collection.query methods with a simple collection', () => { let client: WeaviateClient; @@ -132,6 +134,32 @@ describe('Testing of the collection.query methods with a simple collection', () expect(ret.objects[0].uuid).toEqual(id); }); + requireAtLeast( + 1, + 31, + 0 + )('bm25 search operator (minimum_should_match)', () => { + it('should query with bm25 + operator', async () => { + const ret = await collection.query.bm25('carrot', { + limit: 1, + operator: Bm25Operator.or({ minimumMatch: 1 }), + }); + expect(ret.objects.length).toEqual(1); + expect(ret.objects[0].properties.testProp).toEqual('carrot'); + expect(ret.objects[0].uuid).toEqual(id); + }); + + it('should query with hybrid + bm25Operator', async () => { + const ret = await collection.query.hybrid('carrot', { + limit: 1, + bm25Operator: Bm25Operator.and(), + }); + expect(ret.objects.length).toEqual(1); + expect(ret.objects[0].properties.testProp).toEqual('carrot'); + expect(ret.objects[0].uuid).toEqual(id); + }); + }); + it('should query with hybrid and vector', async () => { const ret = await collection.query.hybrid('carrot', { limit: 1, diff --git a/src/collections/query/types.ts b/src/collections/query/types.ts index 48fed7d6..5eaf0f53 100644 --- a/src/collections/query/types.ts +++ b/src/collections/query/types.ts @@ -84,9 +84,15 @@ export type Bm25QueryProperty = { weight: number; }; +export type Bm25OperatorOr = { operator: 'Or'; minimumMatch: number }; +export type Bm25OperatorAnd = { operator: 'And' }; + +export type Bm25OperatorOptions = Bm25OperatorOr | Bm25OperatorAnd; + export type Bm25SearchOptions = { /** Which properties of the collection to perform the keyword search on. */ queryProperties?: (PrimitiveKeys | Bm25QueryProperty)[]; + operator?: Bm25OperatorOptions; }; /** Base options available in the `query.bm25` method */ @@ -115,6 +121,7 @@ export type HybridSearchOptions = { targetVector?: TargetVectorInputType; /** The specific vector to search for or a specific vector subsearch. If not specified, the query is vectorized and used in the similarity search. */ vector?: NearVectorInputType | HybridNearTextSubSearch | HybridNearVectorSubSearch; + bm25Operator?: Bm25OperatorOptions; }; /** Base options available in the `query.hybrid` method */ diff --git a/src/collections/query/utils.ts b/src/collections/query/utils.ts index 4bbe9f76..0450679d 100644 --- a/src/collections/query/utils.ts +++ b/src/collections/query/utils.ts @@ -1,5 +1,5 @@ import { MultiTargetVectorJoin } from '../index.js'; -import { NearVectorInputType, TargetVectorInputType } from './types.js'; +import { Bm25OperatorOptions, Bm25OperatorOr, NearVectorInputType, TargetVectorInputType } from './types.js'; export class NearVectorInputGuards { public static is1DArray(input: NearVectorInputType): input is number[] { @@ -34,3 +34,13 @@ export class TargetVectorInputGuards { return i.combination !== undefined && i.targetVectors !== undefined; } } + +export class Bm25Operator { + static and(): Bm25OperatorOptions { + return { operator: 'And' }; + } + + static or(opts: Omit): Bm25OperatorOptions { + return { ...opts, operator: 'Or' }; + } +} diff --git a/src/collections/serialize/index.ts b/src/collections/serialize/index.ts index e9a74973..ddebb8e4 100644 --- a/src/collections/serialize/index.ts +++ b/src/collections/serialize/index.ts @@ -15,6 +15,8 @@ import { NearThermalSearch, NearVector, NearVideoSearch, + SearchOperatorOptions, + SearchOperatorOptions_Operator, Targets, VectorForTarget, WeightsForTarget, @@ -115,6 +117,7 @@ import { import { BaseHybridOptions, BaseNearOptions, + Bm25OperatorOptions, Bm25Options, Bm25QueryProperty, Bm25SearchOptions, @@ -960,10 +963,26 @@ export class Serialize { }); }; + private static bm25SearchOperator = ( + searchOperator?: Bm25OperatorOptions + ): SearchOperatorOptions | undefined => { + if (searchOperator) { + return SearchOperatorOptions.fromPartial( + searchOperator.operator === ('And' as const) + ? { operator: SearchOperatorOptions_Operator.OPERATOR_AND } + : { + operator: SearchOperatorOptions_Operator.OPERATOR_OR, + minimumOrTokensMatch: searchOperator.minimumMatch, + } + ); + } + }; + public static bm25Search = (args: { query: string } & Bm25SearchOptions): BM25 => { return BM25.fromPartial({ query: args.query, properties: this.bm25QueryProperties(args.queryProperties), + searchOperator: this.bm25SearchOperator(args.operator), }); }; @@ -1074,6 +1093,7 @@ export class Serialize { vectorBytes: vectorBytes, vectorDistance: args.maxVectorDistance, fusionType: fusionType(args.fusionType), + bm25SearchOperator: this.bm25SearchOperator(args.bm25Operator), targetVectors, targets, nearText, diff --git a/src/proto/v1/base_search.ts b/src/proto/v1/base_search.ts index d2e9606a..ef5795e2 100644 --- a/src/proto/v1/base_search.ts +++ b/src/proto/v1/base_search.ts @@ -100,6 +100,50 @@ export interface VectorForTarget { vectors: Vectors[]; } +export interface SearchOperatorOptions { + operator: SearchOperatorOptions_Operator; + minimumOrTokensMatch?: number | undefined; +} + +export enum SearchOperatorOptions_Operator { + OPERATOR_UNSPECIFIED = 0, + OPERATOR_OR = 1, + OPERATOR_AND = 2, + UNRECOGNIZED = -1, +} + +export function searchOperatorOptions_OperatorFromJSON(object: any): SearchOperatorOptions_Operator { + switch (object) { + case 0: + case "OPERATOR_UNSPECIFIED": + return SearchOperatorOptions_Operator.OPERATOR_UNSPECIFIED; + case 1: + case "OPERATOR_OR": + return SearchOperatorOptions_Operator.OPERATOR_OR; + case 2: + case "OPERATOR_AND": + return SearchOperatorOptions_Operator.OPERATOR_AND; + case -1: + case "UNRECOGNIZED": + default: + return SearchOperatorOptions_Operator.UNRECOGNIZED; + } +} + +export function searchOperatorOptions_OperatorToJSON(object: SearchOperatorOptions_Operator): string { + switch (object) { + case SearchOperatorOptions_Operator.OPERATOR_UNSPECIFIED: + return "OPERATOR_UNSPECIFIED"; + case SearchOperatorOptions_Operator.OPERATOR_OR: + return "OPERATOR_OR"; + case SearchOperatorOptions_Operator.OPERATOR_AND: + return "OPERATOR_AND"; + case SearchOperatorOptions_Operator.UNRECOGNIZED: + default: + return "UNRECOGNIZED"; + } +} + export interface Hybrid { query: string; properties: string[]; @@ -130,6 +174,7 @@ export interface Hybrid { /** same as above. Use the target vector in the hybrid message */ nearVector: NearVector | undefined; targets: Targets | undefined; + bm25SearchOperator?: SearchOperatorOptions | undefined; vectorDistance?: number | undefined; vectors: Vectors[]; } @@ -346,6 +391,7 @@ export interface NearIMUSearch { export interface BM25 { query: string; properties: string[]; + searchOperator?: SearchOperatorOptions | undefined; } function createBaseWeightsForTarget(): WeightsForTarget { @@ -712,6 +758,82 @@ export const VectorForTarget = { }, }; +function createBaseSearchOperatorOptions(): SearchOperatorOptions { + return { operator: 0, minimumOrTokensMatch: undefined }; +} + +export const SearchOperatorOptions = { + encode(message: SearchOperatorOptions, writer: _m0.Writer = _m0.Writer.create()): _m0.Writer { + if (message.operator !== 0) { + writer.uint32(8).int32(message.operator); + } + if (message.minimumOrTokensMatch !== undefined) { + writer.uint32(16).int32(message.minimumOrTokensMatch); + } + return writer; + }, + + decode(input: _m0.Reader | Uint8Array, length?: number): SearchOperatorOptions { + const reader = input instanceof _m0.Reader ? input : _m0.Reader.create(input); + let end = length === undefined ? reader.len : reader.pos + length; + const message = createBaseSearchOperatorOptions(); + while (reader.pos < end) { + const tag = reader.uint32(); + switch (tag >>> 3) { + case 1: + if (tag !== 8) { + break; + } + + message.operator = reader.int32() as any; + continue; + case 2: + if (tag !== 16) { + break; + } + + message.minimumOrTokensMatch = reader.int32(); + continue; + } + if ((tag & 7) === 4 || tag === 0) { + break; + } + reader.skipType(tag & 7); + } + return message; + }, + + fromJSON(object: any): SearchOperatorOptions { + return { + operator: isSet(object.operator) ? searchOperatorOptions_OperatorFromJSON(object.operator) : 0, + minimumOrTokensMatch: isSet(object.minimumOrTokensMatch) + ? globalThis.Number(object.minimumOrTokensMatch) + : undefined, + }; + }, + + toJSON(message: SearchOperatorOptions): unknown { + const obj: any = {}; + if (message.operator !== 0) { + obj.operator = searchOperatorOptions_OperatorToJSON(message.operator); + } + if (message.minimumOrTokensMatch !== undefined) { + obj.minimumOrTokensMatch = Math.round(message.minimumOrTokensMatch); + } + return obj; + }, + + create(base?: DeepPartial): SearchOperatorOptions { + return SearchOperatorOptions.fromPartial(base ?? {}); + }, + fromPartial(object: DeepPartial): SearchOperatorOptions { + const message = createBaseSearchOperatorOptions(); + message.operator = object.operator ?? 0; + message.minimumOrTokensMatch = object.minimumOrTokensMatch ?? undefined; + return message; + }, +}; + function createBaseHybrid(): Hybrid { return { query: "", @@ -724,6 +846,7 @@ function createBaseHybrid(): Hybrid { nearText: undefined, nearVector: undefined, targets: undefined, + bm25SearchOperator: undefined, vectorDistance: undefined, vectors: [], }; @@ -763,6 +886,9 @@ export const Hybrid = { if (message.targets !== undefined) { Targets.encode(message.targets, writer.uint32(82).fork()).ldelim(); } + if (message.bm25SearchOperator !== undefined) { + SearchOperatorOptions.encode(message.bm25SearchOperator, writer.uint32(90).fork()).ldelim(); + } if (message.vectorDistance !== undefined) { writer.uint32(165).float(message.vectorDistance); } @@ -859,6 +985,13 @@ export const Hybrid = { message.targets = Targets.decode(reader, reader.uint32()); continue; + case 11: + if (tag !== 90) { + break; + } + + message.bm25SearchOperator = SearchOperatorOptions.decode(reader, reader.uint32()); + continue; case 20: if (tag !== 165) { break; @@ -898,6 +1031,9 @@ export const Hybrid = { nearText: isSet(object.nearText) ? NearTextSearch.fromJSON(object.nearText) : undefined, nearVector: isSet(object.nearVector) ? NearVector.fromJSON(object.nearVector) : undefined, targets: isSet(object.targets) ? Targets.fromJSON(object.targets) : undefined, + bm25SearchOperator: isSet(object.bm25SearchOperator) + ? SearchOperatorOptions.fromJSON(object.bm25SearchOperator) + : undefined, vectorDistance: isSet(object.vectorDistance) ? globalThis.Number(object.vectorDistance) : undefined, vectors: globalThis.Array.isArray(object?.vectors) ? object.vectors.map((e: any) => Vectors.fromJSON(e)) : [], }; @@ -935,6 +1071,9 @@ export const Hybrid = { if (message.targets !== undefined) { obj.targets = Targets.toJSON(message.targets); } + if (message.bm25SearchOperator !== undefined) { + obj.bm25SearchOperator = SearchOperatorOptions.toJSON(message.bm25SearchOperator); + } if (message.vectorDistance !== undefined) { obj.vectorDistance = message.vectorDistance; } @@ -965,6 +1104,9 @@ export const Hybrid = { message.targets = (object.targets !== undefined && object.targets !== null) ? Targets.fromPartial(object.targets) : undefined; + message.bm25SearchOperator = (object.bm25SearchOperator !== undefined && object.bm25SearchOperator !== null) + ? SearchOperatorOptions.fromPartial(object.bm25SearchOperator) + : undefined; message.vectorDistance = object.vectorDistance ?? undefined; message.vectors = object.vectors?.map((e) => Vectors.fromPartial(e)) || []; return message; @@ -2390,7 +2532,7 @@ export const NearIMUSearch = { }; function createBaseBM25(): BM25 { - return { query: "", properties: [] }; + return { query: "", properties: [], searchOperator: undefined }; } export const BM25 = { @@ -2401,6 +2543,9 @@ export const BM25 = { for (const v of message.properties) { writer.uint32(18).string(v!); } + if (message.searchOperator !== undefined) { + SearchOperatorOptions.encode(message.searchOperator, writer.uint32(26).fork()).ldelim(); + } return writer; }, @@ -2425,6 +2570,13 @@ export const BM25 = { message.properties.push(reader.string()); continue; + case 3: + if (tag !== 26) { + break; + } + + message.searchOperator = SearchOperatorOptions.decode(reader, reader.uint32()); + continue; } if ((tag & 7) === 4 || tag === 0) { break; @@ -2440,6 +2592,7 @@ export const BM25 = { properties: globalThis.Array.isArray(object?.properties) ? object.properties.map((e: any) => globalThis.String(e)) : [], + searchOperator: isSet(object.searchOperator) ? SearchOperatorOptions.fromJSON(object.searchOperator) : undefined, }; }, @@ -2451,6 +2604,9 @@ export const BM25 = { if (message.properties?.length) { obj.properties = message.properties; } + if (message.searchOperator !== undefined) { + obj.searchOperator = SearchOperatorOptions.toJSON(message.searchOperator); + } return obj; }, @@ -2461,6 +2617,9 @@ export const BM25 = { const message = createBaseBM25(); message.query = object.query ?? ""; message.properties = object.properties?.map((e) => e) || []; + message.searchOperator = (object.searchOperator !== undefined && object.searchOperator !== null) + ? SearchOperatorOptions.fromPartial(object.searchOperator) + : undefined; return message; }, };