28
28
from neo4j_graphrag .generation .types import RagInitModel , RagResultModel , RagSearchModel
29
29
from neo4j_graphrag .llm import LLMInterface
30
30
from neo4j_graphrag .llm .types import LLMMessage
31
- from neo4j_graphrag .message_history import InMemoryMessageHistory , MessageHistory
31
+ from neo4j_graphrag .message_history import MessageHistory
32
32
from neo4j_graphrag .retrievers .base import Retriever
33
33
from neo4j_graphrag .types import RetrieverResult
34
34
@@ -129,8 +129,8 @@ def search(
129
129
)
130
130
except ValidationError as e :
131
131
raise SearchValidationError (e .errors ())
132
- if isinstance (message_history , list ):
133
- message_history = InMemoryMessageHistory ( messages = message_history )
132
+ if isinstance (message_history , MessageHistory ):
133
+ message_history = message_history . messages
134
134
query = self .build_query (validated_data .query_text , message_history )
135
135
retriever_result : RetrieverResult = self .retriever .search (
136
136
query_text = query , ** validated_data .retriever_config
@@ -143,7 +143,7 @@ def search(
143
143
logger .debug (f"RAG: prompt={ prompt } " )
144
144
answer = self .llm .invoke (
145
145
prompt ,
146
- message_history . messages if message_history else None ,
146
+ message_history ,
147
147
system_instruction = self .prompt_template .system_instructions ,
148
148
)
149
149
result : dict [str , Any ] = {"answer" : answer .content }
@@ -158,8 +158,8 @@ def build_query(
158
158
) -> str :
159
159
summary_system_message = "You are a summarization assistant. Summarize the given text in no more than 300 words."
160
160
if message_history :
161
- if isinstance (message_history , list ):
162
- message_history = InMemoryMessageHistory ( messages = message_history )
161
+ if isinstance (message_history , MessageHistory ):
162
+ message_history = message_history . messages
163
163
summarization_prompt = self .chat_summary_prompt (
164
164
message_history = message_history
165
165
)
@@ -173,11 +173,10 @@ def build_query(
173
173
def chat_summary_prompt (
174
174
self , message_history : Union [List [LLMMessage ], MessageHistory ]
175
175
) -> str :
176
- if isinstance (message_history , list ):
177
- message_history = InMemoryMessageHistory ( messages = message_history )
176
+ if isinstance (message_history , MessageHistory ):
177
+ message_history = message_history . messages
178
178
message_list = [
179
- f"{ message ['role' ]} : { message ['content' ]} "
180
- for message in message_history .messages
179
+ f"{ message ['role' ]} : { message ['content' ]} " for message in message_history
181
180
]
182
181
history = "\n " .join (message_list )
183
182
return f"""
0 commit comments