# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========

import os
-from typing import Any, Dict, Optional, Union
+from typing import Any, Dict, List, Optional, Type, Union
+
+from pydantic import BaseModel

from camel.configs import ZhipuAIConfig
+from camel.logger import get_logger
+from camel.messages import OpenAIMessage
+from camel.models._utils import try_modify_message_with_format
from camel.models.openai_compatible_model import OpenAICompatibleModel
-from camel.types import ModelType
+from camel.types import (
+    ChatCompletion,
+    ModelType,
+)
from camel.utils import (
    BaseTokenCounter,
    api_keys_required,
)

+logger = get_logger(__name__)
+

class ZhipuAIModel(OpenAICompatibleModel):
    r"""ZhipuAI API in a unified OpenAICompatibleModel interface.
@@ -85,3 +95,52 @@ def __init__(
            max_retries=max_retries,
            **kwargs,
        )
+
+    def _request_parse(
+        self,
+        messages: List[OpenAIMessage],
+        response_format: Type[BaseModel],
+        tools: Optional[List[Dict[str, Any]]] = None,
+    ) -> ChatCompletion:
+        import copy
+
+        request_config = copy.deepcopy(self.model_config_dict)
+        request_config.pop("stream", None)
+        if tools is not None:
+            request_config["tools"] = tools
+
+        try_modify_message_with_format(messages[-1], response_format)
+        request_config["response_format"] = {"type": "json_object"}
+        try:
+            return self._client.beta.chat.completions.parse(
+                messages=messages,
+                model=self.model_type,
+                **request_config,
+            )
+        except Exception as e:
+            logger.error(f"Structured output request failed: {e}")
+            raise
+
+    async def _arequest_parse(
+        self,
+        messages: List[OpenAIMessage],
+        response_format: Type[BaseModel],
+        tools: Optional[List[Dict[str, Any]]] = None,
+    ) -> ChatCompletion:
+        import copy
+
+        request_config = copy.deepcopy(self.model_config_dict)
+        request_config.pop("stream", None)
+        if tools is not None:
+            request_config["tools"] = tools
+        try_modify_message_with_format(messages[-1], response_format)
+        request_config["response_format"] = {"type": "json_object"}
+        try:
+            return await self._async_client.beta.chat.completions.parse(
+                messages=messages,
+                model=self.model_type,
+                **request_config,
+            )
+        except Exception as e:
+            logger.error(f"Structured output request failed: {e}")
+            raise
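
For reference, below is a minimal usage sketch of the new structured-output path; it is not part of the diff. It assumes a ZHIPUAI_API_KEY environment variable is set and that ModelType.GLM_4 exists in the installed CAMEL version; CityInfo is a made-up schema, and the private _request_parse helper is called directly only to illustrate what the method does (in practice it would be reached through the backend's public entry points when a response_format is supplied).

# Hypothetical sketch, not part of the diff. Assumes ZHIPUAI_API_KEY is
# exported and that ModelType.GLM_4 is available in this CAMEL build.
from pydantic import BaseModel

from camel.models.zhipuai_model import ZhipuAIModel
from camel.types import ModelType


class CityInfo(BaseModel):
    city: str
    population: int


model = ZhipuAIModel(model_type=ModelType.GLM_4)

messages = [
    {
        "role": "user",
        "content": "Name one large city and give its population.",
    },
]

# _request_parse() drops any `stream` flag, injects the response-format
# instructions into the last message, forces `json_object` output, and calls
# the OpenAI-compatible beta parse endpoint.
completion = model._request_parse(messages, response_format=CityInfo)
print(completion.choices[0].message.content)

The async variant _arequest_parse behaves the same way but awaits the async client, so both paths apply identical tools and response_format handling.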