Skip to content

Commit e4acf75

Browse files
dirkbrnd and ysolanky authored
Fix tool calling on Llama API (#3064)
## Summary The ReasoningTools require optional params to be handled correctly. Fixes #3063 ## Type of change - [ ] Bug fix - [ ] New feature - [ ] Breaking change - [ ] Improvement - [ ] Model update - [ ] Other: --- ## Checklist - [ ] Code complies with style guidelines - [ ] Ran format/validation scripts (`./scripts/format.sh` and `./scripts/validate.sh`) - [ ] Self-review completed - [ ] Documentation updated (comments, docstrings) - [ ] Examples and guides: Relevant cookbook examples have been included or updated (if applicable) - [ ] Tested in clean environment - [ ] Tests added/updated (if applicable) --- ## Additional Notes Add any important context (deployment instructions, screenshots, security considerations, etc.) --------- Co-authored-by: ysolanky <yash@phidata.com>
1 parent 703151f commit e4acf75

File tree

4 files changed

+63
-3
lines changed

4 files changed

+63
-3
lines changed

libs/agno/agno/models/meta/llama.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -152,6 +152,13 @@ def request_kwargs(self) -> Dict[str, Any]:
152152
# Add tools
153153
if self._tools is not None and len(self._tools) > 0:
154154
request_params["tools"] = self._tools
155+
156+
# Fix optional parameters where the "type" is [<type>, null]
157+
for tool in request_params["tools"]: # type: ignore
158+
if "parameters" in tool["function"] and "properties" in tool["function"]["parameters"]: # type: ignore
159+
for _, obj in tool["function"]["parameters"].get("properties", {}).items(): # type: ignore
160+
if isinstance(obj["type"], list):
161+
obj["type"] = obj["type"][0]
155162

156163
if self.response_format is not None:
157164
request_params["response_format"] = self.response_format
@@ -267,7 +274,6 @@ async def ainvoke_stream(self, messages: List[Message]) -> AsyncIterator[CreateC
267274
log_error(f"Error from Llama API: {e}")
268275
raise ModelProviderError(message=str(e), model_name=self.name, model_id=self.id) from e
269276

270-
# Override base method
271277
@staticmethod
272278
def parse_tool_calls(tool_calls_data: List[EventDeltaToolCallDeltaFunction]) -> List[Dict[str, Any]]:
273279
"""

libs/agno/agno/models/meta/llama_openai.py

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,9 +33,65 @@ class LlamaOpenAI(OpenAILike):
3333

3434
api_key: Optional[str] = getenv("LLAMA_API_KEY")
3535
base_url: Optional[str] = "https://api.llama.com/compat/v1/"
36+
37+
# Request parameters
38+
max_completion_tokens: Optional[int] = None
39+
repetition_penalty: Optional[float] = None
40+
temperature: Optional[float] = None
41+
top_p: Optional[float] = None
42+
top_k: Optional[int] = None
43+
extra_headers: Optional[Any] = None
44+
extra_query: Optional[Any] = None
45+
extra_body: Optional[Any] = None
46+
request_params: Optional[Dict[str, Any]] = None
3647

3748
supports_native_structured_outputs: bool = False
3849
supports_json_schema_outputs: bool = True
50+
51+
52+
@property
53+
def request_kwargs(self) -> Dict[str, Any]:
54+
"""
55+
Returns keyword arguments for API requests.
56+
57+
Returns:
58+
Dict[str, Any]: A dictionary of keyword arguments for API requests.
59+
"""
60+
# Define base request parameters
61+
base_params = {
62+
"max_completion_tokens": self.max_completion_tokens,
63+
"repetition_penalty": self.repetition_penalty,
64+
"temperature": self.temperature,
65+
"top_p": self.top_p,
66+
"top_k": self.top_k,
67+
"extra_headers": self.extra_headers,
68+
"extra_query": self.extra_query,
69+
"extra_body": self.extra_body,
70+
"request_params": self.request_params,
71+
}
72+
73+
# Filter out None values
74+
request_params = {k: v for k, v in base_params.items() if v is not None}
75+
76+
# Add tools
77+
if self._tools is not None and len(self._tools) > 0:
78+
request_params["tools"] = self._tools
79+
80+
# Fix optional parameters where the "type" is [<type>, null]
81+
for tool in request_params["tools"]: # type: ignore
82+
if "parameters" in tool["function"] and "properties" in tool["function"]["parameters"]: # type: ignore
83+
for _, obj in tool["function"]["parameters"].get("properties", {}).items(): # type: ignore
84+
if isinstance(obj["type"], list):
85+
obj["type"] = obj["type"][0]
86+
87+
if self.response_format is not None:
88+
request_params["response_format"] = self.response_format
89+
90+
# Add additional request params if provided
91+
if self.request_params:
92+
request_params.update(self.request_params)
93+
94+
return request_params
3995

4096
def _format_message(self, message: Message) -> Dict[str, Any]:
4197
"""

libs/agno/tests/integration/models/meta/llama/test_tool_use.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -180,7 +180,6 @@ def get_the_weather_in_tokyo():
180180
assert "Tokyo" in response.content
181181

182182

183-
@pytest.mark.skip("Llama models do not not accept optional parameters")
184183
def test_tool_call_custom_tool_optional_parameters():
185184
def get_the_weather(city: Optional[str] = None):
186185
"""

libs/agno/tests/integration/models/meta/llama_openai/test_tool_use.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -180,7 +180,6 @@ def get_the_weather_in_tokyo():
180180
assert "Tokyo" in response.content
181181

182182

183-
@pytest.mark.skip("Llama models do not not accept optional parameters")
184183
def test_tool_call_custom_tool_optional_parameters():
185184
def get_the_weather(city: Optional[str] = None):
186185
"""

0 commit comments

Comments (0)