Skip to content

Commit c79bdf0

Browse files
committed
Try to remove typing
1 parent 882915f commit c79bdf0

File tree

1 file changed

+9
-41
lines changed

1 file changed

+9
-41
lines changed

src/neo4j_graphrag/llm/openai_llm.py

Lines changed: 9 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -16,12 +16,9 @@
1616

1717
import abc
1818
import json
19-
from typing import TYPE_CHECKING, Any, List, Optional, Sequence, Union, cast
19+
from typing import TYPE_CHECKING, Any, List, Optional, Iterable, Union, cast
2020
from openai.types.chat import (
2121
ChatCompletionMessageParam,
22-
ChatCompletionSystemMessageParam,
23-
ChatCompletionUserMessageParam,
24-
ChatCompletionAssistantMessageParam,
2522
ChatCompletionToolParam,
2623
)
2724

@@ -38,6 +35,8 @@
3835
MessageList,
3936
ToolCall,
4037
ToolCallResponse,
38+
SystemMessage,
39+
UserMessage,
4140
)
4241

4342
if TYPE_CHECKING:
@@ -77,51 +76,20 @@ def get_messages(
7776
input: str,
7877
message_history: Optional[Union[List[LLMMessage], MessageHistory]] = None,
7978
system_instruction: Optional[str] = None,
80-
) -> Sequence[
81-
ChatCompletionMessageParam
82-
]: # Returns messages compatible with OpenAI's API
83-
messages: List[ChatCompletionMessageParam] = []
79+
) -> Iterable[ChatCompletionMessageParam]:
80+
messages = []
8481
if system_instruction:
85-
messages.append(
86-
ChatCompletionSystemMessageParam(
87-
role="system", content=str(system_instruction)
88-
)
89-
)
82+
messages.append(SystemMessage(content=system_instruction).model_dump())
9083
if message_history:
9184
if isinstance(message_history, MessageHistory):
9285
message_history = message_history.messages
9386
try:
9487
MessageList(messages=cast(list[BaseMessage], message_history))
9588
except ValidationError as e:
9689
raise LLMGenerationError(e.errors()) from e
97-
for msg in message_history:
98-
if isinstance(msg, dict):
99-
role = msg.get("role")
100-
content = msg.get("content")
101-
else:
102-
msg_dict = msg.model_dump()
103-
role = msg_dict.get("role")
104-
content = msg_dict.get("content")
105-
if role == "system":
106-
messages.append(
107-
ChatCompletionSystemMessageParam(
108-
role="system", content=str(content)
109-
)
110-
)
111-
elif role == "user":
112-
messages.append(
113-
ChatCompletionUserMessageParam(
114-
role="user", content=str(content)
115-
)
116-
)
117-
elif role == "assistant":
118-
messages.append(
119-
ChatCompletionAssistantMessageParam(
120-
role="assistant", content=str(content)
121-
)
122-
)
123-
messages.append(ChatCompletionUserMessageParam(role="user", content=str(input)))
124-
return messages
90+
messages.extend(cast(Iterable[dict[str, Any]], message_history))
91+
messages.append(UserMessage(content=input).model_dump())
92+
return messages # type: ignore
12593

12694
def invoke(
12795
self,

0 commit comments

Comments
 (0)