Skip to content

Commit b2e94d8

Browse files
committed
Makes the build_query and chat_summary_prompt methods in the GraphRAG class private
1 parent baa0bed commit b2e94d8

File tree

2 files changed: +6 additions, −12 deletions

src/neo4j_graphrag/generation/graphrag.py

Lines changed: 5 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -131,7 +131,7 @@ def search(
131131
raise SearchValidationError(e.errors())
132132
if isinstance(message_history, MessageHistory):
133133
message_history = message_history.messages
134-
query = self.build_query(validated_data.query_text, message_history)
134+
query = self._build_query(validated_data.query_text, message_history)
135135
retriever_result: RetrieverResult = self.retriever.search(
136136
query_text=query, **validated_data.retriever_config
137137
)
@@ -151,16 +151,14 @@ def search(
151151
result["retriever_result"] = retriever_result
152152
return RagResultModel(**result)
153153

154-
def build_query(
154+
def _build_query(
155155
self,
156156
query_text: str,
157-
message_history: Optional[Union[List[LLMMessage], MessageHistory]] = None,
157+
message_history: Optional[List[LLMMessage]] = None,
158158
) -> str:
159159
summary_system_message = "You are a summarization assistant. Summarize the given text in no more than 300 words."
160160
if message_history:
161-
if isinstance(message_history, MessageHistory):
162-
message_history = message_history.messages
163-
summarization_prompt = self.chat_summary_prompt(
161+
summarization_prompt = self._chat_summary_prompt(
164162
message_history=message_history
165163
)
166164
summary = self.llm.invoke(
@@ -170,11 +168,7 @@ def build_query(
170168
return self.conversation_prompt(summary=summary, current_query=query_text)
171169
return query_text
172170

173-
def chat_summary_prompt(
174-
self, message_history: Union[List[LLMMessage], MessageHistory]
175-
) -> str:
176-
if isinstance(message_history, MessageHistory):
177-
message_history = message_history.messages
171+
def _chat_summary_prompt(self, message_history: List[LLMMessage]) -> str:
178172
message_list = [
179173
f"{message['role']}: {message['content']}" for message in message_history
180174
]

tests/unit/test_graphrag.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -294,7 +294,7 @@ def test_chat_summary_template(retriever_mock: MagicMock, llm: MagicMock) -> Non
294294
retriever=retriever_mock,
295295
llm=llm,
296296
)
297-
prompt = rag.chat_summary_prompt(message_history=message_history) # type: ignore
297+
prompt = rag._chat_summary_prompt(message_history=message_history) # type: ignore
298298
assert (
299299
prompt
300300
== """

Comments (0)