diff --git a/CHANGELOG.md b/CHANGELOG.md index 378ec3c84..a8bad9486 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -5,6 +5,7 @@ ### Added - Added support for automatic schema extraction from text using LLMs. In the `SimpleKGPipeline`, when the user provides no schema, the automatic schema extraction is enabled by default. +- Added ability to return a user-defined message if context is empty in GraphRAG (which skips the LLM call). ### Fixed diff --git a/docs/source/user_guide_rag.rst b/docs/source/user_guide_rag.rst index 42859c41c..1ad76ef91 100644 --- a/docs/source/user_guide_rag.rst +++ b/docs/source/user_guide_rag.rst @@ -1,15 +1,15 @@ .. _user-guide-rag: User Guide: RAG -################# +############### This guide provides a starting point for using the Neo4j GraphRAG package and configuring it according to specific requirements. -************ +********** Quickstart -************ +********** To perform a GraphRAG query using the `neo4j-graphrag` package, a few components are needed: @@ -63,9 +63,9 @@ In practice, it's done with only a few lines of code: The following sections provide more details about how to customize this code. -****************************** +********************** GraphRAG Configuration -****************************** +********************** Each component can be configured individually: the LLM and the prompt. @@ -775,7 +775,7 @@ See :ref:`hybridcypherretriever`. .. _text2cypher-retriever-user-guide: Text2Cypher Retriever ------------------------------------- +--------------------- This retriever first asks an LLM to generate a Cypher query to fetch the exact information required to answer the question from the database. Then this query is @@ -853,7 +853,7 @@ See :ref:`text2cypherretriever`. .. 
_custom-retriever: Custom Retriever -=================== +================ If the application requires very specific retrieval strategy, it is possible to implement a custom retriever using the `Retriever` interface: @@ -883,14 +883,60 @@ a custom retriever using the `Retriever` interface: See :ref:`rawsearchresult` for a description of the returned type. -****************************** +*********************** +GraphRAG search options +*********************** + +Return context +============== + +By default, the search method only returns the final answer. It is possible to see +what was retrieved as part of the context by setting `return_context=True`: + +.. code:: python + + rag.search("my question", return_context=True) + + +Return a custom message if context is empty +=========================================== + +If the retriever is not able to find any context, the LLM will return an answer anyway. +It is possible to skip the LLM call if the context is empty and return a user-defined message +instead: + +.. code:: python + + rag.search( + "my question", + response_fallback="I can not answer this question because I have no relevant context." + ) + + +Pass configuration to the retriever search method +================================================= + +The retriever search methods accept several configuration options (see above), +which can also be configured through the GraphRAG search method using the `retriever_config` +argument. For instance, the following code snippet illustrates how to define the `top_k` +parameter for the retriever: + +.. code:: python + + rag.search( + "my question", + retriever_config={"top_k": 2} + ) + + +************** DB Operations -****************************** +************** See :ref:`database-interaction-section`. Create a Vector Index -======================== +===================== .. code:: python @@ -918,7 +964,7 @@ Create a Vector Index Populate a Vector Index -========================== +======================= .. 
code:: python @@ -950,7 +996,7 @@ This property will also be added to the vector index. Drop a Vector Index -======================== +=================== .. warning:: diff --git a/examples/question_answering/graphrag.py b/examples/question_answering/graphrag.py index 25186e949..f1cb935de 100644 --- a/examples/question_answering/graphrag.py +++ b/examples/question_answering/graphrag.py @@ -56,6 +56,8 @@ def formatter(record: neo4j.Record) -> RetrieverResultItem: result = rag.search( "Tell me more about Avatar movies", return_context=True, + # optional + response_fallback="I can't answer this question without context", ) print(result.answer) # print(result.retriever_result) diff --git a/src/neo4j_graphrag/generation/graphrag.py b/src/neo4j_graphrag/generation/graphrag.py index 4f0fcdbbc..3e649cc13 100644 --- a/src/neo4j_graphrag/generation/graphrag.py +++ b/src/neo4j_graphrag/generation/graphrag.py @@ -89,6 +89,7 @@ def search( examples: str = "", retriever_config: Optional[dict[str, Any]] = None, return_context: bool | None = None, + response_fallback: str | None = None, ) -> RagResultModel: """ .. warning:: @@ -109,6 +110,7 @@ def search( retriever_config (Optional[dict]): Parameters passed to the retriever. search method; e.g.: top_k return_context (bool): Whether to append the retriever result to the final result (default: False). + response_fallback (Optional[str]): If not null, will return this message instead of calling the LLM if context comes back empty. Returns: RagResultModel: The LLM-generated answer. 
@@ -126,6 +128,7 @@ def search( examples=examples, retriever_config=retriever_config or {}, return_context=return_context, + response_fallback=response_fallback, ) except ValidationError as e: raise SearchValidationError(e.errors()) @@ -135,18 +138,22 @@ def search( retriever_result: RetrieverResult = self.retriever.search( query_text=query, **validated_data.retriever_config ) - context = "\n".join(item.content for item in retriever_result.items) - prompt = self.prompt_template.format( - query_text=query_text, context=context, examples=validated_data.examples - ) - logger.debug(f"RAG: retriever_result={prettify(retriever_result)}") - logger.debug(f"RAG: prompt={prompt}") - answer = self.llm.invoke( - prompt, - message_history, - system_instruction=self.prompt_template.system_instructions, - ) - result: dict[str, Any] = {"answer": answer.content} + if len(retriever_result.items) == 0 and response_fallback is not None: + answer = response_fallback + else: + context = "\n".join(item.content for item in retriever_result.items) + prompt = self.prompt_template.format( + query_text=query_text, context=context, examples=validated_data.examples + ) + logger.debug(f"RAG: retriever_result={prettify(retriever_result)}") + logger.debug(f"RAG: prompt={prompt}") + llm_response = self.llm.invoke( + prompt, + message_history, + system_instruction=self.prompt_template.system_instructions, + ) + answer = llm_response.content + result: dict[str, Any] = {"answer": answer} if return_context: result["retriever_result"] = retriever_result return RagResultModel(**result) diff --git a/src/neo4j_graphrag/generation/types.py b/src/neo4j_graphrag/generation/types.py index 3d9852d29..a03983c27 100644 --- a/src/neo4j_graphrag/generation/types.py +++ b/src/neo4j_graphrag/generation/types.py @@ -43,6 +43,7 @@ class RagSearchModel(BaseModel): examples: str = "" retriever_config: dict[str, Any] = {} return_context: bool = False + response_fallback: str | None = None class RagResultModel(BaseModel): 
diff --git a/tests/unit/test_graphrag.py b/tests/unit/test_graphrag.py index d90e7904d..925b48b78 100644 --- a/tests/unit/test_graphrag.py +++ b/tests/unit/test_graphrag.py @@ -59,9 +59,9 @@ def test_graphrag_happy_path(retriever_mock: MagicMock, llm: MagicMock) -> None: ) llm.invoke.return_value = LLMResponse(content="llm generated text") - res = rag.search("question") + res = rag.search("question", retriever_config={"top_k": 111}) - retriever_mock.search.assert_called_once_with(query_text="question") + retriever_mock.search.assert_called_once_with(query_text="question", top_k=111) llm.invoke.assert_called_once_with( """Context: item content 1 @@ -263,6 +263,23 @@ def test_graphrag_happy_path_custom_system_instruction( assert res.answer == "llm generated text" +def test_graphrag_happy_path_response_fallback( + retriever_mock: MagicMock, llm: MagicMock +) -> None: + rag = GraphRAG( + retriever=retriever_mock, + llm=llm, + ) + retriever_mock.search.return_value = RetrieverResult(items=[]) + res = rag.search( + "question", + response_fallback="I can't answer this question without context", + ) + + assert llm.invoke.call_count == 0 + assert res.answer == "I can't answer this question without context" + + def test_graphrag_initialization_error(llm: MagicMock) -> None: with pytest.raises(RagInitializationError) as excinfo: GraphRAG(