Skip to content

Commit f9f71a8

Browse files
ThomasVitale authored and markpollack committed
RAG - Document joiner should sort by score
ConcatenationDocumentJoiner should sort the final document list by score in descending order, so as to keep the most relevant documents at the front of the list. The current implementation does not do so; this PR fixes that. Signed-off-by: Thomas Vitale <ThomasVitale@users.noreply.github.com>
1 parent 46491c0 commit f9f71a8

File tree

2 files changed

+32
-4
lines changed

2 files changed

+32
-4
lines changed

spring-ai-rag/src/main/java/org/springframework/ai/rag/retrieval/join/ConcatenationDocumentJoiner.java

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
* Copyright 2023-2024 the original author or authors.
2+
* Copyright 2023-2025 the original author or authors.
33
*
44
* Licensed under the Apache License, Version 2.0 (the "License");
55
* you may not use this file except in compliance with the License.
@@ -17,6 +17,7 @@
1717
package org.springframework.ai.rag.retrieval.join;
1818

1919
import java.util.ArrayList;
20+
import java.util.Comparator;
2021
import java.util.List;
2122
import java.util.Map;
2223
import java.util.function.Function;
@@ -32,7 +33,8 @@
3233
/**
3334
* Combines documents retrieved based on multiple queries and from multiple data sources
3435
* by concatenating them into a single collection of documents. In case of duplicate
35-
* documents, the first occurrence is kept. The score of each document is kept as is.
36+
* documents, the first occurrence is kept. The score of each document is kept as is. The
37+
* result is a list of unique documents sorted by their score in descending order.
3638
*
3739
* @author Thomas Vitale
3840
* @since 1.0.0
@@ -54,7 +56,11 @@ public List<Document> join(Map<Query, List<List<Document>>> documentsForQuery) {
5456
.flatMap(List::stream)
5557
.flatMap(List::stream)
5658
.collect(Collectors.toMap(Document::getId, Function.identity(), (existing, duplicate) -> existing))
57-
.values());
59+
.values()
60+
.stream()
61+
.sorted(Comparator.comparingDouble((Document doc) -> doc.getScore() != null ? doc.getScore() : 0.0)
62+
.reversed())
63+
.toList());
5864
}
5965

6066
}

spring-ai-rag/src/test/java/org/springframework/ai/rag/retrieval/join/ConcatenationDocumentJoinerTests.java

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,6 @@
2121
import java.util.Map;
2222

2323
import org.junit.jupiter.api.Test;
24-
2524
import org.springframework.ai.document.Document;
2625
import org.springframework.ai.rag.Query;
2726

@@ -92,4 +91,27 @@ void whenDuplicatedDocumentsThenOnlyFirstOccurrenceIsKept() {
9291
assertThat(result).extracting(Document::getText).containsOnlyOnce("Content 2");
9392
}
9493

94+
@Test
95+
void shouldSortDocumentsByDescendingScore() {
96+
//@formatter:off
97+
DocumentJoiner documentJoiner = new ConcatenationDocumentJoiner();
98+
var documentsForQuery = new HashMap<Query, List<List<Document>>>();
99+
documentsForQuery.put(new Query("query1"), List.of(
100+
List.of(
101+
Document.builder().id("1").text("Content 1").score(0.81).build(),
102+
Document.builder().id("2").text("Content 2").score(0.83).build()),
103+
List.of(
104+
Document.builder().id("3").text("Content 3").score(null).build())));
105+
documentsForQuery.put(new Query("query2"), List.of(
106+
List.of(
107+
Document.builder().id("4").text("Content 4").score(0.85).build(),
108+
Document.builder().id("5").text("Content 5").score(0.77).build())));
109+
110+
List<Document> result = documentJoiner.join(documentsForQuery);
111+
112+
assertThat(result).hasSize(5);
113+
assertThat(result).extracting(Document::getId).containsExactly("4", "2", "1", "5", "3");
114+
//@formatter:on
115+
}
116+
95117
}

0 commit comments

Comments
 (0)