Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
Set,
Tuple,
Union,
cast,
)

from llama_index.core.base.llms.types import (
Expand All @@ -23,6 +24,7 @@
LLMMetadata,
MessageRole,
ContentBlock,
ToolCallBlock,
)
from llama_index.core.base.llms.types import TextBlock as LITextBlock
from llama_index.core.base.llms.types import CitationBlock as LICitationBlock
Expand Down Expand Up @@ -351,8 +353,7 @@ def _completion_response_from_chat_response(

def _get_blocks_and_tool_calls_and_thinking(
self, response: Any
) -> Tuple[List[ContentBlock], List[Dict[str, Any]], List[Dict[str, Any]]]:
tool_calls = []
) -> Tuple[List[ContentBlock], List[Dict[str, Any]]]:
blocks: List[ContentBlock] = []
citations: List[TextCitation] = []
tracked_citations: Set[str] = set()
Expand Down Expand Up @@ -392,9 +393,15 @@ def _get_blocks_and_tool_calls_and_thinking(
)
)
elif isinstance(content_block, ToolUseBlock):
tool_calls.append(content_block.model_dump())
blocks.append(
ToolCallBlock(
tool_call_id=content_block.id,
tool_kwargs=cast(Dict[str, Any] | str, content_block.input),
tool_name=content_block.name,
)
)

return blocks, tool_calls, [x.model_dump() for x in citations]
return blocks, [x.model_dump() for x in citations]

