diff --git a/litellm/utils.py b/litellm/utils.py
index 9b8d11cfd5dc..3592a7398909 100644
--- a/litellm/utils.py
+++ b/litellm/utils.py
@@ -2085,8 +2085,9 @@ def _get_non_default_params(
         if (
             k in default_params
             and v != default_params[k]
-            and _should_drop_param(k=k, additional_drop_params=additional_drop_params)
-            is False
+            and not _should_drop_param(
+                k=k, additional_drop_params=additional_drop_params
+            )
         ):
             non_default_params[k] = v
 
@@ -2571,6 +2572,92 @@ def _remove_unsupported_params(
     return non_default_params
 
 
+def _get_optional_params_defaults() -> dict[str, Any]:
+    """Return a dictionary of default parameter names and values, which are associated with get_optional_params"""
+    return {
+        "additional_drop_params": None,
+        "allowed_openai_params": None,
+        "api_version": None,
+        "audio": None,
+        "custom_llm_provider": "",
+        "drop_params": None,
+        "extra_headers": None,
+        "frequency_penalty": None,
+        "function_call": None,
+        "functions": None,
+        "logit_bias": None,
+        "logprobs": None,
+        "max_completion_tokens": None,
+        "max_retries": None,
+        "max_tokens": None,
+        "messages": None,
+        "modalities": None,
+        "model": None,
+        "n": None,
+        "parallel_tool_calls": None,
+        "prediction": None,
+        "presence_penalty": None,
+        "reasoning_effort": None,
+        "response_format": None,
+        "seed": None,
+        "stop": None,
+        "stream": False,
+        "stream_options": None,
+        "temperature": None,
+        "thinking": None,
+        "tool_choice": None,
+        "tools": None,
+        "top_logprobs": None,
+        "top_p": None,
+        "user": None,
+    }
+
+
+def _get_optional_params_non_default_params(
+    passed_params: dict[str, Any],
+    default_params: dict[str, Any],
+    additional_drop_params=None,
+):
+    """Filter parameters supplied to get_optional_params for non-default values.
+
+    Args:
+        passed_params (dict): Dictionary of parameters passed to the calling function (get_optional_params).
+        default_params (dict, optional): Dictionary of default parameters for get_optional_params.
+        additional_drop_params (list, optional): Additional parameters specified by the end user to exclude
+            from the result.
+    Returns:
+        dict: Parameters that are non-default and not excluded.
+    Raises:
+        ValueError: If any of the excluded parameters are not valid keys in the default_params for get_optional_params.
+    """
+
+    excluded_non_default_params = {
+        "additional_drop_params",
+        "allowed_openai_params",
+        "api_version",
+        "custom_llm_provider",
+        "drop_params",
+        "messages",
+        "model",
+    }
+
+    # From the parameters passed into this function, filter for parameters with non-default values.
+    non_default_params = {
+        k: v
+        for k, v in passed_params.items()
+        if (
+            k in default_params
+            and k not in excluded_non_default_params
+            and v != default_params[k]
+            and not _should_drop_param(
+                k=k, additional_drop_params=additional_drop_params
+            )
+        )
+    }
+
+    return non_default_params
+
+
 def get_optional_params(  # noqa: PLR0915
     # use the openai defaults
     # https://platform.openai.com/docs/api-reference/chat/create
@@ -2659,122 +2746,56 @@ def get_optional_params(  # noqa: PLR0915
         non_default_params=passed_params, optional_params=optional_params
     )
 
-    default_params = {
-        "functions": None,
-        "function_call": None,
-        "temperature": None,
-        "top_p": None,
-        "n": None,
-        "stream": None,
-        "stream_options": None,
-        "stop": None,
-        "max_tokens": None,
-        "max_completion_tokens": None,
-        "modalities": None,
-        "prediction": None,
-        "audio": None,
-        "presence_penalty": None,
-        "frequency_penalty": None,
-        "logit_bias": None,
-        "user": None,
-        "model": None,
-        "custom_llm_provider": "",
-        "response_format": None,
-        "seed": None,
-        "tools": None,
-        "tool_choice": None,
-        "max_retries": None,
-        "logprobs": None,
-        "top_logprobs": None,
-        "extra_headers": None,
-        "api_version": None,
-        "parallel_tool_calls": None,
-        "drop_params": None,
-        "allowed_openai_params": None,
-        "additional_drop_params": None,
-        "messages": None,
-        "reasoning_effort": None,
-        "thinking": None,
-    }
-
-    # filter out those parameters that were passed with non-default values
-
-    non_default_params = {
-        k: v
-        for k, v in passed_params.items()
-        if (
-            k != "model"
-            and k != "custom_llm_provider"
-            and k != "api_version"
-            and k != "drop_params"
-            and k != "allowed_openai_params"
-            and k != "additional_drop_params"
-            and k != "messages"
-            and k in default_params
-            and v != default_params[k]
-            and _should_drop_param(k=k, additional_drop_params=additional_drop_params)
-            is False
-        )
-    }
+    default_params = _get_optional_params_defaults()
+    non_default_params = _get_optional_params_non_default_params(
+        passed_params=passed_params,
+        default_params=default_params,
+        additional_drop_params=additional_drop_params,
+    )
 
     ## raise exception if function calling passed in for a provider that doesn't support it
-    if (
-        "functions" in non_default_params
-        or "function_call" in non_default_params
-        or "tools" in non_default_params
+    if any(
+        param_name in non_default_params
+        for param_name in ("functions", "function_call", "tools")
     ):
+        # Key to store function data which can be used for providers that don't support function calling via params
+        functions_unsupported_model_key = "functions_unsupported_model"
+
+        # Handle Ollama as a special case (ollama actually supports JSON output so we can emulate function calling)
+        if custom_llm_provider == "ollama":
+            optional_params["format"] = "json"
+            # NOTE: This adjusts global state in LiteLLM.
+            litellm.add_function_to_prompt = (
+                True  # so that main.py adds the function call to the prompt
+            )
+            non_default_params.pop(
+                "tool_choice", None
+            )  # causes ollama requests to hang when used later, so remove.
+
+        # If the function isn't going to be added to the prompt, handle all providers that are not OpenAI-compatible
         if (
-            custom_llm_provider == "ollama"
-            and custom_llm_provider != "text-completion-openai"
-            and custom_llm_provider != "azure"
-            and custom_llm_provider != "vertex_ai"
-            and custom_llm_provider != "anyscale"
-            and custom_llm_provider != "together_ai"
-            and custom_llm_provider != "groq"
-            and custom_llm_provider != "nvidia_nim"
-            and custom_llm_provider != "cerebras"
-            and custom_llm_provider != "xai"
-            and custom_llm_provider != "ai21_chat"
-            and custom_llm_provider != "volcengine"
-            and custom_llm_provider != "deepseek"
-            and custom_llm_provider != "codestral"
-            and custom_llm_provider != "mistral"
-            and custom_llm_provider != "anthropic"
-            and custom_llm_provider != "cohere_chat"
-            and custom_llm_provider != "cohere"
-            and custom_llm_provider != "bedrock"
-            and custom_llm_provider != "ollama_chat"
-            and custom_llm_provider != "openrouter"
-            and custom_llm_provider not in litellm.openai_compatible_providers
+            not litellm.add_function_to_prompt
+            and custom_llm_provider not in litellm.openai_compatible_providers
         ):
-            if custom_llm_provider == "ollama":
-                # ollama actually supports json output
-                optional_params["format"] = "json"
-                litellm.add_function_to_prompt = (
-                    True  # so that main.py adds the function call to the prompt
-                )
-                if "tools" in non_default_params:
-                    optional_params["functions_unsupported_model"] = (
-                        non_default_params.pop("tools")
-                    )
-                    non_default_params.pop(
-                        "tool_choice", None
-                    )  # causes ollama requests to hang
-                elif "functions" in non_default_params:
-                    optional_params["functions_unsupported_model"] = (
-                        non_default_params.pop("functions")
-                    )
-            elif (
-                litellm.add_function_to_prompt
-            ):  # if user opts to add it to prompt instead
-                optional_params["functions_unsupported_model"] = non_default_params.pop(
-                    "tools", non_default_params.pop("functions", None)
-                )
-            else:
-                raise UnsupportedParamsError(
-                    status_code=500,
-                    message=f"Function calling is not supported by {custom_llm_provider}.",
-                )
+            raise UnsupportedParamsError(
+                status_code=500,
+                message=f"Function calling is not supported by {custom_llm_provider}.",
+            )
+
+        # Attempt to add the supplied function call to the prompt, preferring tools > functions > function_call.
+        # The assumption is that we want to remove them all regardless of which parameter supplied the value.
+ if "function_call" in non_default_params: + optional_params[functions_unsupported_model_key] = non_default_params.pop( + "function_call" + ) + if "functions" in non_default_params: + optional_params[functions_unsupported_model_key] = non_default_params.pop( + "functions" + ) + if "tools" in non_default_params: + optional_params[functions_unsupported_model_key] = non_default_params.pop( + "tools" + ) provider_config: Optional[BaseConfig] = None if custom_llm_provider is not None and custom_llm_provider in [ diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/tests/litellm/__init__.py b/tests/litellm/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/tests/litellm/proxy/__init__.py b/tests/litellm/proxy/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/tests/litellm/proxy/types_utils/__init__.py b/tests/litellm/proxy/types_utils/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/tests/litellm/test_utils.py b/tests/litellm/test_utils.py new file mode 100644 index 000000000000..c48c5cf466e8 --- /dev/null +++ b/tests/litellm/test_utils.py @@ -0,0 +1,277 @@ +import pytest + +import litellm +from litellm import UnsupportedParamsError +from litellm.utils import ( + _get_optional_params_defaults, + _get_optional_params_non_default_params, + get_optional_params, +) + + +class TestGetOptionalParamsDefaults: + def test_returns_dictionary(self): + """Test that the function returns a dictionary.""" + result = _get_optional_params_defaults() + assert isinstance(result, dict) + + def test_return_value_is_not_mutated(self): + """Test that subsequent calls return independent copies of the default dictionary.""" + first_call = _get_optional_params_defaults() + second_call = _get_optional_params_defaults() + + # Verify they're equal but not the same object + assert first_call == second_call + assert first_call is not second_call + + # Modify the first result and verify the second isn't affected + first_call["temperature"] = 0.7 + assert second_call["temperature"] is None + + @pytest.mark.parametrize( + "param_name, expected_value", + [ + ("additional_drop_params", None), + ("allowed_openai_params", None), + ("api_version", None), + ("audio", None), + ("custom_llm_provider", ""), + ("drop_params", None), + ("extra_headers", None), + ("frequency_penalty", None), + ("function_call", None), + ("functions", None), + ("logit_bias", None), + ("logprobs", None), + ("max_completion_tokens", None), + ("max_retries", None), + ("max_tokens", None), + ("messages", None), + ("modalities", None), + ("model", None), + ("n", None), + ("parallel_tool_calls", None), + ("prediction", None), + ("presence_penalty", None), + ("reasoning_effort", None), + ("response_format", None), + ("seed", None), + ("stop", None), + ("stream", False), + ("stream_options", None), + ("temperature", None), + ("thinking", None), + ("tool_choice", None), + ("tools", None), + ("top_logprobs", None), + ("top_p", None), + ("user", None), + ], + ) + def test_individual_defaults(self, param_name, expected_value): + """Test that each parameter has the expected default value.""" + defaults = _get_optional_params_defaults() + assert param_name in defaults + assert defaults[param_name] == expected_value + + def test_completeness(self): + """Test that the function returns all expected parameters with no extras or missing items.""" + expected_params = { + "additional_drop_params", + "allowed_openai_params", + 
"api_version", + "audio", + "custom_llm_provider", + "drop_params", + "extra_headers", + "frequency_penalty", + "function_call", + "functions", + "logit_bias", + "logprobs", + "max_completion_tokens", + "max_retries", + "max_tokens", + "messages", + "modalities", + "model", + "n", + "parallel_tool_calls", + "prediction", + "presence_penalty", + "reasoning_effort", + "response_format", + "seed", + "stop", + "stream", + "stream_options", + "temperature", + "thinking", + "tool_choice", + "tools", + "top_logprobs", + "top_p", + "user", + } + + actual_params = set(_get_optional_params_defaults().keys()) + + # Check for extra parameters + extra_params = actual_params - expected_params + assert not extra_params, f"Unexpected parameters found: {extra_params}" + + # Check for missing parameters + missing_params = expected_params - actual_params + assert not missing_params, f"Expected parameters missing: {missing_params}" + + def test_custom_llm_provider_is_empty_string(self): + """Specifically test that custom_llm_provider has empty string as default (not None).""" + defaults = _get_optional_params_defaults() + assert defaults["custom_llm_provider"] == "" + assert defaults["custom_llm_provider"] is not None + + def test_stream_is_false(self): + """Specifically test that stream has False as default (not None).""" + defaults = _get_optional_params_defaults() + assert not defaults["stream"] + + def test_all_others_are_none(self): + """Test that all parameters except custom_llm_provider have None as default. + + This test may change in the future or no longer be required, but is included for now. + """ + defaults = _get_optional_params_defaults() + for key, value in defaults.items(): + if key in ["custom_llm_provider", "stream"]: + continue + assert value is None, f"Expected {key} to be None, but got {value}" + + +class TestGetOptionalParamsNonDefaultParams: + @pytest.mark.parametrize( + "passed_params, default_params, additional_drop_params, expected", + [ + # no non-defaults, should return empty + ( + {"model": "gpt-4", "api_version": "v1"}, + _get_optional_params_defaults(), + None, + {}, + ), + # one non-default parameter not excluded + ( + { + "temperature": 0.9, + "additional_drop_params": None, + "allowed_openai_params": "test", + "api_version": "v1", + "custom_llm_provider": "llamafile", + "drop_params": ["foo"], + "messages": ["bar"], + "model": "gpt-4", + }, + _get_optional_params_defaults(), + None, + {"temperature": 0.9}, + ), + # specifically exclude (drop) a parameter that is not default + ( + { + "temperature": 0.9, + "additional_drop_params": None, + "allowed_openai_params": "test", + "api_version": "v1", + "custom_llm_provider": "llamafile", + "drop_params": ["foo"], + "messages": ["bar"], + "model": "gpt-4", + }, + _get_optional_params_defaults(), + ["temperature"], + {}, + ), + # non-default param dropped, not default param left alone + ( + {"temperature": 0.9, "top_p": 0.95}, + _get_optional_params_defaults(), + ["top_p"], + {"temperature": 0.9}, + ), + ], + ) + def test_get_optional_params_non_default_params( + self, passed_params, default_params, additional_drop_params, expected + ): + result = _get_optional_params_non_default_params( + passed_params, + default_params, + additional_drop_params=additional_drop_params, + ) + assert result == expected + + +class TestGetOptionalParms: + def test_raises_on_unsupported_function_calling(self): + original_flag = litellm.add_function_to_prompt + + try: + litellm.add_function_to_prompt = False + + with pytest.raises( + 
+                UnsupportedParamsError,
+                match=r"^litellm.UnsupportedParamsError: Function calling is not supported by bad_provider.",
+            ):
+                get_optional_params(
+                    model="qwerty",
+                    custom_llm_provider="bad_provider",
+                    functions="not_supported",
+                )
+        finally:
+            litellm.add_function_to_prompt = original_flag
+
+    def test_ollama_sets_json_and_removes_tool_choice(self):
+        original_flag = litellm.add_function_to_prompt
+
+        try:
+            optional_params = get_optional_params(
+                model="qwerty",
+                custom_llm_provider="ollama",
+                functions="my_function",
+                tool_choice="auto",
+            )
+
+            assert optional_params["format"] == "json"
+            assert litellm.add_function_to_prompt
+            assert optional_params["functions_unsupported_model"] == "my_function"
+        finally:
+            litellm.add_function_to_prompt = original_flag
+
+    @pytest.mark.parametrize(
+        "tools, functions, function_call, expected_value",
+        [
+            ("foo", None, None, "foo"),
+            (None, None, "baz", "baz"),
+            ("foo", "bar", None, "foo"),
+            ("foo", None, "baz", "foo"),
+            (None, "bar", "baz", "bar"),
+            ("foo", "bar", "baz", "foo"),
+        ],
+    )
+    def test_supplying_tools_funcs_calls(
+        self, tools, functions, function_call, expected_value
+    ):
+        original_flag = litellm.add_function_to_prompt
+        try:
+            optional_params = get_optional_params(
+                model="qwerty",
+                custom_llm_provider="ollama",
+                tools=tools,
+                functions=functions,
+                function_call=function_call,
+            )
+
+            assert optional_params["format"] == "json"
+            assert litellm.add_function_to_prompt
+            assert optional_params["functions_unsupported_model"] == expected_value
+        finally:
+            litellm.add_function_to_prompt = original_flag
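
A minimal usage sketch of the two new helpers, mirroring what the tests above exercise; the parameter values below are illustrative only and are not taken from the diff.

from litellm.utils import (
    _get_optional_params_defaults,
    _get_optional_params_non_default_params,
)

# Illustrative call: only non-default, non-excluded keys should survive the filter.
passed = {
    "model": "gpt-4",  # always excluded from the result by design
    "custom_llm_provider": "openai",  # always excluded from the result by design
    "temperature": 0.2,  # non-default value, so it is kept
    "top_p": None,  # equal to its default, so it is filtered out
    "seed": 42,  # non-default, but dropped via additional_drop_params
}

non_default = _get_optional_params_non_default_params(
    passed_params=passed,
    default_params=_get_optional_params_defaults(),
    additional_drop_params=["seed"],
)

assert non_default == {"temperature": 0.2}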