Skip to content

Commit 770eb89

Browse files
authored
Merge branch 'main' into update-generate-reply-function
2 parents a87f4c2 + c3334b9 commit 770eb89

File tree

18 files changed

+102
-48
lines changed

18 files changed

+102
-48
lines changed

autogen/agentchat/group/safeguards/events.py

Lines changed: 23 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,13 +5,34 @@
55
from __future__ import annotations
66

77
from collections.abc import Callable
8-
from typing import Any
8+
from typing import Any, Literal
99
from uuid import UUID
1010

1111
from termcolor import colored
1212

1313
from ....events.base_event import BaseEvent, wrap_event
1414

15+
# Type for termcolor colors
16+
TermColor = Literal[
17+
"black",
18+
"grey",
19+
"red",
20+
"green",
21+
"yellow",
22+
"blue",
23+
"magenta",
24+
"cyan",
25+
"light_grey",
26+
"dark_grey",
27+
"light_red",
28+
"light_green",
29+
"light_yellow",
30+
"light_blue",
31+
"light_magenta",
32+
"light_cyan",
33+
"white",
34+
]
35+
1536

1637
@wrap_event
1738
class SafeguardEvent(BaseEvent):
@@ -52,7 +73,7 @@ def print(self, f: Callable[..., Any] | None = None) -> None:
5273
f = f or print
5374

5475
# Choose color based on event type
55-
color = "green"
76+
color: TermColor = "green"
5677
if self.event_type == "load":
5778
color = "green"
5879
elif self.event_type == "check":

autogen/interop/pydantic_ai/pydantic_ai.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,7 @@ def inject_params(
6363

6464
@wraps(f)
6565
def wrapper(*args: Any, **kwargs: Any) -> Any:
66-
current_retry = 0 if ctx_typed is None else ctx_typed.retries.get(tool_typed.name, 0)
66+
current_retry = 0 if ctx_typed is None else ctx_typed.retries.get(tool_typed.name, 0) # type: ignore[attr-defined]
6767

6868
if current_retry >= max_retries:
6969
raise ValueError(f"{tool_typed.name} failed after {max_retries} retries")
@@ -73,12 +73,12 @@ def wrapper(*args: Any, **kwargs: Any) -> Any:
7373
kwargs.pop("ctx", None)
7474
ctx_typed.retry = current_retry
7575
result = f(**kwargs, ctx=ctx_typed) # type: ignore[call-arg]
76-
ctx_typed.retries[tool_typed.name] = 0
76+
ctx_typed.retries[tool_typed.name] = 0 # type: ignore[attr-defined]
7777
else:
7878
result = f(**kwargs) # type: ignore[call-arg]
7979
except Exception as e:
8080
if ctx_typed is not None:
81-
ctx_typed.retries[tool_typed.name] = ctx_typed.retries.get(tool_typed.name, 0) + 1
81+
ctx_typed.retries[tool_typed.name] = ctx_typed.retries.get(tool_typed.name, 0) + 1 # type: ignore[attr-defined]
8282
raise e
8383

8484
return result
@@ -145,6 +145,10 @@ def convert_tool(cls, tool: Any, deps: Any = None, **kwargs: Any) -> Tool:
145145
else None
146146
)
147147

148+
# Initialize retries dict for tracking retry counts per tool
149+
if ctx is not None:
150+
ctx.retries = {} # type: ignore[attr-defined]
151+
148152
func = PydanticAIInteroperability.inject_params(
149153
ctx=ctx,
150154
tool=pydantic_ai_tool,
@@ -154,7 +158,7 @@ def convert_tool(cls, tool: Any, deps: Any = None, **kwargs: Any) -> Tool:
154158
name=pydantic_ai_tool.name,
155159
description=pydantic_ai_tool.description,
156160
func_or_tool=func,
157-
parameters_json_schema=pydantic_ai_tool.function_schema.json_schema,
161+
parameters_json_schema=pydantic_ai_tool.function_schema.json_schema, # type: ignore[attr-defined]
158162
)
159163

