|
16 | 16 |
|
17 | 17 | import abc
|
18 | 18 | import json
|
19 |
| -from typing import TYPE_CHECKING, Any, List, Optional, Sequence, Union, cast |
| 19 | +from typing import TYPE_CHECKING, Any, List, Optional, Iterable, Union, cast |
20 | 20 | from openai.types.chat import (
|
21 | 21 | ChatCompletionMessageParam,
|
22 |
| - ChatCompletionSystemMessageParam, |
23 |
| - ChatCompletionUserMessageParam, |
24 |
| - ChatCompletionAssistantMessageParam, |
25 | 22 | ChatCompletionToolParam,
|
26 | 23 | )
|
27 | 24 |
|
|
38 | 35 | MessageList,
|
39 | 36 | ToolCall,
|
40 | 37 | ToolCallResponse,
|
| 38 | + SystemMessage, |
| 39 | + UserMessage, |
41 | 40 | )
|
42 | 41 |
|
43 | 42 | if TYPE_CHECKING:
|
@@ -77,51 +76,20 @@ def get_messages(
|
77 | 76 | input: str,
|
78 | 77 | message_history: Optional[Union[List[LLMMessage], MessageHistory]] = None,
|
79 | 78 | system_instruction: Optional[str] = None,
|
80 |
| - ) -> Sequence[ |
81 |
| - ChatCompletionMessageParam |
82 |
| - ]: # Returns messages compatible with OpenAI's API |
83 |
| - messages: List[ChatCompletionMessageParam] = [] |
| 79 | + ) -> Iterable[ChatCompletionMessageParam]: |
| 80 | + messages = [] |
84 | 81 | if system_instruction:
|
85 |
| - messages.append( |
86 |
| - ChatCompletionSystemMessageParam( |
87 |
| - role="system", content=str(system_instruction) |
88 |
| - ) |
89 |
| - ) |
| 82 | + messages.append(SystemMessage(content=system_instruction).model_dump()) |
90 | 83 | if message_history:
|
91 | 84 | if isinstance(message_history, MessageHistory):
|
92 | 85 | message_history = message_history.messages
|
93 | 86 | try:
|
94 | 87 | MessageList(messages=cast(list[BaseMessage], message_history))
|
95 | 88 | except ValidationError as e:
|
96 | 89 | raise LLMGenerationError(e.errors()) from e
|
97 |
| - for msg in message_history: |
98 |
| - if isinstance(msg, dict): |
99 |
| - role = msg.get("role") |
100 |
| - content = msg.get("content") |
101 |
| - else: |
102 |
| - msg_dict = msg.model_dump() |
103 |
| - role = msg_dict.get("role") |
104 |
| - content = msg_dict.get("content") |
105 |
| - if role == "system": |
106 |
| - messages.append( |
107 |
| - ChatCompletionSystemMessageParam( |
108 |
| - role="system", content=str(content) |
109 |
| - ) |
110 |
| - ) |
111 |
| - elif role == "user": |
112 |
| - messages.append( |
113 |
| - ChatCompletionUserMessageParam( |
114 |
| - role="user", content=str(content) |
115 |
| - ) |
116 |
| - ) |
117 |
| - elif role == "assistant": |
118 |
| - messages.append( |
119 |
| - ChatCompletionAssistantMessageParam( |
120 |
| - role="assistant", content=str(content) |
121 |
| - ) |
122 |
| - ) |
123 |
| - messages.append(ChatCompletionUserMessageParam(role="user", content=str(input))) |
124 |
| - return messages |
| 90 | + messages.extend(cast(Iterable[dict[str, Any]], message_history)) |
| 91 | + messages.append(UserMessage(content=input).model_dump()) |
| 92 | + return messages # type: ignore |
125 | 93 |
|
126 | 94 | def invoke(
|
127 | 95 | self,
|
|
0 commit comments