@@ -148,8 +148,8 @@ class Neo4jMessageHistory(MessageHistory):
148
148
Args:
149
149
session_id (Union[str, int]): Unique identifier for the chat session.
150
150
driver (neo4j.Driver): Neo4j driver instance.
151
- node_label (str, optional): Label used for session nodes in Neo4j. Defaults to "Session".
152
151
window (Optional[PositiveInt], optional): Number of previous messages to return when retrieving messages.
152
+ database (Optional[str], optional): Neo4j database name.
153
153
154
154
"""
155
155
@@ -158,28 +158,33 @@ def __init__(
158
158
session_id : Union [str , int ],
159
159
driver : neo4j .Driver ,
160
160
window : Optional [PositiveInt ] = None ,
161
+ database : Optional [str ] = None ,
161
162
) -> None :
162
163
validated_data = Neo4jMessageHistoryModel (
163
164
session_id = session_id ,
164
165
driver_model = Neo4jDriverModel (driver = driver ),
165
166
window = window ,
167
+ database = database ,
166
168
)
167
169
self ._driver = validated_data .driver_model .driver
168
170
self ._session_id = validated_data .session_id
169
171
self ._window = (
170
172
"" if validated_data .window is None else validated_data .window - 1
171
173
)
174
+ self ._database = validated_data .database
172
175
# Create session node
173
176
self ._driver .execute_query (
174
177
query_ = CREATE_SESSION_NODE_QUERY .format (node_label = "Session" ),
175
178
parameters_ = {"session_id" : self ._session_id },
179
+ database_ = self ._database ,
176
180
)
177
181
178
182
@property
179
183
def messages (self ) -> List [LLMMessage ]:
180
184
result = self ._driver .execute_query (
181
185
query_ = GET_MESSAGES_QUERY .format (node_label = "Session" , window = self ._window ),
182
186
parameters_ = {"session_id" : self ._session_id },
187
+ database_ = self ._database ,
183
188
)
184
189
messages = [
185
190
LLMMessage (
@@ -210,6 +215,7 @@ def add_message(self, message: LLMMessage) -> None:
210
215
"content" : message ["content" ],
211
216
"session_id" : self ._session_id ,
212
217
},
218
+ database_ = self ._database ,
213
219
)
214
220
215
221
def clear (self , delete_session_node : bool = False ) -> None :
@@ -222,9 +228,11 @@ def clear(self, delete_session_node: bool = False) -> None:
222
228
self ._driver .execute_query (
223
229
query_ = DELETE_SESSION_AND_MESSAGES_QUERY .format (node_label = "Session" ),
224
230
parameters_ = {"session_id" : self ._session_id },
231
+ database_ = self ._database ,
225
232
)
226
233
else :
227
234
self ._driver .execute_query (
228
235
query_ = DELETE_MESSAGES_QUERY .format (node_label = "Session" ),
229
236
parameters_ = {"session_id" : self ._session_id },
237
+ database_ = self ._database ,
230
238
)
0 commit comments