160164
@classmethod

autogen/llm_config/client.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ class ModelClient(Protocol):
3232
class ModelClientResponseProtocol(Protocol):
3333
class Choice(Protocol):
3434
class Message(Protocol):
35-
content: str | dict[str, Any] | list[dict[str, Any]]
35+
content: str | dict[str, Any] | list[dict[str, Any]] | None
3636

3737
message: Message
3838

autogen/oai/anthropic.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -150,6 +150,8 @@ def create_client(self):
150150

151151
@require_optional_import("anthropic", "anthropic")
152152
class AnthropicClient:
153+
RESPONSE_USAGE_KEYS: list[str] = ["prompt_tokens", "completion_tokens", "total_tokens", "cost", "model"]
154+
153155
def __init__(self, **kwargs: Unpack[AnthropicEntryDict]):
154156
"""Initialize the Anthropic API client.
155157

autogen/oai/bedrock.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -97,6 +97,8 @@ def create_client(self):
9797
class BedrockClient:
9898
"""Client for Amazon's Bedrock Converse API."""
9999

100+
RESPONSE_USAGE_KEYS: list[str] = ["prompt_tokens", "completion_tokens", "total_tokens", "cost", "model"]
101+
100102
_retries = 5
101103

102104
def __init__(self, **kwargs: Unpack[BedrockEntryDict]):

autogen/oai/cerebras.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,8 @@ def create_client(self):
7575
class CerebrasClient:
7676
"""Client for Cerebras's API."""
7777

