Skip to content

Commit 2dac7a6

Browse files
committed
Updated LLMs
1 parent 7eb0326 commit 2dac7a6

File tree

8 files changed

+107
-44
lines changed

8 files changed

+107
-44
lines changed

examples/customize/llms/custom_llm.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,10 @@
11
import random
22
import string
3-
from typing import Any, Optional
3+
from typing import Any, List, Optional, Union
44

55
from neo4j_graphrag.llm import LLMInterface, LLMResponse
66
from neo4j_graphrag.llm.types import LLMMessage
7+
from neo4j_graphrag.message_history import MessageHistory
78

89

910
class CustomLLM(LLMInterface):
@@ -15,7 +16,7 @@ def __init__(
1516
def invoke(
1617
self,
1718
input: str,
18-
message_history: Optional[list[LLMMessage]] = None,
19+
message_history: Optional[Union[List[LLMMessage], MessageHistory]] = None,
1920
system_instruction: Optional[str] = None,
2021
) -> LLMResponse:
2122
content: str = (
@@ -26,7 +27,7 @@ def invoke(
2627
async def ainvoke(
2728
self,
2829
input: str,
29-
message_history: Optional[list[LLMMessage]] = None,
30+
message_history: Optional[Union[List[LLMMessage], MessageHistory]] = None,
3031
system_instruction: Optional[str] = None,
3132
) -> LLMResponse:
3233
raise NotImplementedError()

src/neo4j_graphrag/llm/anthropic_llm.py

Lines changed: 17 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
# limitations under the License.
1414
from __future__ import annotations
1515

16-
from typing import TYPE_CHECKING, Any, Iterable, Optional, cast
16+
from typing import TYPE_CHECKING, Any, Iterable, List, Optional, Union, cast
1717

1818
from pydantic import ValidationError
1919

@@ -26,6 +26,7 @@
2626
MessageList,
2727
UserMessage,
2828
)
29+
from neo4j_graphrag.message_history import MessageHistory
2930

3031
if TYPE_CHECKING:
3132
from anthropic.types.message_param import MessageParam
@@ -76,10 +77,14 @@ def __init__(
7677
self.async_client = anthropic.AsyncAnthropic(**kwargs)
7778

7879
def get_messages(
79-
self, input: str, message_history: Optional[list[LLMMessage]] = None
80+
self,
81+
input: str,
82+
message_history: Optional[Union[List[LLMMessage], MessageHistory]] = None,
8083
) -> Iterable[MessageParam]:
8184
messages: list[dict[str, str]] = []
8285
if message_history:
86+
if isinstance(message_history, MessageHistory):
87+
message_history = message_history.messages
8388
try:
8489
MessageList(messages=cast(list[BaseMessage], message_history))
8590
except ValidationError as e:
@@ -91,20 +96,23 @@ def get_messages(
9196
def invoke(
9297
self,
9398
input: str,
94-
message_history: Optional[list[LLMMessage]] = None,
99+
message_history: Optional[Union[List[LLMMessage], MessageHistory]] = None,
95100
system_instruction: Optional[str] = None,
96101
) -> LLMResponse:
97102
"""Sends text to the LLM and returns a response.
98103
99104
Args:
100105
input (str): The text to send to the LLM.
101-
message_history (Optional[list]): A collection previous messages, with each message having a specific role assigned.
106+
message_history (Optional[Union[List[LLMMessage], MessageHistory]]): A collection of previous messages,
107+
with each message having a specific role assigned.
102108
system_instruction (Optional[str]): An option to override the llm system message for this invocation.
103109
104110
Returns:
105111
LLMResponse: The response from the LLM.
106112
"""
107113
try:
114+
if isinstance(message_history, MessageHistory):
115+
message_history = message_history.messages
108116
messages = self.get_messages(input, message_history)
109117
response = self.client.messages.create(
110118
model=self.model_name,
@@ -124,20 +132,23 @@ def invoke(
124132
async def ainvoke(
125133
self,
126134
input: str,
127-
message_history: Optional[list[LLMMessage]] = None,
135+
message_history: Optional[Union[List[LLMMessage], MessageHistory]] = None,
128136
system_instruction: Optional[str] = None,
129137
) -> LLMResponse:
130138
"""Asynchronously sends text to the LLM and returns a response.
131139
132140
Args:
133141
input (str): The text to send to the LLM.
134-
message_history (Optional[list]): A collection previous messages, with each message having a specific role assigned.
142+
message_history (Optional[Union[List[LLMMessage], MessageHistory]]): A collection of previous messages,
143+
with each message having a specific role assigned.
135144
system_instruction (Optional[str]): An option to override the llm system message for this invocation.
136145
137146
Returns:
138147
LLMResponse: The response from the LLM.
139148
"""
140149
try:
150+
if isinstance(message_history, MessageHistory):
151+
message_history = message_history.messages
141152
messages = self.get_messages(input, message_history)
142153
response = await self.async_client.messages.create(
143154
model=self.model_name,

src/neo4j_graphrag/llm/base.py

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,9 @@
1515
from __future__ import annotations
1616

1717
from abc import ABC, abstractmethod
18-
from typing import Any, Optional
18+
from typing import Any, List, Optional, Union
19+
20+
from neo4j_graphrag.message_history import MessageHistory
1921

2022
from .types import (
2123
LLMMessage,
@@ -45,14 +47,15 @@ def __init__(
4547
def invoke(
4648
self,
4749
input: str,
48-
message_history: Optional[list[LLMMessage]] = None,
50+
message_history: Optional[Union[List[LLMMessage], MessageHistory]] = None,
4951
system_instruction: Optional[str] = None,
5052
) -> LLMResponse:
5153
"""Sends a text input to the LLM and retrieves a response.
5254
5355
Args:
5456
input (str): Text sent to the LLM.
55-
message_history (Optional[list]): A collection previous messages, with each message having a specific role assigned.
57+
message_history (Optional[Union[List[LLMMessage], MessageHistory]]): A collection of previous messages,
58+
with each message having a specific role assigned.
5659
system_instruction (Optional[str]): An option to override the llm system message for this invocation.
5760
5861
Returns:
@@ -66,14 +69,15 @@ def invoke(
6669
async def ainvoke(
6770
self,
6871
input: str,
69-
message_history: Optional[list[LLMMessage]] = None,
72+
message_history: Optional[Union[List[LLMMessage], MessageHistory]] = None,
7073
system_instruction: Optional[str] = None,
7174
) -> LLMResponse:
7275
"""Asynchronously sends a text input to the LLM and retrieves a response.
7376
7477
Args:
7578
input (str): Text sent to the LLM.
76-
message_history (Optional[list]): A collection previous messages, with each message having a specific role assigned.
79+
message_history (Optional[Union[List[LLMMessage], MessageHistory]]): A collection of previous messages,
80+
with each message having a specific role assigned.
7781
system_instruction (Optional[str]): An option to override the llm system message for this invocation.
7882
7983
Returns:

src/neo4j_graphrag/llm/cohere_llm.py

Lines changed: 15 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
# limitations under the License.
1515
from __future__ import annotations
1616

17-
from typing import TYPE_CHECKING, Any, Iterable, Optional, cast
17+
from typing import TYPE_CHECKING, Any, Iterable, List, Optional, Union, cast
1818

1919
from pydantic import ValidationError
2020

@@ -28,6 +28,7 @@
2828
SystemMessage,
2929
UserMessage,
3030
)
31+
from neo4j_graphrag.message_history import MessageHistory
3132

3233
if TYPE_CHECKING:
3334
from cohere import ChatMessages
@@ -78,13 +79,15 @@ def __init__(
7879
def get_messages(
7980
self,
8081
input: str,
81-
message_history: Optional[list[LLMMessage]] = None,
82+
message_history: Optional[Union[List[LLMMessage], MessageHistory]] = None,
8283
system_instruction: Optional[str] = None,
8384
) -> ChatMessages:
8485
messages = []
8586
if system_instruction:
8687
messages.append(SystemMessage(content=system_instruction).model_dump())
8788
if message_history:
89+
if isinstance(message_history, MessageHistory):
90+
message_history = message_history.messages
8891
try:
8992
MessageList(messages=cast(list[BaseMessage], message_history))
9093
except ValidationError as e:
@@ -96,20 +99,23 @@ def get_messages(
9699
def invoke(
97100
self,
98101
input: str,
99-
message_history: Optional[list[LLMMessage]] = None,
102+
message_history: Optional[Union[List[LLMMessage], MessageHistory]] = None,
100103
system_instruction: Optional[str] = None,
101104
) -> LLMResponse:
102105
"""Sends text to the LLM and returns a response.
103106
104107
Args:
105108
input (str): The text to send to the LLM.
106-
message_history (Optional[list]): A collection previous messages, with each message having a specific role assigned.
109+
message_history (Optional[Union[List[LLMMessage], MessageHistory]]): A collection of previous messages,
110+
with each message having a specific role assigned.
107111
system_instruction (Optional[str]): An option to override the llm system message for this invocation.
108112
109113
Returns:
110114
LLMResponse: The response from the LLM.
111115
"""
112116
try:
117+
if isinstance(message_history, MessageHistory):
118+
message_history = message_history.messages
113119
messages = self.get_messages(input, message_history, system_instruction)
114120
res = self.client.chat(
115121
messages=messages,
@@ -124,20 +130,23 @@ def invoke(
124130
async def ainvoke(
125131
self,
126132
input: str,
127-
message_history: Optional[list[LLMMessage]] = None,
133+
message_history: Optional[Union[List[LLMMessage], MessageHistory]] = None,
128134
system_instruction: Optional[str] = None,
129135
) -> LLMResponse:
130136
"""Asynchronously sends text to the LLM and returns a response.
131137
132138
Args:
133139
input (str): The text to send to the LLM.
134-
message_history (Optional[list]): A collection previous messages, with each message having a specific role assigned.
140+
message_history (Optional[Union[List[LLMMessage], MessageHistory]]): A collection of previous messages,
141+
with each message having a specific role assigned.
135142
system_instruction (Optional[str]): An option to override the llm system message for this invocation.
136143
137144
Returns:
138145
LLMResponse: The response from the LLM.
139146
"""
140147
try:
148+
if isinstance(message_history, MessageHistory):
149+
message_history = message_history.messages
141150
messages = self.get_messages(input, message_history, system_instruction)
142151
res = self.async_client.chat(
143152
messages=messages,

src/neo4j_graphrag/llm/mistralai_llm.py

Lines changed: 14 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
from __future__ import annotations
1616

1717
import os
18-
from typing import Any, Iterable, Optional, cast
18+
from typing import Any, Iterable, List, Optional, Union, cast
1919

2020
from pydantic import ValidationError
2121

@@ -29,6 +29,7 @@
2929
SystemMessage,
3030
UserMessage,
3131
)
32+
from neo4j_graphrag.message_history import MessageHistory
3233

3334
try:
3435
from mistralai import Messages, Mistral
@@ -68,13 +69,15 @@ def __init__(
6869
def get_messages(
6970
self,
7071
input: str,
71-
message_history: Optional[list[LLMMessage]] = None,
72+
message_history: Optional[Union[List[LLMMessage], MessageHistory]] = None,
7273
system_instruction: Optional[str] = None,
7374
) -> list[Messages]:
7475
messages = []
7576
if system_instruction:
7677
messages.append(SystemMessage(content=system_instruction).model_dump())
7778
if message_history:
79+
if isinstance(message_history, MessageHistory):
80+
message_history = message_history.messages
7881
try:
7982
MessageList(messages=cast(list[BaseMessage], message_history))
8083
except ValidationError as e:
@@ -86,15 +89,15 @@ def get_messages(
8689
def invoke(
8790
self,
8891
input: str,
89-
message_history: Optional[list[LLMMessage]] = None,
92+
message_history: Optional[Union[List[LLMMessage], MessageHistory]] = None,
9093
system_instruction: Optional[str] = None,
9194
) -> LLMResponse:
9295
"""Sends a text input to the Mistral chat completion model
9396
and returns the response's content.
9497
9598
Args:
9699
input (str): Text sent to the LLM.
97-
message_history (Optional[list]): A collection previous messages, with each message having a specific role assigned.
100+
message_history (Optional[Union[List[LLMMessage], MessageHistory]]): A collection of previous messages, with each message having a specific role assigned.
98101
system_instruction (Optional[str]): An option to override the llm system message for this invocation.
99102
100103
Returns:
@@ -104,6 +107,8 @@ def invoke(
104107
LLMGenerationError: If anything goes wrong.
105108
"""
106109
try:
110+
if isinstance(message_history, MessageHistory):
111+
message_history = message_history.messages
107112
messages = self.get_messages(input, message_history, system_instruction)
108113
response = self.client.chat.complete(
109114
model=self.model_name,
@@ -122,15 +127,16 @@ def invoke(
122127
async def ainvoke(
123128
self,
124129
input: str,
125-
message_history: Optional[list[LLMMessage]] = None,
130+
message_history: Optional[Union[List[LLMMessage], MessageHistory]] = None,
126131
system_instruction: Optional[str] = None,
127132
) -> LLMResponse:
128133
"""Asynchronously sends a text input to the MistralAI chat
129134
completion model and returns the response's content.
130135
131136
Args:
132137
input (str): Text sent to the LLM.
133-
message_history (Optional[list]): A collection previous messages, with each message having a specific role assigned.
138+
message_history (Optional[Union[List[LLMMessage], MessageHistory]]): A collection of previous messages,
139+
with each message having a specific role assigned.
134140
system_instruction (Optional[str]): An option to override the llm system message for this invocation.
135141
136142
Returns:
@@ -140,6 +146,8 @@ async def ainvoke(
140146
LLMGenerationError: If anything goes wrong.
141147
"""
142148
try:
149+
if isinstance(message_history, MessageHistory):
150+
message_history = message_history.messages
143151
messages = self.get_messages(input, message_history, system_instruction)
144152
response = await self.client.chat.complete_async(
145153
model=self.model_name,

0 commit comments

Comments
 (0)