@llm_chat_callback()
def chat(
Expand All @@ -412,17 +419,12 @@ def chat(
**all_kwargs,
)

blocks, tool_calls, citations = self._get_blocks_and_tool_calls_and_thinking(
response
)
blocks, citations = self._get_blocks_and_tool_calls_and_thinking(response)

return AnthropicChatResponse(
message=ChatMessage(
role=MessageRole.ASSISTANT,
blocks=blocks,
additional_kwargs={
"tool_calls": tool_calls,
},
),
citations=citations,
raw=dict(response),
Expand Down Expand Up @@ -532,13 +534,26 @@ def gen() -> Generator[AnthropicChatResponse, None, None]:
else:
tool_calls_to_send = cur_tool_calls

for tool_call in tool_calls_to_send:
if tool_call.id not in [
block.tool_call_id
for block in content
if isinstance(block, ToolCallBlock)
]:
content.append(
ToolCallBlock(
tool_call_id=tool_call.id,
tool_name=tool_call.name,
tool_kwargs=cast(
Dict[str, Any] | str, tool_call.input
),
)
)

yield AnthropicChatResponse(
message=ChatMessage(
role=role,
blocks=content,
additional_kwargs={
"tool_calls": [t.dict() for t in tool_calls_to_send]
},
),
citations=cur_citations,
delta=content_delta,
Expand All @@ -556,13 +571,31 @@ def gen() -> Generator[AnthropicChatResponse, None, None]:
content.append(cur_block)
cur_block = None

if cur_tool_call is not None:
tool_calls_to_send = [*cur_tool_calls, cur_tool_call]
else:
tool_calls_to_send = cur_tool_calls

for tool_call in tool_calls_to_send:
if tool_call.id not in [
block.tool_call_id
for block in content
if isinstance(block, ToolCallBlock)
]:
content.append(
ToolCallBlock(
tool_call_id=tool_call.id,
tool_name=tool_call.name,
tool_kwargs=cast(
Dict[str, Any] | str, tool_call.input
),
)
)

yield AnthropicChatResponse(
message=ChatMessage(
role=role,
blocks=content,
additional_kwargs={
"tool_calls": [t.dict() for t in tool_calls_to_send]
},
),
citations=cur_citations,
delta=content_delta,
Expand Down Expand Up @@ -600,17 +633,12 @@ async def achat(
**all_kwargs,
)

blocks, tool_calls, citations = self._get_blocks_and_tool_calls_and_thinking(
response
)
blocks, citations = self._get_blocks_and_tool_calls_and_thinking(response)

return AnthropicChatResponse(
message=ChatMessage(
role=MessageRole.ASSISTANT,
blocks=blocks,
additional_kwargs={
"tool_calls": tool_calls,
},
),
citations=citations,
raw=dict(response),
Expand Down Expand Up @@ -720,13 +748,26 @@ async def gen() -> ChatResponseAsyncGen:
else:
tool_calls_to_send = cur_tool_calls

for tool_call in tool_calls_to_send:
if tool_call.id not in [
block.tool_call_id
for block in content
if isinstance(block, ToolCallBlock)
]:
content.append(
ToolCallBlock(
tool_call_id=tool_call.id,
tool_name=tool_call.name,
tool_kwargs=cast(
Dict[str, Any] | str, tool_call.input
),
)
)

yield AnthropicChatResponse(
message=ChatMessage(
role=role,
blocks=content,
additional_kwargs={
"tool_calls": [t.dict() for t in tool_calls_to_send]
},
),
citations=cur_citations,
delta=content_delta,
Expand All @@ -744,13 +785,31 @@ async def gen() -> ChatResponseAsyncGen:
content.append(cur_block)
cur_block = None

if cur_tool_call is not None:
tool_calls_to_send = [*cur_tool_calls, cur_tool_call]
else:
tool_calls_to_send = cur_tool_calls

for tool_call in tool_calls_to_send:
if tool_call.id not in [
block.tool_call_id
for block in content
if isinstance(block, ToolCallBlock)
]:
content.append(
ToolCallBlock(
tool_call_id=tool_call.id,
tool_name=tool_call.name,
tool_kwargs=cast(
Dict[str, Any] | str, tool_call.input
),
)
)

yield AnthropicChatResponse(
message=ChatMessage(
role=role,
blocks=content,
additional_kwargs={
"tool_calls": [t.dict() for t in tool_calls_to_send]
},
),
citations=cur_citations,
delta=content_delta,
Expand Down Expand Up @@ -859,7 +918,11 @@ def get_tool_calls_from_response(
**kwargs: Any,
) -> List[ToolSelection]:
"""Predict and call the tool."""
tool_calls = response.message.additional_kwargs.get("tool_calls", [])
tool_calls = [
block
for block in response.message.blocks
if isinstance(block, ToolCallBlock)
]

if len(tool_calls) < 1:
if error_on_no_tool_call:
Expand All @@ -871,24 +934,16 @@ def get_tool_calls_from_response(

tool_selections = []
for tool_call in tool_calls:
if (
"input" not in tool_call
or "id" not in tool_call
or "name" not in tool_call
):
raise ValueError("Invalid tool call.")
if tool_call["type"] != "tool_use":
raise ValueError("Invalid tool type. Unsupported by Anthropic")
argument_dict = (
json.loads(tool_call["input"])
if isinstance(tool_call["input"], str)
else tool_call["input"]
json.loads(tool_call.tool_kwargs)
if isinstance(tool_call.tool_kwargs, str)
else tool_call.tool_kwargs
)

tool_selections.append(
ToolSelection(
tool_id=tool_call["id"],
tool_name=tool_call["name"],
tool_id=tool_call.tool_call_id or "",
tool_name=tool_call.tool_name,
tool_kwargs=argument_dict,
)
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
CitationBlock,
ThinkingBlock,
ContentBlock,
ToolCallBlock,
)

from anthropic.types import (
Expand All @@ -24,6 +25,7 @@
DocumentBlockParam,
ThinkingBlockParam,
ImageBlockParam,
ToolUseBlockParam,
CacheControlEphemeralParam,
Base64PDFSourceParam,
)
Expand Down Expand Up @@ -269,6 +271,18 @@ def blocks_to_anthropic_blocks(
if global_cache_control:
anthropic_blocks[-1]["cache_control"] = global_cache_control

elif isinstance(block, ToolCallBlock):
anthropic_blocks.append(
ToolUseBlockParam(
id=block.tool_call_id or "",
input=block.tool_kwargs,
name=block.tool_name,
type="tool_use",
)
)
if global_cache_control:
anthropic_blocks[-1]["cache_control"] = global_cache_control

elif isinstance(block, CachePoint):
if len(anthropic_blocks) > 0:
anthropic_blocks[-1]["cache_control"] = CacheControlEphemeralParam(
Expand All @@ -282,6 +296,7 @@ def blocks_to_anthropic_blocks(
else:
raise ValueError(f"Unsupported block type: {type(block)}")

# keep this code for compatibility with older chat histories
tool_calls = kwargs.get("tool_calls", [])
for tool_call in tool_calls:
assert "id" in tool_call
Expand Down Expand Up @@ -359,9 +374,15 @@ def messages_to_anthropic_messages(


def force_single_tool_call(response: ChatResponse) -> None:
tool_calls = response.message.additional_kwargs.get("tool_calls", [])
tool_calls = [
block for block in response.message.blocks if isinstance(block, ToolCallBlock)
]
if len(tool_calls) > 1:
response.message.additional_kwargs["tool_calls"] = [tool_calls[0]]
response.message.blocks = [
block
for block in response.message.blocks
if not isinstance(block, ToolCallBlock)
] + [tool_calls[0]]


# Anthropic models that support prompt caching
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,15 +27,15 @@ dev = [

[project]
name = "llama-index-llms-anthropic"
version = "0.9.5"
version = "0.10.0"
description = "llama-index llms anthropic integration"
authors = [{name = "Your Name", email = "you@example.com"}]
requires-python = ">=3.9,<4.0"
readme = "README.md"
license = "MIT"
dependencies = [
"anthropic[bedrock, vertex]>=0.69.0",
"llama-index-core>=0.14.3,<0.15",
"llama-index-core>=0.14.5,<0.15",
]

[tool.codespell]
Expand Down
Loading
Loading