Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 7 additions & 4 deletions autogen/oai/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@
from ..logger.logger_utils import get_current_ts
from ..runtime_logging import log_chat_completion, log_new_client, log_new_wrapper, logging_enabled
from ..token_count_utils import count_token
from .client_utils import FormatterProtocol, logging_formatter
from .client_utils import FormatterProtocol, logging_formatter, merge_config_with_tools
from .openai_utils import OAI_PRICE1K, get_key, is_valid_api_key

TOOL_ENABLED = False
Expand Down Expand Up @@ -639,8 +639,11 @@ def _create_or_parse(*args, **kwargs):
warnings.warn(
f"The {params.get('model')} model does not support streaming. The stream will be set to False."
)
if params.get("tools", False):
raise ModelToolNotSupportedError(params.get("model"))
if "tools" in params:
if params["tools"]: # If tools exist, raise as unsupported
raise ModelToolNotSupportedError(params.get("model"))
else:
params.pop("tools") # Remove empty tools list
self._process_reasoning_model_params(params)
params["stream"] = False
response = create_or_parse(**params)
Expand Down Expand Up @@ -1083,7 +1086,7 @@ def create(self, **config: Any) -> ModelClient.ModelClientResponseProtocol:
for i in ordered_clients_indices:
# merge the input config with the i-th config in the config list
client_config = self._config_list[i]
full_config = {**config, **client_config, "tools": config.get("tools", []) + client_config.get("tools", [])}
full_config = merge_config_with_tools(config, client_config)

# separate the config into create_config and extra_kwargs
create_config, extra_kwargs = self._separate_create_config(full_config)
Expand Down
30 changes: 30 additions & 0 deletions autogen/oai/client_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,36 @@ def validate_parameter(
return param_value


def merge_config_with_tools(config: dict[str, Any], client_config: dict[str, Any]) -> dict[str, Any]:
    """Merge configuration dictionaries with proper tools and functions handling.

    This function merges two configuration dictionaries while ensuring that:
    1. Empty 'tools' arrays are not added (and are removed if present in the inputs)
    2. 'tools' and deprecated 'functions' parameters are not both present
    3. Actual tool configurations are properly merged (config tools first, then client tools)

    Args:
        config: The base configuration dictionary (e.g., from create() call)
        client_config: The client-specific configuration dictionary (e.g., from config_list)

    Returns:
        dict[str, Any]: The merged configuration with proper tools/functions handling
    """
    # Client config takes precedence for every key except "tools", which is merged below.
    # NOTE: this plain merge may carry a "tools" key in from either input, so the
    # branches below must explicitly remove it when it should not be present.
    full_config = {**config, **client_config}

    config_tools = config.get("tools", [])
    client_tools = client_config.get("tools", [])

    if "functions" in full_config:
        # Deprecated 'functions' API is in use — never emit 'tools' alongside it,
        # including a 'tools' key that survived the initial dict merge.
        full_config.pop("tools", None)
    elif config_tools or client_tools:
        # Merge the actual tool lists (base config's tools first).
        full_config["tools"] = config_tools + client_tools
    else:
        # Both tool lists are empty — drop any empty 'tools' key rather than
        # sending a meaningless empty array to the API.
        full_config.pop("tools", None)

    return full_config


def should_hide_tools(messages: list[dict[str, Any]], tools: list[dict[str, Any]], hide_tools_param: str) -> bool:
"""Determines if tools should be hidden. This function is used to hide tools when they have been run, minimising the chance of the LLM choosing them when they shouldn't.
Parameters:
Expand Down
Loading