diff --git a/src/neo4j_graphrag/message_history.py b/src/neo4j_graphrag/message_history.py index cd0ff87d5..7295479fa 100644 --- a/src/neo4j_graphrag/message_history.py +++ b/src/neo4j_graphrag/message_history.py @@ -148,8 +148,8 @@ class Neo4jMessageHistory(MessageHistory): Args: session_id (Union[str, int]): Unique identifier for the chat session. driver (neo4j.Driver): Neo4j driver instance. - node_label (str, optional): Label used for session nodes in Neo4j. Defaults to "Session". window (Optional[PositiveInt], optional): Number of previous messages to return when retrieving messages. + database (Optional[str], optional): Neo4j database name. """ @@ -158,21 +158,25 @@ def __init__( session_id: Union[str, int], driver: neo4j.Driver, window: Optional[PositiveInt] = None, + database: Optional[str] = None, ) -> None: validated_data = Neo4jMessageHistoryModel( session_id=session_id, driver_model=Neo4jDriverModel(driver=driver), window=window, + database=database, ) self._driver = validated_data.driver_model.driver self._session_id = validated_data.session_id self._window = ( "" if validated_data.window is None else validated_data.window - 1 ) + self._database = validated_data.database # Create session node self._driver.execute_query( query_=CREATE_SESSION_NODE_QUERY.format(node_label="Session"), parameters_={"session_id": self._session_id}, + database_=self._database, ) @property @@ -180,6 +184,7 @@ def messages(self) -> List[LLMMessage]: result = self._driver.execute_query( query_=GET_MESSAGES_QUERY.format(node_label="Session", window=self._window), parameters_={"session_id": self._session_id}, + database_=self._database, ) messages = [ LLMMessage( @@ -210,6 +215,7 @@ def add_message(self, message: LLMMessage) -> None: "content": message["content"], "session_id": self._session_id, }, + database_=self._database, ) def clear(self, delete_session_node: bool = False) -> None: @@ -222,9 +228,11 @@ def clear(self, delete_session_node: bool = False) -> None: self._driver.execute_query( query_=DELETE_SESSION_AND_MESSAGES_QUERY.format(node_label="Session"), parameters_={"session_id": self._session_id}, + database_=self._database, ) else: self._driver.execute_query( query_=DELETE_MESSAGES_QUERY.format(node_label="Session"), parameters_={"session_id": self._session_id}, + database_=self._database, ) diff --git a/src/neo4j_graphrag/types.py b/src/neo4j_graphrag/types.py index d6b969a51..1225b5af3 100644 --- a/src/neo4j_graphrag/types.py +++ b/src/neo4j_graphrag/types.py @@ -302,6 +302,7 @@ class Neo4jMessageHistoryModel(BaseModel): session_id: Union[str, int] driver_model: Neo4jDriverModel window: Optional[PositiveInt] = None + database: Optional[str] = None @field_validator("session_id") def validate_session_id(cls, v: Union[str, int]) -> Union[str, int]: diff --git a/tests/unit/test_message_history.py b/tests/unit/test_message_history.py index 762a6d908..befc74a79 100644 --- a/tests/unit/test_message_history.py +++ b/tests/unit/test_message_history.py @@ -95,3 +95,20 @@ def test_neo4j_message_history_messages_setter(driver: MagicMock) -> None: str(exc_info.value) == "Direct assignment to 'messages' is not allowed. Use the 'add_messages' instead." ) + + +def test_neo4j_message_history_messages_getter_custom_db(driver: MagicMock) -> None: + driver.execute_query.side_effect = [ + MagicMock(records=[]), + MagicMock( + records=[{"result": {"data": {"content": "my message"}, "role": "user"}}] + ), + ] + message_history = Neo4jMessageHistory( + session_id="123", driver=driver, database="my_db" + ) + messages = message_history.messages + assert len(messages) == 1 + assert messages[0] == LLMMessage(content="my message", role="user") + for c in driver.execute_query.call_args_list: + assert c.kwargs["database_"] == "my_db"