Skip to content

Commit 59d74bb

Browse files
authored
Expose database in Neo4jMessageHistory (neo4j#354)
* Expose database in Neo4jMessageHistory * Remove unused import
1 parent e1941ed commit 59d74bb

File tree

3 files changed

+27
-1
lines changed

3 files changed

+27
-1
lines changed

src/neo4j_graphrag/message_history.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -148,8 +148,8 @@ class Neo4jMessageHistory(MessageHistory):
148148
Args:
149149
session_id (Union[str, int]): Unique identifier for the chat session.
150150
driver (neo4j.Driver): Neo4j driver instance.
151-
node_label (str, optional): Label used for session nodes in Neo4j. Defaults to "Session".
152151
window (Optional[PositiveInt], optional): Number of previous messages to return when retrieving messages.
152+
database (Optional[str], optional): Neo4j database name.
153153
154154
"""
155155

@@ -158,28 +158,33 @@ def __init__(
158158
session_id: Union[str, int],
159159
driver: neo4j.Driver,
160160
window: Optional[PositiveInt] = None,
161+
database: Optional[str] = None,
161162
) -> None:
162163
validated_data = Neo4jMessageHistoryModel(
163164
session_id=session_id,
164165
driver_model=Neo4jDriverModel(driver=driver),
165166
window=window,
167+
database=database,
166168
)
167169
self._driver = validated_data.driver_model.driver
168170
self._session_id = validated_data.session_id
169171
self._window = (
170172
"" if validated_data.window is None else validated_data.window - 1
171173
)
174+
self._database = validated_data.database
172175
# Create session node
173176
self._driver.execute_query(
174177
query_=CREATE_SESSION_NODE_QUERY.format(node_label="Session"),
175178
parameters_={"session_id": self._session_id},
179+
database_=self._database,
176180
)
177181

178182
@property
179183
def messages(self) -> List[LLMMessage]:
180184
result = self._driver.execute_query(
181185
query_=GET_MESSAGES_QUERY.format(node_label="Session", window=self._window),
182186
parameters_={"session_id": self._session_id},
187+
database_=self._database,
183188
)
184189
messages = [
185190
LLMMessage(
@@ -210,6 +215,7 @@ def add_message(self, message: LLMMessage) -> None:
210215
"content": message["content"],
211216
"session_id": self._session_id,
212217
},
218+
database_=self._database,
213219
)
214220

215221
def clear(self, delete_session_node: bool = False) -> None:
@@ -222,9 +228,11 @@ def clear(self, delete_session_node: bool = False) -> None:
222228
self._driver.execute_query(
223229
query_=DELETE_SESSION_AND_MESSAGES_QUERY.format(node_label="Session"),
224230
parameters_={"session_id": self._session_id},
231+
database_=self._database,
225232
)
226233
else:
227234
self._driver.execute_query(
228235
query_=DELETE_MESSAGES_QUERY.format(node_label="Session"),
229236
parameters_={"session_id": self._session_id},
237+
database_=self._database,
230238
)

src/neo4j_graphrag/types.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -302,6 +302,7 @@ class Neo4jMessageHistoryModel(BaseModel):
302302
session_id: Union[str, int]
303303
driver_model: Neo4jDriverModel
304304
window: Optional[PositiveInt] = None
305+
database: Optional[str] = None
305306

306307
@field_validator("session_id")
307308
def validate_session_id(cls, v: Union[str, int]) -> Union[str, int]:

tests/unit/test_message_history.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -95,3 +95,20 @@ def test_neo4j_message_history_messages_setter(driver: MagicMock) -> None:
9595
str(exc_info.value)
9696
== "Direct assignment to 'messages' is not allowed. Use the 'add_messages' instead."
9797
)
98+
99+
100+
def test_neo4j_message_history_messages_getter_custom_db(driver: MagicMock) -> None:
101+
driver.execute_query.side_effect = [
102+
MagicMock(records=[]),
103+
MagicMock(
104+
records=[{"result": {"data": {"content": "my message"}, "role": "user"}}]
105+
),
106+
]
107+
message_history = Neo4jMessageHistory(
108+
session_id="123", driver=driver, database="my_db"
109+
)
110+
messages = message_history.messages
111+
assert len(messages) == 1
112+
assert messages[0] == LLMMessage(content="my message", role="user")
113+
for c in driver.execute_query.call_args_list:
114+
assert c.kwargs["database_"] == "my_db"

0 commit comments

Comments
 (0)