|
13 | 13 | # See the License for the specific language governing permissions and
|
14 | 14 | # limitations under the License.
|
15 | 15 |
|
16 |
| -from unittest.mock import MagicMock |
| 16 | +from unittest.mock import MagicMock, call |
17 | 17 |
|
18 | 18 | import neo4j
|
19 | 19 | import pytest
|
20 | 20 | from neo4j_graphrag.exceptions import LLMGenerationError
|
21 | 21 | from neo4j_graphrag.generation.graphrag import GraphRAG
|
22 | 22 | from neo4j_graphrag.generation.types import RagResultModel
|
23 | 23 | from neo4j_graphrag.llm import LLMResponse
|
| 24 | +from neo4j_graphrag.llm.types import LLMMessage |
| 25 | +from neo4j_graphrag.message_history import Neo4jMessageHistory |
24 | 26 | from neo4j_graphrag.retrievers import VectorCypherRetriever
|
25 |
| -from neo4j_graphrag.types import RetrieverResult |
| 27 | +from neo4j_graphrag.types import RetrieverResult, RetrieverResultItem |
26 | 28 |
|
27 | 29 | from tests.e2e.conftest import BiologyEmbedder
|
28 | 30 | from tests.e2e.utils import build_data_objects, populate_neo4j
|
@@ -79,6 +81,93 @@ def test_graphrag_happy_path(
|
79 | 81 | assert result.retriever_result is None
|
80 | 82 |
|
81 | 83 |
|
| 84 | +@pytest.mark.usefixtures("populate_neo4j_db") |
| 85 | +def test_graphrag_happy_path_with_neo4j_message_history( |
| 86 | + retriever_mock: MagicMock, |
| 87 | + llm: MagicMock, |
| 88 | + driver: neo4j.Driver, |
| 89 | +) -> None: |
| 90 | + rag = GraphRAG( |
| 91 | + retriever=retriever_mock, |
| 92 | + llm=llm, |
| 93 | + ) |
| 94 | + retriever_mock.search.return_value = RetrieverResult( |
| 95 | + items=[ |
| 96 | + RetrieverResultItem(content="item content 1"), |
| 97 | + RetrieverResultItem(content="item content 2"), |
| 98 | + ] |
| 99 | + ) |
| 100 | + llm.invoke.side_effect = [ |
| 101 | + LLMResponse(content="llm generated summary"), |
| 102 | + LLMResponse(content="llm generated text"), |
| 103 | + ] |
| 104 | + message_history = Neo4jMessageHistory( |
| 105 | + driver=driver, |
| 106 | + session_id="123", |
| 107 | + node_label="Message", |
| 108 | + ) |
| 109 | + message_history.clear() |
| 110 | + message_history.add_messages( |
| 111 | + messages=[ |
| 112 | + LLMMessage(role="user", content="initial question"), |
| 113 | + LLMMessage(role="assistant", content="answer to initial question"), |
| 114 | + ] |
| 115 | + ) |
| 116 | + res = rag.search( |
| 117 | + query_text="question", |
| 118 | + message_history=message_history, |
| 119 | + ) |
| 120 | + expected_retriever_query_text = """ |
| 121 | +Message Summary: |
| 122 | +llm generated summary |
| 123 | +
|
| 124 | +Current Query: |
| 125 | +question |
| 126 | +""" |
| 127 | + |
| 128 | + first_invocation_input = """ |
| 129 | +Summarize the message history: |
| 130 | +
|
| 131 | +user: initial question |
| 132 | +assistant: answer to initial question |
| 133 | +""" |
| 134 | + first_invocation_system_instruction = "You are a summarization assistant. Summarize the given text in no more than 300 words." |
| 135 | + second_invocation = """Context: |
| 136 | +item content 1 |
| 137 | +item content 2 |
| 138 | +
|
| 139 | +Examples: |
| 140 | +
|
| 141 | +
|
| 142 | +Question: |
| 143 | +question |
| 144 | +
|
| 145 | +Answer: |
| 146 | +""" |
| 147 | + retriever_mock.search.assert_called_once_with( |
| 148 | + query_text=expected_retriever_query_text |
| 149 | + ) |
| 150 | + assert llm.invoke.call_count == 2 |
| 151 | + llm.invoke.assert_has_calls( |
| 152 | + [ |
| 153 | + call( |
| 154 | + input=first_invocation_input, |
| 155 | + system_instruction=first_invocation_system_instruction, |
| 156 | + ), |
| 157 | + call( |
| 158 | + second_invocation, |
| 159 | + message_history.messages, |
| 160 | + system_instruction="Answer the user question using the provided context.", |
| 161 | + ), |
| 162 | + ] |
| 163 | + ) |
| 164 | + |
| 165 | + assert isinstance(res, RagResultModel) |
| 166 | + assert res.answer == "llm generated text" |
| 167 | + assert res.retriever_result is None |
| 168 | + message_history.clear() |
| 169 | + |
| 170 | + |
82 | 171 | @pytest.mark.usefixtures("populate_neo4j_db")
|
83 | 172 | def test_graphrag_happy_path_return_context(
|
84 | 173 | driver: MagicMock, llm: MagicMock, biology_embedder: BiologyEmbedder
|
@@ -127,7 +216,7 @@ def test_graphrag_happy_path_return_context(
|
127 | 216 |
|
128 | 217 | @pytest.mark.usefixtures("populate_neo4j_db")
|
129 | 218 | def test_graphrag_happy_path_examples(
|
130 |
| - driver: MagicMock, llm: MagicMock, biology_embedder: BiologyEmbedder |
| 219 | + driver: MagicMock, llm: MagicMock, biology_embedder: MagicMock |
131 | 220 | ) -> None:
|
132 | 221 | retriever = VectorCypherRetriever(
|
133 | 222 | driver,
|
|
0 commit comments