Skip to content

Commit 756a495

Browse files
committed
Refactored graphrag
1 parent 803596f commit 756a495

File tree

1 file changed

+9
-10
lines changed

1 file changed

+9
-10
lines changed

src/neo4j_graphrag/generation/graphrag.py

Lines changed: 9 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@
2828
from neo4j_graphrag.generation.types import RagInitModel, RagResultModel, RagSearchModel
2929
from neo4j_graphrag.llm import LLMInterface
3030
from neo4j_graphrag.llm.types import LLMMessage
31-
from neo4j_graphrag.message_history import InMemoryMessageHistory, MessageHistory
31+
from neo4j_graphrag.message_history import MessageHistory
3232
from neo4j_graphrag.retrievers.base import Retriever
3333
from neo4j_graphrag.types import RetrieverResult
3434

@@ -129,8 +129,8 @@ def search(
129129
)
130130
except ValidationError as e:
131131
raise SearchValidationError(e.errors())
132-
if isinstance(message_history, list):
133-
message_history = InMemoryMessageHistory(messages=message_history)
132+
if isinstance(message_history, MessageHistory):
133+
message_history = message_history.messages
134134
query = self.build_query(validated_data.query_text, message_history)
135135
retriever_result: RetrieverResult = self.retriever.search(
136136
query_text=query, **validated_data.retriever_config
@@ -143,7 +143,7 @@ def search(
143143
logger.debug(f"RAG: prompt={prompt}")
144144
answer = self.llm.invoke(
145145
prompt,
146-
message_history.messages if message_history else None,
146+
message_history,
147147
system_instruction=self.prompt_template.system_instructions,
148148
)
149149
result: dict[str, Any] = {"answer": answer.content}
@@ -158,8 +158,8 @@ def build_query(
158158
) -> str:
159159
summary_system_message = "You are a summarization assistant. Summarize the given text in no more than 300 words."
160160
if message_history:
161-
if isinstance(message_history, list):
162-
message_history = InMemoryMessageHistory(messages=message_history)
161+
if isinstance(message_history, MessageHistory):
162+
message_history = message_history.messages
163163
summarization_prompt = self.chat_summary_prompt(
164164
message_history=message_history
165165
)
@@ -173,11 +173,10 @@ def build_query(
173173
def chat_summary_prompt(
174174
self, message_history: Union[List[LLMMessage], MessageHistory]
175175
) -> str:
176-
if isinstance(message_history, list):
177-
message_history = InMemoryMessageHistory(messages=message_history)
176+
if isinstance(message_history, MessageHistory):
177+
message_history = message_history.messages
178178
message_list = [
179-
f"{message['role']}: {message['content']}"
180-
for message in message_history.messages
179+
f"{message['role']}: {message['content']}" for message in message_history
181180
]
182181
history = "\n".join(message_list)
183182
return f"""

0 commit comments

Comments
 (0)