Skip to content

Commit 0c2e8cb

Browse files
authored
Add Weighted RRF to Azure Cosmos DB Python SDK (#40899)
* Add Weighted Reciprocal Rank Fusion to Python Cosmos DB Adds weighted reciprocal rank fusion to the python sdk. * Update hybrid_search_aggregator.py * Update CHANGELOG.md * fix pylint * Update README.md * added query optimizations as well * add valid vector search test for wrrf * update quick fixes * pylint fix * Update CHANGELOG.md * update changes * review requested changes * Update CHANGELOG.md * updates should fix deprecration of list in full text score * Update test_query_hybrid_search.py * Update CHANGELOG.md * Update tests and Readme * Update _version.py
1 parent 2b293f3 commit 0c2e8cb

File tree

10 files changed

+379
-53
lines changed

10 files changed

+379
-53
lines changed

sdk/cosmos/azure-cosmos/CHANGELOG.md

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,16 @@
11
## Release History
22

3-
### 4.11.0b2 (Unreleased)
3+
### 4.12.0b1 (unreleased)
44

55
#### Features Added
6-
7-
#### Breaking Changes
6+
* Added ability to use weighted RRF (Reciprocal Rank Fusion) for Hybrid full text search queries. See [PR 40899](https://github.com/Azure/azure-sdk-for-python/pull/40899/files).
87

98
#### Bugs Fixed
109
* Fixed Diagnostics Error Log Formatting to handle error messages from non-CosmosHttpResponseExceptions. See [PR 40889](https://github.com/Azure/azure-sdk-for-python/pull/40889/files)
1110
* Fixed bug where `multiple_write_locations` option in client was not being honored. See [PR 40999](https://github.com/Azure/azure-sdk-for-python/pull/40999).
1211

12+
#### Breaking Changes
13+
1314
#### Other Changes
1415

1516
### 4.11.0b1 (2025-04-30)

sdk/cosmos/azure-cosmos/README.md

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -867,6 +867,13 @@ All of these mentioned queries would look something like this:
867867

868868
- `SELECT TOP 10 c.id, c.text FROM c ORDER BY RANK RRF(FullTextScore(c.text, ['quantum', 'theory']), FullTextScore(c.text, ['model']), VectorDistance(c.embedding, {item_embedding}))"`
869869

870+
You can also use Weighted Reciprocal Rank Fusion to assign different weights to the different scores being used in the RRF function.
871+
This is done by passing in a list of weights to the RRF function in the query. **NOTE: If more weights are given than there are components of the RRF function, or if weights are missing a BAD REQUEST exception will occur.**
872+
- `SELECT TOP 10 c.id, c.text FROM c ORDER BY RANK RRF(FullTextScore(c.text, ['quantum', 'theory']), FullTextScore(c.text, ['model']), VectorDistance(c.embedding, {item_embedding}), [0.5, 0.3, 0.2])`
873+
874+
875+
- `SELECT TOP 10 c.id, c.text FROM c ORDER BY RANK RRF(FullTextScore(c.text, ['quantum', 'theory']), FullTextScore(c.text, ['model']), VectorDistance(c.embedding, {item_embedding}), [-0.5, 0.3, 0.2])`
876+
870877
These queries must always use a TOP or LIMIT clause within the query since hybrid search queries have to look through a lot of data otherwise and may become too expensive or long-running.
871878
Since these queries are relatively expensive, the SDK sets a default limit of 1000 max items per query - if you'd like to raise that further, you
872879
can use the `AZURE_COSMOS_HYBRID_SEARCH_MAX_ITEMS` environment variable to do so. However, be advised that queries with too many vector results

sdk/cosmos/azure-cosmos/azure/cosmos/_cosmos_client_connection.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3196,7 +3196,8 @@ def _GetQueryPlanThroughGateway(self, query: str, resource_link: str, **kwargs:
31963196
documents._QueryFeature.Top + "," +
31973197
documents._QueryFeature.NonStreamingOrderBy + "," +
31983198
documents._QueryFeature.HybridSearch + "," +
3199-
documents._QueryFeature.CountIf)
3199+
documents._QueryFeature.CountIf + "," +
3200+
documents._QueryFeature.WeightedRankFusion)
32003201
if os.environ.get(Constants.NON_STREAMING_ORDER_BY_DISABLED_CONFIG,
32013202
Constants.NON_STREAMING_ORDER_BY_DISABLED_CONFIG_DEFAULT) == "True":
32023203
supported_query_features = (documents._QueryFeature.Aggregate + "," +

sdk/cosmos/azure-cosmos/azure/cosmos/_execution_context/aio/hybrid_search_aggregator.py

Lines changed: 16 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@ def __init__(self, client, resource_link, options, partitioned_query_execution_i
5959
self._document_producer_comparator = None
6060
self._response_hook = response_hook
6161

62-
async def _run_hybrid_search(self):
62+
async def _run_hybrid_search(self): # pylint: disable=too-many-branches, too-many-statements
6363
# Check if we need to run global statistics queries, and if so do for every partition in the container
6464
if self._hybrid_search_query_info['requiresGlobalStatistics']:
6565
target_partition_key_ranges = await self._get_target_partition_key_range(target_all_ranges=True)
@@ -147,21 +147,32 @@ async def _run_hybrid_search(self):
147147
self._format_final_results(drained_results)
148148
return
149149

150+
# Get the Components weights if any
151+
if self._hybrid_search_query_info.get('componentWeights'):
152+
component_weights = self._hybrid_search_query_info['componentWeights']
153+
else:
154+
# If no weights are provided, we default to 1.0 for all components
155+
component_weights = [1.0] * len(self._hybrid_search_query_info['componentQueryInfos'])
156+
150157
# Sort drained results by _rid
151158
drained_results.sort(key=lambda x: x['_rid'])
152159

153160
# Compose component scores matrix, where each tuple is (score, index)
154161
component_scores = _retrieve_component_scores(drained_results)
155162

156-
# Sort by scores in descending order
157-
for score_tuples in component_scores:
158-
score_tuples.sort(key=lambda x: x[0], reverse=True)
163+
# Sort by scores using component weights
164+
for index, score_tuples in enumerate(component_scores):
165+
# Negative Weights will change sorting from Descending to Ascending
166+
ordering = self._hybrid_search_query_info['componentQueryInfos'][index]['orderBy'][0]
167+
comparison_factor = not ordering.lower() == 'ascending'
168+
# pylint: disable=cell-var-from-loop
169+
score_tuples.sort(key=lambda x: x[0], reverse=comparison_factor)
159170

160171
# Compute the ranks
161172
ranks = _compute_ranks(component_scores)
162173

163174
# Compute the RRF scores and add them to output
164-
_compute_rrf_scores(ranks, drained_results)
175+
_compute_rrf_scores(ranks, component_weights, drained_results)
165176

166177
# Finally, sort on the RRF scores to build the final result to return
167178
drained_results.sort(key=lambda x: x['Score'], reverse=True)

sdk/cosmos/azure-cosmos/azure/cosmos/_execution_context/hybrid_search_aggregator.py

Lines changed: 22 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33

44
"""Internal class for multi execution context aggregator implementation in the Azure Cosmos database service.
55
"""
6-
6+
from typing import List, Union
77
from azure.cosmos._execution_context.base_execution_context import _QueryExecutionContextBase
88
from azure.cosmos._execution_context import document_producer
99
from azure.cosmos._routing import routing_range
@@ -36,12 +36,13 @@ def _retrieve_component_scores(drained_results):
3636
return component_scores_list
3737

3838

39-
def _compute_rrf_scores(ranks, query_results):
39+
def _compute_rrf_scores(ranks: List[List[int]], component_weights: List[Union[int, float]], query_results: List[dict]):
4040
component_count = len(ranks)
4141
for index, result in enumerate(query_results):
4242
rrf_score = 0.0
4343
for component_index in range(component_count):
44-
rrf_score += 1.0 / (RRF_CONSTANT + ranks[component_index][index])
44+
rrf_score += component_weights[component_index] / (RRF_CONSTANT + ranks[component_index][index])
45+
4546
# Add the score to the item to be returned
4647
result['Score'] = rrf_score
4748

@@ -54,7 +55,7 @@ def _compute_ranks(component_scores):
5455
rank = 1 # ranks are 1-based
5556
for index, score_tuple in enumerate(scores):
5657
# Identical scores should have the same rank
57-
if index > 0 and score_tuple[0] < scores[index - 1][0]:
58+
if index > 0 and score_tuple[0] != scores[index - 1][0]:
5859
rank += 1
5960
ranks[component_index][score_tuple[1]] = rank
6061

@@ -164,7 +165,7 @@ def __init__(self, client, resource_link, options,
164165
self._document_producer_comparator = None
165166
self._response_hook = response_hook
166167

167-
def _run_hybrid_search(self):
168+
def _run_hybrid_search(self): # pylint: disable=too-many-branches, too-many-statements
168169
# Check if we need to run global statistics queries, and if so do for every partition in the container
169170
if self._hybrid_search_query_info['requiresGlobalStatistics']:
170171
target_partition_key_ranges = self._get_target_partition_key_range(target_all_ranges=True)
@@ -251,21 +252,33 @@ def _run_hybrid_search(self):
251252
self._format_final_results(drained_results)
252253
return
253254

255+
# Get the Components weight if any
256+
if self._hybrid_search_query_info.get('componentWeights'):
257+
component_weights = self._hybrid_search_query_info['componentWeights']
258+
else:
259+
# If no weights are provided, we assume all components have equal weight
260+
component_weights = [1.0] * len(self._hybrid_search_query_info['componentQueryInfos'])
261+
254262
# Sort drained results by _rid
255263
drained_results.sort(key=lambda x: x['_rid'])
256264

257265
# Compose component scores matrix, where each tuple is (score, index)
258266
component_scores = _retrieve_component_scores(drained_results)
259267

260-
# Sort by scores in descending order
261-
for score_tuples in component_scores:
262-
score_tuples.sort(key=lambda x: x[0], reverse=True)
268+
# Sort by scores using component weights
269+
for index, score_tuples in enumerate(component_scores):
270+
# Ordering of the component query is based on if the weight is negative or positive
271+
# A positive weight ordering means descending order, a negative weight ordering means ascending order
272+
ordering = self._hybrid_search_query_info['componentQueryInfos'][index]['orderBy'][0]
273+
comparison_factor = not ordering.lower() == 'ascending'
274+
# pylint: disable=cell-var-from-loop
275+
score_tuples.sort(key=lambda x: x[0], reverse=comparison_factor)
263276

264277
# Compute the ranks
265278
ranks = _compute_ranks(component_scores)
266279

267280
# Compute the RRF scores and add them to output
268-
_compute_rrf_scores(ranks, drained_results)
281+
_compute_rrf_scores(ranks, component_weights, drained_results)
269282

270283
# Finally, sort on the RRF scores to build the final result to return
271284
drained_results.sort(key=lambda x: x['Score'], reverse=True)

sdk/cosmos/azure-cosmos/azure/cosmos/_version.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,4 +19,4 @@
1919
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
2020
# SOFTWARE.
2121

22-
VERSION = "4.11.0b2"
22+
VERSION = "4.12.0b1"

sdk/cosmos/azure-cosmos/azure/cosmos/aio/_cosmos_client_connection_async.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3207,7 +3207,8 @@ async def _GetQueryPlanThroughGateway(self, query: str, resource_link: str, **kw
32073207
documents._QueryFeature.Top + "," +
32083208
documents._QueryFeature.NonStreamingOrderBy + "," +
32093209
documents._QueryFeature.HybridSearch + "," +
3210-
documents._QueryFeature.CountIf)
3210+
documents._QueryFeature.CountIf + "," +
3211+
documents._QueryFeature.WeightedRankFusion)
32113212
if os.environ.get(Constants.NON_STREAMING_ORDER_BY_DISABLED_CONFIG,
32123213
Constants.NON_STREAMING_ORDER_BY_DISABLED_CONFIG_DEFAULT) == "True":
32133214
supported_query_features = (documents._QueryFeature.Aggregate + "," +

sdk/cosmos/azure-cosmos/azure/cosmos/documents.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -438,7 +438,7 @@ class _QueryFeature:
438438
NonStreamingOrderBy: Literal["NonStreamingOrderBy"] = "NonStreamingOrderBy"
439439
HybridSearch: Literal["HybridSearch"] = "HybridSearch"
440440
CountIf: Literal["CountIf"] = "CountIf"
441-
441+
WeightedRankFusion: Literal["WeightedRankFusion"] = "WeightedRankFusion"
442442

443443
class _DistinctType:
444444
NoneType: Literal["None"] = "None"

0 commit comments

Comments
 (0)