Skip to content

Commit 39682b5

Browse files
committed
Fixed tests
1 parent caa52e3 commit 39682b5

File tree

4 files changed

+106
-19
lines changed

4 files changed

+106
-19
lines changed

src/neo4j_graphrag/generation/graphrag.py

Lines changed: 17 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616

1717
import logging
1818
import warnings
19-
from typing import Any, Optional
19+
from typing import Any, List, Optional, Union
2020

2121
from pydantic import ValidationError
2222

@@ -28,6 +28,7 @@
2828
from neo4j_graphrag.generation.types import RagInitModel, RagResultModel, RagSearchModel
2929
from neo4j_graphrag.llm import LLMInterface
3030
from neo4j_graphrag.llm.types import LLMMessage
31+
from neo4j_graphrag.message_history import InMemoryMessageHistory, MessageHistory
3132
from neo4j_graphrag.retrievers.base import Retriever
3233
from neo4j_graphrag.types import RetrieverResult
3334

@@ -84,7 +85,7 @@ def __init__(
8485
def search(
8586
self,
8687
query_text: str = "",
87-
message_history: Optional[list[LLMMessage]] = None,
88+
message_history: Optional[Union[List[LLMMessage], MessageHistory]] = None,
8889
examples: str = "",
8990
retriever_config: Optional[dict[str, Any]] = None,
9091
return_context: bool | None = None,
@@ -127,6 +128,8 @@ def search(
127128
)
128129
except ValidationError as e:
129130
raise SearchValidationError(e.errors())
131+
if isinstance(message_history, list):
132+
message_history = InMemoryMessageHistory(messages=message_history)
130133
query = self.build_query(validated_data.query_text, message_history)
131134
retriever_result: RetrieverResult = self.retriever.search(
132135
query_text=query, **validated_data.retriever_config
@@ -139,7 +142,7 @@ def search(
139142
logger.debug(f"RAG: prompt={prompt}")
140143
answer = self.llm.invoke(
141144
prompt,
142-
message_history,
145+
message_history.messages if message_history else None,
143146
system_instruction=self.prompt_template.system_instructions,
144147
)
145148
result: dict[str, Any] = {"answer": answer.content}
@@ -148,10 +151,14 @@ def search(
148151
return RagResultModel(**result)
149152

150153
def build_query(
151-
self, query_text: str, message_history: Optional[list[LLMMessage]] = None
154+
self,
155+
query_text: str,
156+
message_history: Optional[Union[List[LLMMessage], MessageHistory]] = None,
152157
) -> str:
153158
summary_system_message = "You are a summarization assistant. Summarize the given text in no more than 300 words."
154159
if message_history:
160+
if isinstance(message_history, list):
161+
message_history = InMemoryMessageHistory(messages=message_history)
155162
summarization_prompt = self.chat_summary_prompt(
156163
message_history=message_history
157164
)
@@ -162,10 +169,14 @@ def build_query(
162169
return self.conversation_prompt(summary=summary, current_query=query_text)
163170
return query_text
164171

165-
def chat_summary_prompt(self, message_history: list[LLMMessage]) -> str:
172+
def chat_summary_prompt(
173+
self, message_history: Union[List[LLMMessage], MessageHistory]
174+
) -> str:
175+
if isinstance(message_history, list):
176+
message_history = InMemoryMessageHistory(messages=message_history)
166177
message_list = [
167178
": ".join([f"{value}" for _, value in message.items()])
168-
for message in message_history
179+
for message in message_history.messages
169180
]
170181
history = "\n".join(message_list)
171182
return f"""

src/neo4j_graphrag/message_history.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -57,8 +57,8 @@ def clear(self) -> None: ...
5757

5858

5959
class InMemoryMessageHistory(MessageHistory):
60-
def __init__(self, messages: List[LLMMessage] = []) -> None:
61-
self._messages = messages
60+
def __init__(self, messages: Optional[List[LLMMessage]] = None) -> None:
61+
self._messages = messages or []
6262

6363
@property
6464
def messages(self) -> List[LLMMessage]:

tests/unit/test_graphrag.py

Lines changed: 84 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,8 @@
2121
from neo4j_graphrag.generation.prompts import RagTemplate
2222
from neo4j_graphrag.generation.types import RagResultModel
2323
from neo4j_graphrag.llm import LLMResponse
24+
from neo4j_graphrag.llm.types import LLMMessage
25+
from neo4j_graphrag.message_history import InMemoryMessageHistory
2426
from neo4j_graphrag.types import RetrieverResult, RetrieverResultItem
2527

2628

@@ -114,14 +116,14 @@ def test_graphrag_happy_path_with_message_history(
114116
question
115117
"""
116118

117-
first_invokation_input = """
119+
first_invocation_input = """
118120
Summarize the message history:
119121
120122
user: initial question
121123
assistant: answer to initial question
122124
"""
123-
first_invokation_system_instruction = "You are a summarization assistant. Summarize the given text in no more than 300 words."
124-
second_invokation = """Context:
125+
first_invocation_system_instruction = "You are a summarization assistant. Summarize the given text in no more than 300 words."
126+
second_invocation = """Context:
125127
item content 1
126128
item content 2
127129
@@ -141,11 +143,11 @@ def test_graphrag_happy_path_with_message_history(
141143
llm.invoke.assert_has_calls(
142144
[
143145
call(
144-
input=first_invokation_input,
145-
system_instruction=first_invokation_system_instruction,
146+
input=first_invocation_input,
147+
system_instruction=first_invocation_system_instruction,
146148
),
147149
call(
148-
second_invokation,
150+
second_invocation,
149151
message_history,
150152
system_instruction="Answer the user question using the provided context.",
151153
),
@@ -157,6 +159,82 @@ def test_graphrag_happy_path_with_message_history(
157159
assert res.retriever_result is None
158160

159161

162+
def test_graphrag_happy_path_with_in_memory_message_history(
163+
retriever_mock: MagicMock, llm: MagicMock
164+
) -> None:
165+
rag = GraphRAG(
166+
retriever=retriever_mock,
167+
llm=llm,
168+
)
169+
retriever_mock.search.return_value = RetrieverResult(
170+
items=[
171+
RetrieverResultItem(content="item content 1"),
172+
RetrieverResultItem(content="item content 2"),
173+
]
174+
)
175+
llm.invoke.side_effect = [
176+
LLMResponse(content="llm generated summary"),
177+
LLMResponse(content="llm generated text"),
178+
]
179+
message_history = InMemoryMessageHistory(
180+
messages=[
181+
LLMMessage(role="user", content="initial question"),
182+
LLMMessage(role="assistant", content="answer to initial question"),
183+
]
184+
)
185+
res = rag.search("question", message_history)
186+
187+
expected_retriever_query_text = """
188+
Message Summary:
189+
llm generated summary
190+
191+
Current Query:
192+
question
193+
"""
194+
195+
first_invocation_input = """
196+
Summarize the message history:
197+
198+
user: initial question
199+
assistant: answer to initial question
200+
"""
201+
first_invocation_system_instruction = "You are a summarization assistant. Summarize the given text in no more than 300 words."
202+
second_invocation = """Context:
203+
item content 1
204+
item content 2
205+
206+
Examples:
207+
208+
209+
Question:
210+
question
211+
212+
Answer:
213+
"""
214+
215+
retriever_mock.search.assert_called_once_with(
216+
query_text=expected_retriever_query_text
217+
)
218+
assert llm.invoke.call_count == 2
219+
llm.invoke.assert_has_calls(
220+
[
221+
call(
222+
input=first_invocation_input,
223+
system_instruction=first_invocation_system_instruction,
224+
),
225+
call(
226+
second_invocation,
227+
message_history.messages,
228+
system_instruction="Answer the user question using the provided context.",
229+
),
230+
]
231+
)
232+
233+
assert isinstance(res, RagResultModel)
234+
assert res.answer == "llm generated text"
235+
assert res.retriever_result is None
236+
237+
160238
def test_graphrag_happy_path_custom_system_instruction(
161239
retriever_mock: MagicMock, llm: MagicMock
162240
) -> None:

tests/unit/test_message_history.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -62,9 +62,7 @@ def test_neo4j_message_history_invalid_session_id(driver: MagicMock) -> None:
6262
def test_neo4j_message_history_invalid_driver() -> None:
6363
with pytest.raises(ValidationError) as exc_info:
6464
Neo4jMessageHistory(session_id="123", driver=1.5, node_label="123", window=1) # type: ignore[arg-type]
65-
assert "Input should be a valid dictionary or instance of Neo4jDriver" in str(
66-
exc_info.value
67-
)
65+
assert "Input should be an instance of Driver" in str(exc_info.value)
6866

6967

7068
def test_neo4j_message_history_invalid_node_label(driver: MagicMock) -> None:
@@ -81,8 +79,8 @@ def test_neo4j_message_history_invalid_window(driver: MagicMock) -> None:
8179
assert "Input should be greater than 0" in str(exc_info.value)
8280

8381

84-
def test_neo4j_message_history_messages_setter(neo4j_driver: MagicMock) -> None:
85-
message_history = Neo4jMessageHistory(session_id="123", driver=neo4j_driver)
82+
def test_neo4j_message_history_messages_setter(driver: MagicMock) -> None:
83+
message_history = Neo4jMessageHistory(session_id="123", driver=driver)
8684
with pytest.raises(NotImplementedError) as exc_info:
8785
message_history.messages = [
8886
LLMMessage(role="user", content="may thy knife chip and shatter"),

0 commit comments

Comments
 (0)