Skip to content

Commit 985d6c8

Browse files
committed
Added message history classes
1 parent 4dae4eb commit 985d6c8

File tree

4 files changed

+334
-1
lines changed

4 files changed

+334
-1
lines changed

src/neo4j_graphrag/message_history.py

Lines changed: 145 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,145 @@
1+
from abc import ABC, abstractmethod
2+
from typing import List, Optional, Union
3+
4+
import neo4j
5+
from pydantic import PositiveInt
6+
7+
from neo4j_graphrag.llm.types import (
8+
LLMMessage,
9+
)
10+
from neo4j_graphrag.types import (
11+
Neo4jDriverModel,
12+
Neo4jMessageHistoryModel,
13+
)
14+
15+
CREATE_SESSION_NODE_QUERY = "MERGE (s:`{node_label}` {{id:$session_id}})"
16+
17+
CLEAR_SESSION_QUERY = (
18+
"MATCH (s:`{node_label}`)-[:LAST_MESSAGE]->(last_message) "
19+
"WHERE s.id = $session_id MATCH p=(last_message)<-[:NEXT]-() "
20+
"WITH p, length(p) AS length ORDER BY length DESC LIMIT 1 "
21+
"UNWIND nodes(p) as node DETACH DELETE node;"
22+
)
23+
24+
GET_MESSAGES_QUERY = (
25+
"MATCH (s:`{node_label}`)-[:LAST_MESSAGE]->(last_message) "
26+
"WHERE s.id = $session_id MATCH p=(last_message)<-[:NEXT*0.."
27+
"{window}]-() WITH p, length(p) AS length "
28+
"ORDER BY length DESC LIMIT 1 UNWIND reverse(nodes(p)) AS node "
29+
"RETURN {{data:{{content: node.content}}, role:node.role}} AS result"
30+
)
31+
32+
ADD_MESSAGE_QUERY = (
33+
"MATCH (s:`{node_label}`) WHERE s.id = $session_id "
34+
"OPTIONAL MATCH (s)-[lm:LAST_MESSAGE]->(last_message) "
35+
"CREATE (s)-[:LAST_MESSAGE]->(new:Message) "
36+
"SET new += {{role:$role, content:$content}} "
37+
"WITH new, lm, last_message WHERE last_message IS NOT NULL "
38+
"CREATE (last_message)-[:NEXT]->(new) "
39+
"DELETE lm"
40+
)
41+
42+
43+
class MessageHistory(ABC):
44+
@property
45+
@abstractmethod
46+
def messages(self) -> List[LLMMessage]: ...
47+
48+
@abstractmethod
49+
def add_message(self, message: LLMMessage) -> None: ...
50+
51+
def add_messages(self, messages: List[LLMMessage]) -> None:
52+
for message in messages:
53+
self.add_message(message)
54+
55+
@abstractmethod
56+
def clear(self) -> None: ...
57+
58+
59+
class InMemoryMessageHistory(MessageHistory):
60+
def __init__(self, messages: List[LLMMessage] = []) -> None:
61+
self._messages = messages
62+
63+
@property
64+
def messages(self) -> List[LLMMessage]:
65+
return self._messages
66+
67+
def add_message(self, message: LLMMessage) -> None:
68+
self._messages.append(message)
69+
70+
def add_messages(self, messages: List[LLMMessage]) -> None:
71+
self._messages.extend(messages)
72+
73+
def clear(self) -> None:
74+
self._messages = []
75+
76+
77+
class Neo4jMessageHistory(MessageHistory):
78+
def __init__(
79+
self,
80+
session_id: Union[str, int],
81+
driver: neo4j.Driver,
82+
node_label: str = "Session",
83+
window: Optional[PositiveInt] = None,
84+
) -> None:
85+
validated_data = Neo4jMessageHistoryModel(
86+
session_id=session_id,
87+
driver_model=Neo4jDriverModel(driver=driver),
88+
node_label=node_label,
89+
window=window,
90+
)
91+
self._driver = validated_data.driver_model.driver
92+
self._session_id = validated_data.session_id
93+
self._node_label = validated_data.node_label
94+
self._window = (
95+
"" if validated_data.window is None else validated_data.window - 1
96+
)
97+
# Create session node
98+
self._driver.execute_query(
99+
query_=CREATE_SESSION_NODE_QUERY.format(node_label=self._node_label),
100+
parameters_={"session_id": self._session_id},
101+
)
102+
103+
@property
104+
def messages(self) -> List[LLMMessage]:
105+
result = self._driver.execute_query(
106+
query_=GET_MESSAGES_QUERY.format(
107+
node_label=self._node_label, window=self._window
108+
),
109+
parameters_={"session_id": self._session_id},
110+
)
111+
messages = [
112+
LLMMessage(
113+
content=el["result"]["data"]["content"],
114+
role=el["result"]["role"],
115+
)
116+
for el in result.records
117+
]
118+
return messages
119+
120+
@messages.setter
121+
def messages(self, messages: List[LLMMessage]) -> None:
122+
raise NotImplementedError(
123+
"Direct assignment to 'messages' is not allowed."
124+
" Use the 'add_messages' instead."
125+
)
126+
127+
def add_message(self, message: LLMMessage) -> None:
128+
self._driver.execute_query(
129+
query_=ADD_MESSAGE_QUERY.format(node_label=self._node_label),
130+
parameters_={
131+
"role": message["role"],
132+
"content": message["content"],
133+
"session_id": self._session_id,
134+
},
135+
)
136+
137+
def clear(self) -> None:
138+
self._driver.execute_query(
139+
query_=CLEAR_SESSION_QUERY.format(node_label=self._node_label),
140+
parameters_={"session_id": self._session_id},
141+
)
142+
143+
def __del__(self) -> None:
144+
if self._driver:
145+
self._driver.close()

