From 51fbb313aba92cc7e94503da5ccaf87ec2533427 Mon Sep 17 00:00:00 2001 From: estelle Date: Mon, 9 Jun 2025 14:31:38 +0200 Subject: [PATCH 1/2] Expose database in Neo4jMessageHistory --- src/neo4j_graphrag/message_history.py | 10 +++++++++- src/neo4j_graphrag/types.py | 1 + tests/unit/test_message_history.py | 19 ++++++++++++++++++- 3 files changed, 28 insertions(+), 2 deletions(-) diff --git a/src/neo4j_graphrag/message_history.py b/src/neo4j_graphrag/message_history.py index cd0ff87d5..7295479fa 100644 --- a/src/neo4j_graphrag/message_history.py +++ b/src/neo4j_graphrag/message_history.py @@ -148,8 +148,8 @@ class Neo4jMessageHistory(MessageHistory): Args: session_id (Union[str, int]): Unique identifier for the chat session. driver (neo4j.Driver): Neo4j driver instance. - node_label (str, optional): Label used for session nodes in Neo4j. Defaults to "Session". window (Optional[PositiveInt], optional): Number of previous messages to return when retrieving messages. + database (Optional[str], optional): Neo4j database name. """ @@ -158,21 +158,25 @@ def __init__( session_id: Union[str, int], driver: neo4j.Driver, window: Optional[PositiveInt] = None, + database: Optional[str] = None, ) -> None: validated_data = Neo4jMessageHistoryModel( session_id=session_id, driver_model=Neo4jDriverModel(driver=driver), window=window, + database=database, ) self._driver = validated_data.driver_model.driver self._session_id = validated_data.session_id self._window = ( "" if validated_data.window is None else validated_data.window - 1 ) + self._database = validated_data.database # Create session node self._driver.execute_query( query_=CREATE_SESSION_NODE_QUERY.format(node_label="Session"), parameters_={"session_id": self._session_id}, + database_=self._database, ) @property @@ -180,6 +184,7 @@ def messages(self) -> List[LLMMessage]: result = self._driver.execute_query( query_=GET_MESSAGES_QUERY.format(node_label="Session", window=self._window), parameters_={"session_id": self._session_id}, + database_=self._database, ) messages = [ LLMMessage( @@ -210,6 +215,7 @@ def add_message(self, message: LLMMessage) -> None: "content": message["content"], "session_id": self._session_id, }, + database_=self._database, ) def clear(self, delete_session_node: bool = False) -> None: @@ -222,9 +228,11 @@ def clear(self, delete_session_node: bool = False) -> None: self._driver.execute_query( query_=DELETE_SESSION_AND_MESSAGES_QUERY.format(node_label="Session"), parameters_={"session_id": self._session_id}, + database_=self._database, ) else: self._driver.execute_query( query_=DELETE_MESSAGES_QUERY.format(node_label="Session"), parameters_={"session_id": self._session_id}, + database_=self._database, ) diff --git a/src/neo4j_graphrag/types.py b/src/neo4j_graphrag/types.py index d6b969a51..1225b5af3 100644 --- a/src/neo4j_graphrag/types.py +++ b/src/neo4j_graphrag/types.py @@ -302,6 +302,7 @@ class Neo4jMessageHistoryModel(BaseModel): session_id: Union[str, int] driver_model: Neo4jDriverModel window: Optional[PositiveInt] = None + database: Optional[str] = None @field_validator("session_id") def validate_session_id(cls, v: Union[str, int]) -> Union[str, int]: diff --git a/tests/unit/test_message_history.py b/tests/unit/test_message_history.py index 762a6d908..4172701f3 100644 --- a/tests/unit/test_message_history.py +++ b/tests/unit/test_message_history.py @@ -12,7 +12,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from unittest.mock import MagicMock +from unittest.mock import MagicMock, ANY import pytest from neo4j_graphrag.message_history import InMemoryMessageHistory, Neo4jMessageHistory @@ -95,3 +95,20 @@ def test_neo4j_message_history_messages_setter(driver: MagicMock) -> None: str(exc_info.value) == "Direct assignment to 'messages' is not allowed. Use the 'add_messages' instead." ) + + +def test_neo4j_message_history_messages_getter_custom_db(driver: MagicMock) -> None: + driver.execute_query.side_effect = [ + MagicMock(records=[]), + MagicMock( + records=[{"result": {"data": {"content": "my message"}, "role": "user"}}] + ), + ] + message_history = Neo4jMessageHistory( + session_id="123", driver=driver, database="my_db" + ) + messages = message_history.messages + assert len(messages) == 1 + assert messages[0] == LLMMessage(content="my message", role="user") + for c in driver.execute_query.call_args_list: + assert c.kwargs["database_"] == "my_db" From 8acfaa73d77adca14f921048db377c2dd93886c4 Mon Sep 17 00:00:00 2001 From: estelle Date: Mon, 9 Jun 2025 14:32:30 +0200 Subject: [PATCH 2/2] Remove unused import --- tests/unit/test_message_history.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/unit/test_message_history.py b/tests/unit/test_message_history.py index 4172701f3..befc74a79 100644 --- a/tests/unit/test_message_history.py +++ b/tests/unit/test_message_history.py @@ -12,7 +12,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from unittest.mock import MagicMock, ANY +from unittest.mock import MagicMock import pytest from neo4j_graphrag.message_history import InMemoryMessageHistory, Neo4jMessageHistory