Fix anthropic thinking + response_format #9594

Merged
merged 6 commits on Mar 28, 2025
Changes from all commits
1 change: 1 addition & 0 deletions litellm/constants.py
@@ -7,6 +7,7 @@
DEFAULT_FAILURE_THRESHOLD_PERCENT = (
0.5 # by default, cool down a deployment if 50% of requests fail in a given minute
)
DEFAULT_MAX_TOKENS = 4096
DEFAULT_REDIS_SYNC_INTERVAL = 1
DEFAULT_COOLDOWN_TIME_SECONDS = 5
DEFAULT_REPLICATE_POLLING_RETRIES = 5
17 changes: 15 additions & 2 deletions litellm/llms/anthropic/chat/transformation.py
@@ -300,6 +300,15 @@ def map_openai_params(
model: str,
drop_params: bool,
) -> dict:

is_thinking_enabled = self.is_thinking_enabled(
non_default_params=non_default_params
)

## handle thinking tokens
self.update_optional_params_with_thinking_tokens(
non_default_params=non_default_params, optional_params=optional_params
)
for param, value in non_default_params.items():
if param == "max_tokens":
optional_params["max_tokens"] = value
@@ -349,19 +358,23 @@ def map_openai_params(
- Remember that the model will pass the input to the tool, so the name of the tool and description should be from the model’s perspective.
"""

_tool_choice = {"name": RESPONSE_FORMAT_TOOL_NAME, "type": "tool"}
if not is_thinking_enabled:
_tool_choice = {"name": RESPONSE_FORMAT_TOOL_NAME, "type": "tool"}
optional_params["tool_choice"] = _tool_choice

_tool = self._create_json_tool_call_for_response_format(
json_schema=json_schema,
)
optional_params = self._add_tools_to_optional_params(
optional_params=optional_params, tools=[_tool]
)
optional_params["tool_choice"] = _tool_choice

optional_params["json_mode"] = True
if param == "user":
optional_params["metadata"] = {"user_id": value}
if param == "thinking":
optional_params["thinking"] = value

return optional_params

def _create_json_tool_call_for_response_format(
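In short, when extended thinking is enabled the JSON-schema tool for response_format is still attached and json_mode is still set, but tool_choice is no longer forced, since forced tool use does not appear to be accepted alongside extended thinking. A rough sketch of the mapped parameters (the tool name and schema below are placeholders, not necessarily the values litellm uses):

# Illustrative sketch (not part of the diff): the rough shape of optional_params
# produced when a caller combines `thinking` with `response_format`.
mapped_params = {
    "thinking": {"type": "enabled", "budget_tokens": 16000},
    "tools": [{"name": "json_tool_call", "input_schema": {"type": "object"}}],  # placeholder schema
    "json_mode": True,
    "max_tokens": 20096,  # budget_tokens + DEFAULT_MAX_TOKENS, via update_optional_params_with_thinking_tokens
    # no "tool_choice" entry: with thinking enabled the model is left to call the
    # response-format tool on its own; without thinking, tool_choice would be
    # {"name": "json_tool_call", "type": "tool"} as before this change.
}
print(mapped_params)
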
26 changes: 25 additions & 1 deletion litellm/llms/base_llm/chat/transformation.py
@@ -13,12 +13,13 @@
Optional,
Type,
Union,
cast,
)

import httpx
from pydantic import BaseModel

from litellm.constants import RESPONSE_FORMAT_TOOL_NAME
from litellm.constants import DEFAULT_MAX_TOKENS, RESPONSE_FORMAT_TOOL_NAME
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
from litellm.types.llms.openai import (
AllMessageValues,
@@ -102,6 +103,29 @@ def get_json_schema_from_pydantic_object(
) -> Optional[dict]:
return type_to_response_format_param(response_format=response_format)

def is_thinking_enabled(self, non_default_params: dict) -> bool:
return non_default_params.get("thinking", {}).get("type", None) == "enabled"

def update_optional_params_with_thinking_tokens(
self, non_default_params: dict, optional_params: dict
):
"""
Handles the scenario where max tokens is not specified. Anthropic models (Anthropic API / Bedrock / Vertex AI) require max tokens to be set and to be greater than the thinking token budget.

Checks 'non_default_params' for 'thinking' and 'max_tokens'

If 'thinking' is enabled and 'max_tokens' is not specified, set 'max_tokens' to the thinking token budget + DEFAULT_MAX_TOKENS.
"""
is_thinking_enabled = self.is_thinking_enabled(non_default_params)
if is_thinking_enabled and "max_tokens" not in non_default_params:
thinking_token_budget = cast(dict, non_default_params["thinking"]).get(
"budget_tokens", None
)
if thinking_token_budget is not None:
optional_params["max_tokens"] = (
thinking_token_budget + DEFAULT_MAX_TOKENS
)

def should_fake_stream(
self,
model: Optional[str],
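For reference, a minimal standalone sketch of the behavior the docstring above describes (the helper name below is hypothetical; DEFAULT_MAX_TOKENS is the 4096 constant added in litellm/constants.py):

DEFAULT_MAX_TOKENS = 4096  # value added in litellm/constants.py above

def thinking_max_tokens(non_default_params: dict) -> dict:
    """Return the extra optional params implied by an enabled `thinking` block."""
    optional_params: dict = {}
    thinking = non_default_params.get("thinking", {})
    if thinking.get("type") == "enabled" and "max_tokens" not in non_default_params:
        budget = thinking.get("budget_tokens")
        if budget is not None:
            # leave DEFAULT_MAX_TOKENS of headroom for the visible answer
            optional_params["max_tokens"] = budget + DEFAULT_MAX_TOKENS
    return optional_params

# Worked example: a 16k thinking budget with no explicit max_tokens
print(thinking_max_tokens({"thinking": {"type": "enabled", "budget_tokens": 16000}}))
# -> {'max_tokens': 20096}
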
12 changes: 10 additions & 2 deletions litellm/llms/bedrock/chat/converse_transformation.py
@@ -210,6 +210,10 @@ def map_openai_params(
drop_params: bool,
messages: Optional[List[AllMessageValues]] = None,
) -> dict:
is_thinking_enabled = self.is_thinking_enabled(non_default_params)
self.update_optional_params_with_thinking_tokens(
non_default_params=non_default_params, optional_params=optional_params
)
for param, value in non_default_params.items():
if param == "response_format" and isinstance(value, dict):

@@ -247,8 +251,11 @@ def map_openai_params(
optional_params = self._add_tools_to_optional_params(
optional_params=optional_params, tools=[_tool]
)
if litellm.utils.supports_tool_choice(
model=model, custom_llm_provider=self.custom_llm_provider
if (
litellm.utils.supports_tool_choice(
model=model, custom_llm_provider=self.custom_llm_provider
)
and not is_thinking_enabled
):
optional_params["tool_choice"] = ToolChoiceValuesBlock(
tool=SpecificToolChoiceBlock(
@@ -284,6 +291,7 @@ def map_openai_params(
optional_params["tool_choice"] = _tool_choice_value
if param == "thinking":
optional_params["thinking"] = value

return optional_params

@overload
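The Converse mapping gets the same treatment: thinking tokens are handled up front, and tool_choice is only forced when the model supports tool choice and thinking is disabled. A hedged end-to-end usage sketch of the scenario this enables (requires Bedrock credentials; the model id and budget mirror the new Bedrock test below, while the QA schema is illustrative):

import litellm
from pydantic import BaseModel

class QA(BaseModel):
    question: str
    answer: str

# Structured output together with extended thinking on Bedrock Converse
response = litellm.completion(
    model="bedrock/us.anthropic.claude-3-7-sonnet-20250219-v1:0",
    messages=[{"role": "user", "content": "Generate 5 question + answer pairs"}],
    response_format=QA,
    thinking={"type": "enabled", "budget_tokens": 16000},
)
print(response.choices[0].message.content)
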
23 changes: 23 additions & 0 deletions tests/llm_translation/base_llm_unit_tests.py
@@ -1008,6 +1008,11 @@ def get_base_completion_call_args(self) -> dict:
"""Must return the base completion call args"""
pass

@abstractmethod
def get_base_completion_call_args_with_thinking(self) -> dict:
"""Must return the base completion call args"""
pass

@property
def completion_function(self):
return litellm.completion
@@ -1066,3 +1071,21 @@ def test_anthropic_response_format_streaming_vs_non_streaming(self):
json.loads(built_response.choices[0].message.content).keys()
== json.loads(non_stream_response.choices[0].message.content).keys()
), f"Got={json.loads(built_response.choices[0].message.content)}, Expected={json.loads(non_stream_response.choices[0].message.content)}"

def test_completion_thinking_with_response_format(self):
from pydantic import BaseModel

class RFormat(BaseModel):
question: str
answer: str

base_completion_call_args = self.get_base_completion_call_args_with_thinking()

messages = [{"role": "user", "content": "Generate 5 question + answer pairs"}]
response = self.completion_function(
**base_completion_call_args,
messages=messages,
response_format=RFormat,
)

print(response)
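The shared test above only prints the response; an illustrative follow-up (not part of the diff) shows how the structured content could be validated against the same Pydantic model, using a made-up sample string:

import json
from pydantic import BaseModel

class RFormat(BaseModel):
    question: str
    answer: str

# Assumes the provider returned JSON matching RFormat in choices[0].message.content
sample_content = '{"question": "What does this PR fix?", "answer": "thinking + response_format"}'
validated = RFormat(**json.loads(sample_content))
print(validated)
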
6 changes: 6 additions & 0 deletions tests/llm_translation/test_anthropic_completion.py
@@ -467,6 +467,12 @@ class TestAnthropicCompletion(BaseLLMChatTest, BaseAnthropicChatTest):
def get_base_completion_call_args(self) -> dict:
return {"model": "anthropic/claude-3-5-sonnet-20240620"}

def get_base_completion_call_args_with_thinking(self) -> dict:
return {
"model": "anthropic/claude-3-7-sonnet-latest",
"thinking": {"type": "enabled", "budget_tokens": 16000},
}

def test_tool_call_no_arguments(self, tool_call_no_arguments):
"""Test that tool calls with no arguments is translated correctly. Relevant issue: https://github.com/BerriAI/litellm/issues/6833"""
from litellm.litellm_core_utils.prompt_templates.factory import (
15 changes: 14 additions & 1 deletion tests/llm_translation/test_bedrock_completion.py
@@ -35,7 +35,7 @@
from litellm.llms.bedrock.chat import BedrockLLM
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
from litellm.litellm_core_utils.prompt_templates.factory import _bedrock_tools_pt
from base_llm_unit_tests import BaseLLMChatTest
from base_llm_unit_tests import BaseLLMChatTest, BaseAnthropicChatTest
from base_rerank_unit_tests import BaseLLMRerankTest
from base_embedding_unit_tests import BaseLLMEmbeddingTest

@@ -2191,6 +2191,19 @@ def test_completion_cost(self):
assert cost > 0


class TestBedrockConverseAnthropicUnitTests(BaseAnthropicChatTest):
def get_base_completion_call_args(self) -> dict:
return {
"model": "bedrock/us.anthropic.claude-3-5-sonnet-20241022-v2:0",
}

def get_base_completion_call_args_with_thinking(self) -> dict:
return {
"model": "bedrock/us.anthropic.claude-3-7-sonnet-20250219-v1:0",
"thinking": {"type": "enabled", "budget_tokens": 16000},
}


class TestBedrockConverseChatNormal(BaseLLMChatTest):
def get_base_completion_call_args(self) -> dict:
os.environ["LITELLM_LOCAL_MODEL_COST_MAP"] = "True"
2 changes: 2 additions & 0 deletions tests/llm_translation/test_cohere.py
@@ -55,6 +55,8 @@ async def test_chat_completion_cohere_citations(stream):
assert citations_chunk
else:
assert response.citations is not None
except litellm.ServiceUnavailableError:
pass
except Exception as e:
pytest.fail(f"Error occurred: {e}")