78+
RESPONSE_USAGE_KEYS: list[str] = ["prompt_tokens", "completion_tokens", "total_tokens", "cost", "model"]
79+
7880
def __init__(self, api_key=None, **kwargs: Unpack[CerebrasEntryDict]):
7981
"""Requires api_key or environment variable to be set
8082

autogen/oai/client.py

Lines changed: 33 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,8 @@
1919
from pydantic import BaseModel, Field, HttpUrl
2020
from pydantic.type_adapter import TypeAdapter
2121

22+
from autogen.oai.oai_models.chat_completion import ChatCompletionExtended
23+
2224
from ..cache import Cache
2325
from ..code_utils import content_str
2426
from ..doc_utils import export_module
@@ -57,10 +59,10 @@
5759

5860
if openai.__version__ >= "1.1.0":
5961
TOOL_ENABLED = True
60-
ERROR = None
62+
ERROR: ImportError | None = None
6163
from openai.lib._pydantic import _ensure_strict_json_schema
6264
else:
63-
ERROR: ImportError | None = ImportError("Please install openai>=1 and diskcache to use autogen.OpenAIWrapper.")
65+
ERROR = ImportError("Please install openai>=1 and diskcache to use autogen.OpenAIWrapper.") # type: ignore[assignment]
6466

6567
# OpenAI = object
6668
# AzureOpenAI = object
@@ -77,7 +79,7 @@
7779
if cerebras_result.is_successful:
7880
cerebras_import_exception: ImportError | None = None
7981
else:
80-
cerebras_AuthenticationError = cerebras_InternalServerError = cerebras_RateLimitError = Exception # noqa: N816
82+
cerebras_AuthenticationError = cerebras_InternalServerError = cerebras_RateLimitError = Exception # type: ignore[assignment,misc] # noqa: N816
8183
cerebras_import_exception = ImportError("cerebras_cloud_sdk not found")
8284

8385
with optional_import_block() as gemini_result:
@@ -91,7 +93,7 @@
9193
if gemini_result.is_successful:
9294
gemini_import_exception: ImportError | None = None
9395
else:
94-
gemini_InternalServerError = gemini_ResourceExhausted = Exception # noqa: N816
96+
gemini_InternalServerError = gemini_ResourceExhausted = Exception # type: ignore[assignment,misc] # noqa: N816
9597
gemini_import_exception = ImportError("google-genai not found")
9698

9799
with optional_import_block() as anthropic_result:
@@ -105,7 +107,7 @@
105107
if anthropic_result.is_successful:
106108
anthropic_import_exception: ImportError | None = None
107109
else:
108-
anthorpic_InternalServerError = anthorpic_RateLimitError = Exception # noqa: N816
110+
anthorpic_InternalServerError = anthorpic_RateLimitError = Exception # type: ignore[assignment,misc] # noqa: N816
109111
anthropic_import_exception = ImportError("anthropic not found")
110112

111113
with optional_import_block() as mistral_result:
@@ -174,7 +176,7 @@
174176
if ollama_result.is_successful:
175177
ollama_import_exception: ImportError | None = None
176178
else:
177-
ollama_RequestError = ollama_ResponseError = Exception # noqa: N816
179+
ollama_RequestError = ollama_ResponseError = Exception # type: ignore[assignment,misc] # noqa: N816
178180
ollama_import_exception = ImportError("ollama not found")
179181

180182
with optional_import_block() as bedrock_result:
@@ -340,6 +342,8 @@ def __init__(self, config):
340342
class OpenAIClient:
341343
"""Follows the Client protocol and wraps the OpenAI client."""
342344

345+
RESPONSE_USAGE_KEYS: list[str] = ["prompt_tokens", "completion_tokens", "total_tokens", "cost", "model"]
346+
343347
def __init__(self, client: OpenAI | AzureOpenAI, response_format: BaseModel | dict[str, Any] | None = None):
344348
self._oai_client = client
345349
self.response_format = response_format
@@ -712,7 +716,7 @@ def cost(self, response: ChatCompletion | Completion) -> float:
712716
return tmp_price1K * (n_input_tokens + n_output_tokens) / 1000 # type: ignore [operator]
713717

714718
@staticmethod
715-
def get_usage(response: ChatCompletion | Completion) -> dict:
719+
def get_usage(response: ChatCompletion | Completion) -> dict[str, Any]:
716720
return {
717721
"prompt_tokens": response.usage.prompt_tokens if response.usage is not None else 0,
718722
"completion_tokens": response.usage.completion_tokens if response.usage is not None else 0,
@@ -900,21 +904,21 @@ def _register_default_client(self, config: dict[str, Any], openai_config: dict[s
900904
def create_azure_openai_client() -> AzureOpenAI:
901905
self._configure_azure_openai(config, openai_config)
902906
client = AzureOpenAI(**openai_config)
903-
self._clients.append(OpenAIClient(client, response_format=response_format))
907+
self._clients.append(OpenAIClient(client, response_format=response_format)) # type: ignore[arg-type]
904908
return client
905909

906910
client = create_azure_openai_client()
907911
elif api_type is not None and api_type.startswith("cerebras"):
908912
if cerebras_import_exception:
909913
raise ImportError("Please install `cerebras_cloud_sdk` to use Cerebras OpenAI API.")
910914
client = CerebrasClient(response_format=response_format, **openai_config)
911-
self._clients.append(client)
915+
self._clients.append(client) # type: ignore[arg-type]
912916
elif api_type is not None and api_type.startswith("google"):
913917
if gemini_import_exception:
914918
raise ImportError("Please install `google-genai` and 'vertexai' to use Google's API.")
915919
self._configure_openai_config_for_gemini(config, openai_config)
916920
client = GeminiClient(response_format=response_format, **openai_config)
917-
self._clients.append(client)
921+
self._clients.append(client) # type: ignore[arg-type]
918922
elif api_type is not None and api_type.startswith("anthropic"):
919923
if "api_key" not in config and "aws_region" in config:
920924
self._configure_openai_config_for_bedrock(config, openai_config)
@@ -923,44 +927,44 @@ def create_azure_openai_client() -> AzureOpenAI:
923927
if anthropic_import_exception:
924928
raise ImportError("Please install `anthropic` to use Anthropic API.")
925929
client = AnthropicClient(response_format=response_format, **openai_config)
926-
self._clients.append(client)
930+
self._clients.append(client) # type: ignore[arg-type]
927931
elif api_type is not None and api_type.startswith("mistral"):
928932
if mistral_import_exception:
929933
raise ImportError("Please install `mistralai` to use the Mistral.AI API.")
930934
client = MistralAIClient(response_format=response_format, **openai_config)
931-
self._clients.append(client)
935+
self._clients.append(client) # type: ignore[arg-type]
932936
elif api_type is not None and api_type.startswith("together"):
933937
if together_import_exception:
934938
raise ImportError("Please install `together` to use the Together.AI API.")
935939
client = TogetherClient(response_format=response_format, **openai_config)
936-
self._clients.append(client)
940+
self._clients.append(client) # type: ignore[arg-type]
937941
elif api_type is not None and api_type.startswith("groq"):
938942
if groq_import_exception:
939943
raise ImportError("Please install `groq` to use the Groq API.")
940944
client = GroqClient(response_format=response_format, **openai_config)
941-
self._clients.append(client)
945+
self._clients.append(client) # type: ignore[arg-type]
942946
elif api_type is not None and api_type.startswith("cohere"):
943947
if cohere_import_exception:
944948
raise ImportError("Please install `cohere` to use the Cohere API.")
945949
client = CohereClient(response_format=response_format, **openai_config)
946-
self._clients.append(client)
950+
self._clients.append(client) # type: ignore[arg-type]
947951
elif api_type is not None and api_type.startswith("ollama"):
948952
if ollama_import_exception:
949953
raise ImportError("Please install `ollama` and `fix-busted-json` to use the Ollama API.")
950954
client = OllamaClient(response_format=response_format, **openai_config)
951-
self._clients.append(client)
955+
self._clients.append(client) # type: ignore[arg-type]
952956
elif api_type is not None and api_type.startswith("bedrock"):
953957
self._configure_openai_config_for_bedrock(config, openai_config)
954958
if bedrock_import_exception:
955959
raise ImportError("Please install `boto3` to use the Amazon Bedrock API.")
956960
client = BedrockClient(response_format=response_format, **openai_config)
957-
self._clients.append(client)
961+
self._clients.append(client) # type: ignore[arg-type]
958962
elif api_type is not None and api_type.startswith("responses"):
959963
# OpenAI Responses API (stateful). Reuse the same OpenAI SDK but call the `/responses` endpoint via the new client.
960964
@require_optional_import("openai>=1.66.2", "openai")
961965
def create_responses_client() -> OpenAI:
962966
client = OpenAI(**openai_config)
963-
self._clients.append(OpenAIResponsesClient(client, response_format=response_format))
967+
self._clients.append(OpenAIResponsesClient(client, response_format=response_format)) # type: ignore[arg-type]
964968
return client
965969

966970
client = create_responses_client()
@@ -969,7 +973,7 @@ def create_responses_client() -> OpenAI:
969973
@require_optional_import("openai>=1.66.2", "openai")
970974
def create_openai_client() -> OpenAI:
971975
client = OpenAI(**openai_config)
972-
self._clients.append(OpenAIClient(client, response_format))
976+
self._clients.append(OpenAIClient(client, response_format)) # type: ignore[arg-type]
973977
return client
974978

975979
client = create_openai_client()
@@ -1134,12 +1138,12 @@ def create(self, **config: Any) -> ModelClient.ModelClientResponseProtocol:
11341138
)
11351139
request_ts = get_current_ts()
11361140

1137-
response: ModelClient.ModelClientResponseProtocol = cache.get(key, None)
1141+
response: ChatCompletionExtended | None = cache.get(key, None)
11381142

11391143
if response is not None:
11401144
response.message_retrieval_function = client.message_retrieval
11411145
try:
1142-
response.cost # type: ignore [attr-defined]
1146+
response.cost
11431147
except AttributeError:
11441148
# update attribute if cost is not calculated
11451149
response.cost = client.cost(response)
@@ -1157,7 +1161,7 @@ def create(self, **config: Any) -> ModelClient.ModelClientResponseProtocol:
11571161
request=params,
11581162
response=response,
11591163
is_cached=1,
1160-
cost=response.cost,
1164+
cost=response.cost if response.cost is not None else 0.0,
11611165
start_time=request_ts,
11621166
)
11631167

@@ -1272,12 +1276,10 @@ def create(self, **config: Any) -> ModelClient.ModelClientResponseProtocol:
12721276
raise RuntimeError("Should not reach here.")
12731277

12741278
@staticmethod
1275-
def _cost_with_customized_price(
1276-
response: ModelClient.ModelClientResponseProtocol, price_1k: tuple[float, float]
1277-
) -> None:
1279+
def _cost_with_customized_price(response: ChatCompletion | Completion, price_1k: tuple[float, float]) -> float:
12781280
"""If a customized cost is passed, overwrite the cost in the response."""
1279-
n_input_tokens = response.usage.prompt_tokens if response.usage is not None else 0 # type: ignore [union-attr]
1280-
n_output_tokens = response.usage.completion_tokens if response.usage is not None else 0 # type: ignore [union-attr]
1281+
n_input_tokens = response.usage.prompt_tokens if response.usage is not None else 0
1282+
n_output_tokens = response.usage.completion_tokens if response.usage is not None else 0
12811283
if n_output_tokens is None:
12821284
n_output_tokens = 0
12831285
return (n_input_tokens * price_1k[0] + n_output_tokens * price_1k[1]) / 1000
@@ -1451,17 +1453,17 @@ def clear_usage_summary(self) -> None:
14511453

14521454
@classmethod
14531455
def extract_text_or_completion_object(
1454-
cls, response: ModelClient.ModelClientResponseProtocol
1455-
) -> list[str] | list[ModelClient.ModelClientResponseProtocol.Choice.Message]:
1456+
cls, response: ChatCompletionExtended
1457+
) -> list[str] | list[ChatCompletionMessage]:
14561458
"""Extract the text or ChatCompletion objects from a completion or chat response.
14571459
14581460
Args:
1459-
response (ChatCompletion | Completion): The response from openai.
1461+
response: The response from openai with message_retrieval_function attached.
14601462
14611463
Returns:
14621464
A list of text, or a list of ChatCompletion objects if function_call/tool_calls are present.
14631465
"""
1464-
return response.message_retrieval_function(response)
1466+
return response.message_retrieval_function(response) # type: ignore [misc]
14651467

14661468

14671469
# -----------------------------------------------------------------------------

autogen/oai/cohere.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -99,6 +99,8 @@ def create_client(self):
9999
class CohereClient:
100100
"""Client for Cohere's API."""
101101

102+
RESPONSE_USAGE_KEYS: list[str] = ["prompt_tokens", "completion_tokens", "total_tokens", "cost", "model"]
103+
102104
def __init__(self, **kwargs: Unpack[CohereEntryDict]):
103105
"""Requires api_key or environment variable to be set
104106

autogen/oai/gemini.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -139,6 +139,8 @@ def create_client(self):
139139
class GeminiClient:
140140
"""Client for Google's Gemini API."""
141141

142+
RESPONSE_USAGE_KEYS: list[str] = ["prompt_tokens", "completion_tokens", "total_tokens", "cost", "model"]
143+
142144
# Mapping, where Key is a term used by Autogen, and Value is a term used by Gemini
143145
PARAMS_MAPPING = {
144146
"max_tokens": "max_output_tokens",

autogen/oai/groq.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,8 @@ def create_client(self):
7777
class GroqClient:
7878
"""Client for Groq's API."""
7979

80+
RESPONSE_USAGE_KEYS: list[str] = ["prompt_tokens", "completion_tokens", "total_tokens", "cost", "model"]
81+
8082
def __init__(self, **kwargs: Unpack[GroqEntryDict]):
8183
"""Requires api_key or environment variable to be set
8284

0 commit comments

Comments
 (0)