55
55
56
56
57
57
class MessageHistory (ABC ):
58
+ """Abstract base class for message history storage."""
59
+
58
60
@property
59
61
@abstractmethod
60
62
def messages (self ) -> List [LLMMessage ]: ...
@@ -71,6 +73,24 @@ def clear(self) -> None: ...
71
73
72
74
73
75
class InMemoryMessageHistory (MessageHistory ):
76
+ """Message history stored in memory
77
+
78
+ Example:
79
+
80
+ .. code-block:: python
81
+
82
+ from neo4j_graphrag.llm.types import LLMMessage
83
+ from neo4j_graphrag.message_history import InMemoryMessageHistory
84
+
85
+ history = InMemoryMessageHistory()
86
+
87
+ message = LLMMessage(role="user", content="Hello!")
88
+ history.add_message(message)
89
+
90
+ Args:
91
+ messages (Optional[List[LLMMessage]]): List of messages to initialize the history with. Defaults to None.
92
+ """
93
+
74
94
def __init__ (self , messages : Optional [List [LLMMessage ]] = None ) -> None :
75
95
self ._messages = messages or []
76
96
@@ -89,6 +109,33 @@ def clear(self) -> None:
89
109
90
110
91
111
class Neo4jMessageHistory (MessageHistory ):
112
+ """Message history stored in a Neo4j database
113
+
114
+ Example:
115
+
116
+ .. code-block:: python
117
+
118
+ import neo4j
119
+ from neo4j_graphrag.llm.types import LLMMessage
120
+ from neo4j_graphrag.message_history import Neo4jMessageHistory
121
+
122
+ driver = neo4j.GraphDatabase.driver(URI, auth=AUTH)
123
+
124
+ history = Neo4jMessageHistory(
125
+ session_id="123", driver=driver, node_label="Message", window=10
126
+ )
127
+
128
+ message = LLMMessage(role="user", content="Hello!")
129
+ history.add_message(message)
130
+
131
+ Args:
132
+ session_id (Union[str, int]): Unique identifier for the chat session.
133
+ driver (neo4j.Driver): Neo4j driver instance.
134
+ node_label (str, optional): Label used for session nodes in Neo4j. Defaults to "Session".
135
+ window (Optional[PositiveInt], optional): Number of previous messages to return when retrieving messages.
136
+
137
+ """
138
+
92
139
def __init__ (
93
140
self ,
94
141
session_id : Union [str , int ],
@@ -139,6 +186,11 @@ def messages(self, messages: List[LLMMessage]) -> None:
139
186
)
140
187
141
188
def add_message (self , message : LLMMessage ) -> None :
189
+ """Add a message to the message history.
190
+
191
+ Args:
192
+ message (LLMMessage): The message to add.
193
+ """
142
194
self ._driver .execute_query (
143
195
query_ = ADD_MESSAGE_QUERY .format (node_label = self ._node_label ),
144
196
parameters_ = {
@@ -149,6 +201,7 @@ def add_message(self, message: LLMMessage) -> None:
149
201
)
150
202
151
203
def clear (self ) -> None :
204
+ """Clear the message history."""
152
205
self ._driver .execute_query (
153
206
query_ = CLEAR_SESSION_QUERY .format (node_label = self ._node_label ),
154
207
parameters_ = {"session_id" : self ._session_id },
0 commit comments