Skip to content

Commit f56ba20

Browse files
committed
Updated CLEAR_SESSION_QUERY
1 parent 6e8e10d commit f56ba20

File tree

3 files changed

+16
-5
lines changed

3 files changed

+16
-5
lines changed

src/neo4j_graphrag/message_history.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -30,10 +30,12 @@
3030
CREATE_SESSION_NODE_QUERY = "MERGE (s:`{node_label}` {{id:$session_id}})"
3131

3232
CLEAR_SESSION_QUERY = (
33-
"MATCH (s:`{node_label}`)-[:LAST_MESSAGE]->(last_message) "
34-
"WHERE s.id = $session_id MATCH p=(last_message)<-[:NEXT]-() "
35-
"WITH p, length(p) AS length ORDER BY length DESC LIMIT 1 "
36-
"UNWIND nodes(p) as node DETACH DELETE node;"
33+
"MATCH (s:`{node_label}`) "
34+
"WHERE s.id = $session_id "
35+
"OPTIONAL MATCH p=(s)-[:LAST_MESSAGE]->(:Message)<-[:NEXT*0..]-(:Message) "
36+
"WITH CASE WHEN p IS NULL THEN [s] ELSE nodes(p) + [s] END AS nodes "
37+
"UNWIND nodes AS node "
38+
"DETACH DELETE node;"
3739
)
3840

3941
GET_MESSAGES_QUERY = (

tests/e2e/test_graphrag_e2e.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -105,7 +105,6 @@ def test_graphrag_happy_path_with_neo4j_message_history(
105105
driver=driver,
106106
session_id="123",
107107
)
108-
message_history.clear()
109108
message_history.add_messages(
110109
messages=[
111110
LLMMessage(role="user", content="initial question"),

tests/e2e/test_message_history_e2e.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,16 @@ def test_neo4j_message_history_clear(driver: neo4j.Driver) -> None:
7676
assert len(message_history.messages) == 0
7777

7878

79+
def test_neo4j_message_history_clear_no_messages(driver: neo4j.Driver) -> None:
80+
driver.execute_query(query_="MATCH (n) DETACH DELETE n;")
81+
message_history = Neo4jMessageHistory(session_id="123", driver=driver)
82+
message_history.clear()
83+
results = driver.execute_query(
84+
query_="MATCH (s:`Session`) WHERE s.id = '123' RETURN s"
85+
)
86+
assert results.records == []
87+
88+
7989
def test_neo4j_message_window_size(driver: neo4j.Driver) -> None:
8090
driver.execute_query(query_="MATCH (n) DETACH DELETE n;")
8191
message_history = Neo4jMessageHistory(session_id="123", driver=driver, window=1)

0 commit comments

Comments
 (0)