
Commit f8092fc

Support for conversations with message history (#234)
* Add system_instruction parameter
* Add chat_history parameter
* Add missing doc strings
* OpenAI
* Add a summary of the chat history to the query embedding
  * an extra LLM call adds to latency and cost, but including the entire - or even part of - the chat history can potentially create a very large embedding context
* Anthropic
* Change return type of Anthropic get_messages()
  * from list to anthropic.MessageParam
* Cohere
  * upgrading from Cohere API v1 to API v2, as v2 handles chat history in a way that is consistent with the other providers
* Mistral
* VertexAI
* Formatting
* Fix mypy errors
* Ollama
  * plus added the `options` parameter to the ollama `chat` call
* Override of the system message
  * an idea of how to override the system instructions for some invocations
* Use TYPE_CHECKING for dev dependencies
* Formatting
* Rename `chat_history` to `message_history`
* Use BaseMessage class type
  * for the type declaration of the `message_history` parameter
* System instruction override
* Revert BaseMessage class type
  * bring back the list[dict[str, str]] type declaration for the `message_history` parameter
* Fix mypy errors
* Update tests
* Fix ollama NameError
* Fix NameError in unit tests
  * at the same time as making mypy shut up
* Add TypedDict `LLMMessage`
  * to help with the type declaration of the message history
* Simplify the retriever prompt
* Fix E2E tests
* Unit tests for the system instruction override
* Move and rename the prompts
  * ... for query embedding and summarization to the GraphRAG class
* Update changelog
* Add missing parameter in example
* Add LLMMessage to the docs
* Update docs README
1 parent 324fd2c commit f8092fc

22 files changed (+1527, −238 lines)
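
At a glance, the new conversational surface looks like this; a minimal sketch, with the provider, model name, and message contents as illustrative assumptions:

```python
from neo4j_graphrag.llm import OpenAILLM
from neo4j_graphrag.llm.types import LLMMessage

# LLMMessage is a TypedDict, so plain role/content dicts are accepted.
message_history: list[LLMMessage] = [
    {"role": "user", "content": "What is a knowledge graph?"},
    {"role": "assistant", "content": "A graph of entities and their relationships."},
]

llm = OpenAILLM(model_name="gpt-4o")  # model name is an assumption
response = llm.invoke(
    "Give me a concrete example.",
    message_history=message_history,
    # Optional per-call override of the system message set at construction.
    system_instruction="Answer in two sentences.",
)
print(response.content)
```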

CHANGELOG.md

Lines changed: 8 additions & 0 deletions

```diff
@@ -2,6 +2,14 @@
 
 ## Next
 
+### Added
+- Support for conversations with message history, including a new `message_history` parameter for LLM interactions.
+- Ability to include system instructions and override them for specific invocations.
+- Summarization of chat history to enhance query embedding and context handling.
+
+### Changed
+- Updated LLM implementations to handle message history consistently across providers.
+
 ## 1.3.0
 
 ### Added
```
docs/README.md

Lines changed: 4 additions & 2 deletions

````diff
@@ -2,9 +2,11 @@
 
 Building the docs requires Python 3.8.1+
 
-Ensure the dev dependencies in `pyproject.toml` are installed.
+1. Ensure the dev dependencies in `pyproject.toml` are installed.
 
-From the root directory, run the Makefile:
+2. Add your changes to the appropriate `.rst` source file in the `docs/source` directory.
+
+3. From the root directory, run the Makefile:
 
 ```
 make -C docs html
````

docs/source/types.rst

Lines changed: 6 additions & 0 deletions

```diff
@@ -28,6 +28,12 @@ LLMResponse
 .. autoclass:: neo4j_graphrag.llm.types.LLMResponse
 
 
+LLMMessage
+===========
+
+.. autoclass:: neo4j_graphrag.llm.types.LLMMessage
+
+
 RagResultModel
 ==============
 
```

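The commit message describes `LLMMessage` as a TypedDict added to type the message history. A sketch of the shape implied by the diffs below; the exact role literals are an assumption:

```python
from typing import Literal, TypedDict

class LLMMessage(TypedDict):
    # Role names assumed from the user/assistant/system messages in the diffs.
    role: Literal["system", "user", "assistant"]
    content: str
```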
examples/customize/llms/custom_llm.py

Lines changed: 17 additions & 4 deletions

```diff
@@ -1,21 +1,34 @@
 import random
 import string
-from typing import Any
+from typing import Any, Optional
 
 from neo4j_graphrag.llm import LLMInterface, LLMResponse
+from neo4j_graphrag.llm.types import LLMMessage
 
 
 class CustomLLM(LLMInterface):
-    def __init__(self, model_name: str, **kwargs: Any):
+    def __init__(
+        self, model_name: str, system_instruction: Optional[str] = None, **kwargs: Any
+    ):
         super().__init__(model_name, **kwargs)
 
-    def invoke(self, input: str) -> LLMResponse:
+    def invoke(
+        self,
+        input: str,
+        message_history: Optional[list[LLMMessage]] = None,
+        system_instruction: Optional[str] = None,
+    ) -> LLMResponse:
         content: str = (
             self.model_name + ": " + "".join(random.choices(string.ascii_letters, k=30))
         )
         return LLMResponse(content=content)
 
-    async def ainvoke(self, input: str) -> LLMResponse:
+    async def ainvoke(
+        self,
+        input: str,
+        message_history: Optional[list[LLMMessage]] = None,
+        system_instruction: Optional[str] = None,
+    ) -> LLMResponse:
         raise NotImplementedError()
```

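A quick usage sketch of the updated example class; the history content is illustrative:

```python
llm = CustomLLM(model_name="my-model", system_instruction="Answer briefly.")

history = [
    {"role": "user", "content": "What is Neo4j?"},
    {"role": "assistant", "content": "A graph database."},
]
# This toy implementation ignores the history and returns random letters,
# but its signature now matches the extended LLMInterface.
print(llm.invoke("How is data stored?", message_history=history).content)
```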
src/neo4j_graphrag/generation/graphrag.py

Lines changed: 46 additions & 7 deletions

```diff
@@ -27,6 +27,7 @@
 from neo4j_graphrag.generation.prompts import RagTemplate
 from neo4j_graphrag.generation.types import RagInitModel, RagResultModel, RagSearchModel
 from neo4j_graphrag.llm import LLMInterface
+from neo4j_graphrag.llm.types import LLMMessage
 from neo4j_graphrag.retrievers.base import Retriever
 from neo4j_graphrag.types import RetrieverResult
 
@@ -83,6 +84,7 @@ def __init__(
     def search(
         self,
         query_text: str = "",
+        message_history: Optional[list[LLMMessage]] = None,
         examples: str = "",
         retriever_config: Optional[dict[str, Any]] = None,
         return_context: bool | None = None,
@@ -99,14 +101,15 @@ def search(
 
 
         Args:
-            query_text (str): The user question
+            query_text (str): The user question.
+            message_history (Optional[list]): A collection of previous messages, with each message having a specific role assigned.
             examples (str): Examples added to the LLM prompt.
-            retriever_config (Optional[dict]): Parameters passed to the retriever
+            retriever_config (Optional[dict]): Parameters passed to the retriever.
                 search method; e.g.: top_k
-            return_context (bool): Whether to append the retriever result to the final result (default: False)
+            return_context (bool): Whether to append the retriever result to the final result (default: False).
 
         Returns:
-            RagResultModel: The LLM-generated answer
+            RagResultModel: The LLM-generated answer.
 
         """
         if return_context is None:
@@ -124,18 +127,54 @@
             )
         except ValidationError as e:
             raise SearchValidationError(e.errors())
-        query_text = validated_data.query_text
+        query = self.build_query(validated_data.query_text, message_history)
         retriever_result: RetrieverResult = self.retriever.search(
-            query_text=query_text, **validated_data.retriever_config
+            query_text=query, **validated_data.retriever_config
        )
         context = "\n".join(item.content for item in retriever_result.items)
         prompt = self.prompt_template.format(
             query_text=query_text, context=context, examples=validated_data.examples
         )
         logger.debug(f"RAG: retriever_result={retriever_result}")
         logger.debug(f"RAG: prompt={prompt}")
-        answer = self.llm.invoke(prompt)
+        answer = self.llm.invoke(prompt, message_history)
         result: dict[str, Any] = {"answer": answer.content}
         if return_context:
             result["retriever_result"] = retriever_result
         return RagResultModel(**result)
+
+    def build_query(
+        self, query_text: str, message_history: Optional[list[LLMMessage]] = None
+    ) -> str:
+        summary_system_message = "You are a summarization assistant. Summarize the given text in no more than 300 words."
+        if message_history:
+            summarization_prompt = self.chat_summary_prompt(
+                message_history=message_history
+            )
+            summary = self.llm.invoke(
+                input=summarization_prompt,
+                system_instruction=summary_system_message,
+            ).content
+            return self.conversation_prompt(summary=summary, current_query=query_text)
+        return query_text
+
+    def chat_summary_prompt(self, message_history: list[LLMMessage]) -> str:
+        message_list = [
+            ": ".join([f"{value}" for _, value in message.items()])
+            for message in message_history
+        ]
+        history = "\n".join(message_list)
+        return f"""
+Summarize the message history:
+
+{history}
+"""
+
+    def conversation_prompt(self, summary: str, current_query: str) -> str:
+        return f"""
+Message Summary:
+{summary}
+
+Current Query:
+{current_query}
+"""
```
src/neo4j_graphrag/llm/anthropic_llm.py

Lines changed: 63 additions & 19 deletions

```diff
@@ -13,11 +13,22 @@
 # limitations under the License.
 from __future__ import annotations
 
-from typing import Any, Optional
+from typing import Any, Iterable, Optional, TYPE_CHECKING, cast
+
+from pydantic import ValidationError
 
 from neo4j_graphrag.exceptions import LLMGenerationError
 from neo4j_graphrag.llm.base import LLMInterface
-from neo4j_graphrag.llm.types import LLMResponse
+from neo4j_graphrag.llm.types import (
+    BaseMessage,
+    LLMMessage,
+    LLMResponse,
+    MessageList,
+    UserMessage,
+)
+
+if TYPE_CHECKING:
+    from anthropic.types.message_param import MessageParam
 
 
 class AnthropicLLM(LLMInterface):
@@ -26,6 +37,7 @@ class AnthropicLLM(LLMInterface):
     Args:
         model_name (str, optional): Name of the LLM to use. Defaults to "gemini-1.5-flash-001".
         model_params (Optional[dict], optional): Additional parameters passed to the model when text is sent to it. Defaults to None.
+        system_instruction (Optional[str], optional): Additional instructions for setting the behavior and context for the model in a conversation. Defaults to None.
         **kwargs (Any): Arguments passed to the model when the class is initialised. Defaults to None.
 
     Raises:
@@ -49,6 +61,7 @@ def __init__(
         self,
         model_name: str,
         model_params: Optional[dict[str, Any]] = None,
+        system_instruction: Optional[str] = None,
         **kwargs: Any,
     ):
         try:
@@ -58,55 +71,86 @@
                 """Could not import Anthropic Python client.
                 Please install it with `pip install "neo4j-graphrag[anthropic]"`."""
             )
-        super().__init__(model_name, model_params)
+        super().__init__(model_name, model_params, system_instruction)
         self.anthropic = anthropic
         self.client = anthropic.Anthropic(**kwargs)
         self.async_client = anthropic.AsyncAnthropic(**kwargs)
 
-    def invoke(self, input: str) -> LLMResponse:
+    def get_messages(
+        self, input: str, message_history: Optional[list[LLMMessage]] = None
+    ) -> Iterable[MessageParam]:
+        messages: list[dict[str, str]] = []
+        if message_history:
+            try:
+                MessageList(messages=cast(list[BaseMessage], message_history))
+            except ValidationError as e:
+                raise LLMGenerationError(e.errors()) from e
+            messages.extend(cast(Iterable[dict[str, Any]], message_history))
+        messages.append(UserMessage(content=input).model_dump())
+        return messages  # type: ignore
+
+    def invoke(
+        self,
+        input: str,
+        message_history: Optional[list[LLMMessage]] = None,
+        system_instruction: Optional[str] = None,
+    ) -> LLMResponse:
         """Sends text to the LLM and returns a response.
 
         Args:
             input (str): The text to send to the LLM.
+            message_history (Optional[list]): A collection of previous messages, with each message having a specific role assigned.
+            system_instruction (Optional[str]): An option to override the LLM system message for this invocation.
 
         Returns:
             LLMResponse: The response from the LLM.
         """
         try:
+            messages = self.get_messages(input, message_history)
+            system_message = (
+                system_instruction
+                if system_instruction is not None
+                else self.system_instruction
+            )
             response = self.client.messages.create(
                 model=self.model_name,
-                messages=[
-                    {
-                        "role": "user",
-                        "content": input,
-                    }
-                ],
+                system=system_message,  # type: ignore
+                messages=messages,
                 **self.model_params,
             )
-            return LLMResponse(content=response.content)
+            return LLMResponse(content=response.content)  # type: ignore
         except self.anthropic.APIError as e:
             raise LLMGenerationError(e)
 
-    async def ainvoke(self, input: str) -> LLMResponse:
+    async def ainvoke(
+        self,
+        input: str,
+        message_history: Optional[list[LLMMessage]] = None,
+        system_instruction: Optional[str] = None,
+    ) -> LLMResponse:
         """Asynchronously sends text to the LLM and returns a response.
 
         Args:
             input (str): The text to send to the LLM.
+            message_history (Optional[list]): A collection of previous messages, with each message having a specific role assigned.
+            system_instruction (Optional[str]): An option to override the LLM system message for this invocation.
 
         Returns:
             LLMResponse: The response from the LLM.
         """
         try:
+            messages = self.get_messages(input, message_history)
+            system_message = (
+                system_instruction
+                if system_instruction is not None
+                else self.system_instruction
+            )
             response = await self.async_client.messages.create(
                 model=self.model_name,
-                messages=[
-                    {
-                        "role": "user",
-                        "content": input,
-                    }
-                ],
+                system=system_message,  # type: ignore
+                messages=messages,
                 **self.model_params,
             )
-            return LLMResponse(content=response.content)
+            return LLMResponse(content=response.content)  # type: ignore
         except self.anthropic.APIError as e:
             raise LLMGenerationError(e)
```

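A minimal sketch of the new conversational Anthropic call; the model name, token limit, and message contents are assumptions for illustration:

```python
from neo4j_graphrag.llm import AnthropicLLM

llm = AnthropicLLM(
    model_name="claude-3-5-sonnet-20240620",  # assumed model name
    model_params={"max_tokens": 1000},
    system_instruction="You are a terse assistant.",
)
history = [
    {"role": "user", "content": "Hi, I'm planning a trip to Norway."},
    {"role": "assistant", "content": "Great! Fjords or northern lights?"},
]
# The per-call system_instruction takes precedence over the one given above.
response = llm.invoke(
    "Fjords.",
    message_history=history,
    system_instruction="Answer in one sentence.",
)
print(response.content)
```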
src/neo4j_graphrag/llm/base.py

Lines changed: 22 additions & 5 deletions

```diff
@@ -17,7 +17,7 @@
 from abc import ABC, abstractmethod
 from typing import Any, Optional
 
-from .types import LLMResponse
+from .types import LLMMessage, LLMResponse
 
 
 class LLMInterface(ABC):
@@ -26,24 +26,34 @@ class LLMInterface(ABC):
     Args:
         model_name (str): The name of the language model.
         model_params (Optional[dict], optional): Additional parameters passed to the model when text is sent to it. Defaults to None.
+        system_instruction (Optional[str], optional): Additional instructions for setting the behavior and context for the model in a conversation. Defaults to None.
         **kwargs (Any): Arguments passed to the model when the class is initialised. Defaults to None.
     """
 
     def __init__(
         self,
         model_name: str,
         model_params: Optional[dict[str, Any]] = None,
+        system_instruction: Optional[str] = None,
         **kwargs: Any,
     ):
         self.model_name = model_name
         self.model_params = model_params or {}
+        self.system_instruction = system_instruction
 
     @abstractmethod
-    def invoke(self, input: str) -> LLMResponse:
+    def invoke(
+        self,
+        input: str,
+        message_history: Optional[list[LLMMessage]] = None,
+        system_instruction: Optional[str] = None,
+    ) -> LLMResponse:
         """Sends a text input to the LLM and retrieves a response.
 
         Args:
-            input (str): Text sent to the LLM
+            input (str): Text sent to the LLM.
+            message_history (Optional[list]): A collection of previous messages, with each message having a specific role assigned.
+            system_instruction (Optional[str]): An option to override the LLM system message for this invocation.
 
         Returns:
             LLMResponse: The response from the LLM.
@@ -53,11 +63,18 @@ def invoke(self, input: str) -> LLMResponse:
         """
 
     @abstractmethod
-    async def ainvoke(self, input: str) -> LLMResponse:
+    async def ainvoke(
+        self,
+        input: str,
+        message_history: Optional[list[LLMMessage]] = None,
+        system_instruction: Optional[str] = None,
+    ) -> LLMResponse:
         """Asynchronously sends a text input to the LLM and retrieves a response.
 
         Args:
-            input (str): Text sent to the LLM
+            input (str): Text sent to the LLM.
+            message_history (Optional[list]): A collection of previous messages, with each message having a specific role assigned.
+            system_instruction (Optional[str]): An option to override the LLM system message for this invocation.
 
         Returns:
             LLMResponse: The response from the LLM.
```

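One detail worth noting: each provider resolves the effective system message per call, preferring the per-call override and falling back to the instance-level instruction. Restated as a standalone helper (the function name is hypothetical):

```python
from typing import Optional

def resolve_system_message(
    call_override: Optional[str], instance_default: Optional[str]
) -> Optional[str]:
    # Mirrors the pattern in the provider diffs: a per-call system_instruction
    # wins whenever it is not None (so an explicit empty string also counts
    # as an override); otherwise the instruction set at construction is used.
    return call_override if call_override is not None else instance_default
```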