|
12 | 12 | # limitations under the License. |
13 | 13 | # =========== Copyright 2023 @ CAMEL-AI.org. All Rights Reserved. =========== |
14 | 14 | import os |
15 | | -from typing import Any, Dict, List, Optional |
| 15 | +from typing import Any, Dict, List, Optional, Union |
16 | 16 |
|
17 | 17 | from anthropic import NOT_GIVEN, Anthropic |
18 | 18 |
|
19 | | -from camel.configs import ANTHROPIC_API_PARAMS |
| 19 | +from camel.configs import ANTHROPIC_API_PARAMS, AnthropicConfig |
20 | 20 | from camel.messages import OpenAIMessage |
21 | 21 | from camel.models.base_model import BaseModelBackend |
22 | 22 | from camel.types import ChatCompletion, ModelType |
|
28 | 28 |
|
29 | 29 |
|
30 | 30 | class AnthropicModel(BaseModelBackend): |
31 | | - r"""Anthropic API in a unified BaseModelBackend interface.""" |
| 31 | + r"""Anthropic API in a unified BaseModelBackend interface. |
| 32 | +
|
| 33 | + Args: |
| 34 | + model_type (Union[ModelType, str]): Model for which a backend is |
| 35 | + created, one of CLAUDE_* series. |
| 36 | + model_config_dict (Optional[Dict[str, Any]], optional): A dictionary |
| 37 | + that will be fed into Anthropic.messages.create(). If |
| 38 | + :obj:`None`, :obj:`AnthropicConfig().as_dict()` will be used. |
| 39 | + (default::obj:`None`) |
| 40 | + api_key (Optional[str], optional): The API key for authenticating with |
| 41 | + the Anthropic service. (default: :obj:`None`) |
| 42 | + url (Optional[str], optional): The url to the Anthropic service. |
| 43 | + (default: :obj:`None`) |
| 44 | + token_counter (Optional[BaseTokenCounter], optional): Token counter to |
| 45 | + use for the model. If not provided, :obj:`AnthropicTokenCounter` |
| 46 | + will be used. (default: :obj:`None`) |
| 47 | + """ |
32 | 48 |
|
33 | 49 | def __init__( |
34 | 50 | self, |
35 | | - model_type: ModelType, |
36 | | - model_config_dict: Dict[str, Any], |
| 51 | + model_type: Union[ModelType, str], |
| 52 | + model_config_dict: Optional[Dict[str, Any]] = None, |
37 | 53 | api_key: Optional[str] = None, |
38 | 54 | url: Optional[str] = None, |
39 | 55 | token_counter: Optional[BaseTokenCounter] = None, |
40 | 56 | ) -> None: |
41 | | - r"""Constructor for Anthropic backend. |
42 | | -
|
43 | | - Args: |
44 | | - model_type (ModelType): Model for which a backend is created, |
45 | | - one of CLAUDE_* series. |
46 | | - model_config_dict (Dict[str, Any]): A dictionary that will |
47 | | - be fed into Anthropic.messages.create(). |
48 | | - api_key (Optional[str]): The API key for authenticating with the |
49 | | - Anthropic service. (default: :obj:`None`) |
50 | | - url (Optional[str]): The url to the Anthropic service. (default: |
51 | | - :obj:`None`) |
52 | | - token_counter (Optional[BaseTokenCounter]): Token counter to use |
53 | | - for the model. If not provided, `AnthropicTokenCounter` will |
54 | | - be used. |
55 | | - """ |
| 57 | + if model_config_dict is None: |
| 58 | + model_config_dict = AnthropicConfig().as_dict() |
| 59 | + api_key = api_key or os.environ.get("ANTHROPIC_API_KEY") |
| 60 | + url = url or os.environ.get("ANTHROPIC_API_BASE_URL") |
56 | 61 | super().__init__( |
57 | 62 | model_type, model_config_dict, api_key, url, token_counter |
58 | 63 | ) |
59 | | - self._api_key = api_key or os.environ.get("ANTHROPIC_API_KEY") |
60 | | - self._url = url or os.environ.get("ANTHROPIC_API_BASE_URL") |
61 | 64 | self.client = Anthropic(api_key=self._api_key, base_url=self._url) |
62 | 65 |
|
63 | 66 | def _convert_response_from_anthropic_to_openai(self, response): |
@@ -89,7 +92,7 @@ def token_counter(self) -> BaseTokenCounter: |
89 | 92 | tokenization style. |
90 | 93 | """ |
91 | 94 | if not self._token_counter: |
92 | | - self._token_counter = AnthropicTokenCounter(self.model_type) |
| 95 | + self._token_counter = AnthropicTokenCounter() |
93 | 96 | return self._token_counter |
94 | 97 |
|
95 | 98 | def count_tokens_from_prompt(self, prompt: str) -> int: |
@@ -123,7 +126,7 @@ def run( |
123 | 126 | else: |
124 | 127 | sys_msg = NOT_GIVEN # type: ignore[assignment] |
125 | 128 | response = self.client.messages.create( |
126 | | - model=self.model_type.value, |
| 129 | + model=self.model_type, |
127 | 130 | system=sys_msg, |
128 | 131 | messages=messages, # type: ignore[arg-type] |
129 | 132 | **self.model_config_dict, |
|
0 commit comments