Skip to content

Commit fb35857

Browse files
WHALEEYEWendong-Fanlightaime
authored
refactor: ModelType (#998)
Co-authored-by: Wendong <w3ndong.fan@gmail.com> Co-authored-by: Wendong-Fan <133094783+Wendong-Fan@users.noreply.github.com> Co-authored-by: Guohao Li <lightaime@gmail.com>
1 parent 081d20b commit fb35857

File tree

85 files changed

+2348
-3035
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

85 files changed

+2348
-3035
lines changed

.github/workflows/build_package.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@ jobs:
5959
AZURE_OPENAI_API_KEY: "${{ secrets.AZURE_OPENAI_API_KEY }}"
6060
AZURE_API_VERSION: "${{ secrets.AZURE_API_VERSION }}"
6161
AZURE_DEPLOYMENT_NAME: "${{ secrets.AZURE_DEPLOYMENT_NAME }}"
62-
AZURE_OPENAI_ENDPOINT: "${{ secrets.AZURE_OPENAI_ENDPOINT }}"
62+
AZURE_OPENAI_BASE_URL: "${{ secrets.AZURE_OPENAI_BASE_URL }}"
6363
MISTRAL_API_KEY: "${{ secrets.MISTRAL_API_KEY }}"
6464
REKA_API_KEY: "${{ secrets.REKA_API_KEY }}"
6565
NEO4J_URI: "${{ secrets.NEO4J_URI }}"

.github/workflows/pytest_package.yml

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ jobs:
3838
AZURE_OPENAI_API_KEY: "to-be-filled"
3939
AZURE_API_VERSION: "to-be-filled"
4040
AZURE_DEPLOYMENT_NAME: "to-be-filled"
41-
AZURE_OPENAI_ENDPOINT: "https://camel.openai.azure.com/"
41+
AZURE_OPENAI_BASE_URL: "https://camel.openai.azure.com/"
4242
MISTRAL_API_KEY: "${{ secrets.MISTRAL_API_KEY }}"
4343
REKA_API_KEY: "${{ secrets.REKA_API_KEY }}"
4444
NEO4J_URI: "${{ secrets.NEO4J_URI }}"
@@ -74,7 +74,7 @@ jobs:
7474
AZURE_OPENAI_API_KEY: "to-be-filled"
7575
AZURE_API_VERSION: "to-be-filled"
7676
AZURE_DEPLOYMENT_NAME: "to-be-filled"
77-
AZURE_OPENAI_ENDPOINT: "https://camel.openai.azure.com/"
77+
AZURE_OPENAI_BASE_URL: "https://camel.openai.azure.com/"
7878
MISTRAL_API_KEY: "${{ secrets.MISTRAL_API_KEY }}"
7979
REKA_API_KEY: "${{ secrets.REKA_API_KEY }}"
8080
NEO4J_URI: "${{ secrets.NEO4J_URI }}"
@@ -110,7 +110,7 @@ jobs:
110110
AZURE_OPENAI_API_KEY: ${{ secrets.AZURE_OPENAI_API_KEY }}"
111111
AZURE_API_VERSION: ${{ secrets.AZURE_API_VERSION }}"
112112
AZURE_DEPLOYMENT_NAME: ${{ secrets.AZURE_DEPLOYMENT_NAME }}"
113-
AZURE_OPENAI_ENDPOINT: ${{ secrets.AZURE_OPENAI_ENDPOINT }}"
113+
AZURE_OPENAI_BASE_URL: ${{ secrets.AZURE_OPENAI_BASE_URL }}"
114114
MISTRAL_API_KEY: "${{ secrets.MISTRAL_API_KEY }}"
115115
REKA_API_KEY: "${{ secrets.REKA_API_KEY }}"
116116
NEO4J_URI: "${{ secrets.NEO4J_URI }}"

camel/agents/chat_agent.py

Lines changed: 9 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,6 @@
3434
from pydantic import BaseModel
3535

3636
from camel.agents.base import BaseAgent
37-
from camel.configs import ChatGPTConfig
3837
from camel.memories import (
3938
AgentMemory,
4039
ChatHistoryMemory,
@@ -169,14 +168,13 @@ def __init__(
169168
else ModelFactory.create(
170169
model_platform=ModelPlatformType.OPENAI,
171170
model_type=ModelType.GPT_4O_MINI,
172-
model_config_dict=ChatGPTConfig().as_dict(),
173171
)
174172
)
175173
self.output_language: Optional[str] = output_language
176174
if self.output_language is not None:
177175
self.set_output_language(self.output_language)
178176

179-
self.model_type: ModelType = self.model_backend.model_type
177+
self.model_type = self.model_backend.model_type
180178

181179
# tool registration
182180
external_tools = external_tools or []
@@ -439,12 +437,7 @@ def step(
439437
a boolean indicating whether the chat session has terminated,
440438
and information about the chat session.
441439
"""
442-
if (
443-
isinstance(self.model_type, ModelType)
444-
and "lama" in self.model_type.value
445-
or isinstance(self.model_type, str)
446-
and "lama" in self.model_type
447-
):
440+
if "llama" in self.model_type.lower():
448441
if self.model_backend.model_config_dict.get("tools", None):
449442
tool_prompt = self._generate_tool_prompt(self.tool_schema_list)
450443

@@ -525,10 +518,7 @@ def step(
525518
self._step_tool_call_and_update(response)
526519
)
527520

528-
if (
529-
output_schema is not None
530-
and self.model_type.supports_tool_calling
531-
):
521+
if output_schema is not None:
532522
(
533523
output_messages,
534524
finish_reasons,
@@ -619,7 +609,7 @@ def step(
619609

620610
if (
621611
output_schema is not None
622-
and self.model_type.supports_tool_calling
612+
and self.model_type.support_native_tool_calling
623613
):
624614
(
625615
output_messages,
@@ -727,7 +717,10 @@ async def step_async(
727717
await self._step_tool_call_and_update_async(response)
728718
)
729719

730-
if output_schema is not None and self.model_type.supports_tool_calling:
720+
if (
721+
output_schema is not None
722+
and self.model_type.support_native_tool_calling
723+
):
731724
(
732725
output_messages,
733726
finish_reasons,
@@ -1193,10 +1186,7 @@ def get_usage_dict(
11931186
Returns:
11941187
dict: Usage dictionary.
11951188
"""
1196-
if isinstance(self.model_type, ModelType):
1197-
encoding = get_model_encoding(self.model_type.value_for_tiktoken)
1198-
else:
1199-
encoding = get_model_encoding("gpt-4o-mini")
1189+
encoding = get_model_encoding(self.model_type.value_for_tiktoken)
12001190
completion_tokens = 0
12011191
for message in output_messages:
12021192
completion_tokens += len(encoding.encode(message.content))

camel/configs/__init__.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
from .litellm_config import LITELLM_API_PARAMS, LiteLLMConfig
1919
from .mistral_config import MISTRAL_API_PARAMS, MistralConfig
2020
from .ollama_config import OLLAMA_API_PARAMS, OllamaConfig
21-
from .openai_config import OPENAI_API_PARAMS, ChatGPTConfig, OpenSourceConfig
21+
from .openai_config import OPENAI_API_PARAMS, ChatGPTConfig
2222
from .reka_config import REKA_API_PARAMS, RekaConfig
2323
from .samba_config import (
2424
SAMBA_CLOUD_API_PARAMS,
@@ -40,7 +40,6 @@
4040
'ANTHROPIC_API_PARAMS',
4141
'GROQ_API_PARAMS',
4242
'GroqConfig',
43-
'OpenSourceConfig',
4443
'LiteLLMConfig',
4544
'LITELLM_API_PARAMS',
4645
'OllamaConfig',

camel/configs/openai_config.py

Lines changed: 0 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -112,24 +112,3 @@ class ChatGPTConfig(BaseConfig):
112112

113113

114114
OPENAI_API_PARAMS = {param for param in ChatGPTConfig.model_fields.keys()}
115-
116-
117-
class OpenSourceConfig(BaseConfig):
118-
r"""Defines parameters for setting up open-source models and includes
119-
parameters to be passed to chat completion function of OpenAI API.
120-
121-
Args:
122-
model_path (str): The path to a local folder containing the model
123-
files or the model card in HuggingFace hub.
124-
server_url (str): The URL to the server running the model inference
125-
which will be used as the API base of OpenAI API.
126-
api_params (ChatGPTConfig): An instance of :obj:ChatGPTConfig to
127-
contain the arguments to be passed to OpenAI API.
128-
"""
129-
130-
# Maybe the param needs to be renamed.
131-
# Warning: Field "model_path" has conflict with protected namespace
132-
# "model_".
133-
model_path: str
134-
server_url: str
135-
api_params: ChatGPTConfig = Field(default_factory=ChatGPTConfig)

camel/models/__init__.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -21,9 +21,8 @@
2121
from .model_factory import ModelFactory
2222
from .nemotron_model import NemotronModel
2323
from .ollama_model import OllamaModel
24-
from .open_source_model import OpenSourceModel
2524
from .openai_audio_models import OpenAIAudioModels
26-
from .openai_compatibility_model import OpenAICompatibilityModel
25+
from .openai_compatible_model import OpenAICompatibleModel
2726
from .openai_model import OpenAIModel
2827
from .reka_model import RekaModel
2928
from .samba_model import SambaModel
@@ -41,15 +40,14 @@
4140
'GroqModel',
4241
'StubModel',
4342
'ZhipuAIModel',
44-
'OpenSourceModel',
4543
'ModelFactory',
4644
'LiteLLMModel',
4745
'OpenAIAudioModels',
4846
'NemotronModel',
4947
'OllamaModel',
5048
'VLLMModel',
5149
'GeminiModel',
52-
'OpenAICompatibilityModel',
50+
'OpenAICompatibleModel',
5351
'RekaModel',
5452
'SambaModel',
5553
'TogetherAIModel',

camel/models/anthropic_model.py

Lines changed: 27 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -12,11 +12,11 @@
1212
# limitations under the License.
1313
# =========== Copyright 2023 @ CAMEL-AI.org. All Rights Reserved. ===========
1414
import os
15-
from typing import Any, Dict, List, Optional
15+
from typing import Any, Dict, List, Optional, Union
1616

1717
from anthropic import NOT_GIVEN, Anthropic
1818

19-
from camel.configs import ANTHROPIC_API_PARAMS
19+
from camel.configs import ANTHROPIC_API_PARAMS, AnthropicConfig
2020
from camel.messages import OpenAIMessage
2121
from camel.models.base_model import BaseModelBackend
2222
from camel.types import ChatCompletion, ModelType
@@ -28,36 +28,39 @@
2828

2929

3030
class AnthropicModel(BaseModelBackend):
31-
r"""Anthropic API in a unified BaseModelBackend interface."""
31+
r"""Anthropic API in a unified BaseModelBackend interface.
32+
33+
Args:
34+
model_type (Union[ModelType, str]): Model for which a backend is
35+
created, one of CLAUDE_* series.
36+
model_config_dict (Optional[Dict[str, Any]], optional): A dictionary
37+
that will be fed into Anthropic.messages.create(). If
38+
:obj:`None`, :obj:`AnthropicConfig().as_dict()` will be used.
39+
(default::obj:`None`)
40+
api_key (Optional[str], optional): The API key for authenticating with
41+
the Anthropic service. (default: :obj:`None`)
42+
url (Optional[str], optional): The url to the Anthropic service.
43+
(default: :obj:`None`)
44+
token_counter (Optional[BaseTokenCounter], optional): Token counter to
45+
use for the model. If not provided, :obj:`AnthropicTokenCounter`
46+
will be used. (default: :obj:`None`)
47+
"""
3248

3349
def __init__(
3450
self,
35-
model_type: ModelType,
36-
model_config_dict: Dict[str, Any],
51+
model_type: Union[ModelType, str],
52+
model_config_dict: Optional[Dict[str, Any]] = None,
3753
api_key: Optional[str] = None,
3854
url: Optional[str] = None,
3955
token_counter: Optional[BaseTokenCounter] = None,
4056
) -> None:
41-
r"""Constructor for Anthropic backend.
42-
43-
Args:
44-
model_type (ModelType): Model for which a backend is created,
45-
one of CLAUDE_* series.
46-
model_config_dict (Dict[str, Any]): A dictionary that will
47-
be fed into Anthropic.messages.create().
48-
api_key (Optional[str]): The API key for authenticating with the
49-
Anthropic service. (default: :obj:`None`)
50-
url (Optional[str]): The url to the Anthropic service. (default:
51-
:obj:`None`)
52-
token_counter (Optional[BaseTokenCounter]): Token counter to use
53-
for the model. If not provided, `AnthropicTokenCounter` will
54-
be used.
55-
"""
57+
if model_config_dict is None:
58+
model_config_dict = AnthropicConfig().as_dict()
59+
api_key = api_key or os.environ.get("ANTHROPIC_API_KEY")
60+
url = url or os.environ.get("ANTHROPIC_API_BASE_URL")
5661
super().__init__(
5762
model_type, model_config_dict, api_key, url, token_counter
5863
)
59-
self._api_key = api_key or os.environ.get("ANTHROPIC_API_KEY")
60-
self._url = url or os.environ.get("ANTHROPIC_API_BASE_URL")
6164
self.client = Anthropic(api_key=self._api_key, base_url=self._url)
6265

6366
def _convert_response_from_anthropic_to_openai(self, response):
@@ -89,7 +92,7 @@ def token_counter(self) -> BaseTokenCounter:
8992
tokenization style.
9093
"""
9194
if not self._token_counter:
92-
self._token_counter = AnthropicTokenCounter(self.model_type)
95+
self._token_counter = AnthropicTokenCounter()
9396
return self._token_counter
9497

9598
def count_tokens_from_prompt(self, prompt: str) -> int:
@@ -123,7 +126,7 @@ def run(
123126
else:
124127
sys_msg = NOT_GIVEN # type: ignore[assignment]
125128
response = self.client.messages.create(
126-
model=self.model_type.value,
129+
model=self.model_type,
127130
system=sys_msg,
128131
messages=messages, # type: ignore[arg-type]
129132
**self.model_config_dict,

0 commit comments

Comments
 (0)