Commit 7eb0326

Added test_graphrag_happy_path_with_neo4j_message_history
1 parent 39682b5 commit 7eb0326

File tree

2 files changed: +95 −5 lines


src/neo4j_graphrag/generation/graphrag.py

Lines changed: 3 additions & 2 deletions
@@ -103,7 +103,8 @@ def search(
 
         Args:
             query_text (str): The user question.
-            message_history (Optional[list]): A collection previous messages, with each message having a specific role assigned.
+            message_history (Optional[Union[List[LLMMessage], MessageHistory]]): A collection previous messages,
+                with each message having a specific role assigned.
             examples (str): Examples added to the LLM prompt.
             retriever_config (Optional[dict]): Parameters passed to the retriever.
                 search method; e.g.: top_k

@@ -175,7 +176,7 @@ def chat_summary_prompt(
         if isinstance(message_history, list):
             message_history = InMemoryMessageHistory(messages=message_history)
         message_list = [
-            ": ".join([f"{value}" for _, value in message.items()])
+            f"{message['role']}: {message['content']}"
             for message in message_history.messages
         ]
         history = "\n".join(message_list)
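The chat_summary_prompt change swaps a generic join over message.items() for an explicit role/content f-string. A minimal sketch of what that means for a single message dict (the example dict below is illustrative, not taken from the library):

# Illustrative only: compares the old and new formatting shown in the diff above.
# `message` is a made-up example dict, not data from the repository.
message = {"role": "user", "content": "initial question"}

# Old line: joins every value in the dict with ": ", so the output depends on
# key order and silently includes any extra keys the dict might carry.
old_line = ": ".join([f"{value}" for _, value in message.items()])

# New line: picks role and content explicitly.
new_line = f"{message['role']}: {message['content']}"

assert old_line == new_line == "user: initial question"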

tests/e2e/test_graphrag_e2e.py

Lines changed: 92 additions & 3 deletions
@@ -13,16 +13,18 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from unittest.mock import MagicMock
+from unittest.mock import MagicMock, call
 
 import neo4j
 import pytest
 from neo4j_graphrag.exceptions import LLMGenerationError
 from neo4j_graphrag.generation.graphrag import GraphRAG
 from neo4j_graphrag.generation.types import RagResultModel
 from neo4j_graphrag.llm import LLMResponse
+from neo4j_graphrag.llm.types import LLMMessage
+from neo4j_graphrag.message_history import Neo4jMessageHistory
 from neo4j_graphrag.retrievers import VectorCypherRetriever
-from neo4j_graphrag.types import RetrieverResult
+from neo4j_graphrag.types import RetrieverResult, RetrieverResultItem
 
 from tests.e2e.conftest import BiologyEmbedder
 from tests.e2e.utils import build_data_objects, populate_neo4j

@@ -79,6 +81,93 @@ def test_graphrag_happy_path(
     assert result.retriever_result is None
 
 
+@pytest.mark.usefixtures("populate_neo4j_db")
+def test_graphrag_happy_path_with_neo4j_message_history(
+    retriever_mock: MagicMock,
+    llm: MagicMock,
+    driver: neo4j.Driver,
+) -> None:
+    rag = GraphRAG(
+        retriever=retriever_mock,
+        llm=llm,
+    )
+    retriever_mock.search.return_value = RetrieverResult(
+        items=[
+            RetrieverResultItem(content="item content 1"),
+            RetrieverResultItem(content="item content 2"),
+        ]
+    )
+    llm.invoke.side_effect = [
+        LLMResponse(content="llm generated summary"),
+        LLMResponse(content="llm generated text"),
+    ]
+    message_history = Neo4jMessageHistory(
+        driver=driver,
+        session_id="123",
+        node_label="Message",
+    )
+    message_history.clear()
+    message_history.add_messages(
+        messages=[
+            LLMMessage(role="user", content="initial question"),
+            LLMMessage(role="assistant", content="answer to initial question"),
+        ]
+    )
+    res = rag.search(
+        query_text="question",
+        message_history=message_history,
+    )
+    expected_retriever_query_text = """
+Message Summary:
+llm generated summary
+
+Current Query:
+question
+"""
+
+    first_invocation_input = """
+Summarize the message history:
+
+user: initial question
+assistant: answer to initial question
+"""
+    first_invocation_system_instruction = "You are a summarization assistant. Summarize the given text in no more than 300 words."
+    second_invocation = """Context:
+item content 1
+item content 2
+
+Examples:
+
+
+Question:
+question
+
+Answer:
+"""
+    retriever_mock.search.assert_called_once_with(
+        query_text=expected_retriever_query_text
+    )
+    assert llm.invoke.call_count == 2
+    llm.invoke.assert_has_calls(
+        [
+            call(
+                input=first_invocation_input,
+                system_instruction=first_invocation_system_instruction,
+            ),
+            call(
+                second_invocation,
+                message_history.messages,
+                system_instruction="Answer the user question using the provided context.",
+            ),
+        ]
+    )
+
+    assert isinstance(res, RagResultModel)
+    assert res.answer == "llm generated text"
+    assert res.retriever_result is None
+    message_history.clear()
+
+
 @pytest.mark.usefixtures("populate_neo4j_db")
 def test_graphrag_happy_path_return_context(
     driver: MagicMock, llm: MagicMock, biology_embedder: BiologyEmbedder

@@ -127,7 +216,7 @@ def test_graphrag_happy_path_return_context(
 
 @pytest.mark.usefixtures("populate_neo4j_db")
 def test_graphrag_happy_path_examples(
-    driver: MagicMock, llm: MagicMock, biology_embedder: BiologyEmbedder
+    driver: MagicMock, llm: MagicMock, biology_embedder: MagicMock
 ) -> None:
     retriever = VectorCypherRetriever(
         driver,
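For context, a minimal usage sketch of the flow the new e2e test exercises: storing prior turns in a Neo4jMessageHistory and passing it to GraphRAG.search. It assumes `rag` is an already constructed GraphRAG instance and `driver` an open neo4j.Driver; the session id and node label simply mirror the test values.

# Sketch only: mirrors the calls made in the new test above. `rag` (a GraphRAG
# instance) and `driver` (an open neo4j.Driver) are assumed to exist already.
from neo4j_graphrag.llm.types import LLMMessage
from neo4j_graphrag.message_history import Neo4jMessageHistory

message_history = Neo4jMessageHistory(
    driver=driver,         # existing neo4j.Driver (assumed)
    session_id="123",      # conversation identifier, same value as the test
    node_label="Message",  # label used for the stored message nodes
)
message_history.add_messages(
    messages=[
        LLMMessage(role="user", content="initial question"),
        LLMMessage(role="assistant", content="answer to initial question"),
    ]
)

# search() accepts either a MessageHistory object or a plain list of
# LLMMessage, per the updated docstring in graphrag.py above.
result = rag.search(query_text="question", message_history=message_history)
print(result.answer)
message_history.clear()  # drop the stored turns when the session is finished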
