Return a user-defined message if context is empty in GraphRAG #343

Merged
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -5,6 +5,7 @@
### Added

- Added support for automatic schema extraction from text using LLMs. In the `SimpleKGPipeline`, when the user provides no schema, the automatic schema extraction is enabled by default.
- Added ability to return a user-defined message if context is empty in GraphRAG (which skips the LLM call).

### Fixed

70 changes: 58 additions & 12 deletions docs/source/user_guide_rag.rst
@@ -1,15 +1,15 @@
.. _user-guide-rag:

User Guide: RAG
#################
###############

This guide provides a starting point for using the Neo4j GraphRAG package
and configuring it according to specific requirements.


************
**********
Quickstart
************
**********

To perform a GraphRAG query using the `neo4j-graphrag` package, a few components are needed:

@@ -63,9 +63,9 @@ In practice, it's done with only a few lines of code:

The following sections provide more details about how to customize this code.

******************************
**********************
GraphRAG Configuration
******************************
**********************

Each component can be configured individually: the LLM and the prompt.

@@ -775,7 +775,7 @@ See :ref:`hybridcypherretriever`.
.. _text2cypher-retriever-user-guide:

Text2Cypher Retriever
------------------------------------
---------------------

This retriever first asks an LLM to generate a Cypher query to fetch the exact
information required to answer the question from the database. Then this query is
@@ -853,7 +853,7 @@ See :ref:`text2cypherretriever`.
.. _custom-retriever:

Custom Retriever
===================
================

If the application requires a very specific retrieval strategy, it is possible to implement
a custom retriever using the `Retriever` interface:
@@ -883,14 +883,60 @@ a custom retriever using the `Retriever` interface:
See :ref:`rawsearchresult` for a description of the returned type.


******************************
***********************
GraphRAG search options
***********************

Return context
==============

By default, the search method only returns the final answer. It is possible to see
what was retrieved as part of the context by setting `return_context=True`:

.. code:: python

rag.search("my question", return_context=True)


Return a custom message if context is empty
===========================================

If the retriever does not find any context, the LLM will still generate an answer.
It is possible to skip the LLM call in this case and return a user-defined message
instead:

.. code:: python

rag.search(
"my question",
response_fallback="I can not answer this question because I have no relevant context."
)
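
The behaviour can be sketched in isolation — a simplified illustration of the control flow added by this PR, not the package's actual implementation; `retrieved_items` and `call_llm` are hypothetical stand-ins:

```python
def answer_or_fallback(retrieved_items, response_fallback, call_llm):
    """Return the fallback message instead of the LLM answer when no context was retrieved."""
    if not retrieved_items and response_fallback is not None:
        # Empty context: skip the LLM call entirely.
        return response_fallback
    context = "\n".join(retrieved_items)
    return call_llm(context)

# With an empty context, the LLM callable is never invoked:
print(answer_or_fallback([], "I can not answer this question.", lambda ctx: "llm answer"))
# → I can not answer this question.
```

Note that the fallback only applies when `response_fallback` is provided; otherwise the LLM is called as usual, even with an empty context.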


Pass configuration to the retriever search method
=================================================

The retriever's search method has several configuration options (see above),
which can also be set through the GraphRAG search method using the `retriever_config`
argument. For instance, the following code snippet illustrates how to define the `top_k`
parameter for the retriever:

.. code:: python

rag.search(
"my question",
retriever_config={"top_k": 2}
)
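
Internally, these entries are forwarded to the retriever's search method as keyword arguments. A minimal sketch of that pass-through (`fake_search` is a hypothetical stand-in for a retriever's search method):

```python
def forward_retriever_config(search_fn, query_text, retriever_config=None):
    # Every key in retriever_config becomes a keyword argument of the
    # retriever's search method, e.g. top_k=2.
    return search_fn(query_text=query_text, **(retriever_config or {}))

def fake_search(query_text, top_k=5):
    return f"{query_text} (top_k={top_k})"

print(forward_retriever_config(fake_search, "my question", {"top_k": 2}))
# → my question (top_k=2)
```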


**************
DB Operations
******************************
**************

See :ref:`database-interaction-section`.

Create a Vector Index
========================
=====================

.. code:: python

@@ -918,7 +964,7 @@ Create a Vector Index


Populate a Vector Index
==========================
=======================

.. code:: python

@@ -950,7 +996,7 @@ This property will also be added to the vector index.


Drop a Vector Index
========================
===================

.. warning::

2 changes: 2 additions & 0 deletions examples/question_answering/graphrag.py
@@ -56,6 +56,8 @@ def formatter(record: neo4j.Record) -> RetrieverResultItem:
result = rag.search(
"Tell me more about Avatar movies",
return_context=True,
# optional
response_fallback="I can't answer this question without context",
)
print(result.answer)
# print(result.retriever_result)
31 changes: 19 additions & 12 deletions src/neo4j_graphrag/generation/graphrag.py
@@ -89,6 +89,7 @@ def search(
examples: str = "",
retriever_config: Optional[dict[str, Any]] = None,
return_context: bool | None = None,
response_fallback: str | None = None,
) -> RagResultModel:
"""
.. warning::
@@ -109,6 +110,7 @@
retriever_config (Optional[dict]): Parameters passed to the retriever
search method; e.g.: top_k
return_context (bool): Whether to append the retriever result to the final result (default: False).
response_fallback (Optional[str]): If not None, this message is returned instead of calling the LLM when the retrieved context is empty.

Returns:
RagResultModel: The LLM-generated answer.
@@ -126,6 +128,7 @@
examples=examples,
retriever_config=retriever_config or {},
return_context=return_context,
response_fallback=response_fallback,
)
except ValidationError as e:
raise SearchValidationError(e.errors())
@@ -135,18 +138,22 @@
retriever_result: RetrieverResult = self.retriever.search(
query_text=query, **validated_data.retriever_config
)
context = "\n".join(item.content for item in retriever_result.items)
prompt = self.prompt_template.format(
query_text=query_text, context=context, examples=validated_data.examples
)
logger.debug(f"RAG: retriever_result={prettify(retriever_result)}")
logger.debug(f"RAG: prompt={prompt}")
answer = self.llm.invoke(
prompt,
message_history,
system_instruction=self.prompt_template.system_instructions,
)
result: dict[str, Any] = {"answer": answer.content}
if len(retriever_result.items) == 0 and response_fallback is not None:
answer = response_fallback
else:
context = "\n".join(item.content for item in retriever_result.items)
prompt = self.prompt_template.format(
query_text=query_text, context=context, examples=validated_data.examples
)
logger.debug(f"RAG: retriever_result={prettify(retriever_result)}")
logger.debug(f"RAG: prompt={prompt}")
llm_response = self.llm.invoke(
prompt,
message_history,
system_instruction=self.prompt_template.system_instructions,
)
answer = llm_response.content
result: dict[str, Any] = {"answer": answer}
if return_context:
result["retriever_result"] = retriever_result
return RagResultModel(**result)
1 change: 1 addition & 0 deletions src/neo4j_graphrag/generation/types.py
@@ -43,6 +43,7 @@ class RagSearchModel(BaseModel):
examples: str = ""
retriever_config: dict[str, Any] = {}
return_context: bool = False
response_fallback: str | None = None


class RagResultModel(BaseModel):
21 changes: 19 additions & 2 deletions tests/unit/test_graphrag.py
@@ -59,9 +59,9 @@ def test_graphrag_happy_path(retriever_mock: MagicMock, llm: MagicMock) -> None:
)
llm.invoke.return_value = LLMResponse(content="llm generated text")

res = rag.search("question")
res = rag.search("question", retriever_config={"top_k": 111})

retriever_mock.search.assert_called_once_with(query_text="question")
retriever_mock.search.assert_called_once_with(query_text="question", top_k=111)
llm.invoke.assert_called_once_with(
"""Context:
item content 1
@@ -263,6 +263,23 @@ def test_graphrag_happy_path_custom_system_instruction(
assert res.answer == "llm generated text"


def test_graphrag_happy_path_response_fallback(
retriever_mock: MagicMock, llm: MagicMock
) -> None:
rag = GraphRAG(
retriever=retriever_mock,
llm=llm,
)
retriever_mock.search.return_value = RetrieverResult(items=[])
res = rag.search(
"question",
response_fallback="I can't answer this question without context",
)

assert llm.invoke.call_count == 0
assert res.answer == "I can't answer this question without context"


def test_graphrag_initialization_error(llm: MagicMock) -> None:
with pytest.raises(RagInitializationError) as excinfo:
GraphRAG(