1919from  pydantic  import  BaseModel , Field , HttpUrl 
2020from  pydantic .type_adapter  import  TypeAdapter 
2121
22+ from  autogen .oai .oai_models .chat_completion  import  ChatCompletionExtended 
23+ 
2224from  ..cache  import  Cache 
2325from  ..code_utils  import  content_str 
2426from  ..doc_utils  import  export_module 
5759
5860    if  openai .__version__  >=  "1.1.0" :
5961        TOOL_ENABLED  =  True 
60-     ERROR  =  None 
62+     ERROR :  ImportError   |   None  =  None 
6163    from  openai .lib ._pydantic  import  _ensure_strict_json_schema 
6264else :
63-     ERROR :  ImportError   |   None   =  ImportError ("Please install openai>=1 and diskcache to use autogen.OpenAIWrapper." )
65+     ERROR   =  ImportError ("Please install openai>=1 and diskcache to use autogen.OpenAIWrapper." )   # type: ignore[assignment] 
6466
6567    # OpenAI = object 
6668    # AzureOpenAI = object 
7779if  cerebras_result .is_successful :
7880    cerebras_import_exception : ImportError  |  None  =  None 
7981else :
80-     cerebras_AuthenticationError  =  cerebras_InternalServerError  =  cerebras_RateLimitError  =  Exception   # noqa: N816 
82+     cerebras_AuthenticationError  =  cerebras_InternalServerError  =  cerebras_RateLimitError  =  Exception   # type: ignore[assignment,misc]  #  noqa: N816 
8183    cerebras_import_exception  =  ImportError ("cerebras_cloud_sdk not found" )
8284
8385with  optional_import_block () as  gemini_result :
9193if  gemini_result .is_successful :
9294    gemini_import_exception : ImportError  |  None  =  None 
9395else :
94-     gemini_InternalServerError  =  gemini_ResourceExhausted  =  Exception   # noqa: N816 
96+     gemini_InternalServerError  =  gemini_ResourceExhausted  =  Exception   # type: ignore[assignment,misc]  #  noqa: N816 
9597    gemini_import_exception  =  ImportError ("google-genai not found" )
9698
9799with  optional_import_block () as  anthropic_result :
105107if  anthropic_result .is_successful :
106108    anthropic_import_exception : ImportError  |  None  =  None 
107109else :
108-     anthorpic_InternalServerError  =  anthorpic_RateLimitError  =  Exception   # noqa: N816 
110+     anthorpic_InternalServerError  =  anthorpic_RateLimitError  =  Exception   # type: ignore[assignment,misc]  #  noqa: N816 
109111    anthropic_import_exception  =  ImportError ("anthropic not found" )
110112
111113with  optional_import_block () as  mistral_result :
174176if  ollama_result .is_successful :
175177    ollama_import_exception : ImportError  |  None  =  None 
176178else :
177-     ollama_RequestError  =  ollama_ResponseError  =  Exception   # noqa: N816 
179+     ollama_RequestError  =  ollama_ResponseError  =  Exception   # type: ignore[assignment,misc]  #  noqa: N816 
178180    ollama_import_exception  =  ImportError ("ollama not found" )
179181
180182with  optional_import_block () as  bedrock_result :
@@ -340,6 +342,8 @@ def __init__(self, config):
340342class  OpenAIClient :
341343    """Follows the Client protocol and wraps the OpenAI client.""" 
342344
345+     RESPONSE_USAGE_KEYS : list [str ] =  ["prompt_tokens" , "completion_tokens" , "total_tokens" , "cost" , "model" ]
346+ 
343347    def  __init__ (self , client : OpenAI  |  AzureOpenAI , response_format : BaseModel  |  dict [str , Any ] |  None  =  None ):
344348        self ._oai_client  =  client 
345349        self .response_format  =  response_format 
@@ -712,7 +716,7 @@ def cost(self, response: ChatCompletion | Completion) -> float:
712716        return  tmp_price1K  *  (n_input_tokens  +  n_output_tokens ) /  1000   # type: ignore [operator] 
713717
714718    @staticmethod  
715-     def  get_usage (response : ChatCompletion  |  Completion ) ->  dict :
719+     def  get_usage (response : ChatCompletion  |  Completion ) ->  dict [ str ,  Any ] :
716720        return  {
717721            "prompt_tokens" : response .usage .prompt_tokens  if  response .usage  is  not None  else  0 ,
718722            "completion_tokens" : response .usage .completion_tokens  if  response .usage  is  not None  else  0 ,
@@ -900,21 +904,21 @@ def _register_default_client(self, config: dict[str, Any], openai_config: dict[s
900904                def  create_azure_openai_client () ->  AzureOpenAI :
901905                    self ._configure_azure_openai (config , openai_config )
902906                    client  =  AzureOpenAI (** openai_config )
903-                     self ._clients .append (OpenAIClient (client , response_format = response_format ))
907+                     self ._clients .append (OpenAIClient (client , response_format = response_format ))   # type: ignore[arg-type] 
904908                    return  client 
905909
906910                client  =  create_azure_openai_client ()
907911            elif  api_type  is  not None  and  api_type .startswith ("cerebras" ):
908912                if  cerebras_import_exception :
909913                    raise  ImportError ("Please install `cerebras_cloud_sdk` to use Cerebras OpenAI API." )
910914                client  =  CerebrasClient (response_format = response_format , ** openai_config )
911-                 self ._clients .append (client )
915+                 self ._clients .append (client )   # type: ignore[arg-type] 
912916            elif  api_type  is  not None  and  api_type .startswith ("google" ):
913917                if  gemini_import_exception :
914918                    raise  ImportError ("Please install `google-genai` and 'vertexai' to use Google's API." )
915919                self ._configure_openai_config_for_gemini (config , openai_config )
916920                client  =  GeminiClient (response_format = response_format , ** openai_config )
917-                 self ._clients .append (client )
921+                 self ._clients .append (client )   # type: ignore[arg-type] 
918922            elif  api_type  is  not None  and  api_type .startswith ("anthropic" ):
919923                if  "api_key"  not  in config  and  "aws_region"  in  config :
920924                    self ._configure_openai_config_for_bedrock (config , openai_config )
@@ -923,44 +927,44 @@ def create_azure_openai_client() -> AzureOpenAI:
923927                if  anthropic_import_exception :
924928                    raise  ImportError ("Please install `anthropic` to use Anthropic API." )
925929                client  =  AnthropicClient (response_format = response_format , ** openai_config )
926-                 self ._clients .append (client )
930+                 self ._clients .append (client )   # type: ignore[arg-type] 
927931            elif  api_type  is  not None  and  api_type .startswith ("mistral" ):
928932                if  mistral_import_exception :
929933                    raise  ImportError ("Please install `mistralai` to use the Mistral.AI API." )
930934                client  =  MistralAIClient (response_format = response_format , ** openai_config )
931-                 self ._clients .append (client )
935+                 self ._clients .append (client )   # type: ignore[arg-type] 
932936            elif  api_type  is  not None  and  api_type .startswith ("together" ):
933937                if  together_import_exception :
934938                    raise  ImportError ("Please install `together` to use the Together.AI API." )
935939                client  =  TogetherClient (response_format = response_format , ** openai_config )
936-                 self ._clients .append (client )
940+                 self ._clients .append (client )   # type: ignore[arg-type] 
937941            elif  api_type  is  not None  and  api_type .startswith ("groq" ):
938942                if  groq_import_exception :
939943                    raise  ImportError ("Please install `groq` to use the Groq API." )
940944                client  =  GroqClient (response_format = response_format , ** openai_config )
941-                 self ._clients .append (client )
945+                 self ._clients .append (client )   # type: ignore[arg-type] 
942946            elif  api_type  is  not None  and  api_type .startswith ("cohere" ):
943947                if  cohere_import_exception :
944948                    raise  ImportError ("Please install `cohere` to use the Cohere API." )
945949                client  =  CohereClient (response_format = response_format , ** openai_config )
946-                 self ._clients .append (client )
950+                 self ._clients .append (client )   # type: ignore[arg-type] 
947951            elif  api_type  is  not None  and  api_type .startswith ("ollama" ):
948952                if  ollama_import_exception :
949953                    raise  ImportError ("Please install `ollama` and `fix-busted-json` to use the Ollama API." )
950954                client  =  OllamaClient (response_format = response_format , ** openai_config )
951-                 self ._clients .append (client )
955+                 self ._clients .append (client )   # type: ignore[arg-type] 
952956            elif  api_type  is  not None  and  api_type .startswith ("bedrock" ):
953957                self ._configure_openai_config_for_bedrock (config , openai_config )
954958                if  bedrock_import_exception :
955959                    raise  ImportError ("Please install `boto3` to use the Amazon Bedrock API." )
956960                client  =  BedrockClient (response_format = response_format , ** openai_config )
957-                 self ._clients .append (client )
961+                 self ._clients .append (client )   # type: ignore[arg-type] 
958962            elif  api_type  is  not None  and  api_type .startswith ("responses" ):
959963                # OpenAI Responses API (stateful). Reuse the same OpenAI SDK but call the `/responses` endpoint via the new client. 
960964                @require_optional_import ("openai>=1.66.2" , "openai" ) 
961965                def  create_responses_client () ->  OpenAI :
962966                    client  =  OpenAI (** openai_config )
963-                     self ._clients .append (OpenAIResponsesClient (client , response_format = response_format ))
967+                     self ._clients .append (OpenAIResponsesClient (client , response_format = response_format ))   # type: ignore[arg-type] 
964968                    return  client 
965969
966970                client  =  create_responses_client ()
@@ -969,7 +973,7 @@ def create_responses_client() -> OpenAI:
969973                @require_optional_import ("openai>=1.66.2" , "openai" ) 
970974                def  create_openai_client () ->  OpenAI :
971975                    client  =  OpenAI (** openai_config )
972-                     self ._clients .append (OpenAIClient (client , response_format ))
976+                     self ._clients .append (OpenAIClient (client , response_format ))   # type: ignore[arg-type] 
973977                    return  client 
974978
975979                client  =  create_openai_client ()
@@ -1134,12 +1138,12 @@ def create(self, **config: Any) -> ModelClient.ModelClientResponseProtocol:
11341138                    )
11351139                    request_ts  =  get_current_ts ()
11361140
1137-                     response : ModelClient . ModelClientResponseProtocol  =  cache .get (key , None )
1141+                     response : ChatCompletionExtended   |   None  =  cache .get (key , None )
11381142
11391143                    if  response  is  not None :
11401144                        response .message_retrieval_function  =  client .message_retrieval 
11411145                        try :
1142-                             response .cost    # type: ignore [attr-defined] 
1146+                             response .cost 
11431147                        except  AttributeError :
11441148                            # update attribute if cost is not calculated 
11451149                            response .cost  =  client .cost (response )
@@ -1157,7 +1161,7 @@ def create(self, **config: Any) -> ModelClient.ModelClientResponseProtocol:
11571161                                request = params ,
11581162                                response = response ,
11591163                                is_cached = 1 ,
1160-                                 cost = response .cost ,
1164+                                 cost = response .cost   if   response . cost   is   not   None   else   0.0 ,
11611165                                start_time = request_ts ,
11621166                            )
11631167
@@ -1272,12 +1276,10 @@ def create(self, **config: Any) -> ModelClient.ModelClientResponseProtocol:
12721276        raise  RuntimeError ("Should not reach here." )
12731277
12741278    @staticmethod  
1275-     def  _cost_with_customized_price (
1276-         response : ModelClient .ModelClientResponseProtocol , price_1k : tuple [float , float ]
1277-     ) ->  None :
1279+     def  _cost_with_customized_price (response : ChatCompletion  |  Completion , price_1k : tuple [float , float ]) ->  float :
12781280        """If a customized cost is passed, overwrite the cost in the response.""" 
1279-         n_input_tokens  =  response .usage .prompt_tokens  if  response .usage  is  not None  else  0    # type: ignore [union-attr] 
1280-         n_output_tokens  =  response .usage .completion_tokens  if  response .usage  is  not None  else  0    # type: ignore [union-attr] 
1281+         n_input_tokens  =  response .usage .prompt_tokens  if  response .usage  is  not None  else  0 
1282+         n_output_tokens  =  response .usage .completion_tokens  if  response .usage  is  not None  else  0 
12811283        if  n_output_tokens  is  None :
12821284            n_output_tokens  =  0 
12831285        return  (n_input_tokens  *  price_1k [0 ] +  n_output_tokens  *  price_1k [1 ]) /  1000 
@@ -1451,17 +1453,17 @@ def clear_usage_summary(self) -> None:
14511453
14521454    @classmethod  
14531455    def  extract_text_or_completion_object (
1454-         cls , response : ModelClient . ModelClientResponseProtocol 
1455-     ) ->  list [str ] |  list [ModelClient . ModelClientResponseProtocol . Choice . Message ]:
1456+         cls , response : ChatCompletionExtended 
1457+     ) ->  list [str ] |  list [ChatCompletionMessage ]:
14561458        """Extract the text or ChatCompletion objects from a completion or chat response. 
14571459
14581460        Args: 
1459-             response (ChatCompletion | Completion) : The response from openai. 
1461+             response: The response from openai with message_retrieval_function attached . 
14601462
14611463        Returns: 
14621464            A list of text, or a list of ChatCompletion objects if function_call/tool_calls are present. 
14631465        """ 
1464-         return  response .message_retrieval_function (response )
1466+         return  response .message_retrieval_function (response )   # type: ignore [misc] 
14651467
14661468
14671469# ----------------------------------------------------------------------------- 
0 commit comments