Skip to content

Commit 8ae9423

Browse files
committed
Removed node_label parameter from Neo4jMessageHistory
1 parent c3c28c8 commit 8ae9423

File tree

6 files changed

+10
-35
lines changed

6 files changed

+10
-35
lines changed

examples/customize/llms/llm_with_neo4j_message_history.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -34,9 +34,7 @@
3434
database=DATABASE,
3535
)
3636

37-
history = Neo4jMessageHistory(
38-
session_id="123", driver=driver, node_label="Message", window=10
39-
)
37+
history = Neo4jMessageHistory(session_id="123", driver=driver, window=10)
4038

4139
for question in questions:
4240
res: LLMResponse = llm.invoke(

examples/question_answering/graphrag_with_neo4j_message_history.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -51,9 +51,7 @@
5151
llm=llm,
5252
)
5353

54-
history = Neo4jMessageHistory(
55-
session_id="123", driver=driver, node_label="Message", window=10
56-
)
54+
history = Neo4jMessageHistory(session_id="123", driver=driver, window=10)
5755

5856
questions = [
5957
"Who starred in the Apollo 13 movies?",

src/neo4j_graphrag/message_history.py

Lines changed: 5 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -129,7 +129,7 @@ class Neo4jMessageHistory(MessageHistory):
129129
driver = neo4j.GraphDatabase.driver(URI, auth=AUTH)
130130
131131
history = Neo4jMessageHistory(
132-
session_id="123", driver=driver, node_label="Message", window=10
132+
session_id="123", driver=driver, window=10
133133
)
134134
135135
message = LLMMessage(role="user", content="Hello!")
@@ -147,33 +147,28 @@ def __init__(
147147
self,
148148
session_id: Union[str, int],
149149
driver: neo4j.Driver,
150-
node_label: str = "Session",
151150
window: Optional[PositiveInt] = None,
152151
) -> None:
153152
validated_data = Neo4jMessageHistoryModel(
154153
session_id=session_id,
155154
driver_model=Neo4jDriverModel(driver=driver),
156-
node_label=node_label,
157155
window=window,
158156
)
159157
self._driver = validated_data.driver_model.driver
160158
self._session_id = validated_data.session_id
161-
self._node_label = validated_data.node_label
162159
self._window = (
163160
"" if validated_data.window is None else validated_data.window - 1
164161
)
165162
# Create session node
166163
self._driver.execute_query(
167-
query_=CREATE_SESSION_NODE_QUERY.format(node_label=self._node_label),
164+
query_=CREATE_SESSION_NODE_QUERY.format(node_label="Session"),
168165
parameters_={"session_id": self._session_id},
169166
)
170167

171168
@property
172169
def messages(self) -> List[LLMMessage]:
173170
result = self._driver.execute_query(
174-
query_=GET_MESSAGES_QUERY.format(
175-
node_label=self._node_label, window=self._window
176-
),
171+
query_=GET_MESSAGES_QUERY.format(node_label="Session", window=self._window),
177172
parameters_={"session_id": self._session_id},
178173
)
179174
messages = [
@@ -199,7 +194,7 @@ def add_message(self, message: LLMMessage) -> None:
199194
message (LLMMessage): The message to add.
200195
"""
201196
self._driver.execute_query(
202-
query_=ADD_MESSAGE_QUERY.format(node_label=self._node_label),
197+
query_=ADD_MESSAGE_QUERY.format(node_label="Session"),
203198
parameters_={
204199
"role": message["role"],
205200
"content": message["content"],
@@ -210,6 +205,6 @@ def add_message(self, message: LLMMessage) -> None:
210205
def clear(self) -> None:
211206
"""Clear the message history."""
212207
self._driver.execute_query(
213-
query_=CLEAR_SESSION_QUERY.format(node_label=self._node_label),
208+
query_=CLEAR_SESSION_QUERY.format(node_label="Session"),
214209
parameters_={"session_id": self._session_id},
215210
)

src/neo4j_graphrag/types.py

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -256,17 +256,10 @@ class Text2CypherRetrieverModel(BaseModel):
256256
class Neo4jMessageHistoryModel(BaseModel):
257257
session_id: Union[str, int]
258258
driver_model: Neo4jDriverModel
259-
node_label: str = "Session"
260259
window: Optional[PositiveInt] = None
261260

262261
@field_validator("session_id")
263262
def validate_session_id(cls, v: Union[str, int]) -> Union[str, int]:
264263
if isinstance(v, str) and len(v) == 0:
265264
raise ValueError("session_id cannot be empty")
266265
return v
267-
268-
@field_validator("node_label")
269-
def validate_node_label(cls, v: str) -> str:
270-
if len(v) == 0:
271-
raise ValueError("node_label cannot be empty")
272-
return v

tests/e2e/test_graphrag_e2e.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -104,7 +104,6 @@ def test_graphrag_happy_path_with_neo4j_message_history(
104104
message_history = Neo4jMessageHistory(
105105
driver=driver,
106106
session_id="123",
107-
node_label="Message",
108107
)
109108
message_history.clear()
110109
message_history.add_messages(

tests/unit/test_message_history.py

Lines changed: 3 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -69,27 +69,19 @@ def test_in_memory_message_history_clear() -> None:
6969

7070
def test_neo4j_message_history_invalid_session_id(driver: MagicMock) -> None:
7171
with pytest.raises(ValidationError) as exc_info:
72-
Neo4jMessageHistory(session_id=1.5, driver=driver, node_label="123", window=1) # type: ignore[arg-type]
72+
Neo4jMessageHistory(session_id=1.5, driver=driver, window=1) # type: ignore[arg-type]
7373
assert "Input should be a valid string" in str(exc_info.value)
7474

7575

7676
def test_neo4j_message_history_invalid_driver() -> None:
7777
with pytest.raises(ValidationError) as exc_info:
78-
Neo4jMessageHistory(session_id="123", driver=1.5, node_label="123", window=1) # type: ignore[arg-type]
78+
Neo4jMessageHistory(session_id="123", driver=1.5, window=1) # type: ignore[arg-type]
7979
assert "Input should be an instance of Driver" in str(exc_info.value)
8080

8181

82-
def test_neo4j_message_history_invalid_node_label(driver: MagicMock) -> None:
83-
with pytest.raises(ValidationError) as exc_info:
84-
Neo4jMessageHistory(session_id="123", driver=driver, node_label=1.5, window=1) # type: ignore[arg-type]
85-
assert "Input should be a valid string" in str(exc_info.value)
86-
87-
8882
def test_neo4j_message_history_invalid_window(driver: MagicMock) -> None:
8983
with pytest.raises(ValidationError) as exc_info:
90-
Neo4jMessageHistory(
91-
session_id="123", driver=driver, node_label="123", window=-1
92-
)
84+
Neo4jMessageHistory(session_id="123", driver=driver, window=-1)
9385
assert "Input should be greater than 0" in str(exc_info.value)
9486

9587

0 commit comments

Comments
 (0)