src/neo4j_graphrag/types.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
from __future__ import annotations
1616

1717
from enum import Enum
18-
from typing import Any, Callable, Literal, Optional
18+
from typing import Any, Callable, Literal, Optional, Union
1919

2020
import neo4j
2121
from pydantic import (
@@ -251,3 +251,10 @@ class Text2CypherRetrieverModel(BaseModel):
251251
result_formatter: Optional[Callable[[neo4j.Record], RetrieverResultItem]] = None
252252
custom_prompt: Optional[str] = None
253253
neo4j_database: Optional[str] = None
254+
255+
256+
class Neo4jMessageHistoryModel(BaseModel):
257+
session_id: Union[str, int]
258+
driver_model: Neo4jDriverModel
259+
node_label: str = "Session"
260+
window: Optional[PositiveInt] = None

tests/e2e/test_message_history_e2e.py

Lines changed: 88 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,88 @@
1+
import neo4j
2+
from neo4j_graphrag.llm.types import LLMMessage
3+
from neo4j_graphrag.message_history import Neo4jMessageHistory
4+
5+
6+
def test_neo4j_message_history_add_message(driver: neo4j.Driver) -> None:
7+
driver.execute_query(query_="MATCH (n) DETACH DELETE n;")
8+
message_history = Neo4jMessageHistory(session_id="123", driver=driver)
9+
message_history.add_message(
10+
LLMMessage(role="user", content="Hello"),
11+
)
12+
assert len(message_history.messages) == 1
13+
assert message_history.messages[0]["role"] == "user"
14+
assert message_history.messages[0]["content"] == "Hello"
15+
16+
17+
def test_neo4j_message_history_add_messages(driver: neo4j.Driver) -> None:
18+
driver.execute_query(query_="MATCH (n) DETACH DELETE n;")
19+
message_history = Neo4jMessageHistory(session_id="123", driver=driver)
20+
message_history.add_messages(
21+
[
22+
LLMMessage(role="system", content="You are a helpful assistant."),
23+
LLMMessage(role="user", content="Hello"),
24+
LLMMessage(
25+
role="assistant",
26+
content="Hello, how may I help you today?",
27+
),
28+
LLMMessage(role="user", content="I'd like to buy a new car."),
29+
LLMMessage(
30+
role="assistant",
31+
content="I'd be happy to help you find the perfect car.",
32+
),
33+
]
34+
)
35+
assert len(message_history.messages) == 5
36+
assert message_history.messages[0]["role"] == "system"
37+
assert message_history.messages[0]["content"] == "You are a helpful assistant."
38+
assert message_history.messages[1]["role"] == "user"
39+
assert message_history.messages[1]["content"] == "Hello"
40+
assert message_history.messages[2]["role"] == "assistant"
41+
assert message_history.messages[2]["content"] == "Hello, how may I help you today?"
42+
assert message_history.messages[3]["role"] == "user"
43+
assert message_history.messages[3]["content"] == "I'd like to buy a new car."
44+
assert message_history.messages[4]["role"] == "assistant"
45+
assert (
46+
message_history.messages[4]["content"]
47+
== "I'd be happy to help you find the perfect car."
48+
)
49+
50+
51+
def test_neo4j_message_history_clear(driver: neo4j.Driver) -> None:
52+
driver.execute_query(query_="MATCH (n) DETACH DELETE n;")
53+
message_history = Neo4jMessageHistory(session_id="123", driver=driver)
54+
message_history.add_messages(
55+
[
56+
LLMMessage(role="system", content="You are a helpful assistant."),
57+
LLMMessage(role="user", content="Hello"),
58+
]
59+
)
60+
assert len(message_history.messages) == 2
61+
message_history.clear()
62+
assert len(message_history.messages) == 0
63+
64+
65+
def test_neo4j_message_window_size(driver: neo4j.Driver) -> None:
66+
driver.execute_query(query_="MATCH (n) DETACH DELETE n;")
67+
message_history = Neo4jMessageHistory(session_id="123", driver=driver, window=1)
68+
message_history.add_messages(
69+
[
70+
LLMMessage(role="system", content="You are a helpful assistant."),
71+
LLMMessage(role="user", content="Hello"),
72+
LLMMessage(
73+
role="assistant",
74+
content="Hello, how may I help you today?",
75+
),
76+
LLMMessage(role="user", content="I'd like to buy a new car."),
77+
LLMMessage(
78+
role="assistant",
79+
content="I'd be happy to help you find the perfect car.",
80+
),
81+
]
82+
)
83+
assert len(message_history.messages) == 1
84+
assert (
85+
message_history.messages[0]["content"]
86+
== "I'd be happy to help you find the perfect car."
87+
)
88+
assert message_history.messages[0]["role"] == "assistant"

tests/unit/test_message_history.py

Lines changed: 93 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,93 @@
1+
from unittest.mock import MagicMock
2+
3+
import pytest
4+
from neo4j_graphrag.llm.types import LLMMessage
5+
from neo4j_graphrag.message_history import InMemoryMessageHistory, Neo4jMessageHistory
6+
from pydantic import ValidationError
7+
8+
9+
def test_in_memory_message_history_add_message() -> None:
10+
message_history = InMemoryMessageHistory()
11+
message_history.add_message(
12+
LLMMessage(role="user", content="may thy knife chip and shatter")
13+
)
14+
assert len(message_history.messages) == 1
15+
assert message_history.messages[0]["role"] == "user"
16+
assert message_history.messages[0]["content"] == "may thy knife chip and shatter"
17+
18+
19+
def test_in_memory_message_history_add_messages() -> None:
20+
message_history = InMemoryMessageHistory()
21+
message_history.add_messages(
22+
[
23+
LLMMessage(role="user", content="may thy knife chip and shatter"),
24+
LLMMessage(
25+
role="assistant",
26+
content="He who controls the spice controls the universe.",
27+
),
28+
]
29+
)
30+
assert len(message_history.messages) == 2
31+
assert message_history.messages[0]["role"] == "user"
32+
assert message_history.messages[0]["content"] == "may thy knife chip and shatter"
33+
assert message_history.messages[1]["role"] == "assistant"
34+
assert (
35+
message_history.messages[1]["content"]
36+
== "He who controls the spice controls the universe."
37+
)
38+
39+
40+
def test_in_memory_message_history_clear() -> None:
41+
message_history = InMemoryMessageHistory()
42+
message_history.add_messages(
43+
[
44+
LLMMessage(role="user", content="may thy knife chip and shatter"),
45+
LLMMessage(
46+
role="assistant",
47+
content="He who controls the spice controls the universe.",
48+
),
49+
]
50+
)
51+
assert len(message_history.messages) == 2
52+
message_history.clear()
53+
assert len(message_history.messages) == 0
54+
55+
56+
def test_neo4j_message_history_invalid_session_id(driver: MagicMock) -> None:
57+
with pytest.raises(ValidationError) as exc_info:
58+
Neo4jMessageHistory(session_id=1.5, driver=driver, node_label="123", window=1) # type: ignore[arg-type]
59+
assert "Input should be a valid string" in str(exc_info.value)
60+
61+
62+
def test_neo4j_message_history_invalid_driver() -> None:
63+
with pytest.raises(ValidationError) as exc_info:
64+
Neo4jMessageHistory(session_id="123", driver=1.5, node_label="123", window=1) # type: ignore[arg-type]
65+
assert "Input should be a valid dictionary or instance of Neo4jDriver" in str(
66+
exc_info.value
67+
)
68+
69+
70+
def test_neo4j_message_history_invalid_node_label(driver: MagicMock) -> None:
71+
with pytest.raises(ValidationError) as exc_info:
72+
Neo4jMessageHistory(session_id="123", driver=driver, node_label=1.5, window=1) # type: ignore[arg-type]
73+
assert "Input should be a valid string" in str(exc_info.value)
74+
75+
76+
def test_neo4j_message_history_invalid_window(driver: MagicMock) -> None:
77+
with pytest.raises(ValidationError) as exc_info:
78+
Neo4jMessageHistory(
79+
session_id="123", driver=driver, node_label="123", window=-1
80+
)
81+
assert "Input should be greater than 0" in str(exc_info.value)
82+
83+
84+
def test_neo4j_message_history_messages_setter(neo4j_driver: MagicMock) -> None:
85+
message_history = Neo4jMessageHistory(session_id="123", driver=neo4j_driver)
86+
with pytest.raises(NotImplementedError) as exc_info:
87+
message_history.messages = [
88+
LLMMessage(role="user", content="may thy knife chip and shatter"),
89+
]
90+
assert (
91+
str(exc_info.value)
92+
== "Direct assignment to 'messages' is not allowed. Use the 'add_messages' instead."
93+
)

0 commit comments

Comments
 (0)