
Commit 2686329

Return a user-defined message if context is empty in GraphRAG (#343)

* Return a user-defined message if context is empty in GraphRAG
* Rename parameter

1 parent 8b3af81

File tree: 6 files changed, +100 −26 lines

CHANGELOG.md

Lines changed: 1 addition & 0 deletions

@@ -5,6 +5,7 @@
 ### Added

 - Added support for automatic schema extraction from text using LLMs. In the `SimpleKGPipeline`, when the user provides no schema, the automatic schema extraction is enabled by default.
+- Added ability to return a user-defined message if context is empty in GraphRAG (which skips the LLM call).

 ### Fixed

docs/source/user_guide_rag.rst

Lines changed: 58 additions & 12 deletions

@@ -1,15 +1,15 @@
 .. _user-guide-rag:

 User Guide: RAG
-#################
+###############

 This guide provides a starting point for using the Neo4j GraphRAG package
 and configuring it according to specific requirements.


-************
+**********
 Quickstart
-************
+**********

 To perform a GraphRAG query using the `neo4j-graphrag` package, a few components are needed:

@@ -63,9 +63,9 @@ In practice, it's done with only a few lines of code:

 The following sections provide more details about how to customize this code.

-******************************
+**********************
 GraphRAG Configuration
-******************************
+**********************

 Each component can be configured individually: the LLM and the prompt.

@@ -775,7 +775,7 @@ See :ref:`hybridcypherretriever`.
 .. _text2cypher-retriever-user-guide:

 Text2Cypher Retriever
-------------------------------------
+---------------------

 This retriever first asks an LLM to generate a Cypher query to fetch the exact
 information required to answer the question from the database. Then this query is
@@ -853,7 +853,7 @@ See :ref:`text2cypherretriever`.
 .. _custom-retriever:

 Custom Retriever
-===================
+================

 If the application requires a very specific retrieval strategy, it is possible to implement
 a custom retriever using the `Retriever` interface:
@@ -883,14 +883,60 @@ a custom retriever using the `Retriever` interface:
 See :ref:`rawsearchresult` for a description of the returned type.


-******************************
+***********************
+GraphRAG search options
+***********************
+
+Return context
+==============
+
+By default, the search method only returns the final answer. It is possible to see
+what was retrieved as part of the context by setting `return_context=True`:
+
+.. code:: python
+
+    rag.search("my question", return_context=True)
+
+
+Return a custom message if context is empty
+===========================================
+
+If the retriever is not able to find any context, the LLM will return an answer anyway.
+It is possible to skip the LLM call if the context is empty and return a user-defined message
+instead:
+
+.. code:: python
+
+    rag.search(
+        "my question",
+        response_fallback="I can not answer this question because I have no relevant context."
+    )
+
+
+Pass configuration to the retriever search method
+=================================================
+
+The retriever search methods have a number of configuration options (see above),
+which can also be set through the GraphRAG search method using the `retriever_config`
+argument. For instance, the following code snippet illustrates how to define the `top_k`
+parameter for the retriever:
+
+.. code:: python
+
+    rag.search(
+        "my question",
+        retriever_config={"top_k": 2}
+    )
+
+
+**************
 DB Operations
-******************************
+**************

 See :ref:`database-interaction-section`.

 Create a Vector Index
-========================
+=====================

 .. code:: python

@@ -918,7 +964,7 @@ Create a Vector Index

 Populate a Vector Index
-==========================
+=======================

 .. code:: python

@@ -950,7 +996,7 @@ This property will also be added to the vector index.


 Drop a Vector Index
-========================
+===================

 .. warning::

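The `retriever_config` forwarding described in the docs above can be illustrated with a small self-contained sketch. This is not the library's actual code: `DummyRetriever` and this standalone `rag_search` function are hypothetical stand-ins showing the idea that the dict is unpacked as keyword arguments to the retriever's own `search` method.

```python
from typing import Any, Dict, Optional


class DummyRetriever:
    """Hypothetical retriever that records the options it was called with."""

    def __init__(self) -> None:
        self.last_kwargs: Dict[str, Any] = {}

    def search(self, query_text: str, top_k: int = 5) -> list:
        self.last_kwargs = {"query_text": query_text, "top_k": top_k}
        return []


def rag_search(
    retriever: DummyRetriever,
    query_text: str,
    retriever_config: Optional[Dict[str, Any]] = None,
) -> list:
    # `retriever_config or {}` mirrors the defaulting seen in the source diff
    return retriever.search(query_text=query_text, **(retriever_config or {}))


r = DummyRetriever()
rag_search(r, "my question", retriever_config={"top_k": 2})
print(r.last_kwargs)  # {'query_text': 'my question', 'top_k': 2}
```

Because the dict is splatted into the call, any keyword the retriever's `search` method accepts can be configured per query without changing the GraphRAG setup.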
examples/question_answering/graphrag.py

Lines changed: 2 additions & 0 deletions

@@ -56,6 +56,8 @@ def formatter(record: neo4j.Record) -> RetrieverResultItem:
 result = rag.search(
     "Tell me more about Avatar movies",
     return_context=True,
+    # optional
+    response_fallback="I can't answer this question without context",
 )
 print(result.answer)
 # print(result.retriever_result)

src/neo4j_graphrag/generation/graphrag.py

Lines changed: 19 additions & 12 deletions

@@ -89,6 +89,7 @@ def search(
         examples: str = "",
         retriever_config: Optional[dict[str, Any]] = None,
         return_context: bool | None = None,
+        response_fallback: str | None = None,
     ) -> RagResultModel:
         """
         .. warning::
@@ -109,6 +110,7 @@ def search(
             retriever_config (Optional[dict]): Parameters passed to the retriever
                 search method; e.g.: top_k
             return_context (bool): Whether to append the retriever result to the final result (default: False).
+            response_fallback (Optional[str]): If not null, will return this message instead of calling the LLM if the context comes back empty.

         Returns:
             RagResultModel: The LLM-generated answer.
@@ -126,6 +128,7 @@ def search(
                 examples=examples,
                 retriever_config=retriever_config or {},
                 return_context=return_context,
+                response_fallback=response_fallback,
             )
         except ValidationError as e:
             raise SearchValidationError(e.errors())
@@ -135,18 +138,22 @@ def search(
         retriever_result: RetrieverResult = self.retriever.search(
             query_text=query, **validated_data.retriever_config
         )
-        context = "\n".join(item.content for item in retriever_result.items)
-        prompt = self.prompt_template.format(
-            query_text=query_text, context=context, examples=validated_data.examples
-        )
-        logger.debug(f"RAG: retriever_result={prettify(retriever_result)}")
-        logger.debug(f"RAG: prompt={prompt}")
-        answer = self.llm.invoke(
-            prompt,
-            message_history,
-            system_instruction=self.prompt_template.system_instructions,
-        )
-        result: dict[str, Any] = {"answer": answer.content}
+        if len(retriever_result.items) == 0 and response_fallback is not None:
+            answer = response_fallback
+        else:
+            context = "\n".join(item.content for item in retriever_result.items)
+            prompt = self.prompt_template.format(
+                query_text=query_text, context=context, examples=validated_data.examples
+            )
+            logger.debug(f"RAG: retriever_result={prettify(retriever_result)}")
+            logger.debug(f"RAG: prompt={prompt}")
+            llm_response = self.llm.invoke(
+                prompt,
+                message_history,
+                system_instruction=self.prompt_template.system_instructions,
+            )
+            answer = llm_response.content
+        result: dict[str, Any] = {"answer": answer}
         if return_context:
             result["retriever_result"] = retriever_result
         return RagResultModel(**result)

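The control flow introduced in the diff above can be sketched as a minimal, self-contained program. The `Item`, `RetrieverResult`, and `rag_search` names here are simplified stand-ins for the library's classes, and `fake_llm` is a hypothetical stub: the point is only that when the retriever returns no items and a `response_fallback` is provided, the LLM is never invoked.

```python
from dataclasses import dataclass
from typing import Callable, List, Optional


@dataclass
class Item:
    content: str


@dataclass
class RetrieverResult:
    items: List[Item]


llm_calls: List[str] = []


def fake_llm(prompt: str) -> str:
    # Stub LLM that records every prompt it receives
    llm_calls.append(prompt)
    return "llm generated text"


def rag_search(
    query_text: str,
    retriever: Callable[[str], RetrieverResult],
    llm: Callable[[str], str],
    response_fallback: Optional[str] = None,
) -> str:
    retriever_result = retriever(query_text)
    # Empty context + fallback provided: skip the LLM call entirely
    if len(retriever_result.items) == 0 and response_fallback is not None:
        return response_fallback
    context = "\n".join(item.content for item in retriever_result.items)
    return llm(f"Context:\n{context}\nQuestion:\n{query_text}")


answer = rag_search(
    "my question",
    lambda q: RetrieverResult(items=[]),
    fake_llm,
    response_fallback="I can't answer this question without context",
)
print(answer)
print(len(llm_calls))  # 0: the LLM was never called
```

When the retriever does return items, the same function falls through to the normal prompt-building path, matching the `else` branch in the source above.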
src/neo4j_graphrag/generation/types.py

Lines changed: 1 addition & 0 deletions

@@ -43,6 +43,7 @@ class RagSearchModel(BaseModel):
     examples: str = ""
     retriever_config: dict[str, Any] = {}
     return_context: bool = False
+    response_fallback: str | None = None


 class RagResultModel(BaseModel):

tests/unit/test_graphrag.py

Lines changed: 19 additions & 2 deletions

@@ -59,9 +59,9 @@ def test_graphrag_happy_path(retriever_mock: MagicMock, llm: MagicMock) -> None:
     )
     llm.invoke.return_value = LLMResponse(content="llm generated text")

-    res = rag.search("question")
+    res = rag.search("question", retriever_config={"top_k": 111})

-    retriever_mock.search.assert_called_once_with(query_text="question")
+    retriever_mock.search.assert_called_once_with(query_text="question", top_k=111)
     llm.invoke.assert_called_once_with(
         """Context:
 item content 1
@@ -263,6 +263,23 @@ def test_graphrag_happy_path_custom_system_instruction(
     assert res.answer == "llm generated text"


+def test_graphrag_happy_path_response_fallback(
+    retriever_mock: MagicMock, llm: MagicMock
+) -> None:
+    rag = GraphRAG(
+        retriever=retriever_mock,
+        llm=llm,
+    )
+    retriever_mock.search.return_value = RetrieverResult(items=[])
+    res = rag.search(
+        "question",
+        response_fallback="I can't answer this question without context",
+    )
+
+    assert llm.invoke.call_count == 0
+    assert res.answer == "I can't answer this question without context"
+
+
 def test_graphrag_initialization_error(llm: MagicMock) -> None:
     with pytest.raises(RagInitializationError) as excinfo:
         GraphRAG(
