Commit 4464a77

Author: Salma Elshafey (committed)
Commit message: Reformat using Black
Parent: fef02f5
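
This commit is a mechanical formatting pass: every hunk below is Black rewrapping lines that exceed its default 88-character limit, normalizing quotes to double quotes, and adding trailing commas. As a minimal sketch of that transformation (assuming the `black` package is installed; this snippet is illustrative and not part of the commit):

import black

# One of the over-long signatures from the diff below, as a source string.
src = (
    "def __init__(self, model_config, *, "
    "threshold=_DEFAULT_TOOL_CALL_ACCURACY_SCORE, **kwargs):\n"
    "    pass\n"
)

# black.format_str applies the same rules as the `black` CLI; with the
# default 88-column limit it splits the signature across lines, producing
# the same shape seen in the first file's hunks.
print(black.format_str(src, mode=black.Mode(line_length=88)))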

File tree: 3 files changed, +304 / -143 lines

sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/_evaluators/_tool_call_accuracy/_tool_call_accuracy.py

Lines changed: 66 additions & 34 deletions
@@ -8,7 +8,12 @@
 from typing import Dict, List, Union, TypeVar, cast
 from typing_extensions import overload, override
 from azure.ai.evaluation._evaluators._common import PromptyEvaluatorBase
-from azure.ai.evaluation._exceptions import ErrorBlame, ErrorCategory, ErrorTarget, EvaluationException
+from azure.ai.evaluation._exceptions import (
+    ErrorBlame,
+    ErrorCategory,
+    ErrorTarget,
+    EvaluationException,
+)
 from ..._common.utils import check_score_is_valid
 from azure.ai.evaluation._common._experimental import experimental

@@ -74,7 +79,9 @@ class ToolCallAccuracyEvaluator(PromptyEvaluatorBase[Union[str, float]]):

     _NO_TOOL_CALLS_MESSAGE = "No tool calls found in response or provided tool_calls."
     _NO_TOOL_DEFINITIONS_MESSAGE = "Tool definitions must be provided."
-    _TOOL_DEFINITIONS_MISSING_MESSAGE = "Tool definitions for all tool calls must be provided."
+    _TOOL_DEFINITIONS_MISSING_MESSAGE = (
+        "Tool definitions for all tool calls must be provided."
+    )
     _INVALID_SCORE_MESSAGE = "Tool call accuracy score must be between 1 and 5."

     _LLM_SCORE_KEY = "tool_calls_success_level"
@@ -83,11 +90,18 @@ class ToolCallAccuracyEvaluator(PromptyEvaluatorBase[Union[str, float]]):
     """Evaluator identifier, experimental and to be used only with evaluation in cloud."""

     @override
-    def __init__(self, model_config, *, threshold=_DEFAULT_TOOL_CALL_ACCURACY_SCORE, **kwargs):
+    def __init__(
+        self, model_config, *, threshold=_DEFAULT_TOOL_CALL_ACCURACY_SCORE, **kwargs
+    ):
         current_dir = os.path.dirname(__file__)
         prompty_path = os.path.join(current_dir, self._PROMPTY_FILE)
         self.threshold = threshold
-        super().__init__(model_config=model_config, prompty_file=prompty_path, result_key=self._RESULT_KEY, **kwargs)
+        super().__init__(
+            model_config=model_config,
+            prompty_file=prompty_path,
+            result_key=self._RESULT_KEY,
+            **kwargs,
+        )

     @overload
     def __call__(
@@ -164,7 +178,9 @@ def _convert_kwargs_to_eval_input(self, **kwargs):
             tool_definitions = [tool_definitions]

         try:
-            needed_tool_definitions = self._extract_needed_tool_definitions(tool_calls, tool_definitions)
+            needed_tool_definitions = self._extract_needed_tool_definitions(
+                tool_calls, tool_definitions
+            )
         except EvaluationException as e:
             return {"error_message": self._TOOL_DEFINITIONS_MISSING_MESSAGE}
         if len(needed_tool_definitions) == 0:
@@ -173,9 +189,8 @@ def _convert_kwargs_to_eval_input(self, **kwargs):
         return {
             "query": query,
             "tool_calls": tool_calls,
-            "tool_definitions": needed_tool_definitions
+            "tool_definitions": needed_tool_definitions,
         }
-

     @override
     async def _do_eval(self, eval_input: Dict) -> Dict[str, Union[float, str]]:  # type: ignore[override]
@@ -192,35 +207,39 @@ async def _do_eval(self, eval_input: Dict) -> Dict[str, Union[float, str]]:  # t

         if isinstance(llm_output, dict):
             score = llm_output.get(self._LLM_SCORE_KEY, None)
-            if not score or not check_score_is_valid(score, ToolCallAccuracyEvaluator._MIN_TOOL_CALL_ACCURACY_SCORE, ToolCallAccuracyEvaluator._MAX_TOOL_CALL_ACCURACY_SCORE):
+            if not score or not check_score_is_valid(
+                score,
+                ToolCallAccuracyEvaluator._MIN_TOOL_CALL_ACCURACY_SCORE,
+                ToolCallAccuracyEvaluator._MAX_TOOL_CALL_ACCURACY_SCORE,
+            ):
                 raise EvaluationException(
                     message=f"Invalid score value: {score}. Expected a number in range [{ToolCallAccuracyEvaluator._MIN_TOOL_CALL_ACCURACY_SCORE}, {ToolCallAccuracyEvaluator._MAX_TOOL_CALL_ACCURACY_SCORE}].",
                     internal_message="Invalid score value.",
                     category=ErrorCategory.FAILED_EXECUTION,
                     blame=ErrorBlame.SYSTEM_ERROR,
                 )
-
+
             # Format the output
             reason = llm_output.get("chain_of_thought", "")
             score = float(score)
-            score_result = 'pass' if score >= self.threshold else 'fail'
+            score_result = "pass" if score >= self.threshold else "fail"
             response_dict = {
                 self._result_key: score,
                 f"{self._result_key}_result": score_result,
                 f"{self._result_key}_threshold": self.threshold,
                 f"{self._result_key}_reason": reason,
-                'details': llm_output.get('details', {}),
+                "details": llm_output.get("details", {}),
             }
             return response_dict
-
+
         else:
             raise EvaluationException(
-                    message="Tool call accuracy evaluator returned invalid output.",
-                    blame=ErrorBlame.SYSTEM_ERROR,
-                    category=ErrorCategory.FAILED_EXECUTION,
-                    target=ErrorTarget.TOOL_CALL_ACCURACY_EVALUATOR,
-                )
-
+                message="Tool call accuracy evaluator returned invalid output.",
+                blame=ErrorBlame.SYSTEM_ERROR,
+                category=ErrorCategory.FAILED_EXECUTION,
+                target=ErrorTarget.TOOL_CALL_ACCURACY_EVALUATOR,
+            )
+
     async def _real_call(self, **kwargs):
         """The asynchronous call where real end-to-end evaluation logic is performed.

@@ -231,14 +250,14 @@ async def _real_call(self, **kwargs):
         """
         # Convert inputs into list of evaluable inputs.
         eval_input = self._convert_kwargs_to_eval_input(**kwargs)
-        if isinstance(eval_input, dict) and eval_input.get('error_message'):
+        if isinstance(eval_input, dict) and eval_input.get("error_message"):
             # If there is an error message, return not applicable result
-            return self._not_applicable_result(eval_input.get('error_message'))
+            return self._not_applicable_result(eval_input.get("error_message"))
         # Do the evaluation
         result = await self._do_eval(eval_input)
         # Return the result
         return result
-
+
     def _not_applicable_result(self, error_message):
         """Return a result indicating that the tool call is not applicable for evaluation.
         :param eval_input: The input to the evaluator.
@@ -249,13 +268,12 @@ def _not_applicable_result(self, error_message):
         # If no tool calls were made or tool call type is not supported, return not applicable result
         return {
             self._result_key: self._NOT_APPLICABLE_RESULT,
-            f"{self._result_key}_result": 'pass',
+            f"{self._result_key}_result": "pass",
             f"{self._result_key}_threshold": self.threshold,
             f"{self._result_key}_reason": error_message,
             "details": {},
-
         }
-
+
     def _parse_tools_from_response(self, response):
         """Parse the response to extract tool calls and results.
         :param response: The response to parse.
@@ -266,29 +284,40 @@ def _parse_tools_from_response(self, response):
         tool_calls = []
         tool_results_map = {}
         if isinstance(response, list):
-            for message in response:
+            for message in response:
                 # Extract tool calls from assistant messages
-                if message.get("role") == "assistant" and isinstance(message.get("content"), list):
+                if message.get("role") == "assistant" and isinstance(
+                    message.get("content"), list
+                ):
                     for content_item in message.get("content"):
-                        if isinstance(content_item, dict) and content_item.get("type") == "tool_call":
+                        if (
+                            isinstance(content_item, dict)
+                            and content_item.get("type") == "tool_call"
+                        ):
                             tool_calls.append(content_item)

                 # Extract tool results from tool messages
                 elif message.get("role") == "tool" and message.get("tool_call_id"):
                     tool_call_id = message.get("tool_call_id")
-                    if isinstance(message.get("content"), list) and len(message.get("content")) > 0:
+                    if (
+                        isinstance(message.get("content"), list)
+                        and len(message.get("content")) > 0
+                    ):
                         result_content = message.get("content")[0]
-                        if isinstance(result_content, dict) and result_content.get("type") == "tool_result":
+                        if (
+                            isinstance(result_content, dict)
+                            and result_content.get("type") == "tool_result"
+                        ):
                             tool_results_map[tool_call_id] = result_content

         # Attach results to their corresponding calls
         for tool_call in tool_calls:
             tool_call_id = tool_call.get("tool_call_id")
             if tool_call_id in tool_results_map:
-                tool_call["tool_result"] = tool_results_map[tool_call_id]['tool_result']
+                tool_call["tool_result"] = tool_results_map[tool_call_id]["tool_result"]

         return tool_calls
-
+
     def _extract_needed_tool_definitions(self, tool_calls, tool_definitions):
         """Extract the tool definitions that are needed for the provided tool calls.
         :param tool_calls: List of tool calls to evaluate.
@@ -302,9 +331,12 @@ def _extract_needed_tool_definitions(self, tool_calls, tool_definitions):
         for tool_call in tool_calls:
             if isinstance(tool_call, dict) and tool_call.get("type") == "tool_call":
                 tool_name = tool_call.get("name")
-                tool_definition = [tool for tool in tool_definitions
-                                   if tool.get("name") == tool_name and
-                                   tool.get("type", "function") == "function"]
+                tool_definition = [
+                    tool
+                    for tool in tool_definitions
+                    if tool.get("name") == tool_name
+                    and tool.get("type", "function") == "function"
+                ]
                 if len(tool_definition) > 0:
                     needed_tool_definitions.extend(tool_definition)
                 else:
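
For orientation, here is a hedged sketch of how the reformatted evaluator is invoked. The tool-call and tool-definition payloads are taken from the test diff below; the import path and AzureOpenAIModelConfiguration fields follow the usual azure-ai-evaluation conventions, the endpoint/key/deployment values are placeholders, and the exact `_RESULT_KEY` string is defined outside this diff:

from azure.ai.evaluation import AzureOpenAIModelConfiguration, ToolCallAccuracyEvaluator

# Placeholder credentials: substitute your own Azure OpenAI resource.
model_config = AzureOpenAIModelConfiguration(
    azure_endpoint="https://<resource>.openai.azure.com",
    api_key="<api-key>",
    azure_deployment="<deployment>",
)

evaluator = ToolCallAccuracyEvaluator(model_config, threshold=3)

result = evaluator(
    query="Where is the Eiffel Tower?",
    tool_calls=[
        {
            "type": "tool_call",
            "name": "fetch_weather",
            "arguments": {"location": "Paris"},
        }
    ],
    tool_definitions=[
        {
            "name": "fetch_weather",
            "description": "Fetches the weather information for the specified location.",
            "parameters": {
                "type": "object",
                "properties": {
                    "location": {
                        "type": "string",
                        "description": "The location to fetch weather for.",
                    }
                },
            },
        }
    ],
)

# Per _do_eval above, result holds the score under the evaluator's result key,
# plus "<key>_result" ("pass"/"fail" against the threshold), "<key>_threshold",
# "<key>_reason", and "details".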

sdk/evaluation/azure-ai-evaluation/tests/unittests/test_agent_evaluators.py

Lines changed: 83 additions & 55 deletions
@@ -12,80 +12,108 @@ def test_tool_call_accuracy_evaluator_missing_inputs(self, mock_model_config):
         # Test with missing tool_calls and response
         result = tool_call_accuracy(
             query="Where is the Eiffel Tower?",
-            tool_definitions=[{
-                "name": "fetch_weather",
-                "description": "Fetches the weather information for the specified location.",
-                "parameters": {
-                    "type": "object",
-                    "properties": {
-                        "location": {
-                            "type": "string",
-                            "description": "The location to fetch weather for."
-                        }
-                    }
+            tool_definitions=[
+                {
+                    "name": "fetch_weather",
+                    "description": "Fetches the weather information for the specified location.",
+                    "parameters": {
+                        "type": "object",
+                        "properties": {
+                            "location": {
+                                "type": "string",
+                                "description": "The location to fetch weather for.",
+                            }
+                        },
+                    },
                 }
-            }]
+            ],
+        )
+        assert (
+            result[ToolCallAccuracyEvaluator._RESULT_KEY]
+            == ToolCallAccuracyEvaluator._NOT_APPLICABLE_RESULT
+        )
+        assert (
+            ToolCallAccuracyEvaluator._NO_TOOL_CALLS_MESSAGE
+            in result[f"{ToolCallAccuracyEvaluator._RESULT_KEY}_reason"]
         )
-        assert result[ToolCallAccuracyEvaluator._RESULT_KEY] == ToolCallAccuracyEvaluator._NOT_APPLICABLE_RESULT
-        assert ToolCallAccuracyEvaluator._NO_TOOL_CALLS_MESSAGE in result[f"{ToolCallAccuracyEvaluator._RESULT_KEY}_reason"]

         # Test with missing tool_definitions
         result = tool_call_accuracy(
             query="Where is the Eiffel Tower?",
             tool_definitions=[],
-            tool_calls=[{
-                "type": "tool_call",
-                "name": "fetch_weather",
-                "arguments": {
-                    "location": "Tokyo"
+            tool_calls=[
+                {
+                    "type": "tool_call",
+                    "name": "fetch_weather",
+                    "arguments": {"location": "Tokyo"},
                 }
-            }]
+            ],
+        )
+        assert (
+            result[ToolCallAccuracyEvaluator._RESULT_KEY]
+            == ToolCallAccuracyEvaluator._NOT_APPLICABLE_RESULT
+        )
+        assert (
+            ToolCallAccuracyEvaluator._NO_TOOL_DEFINITIONS_MESSAGE
+            in result[f"{ToolCallAccuracyEvaluator._RESULT_KEY}_reason"]
         )
-        assert result[ToolCallAccuracyEvaluator._RESULT_KEY] == ToolCallAccuracyEvaluator._NOT_APPLICABLE_RESULT
-        assert ToolCallAccuracyEvaluator._NO_TOOL_DEFINITIONS_MESSAGE in result[f"{ToolCallAccuracyEvaluator._RESULT_KEY}_reason"]

         # Test with response that has no tool calls
         result = tool_call_accuracy(
             query="Where is the Eiffel Tower?",
             response="The Eiffel Tower is in Paris.",
-            tool_definitions=[{
-                "name": "fetch_weather",
-                "description": "Fetches the weather information for the specified location.",
-                "parameters": {
-                    "type": "object",
-                    "properties": {
-                        "location": {
-                            "type": "string",
-                            "description": "The location to fetch weather for."
-                        }
-                    }
+            tool_definitions=[
+                {
+                    "name": "fetch_weather",
+                    "description": "Fetches the weather information for the specified location.",
+                    "parameters": {
+                        "type": "object",
+                        "properties": {
+                            "location": {
+                                "type": "string",
+                                "description": "The location to fetch weather for.",
+                            }
+                        },
+                    },
                 }
-            }]
+            ],
+        )
+        assert (
+            result[ToolCallAccuracyEvaluator._RESULT_KEY]
+            == ToolCallAccuracyEvaluator._NOT_APPLICABLE_RESULT
+        )
+        assert (
+            ToolCallAccuracyEvaluator._NO_TOOL_CALLS_MESSAGE
+            in result[f"{ToolCallAccuracyEvaluator._RESULT_KEY}_reason"]
         )
-        assert result[ToolCallAccuracyEvaluator._RESULT_KEY] == ToolCallAccuracyEvaluator._NOT_APPLICABLE_RESULT
-        assert ToolCallAccuracyEvaluator._NO_TOOL_CALLS_MESSAGE in result[f"{ToolCallAccuracyEvaluator._RESULT_KEY}_reason"]

         # Test with tool call for which definition is not provided
         result = tool_call_accuracy(
             query="Where is the Eiffel Tower?",
-            tool_calls=[{
-                "type": "tool_call",
-                "name": "some_other_tool",
-                "arguments": {}
-            }],
-            tool_definitions=[{
-                "name": "fetch_weather",
-                "description": "Fetches the weather information for the specified location.",
-                "parameters": {
-                    "type": "object",
-                    "properties": {
-                        "location": {
-                            "type": "string",
-                            "description": "The location to fetch weather for."
-                        }
-                    }
+            tool_calls=[
+                {"type": "tool_call", "name": "some_other_tool", "arguments": {}}
+            ],
+            tool_definitions=[
+                {
+                    "name": "fetch_weather",
+                    "description": "Fetches the weather information for the specified location.",
+                    "parameters": {
+                        "type": "object",
+                        "properties": {
+                            "location": {
+                                "type": "string",
+                                "description": "The location to fetch weather for.",
+                            }
+                        },
+                    },
                 }
-            }]
+            ],
+        )
+        assert (
+            result[ToolCallAccuracyEvaluator._RESULT_KEY]
+            == ToolCallAccuracyEvaluator._NOT_APPLICABLE_RESULT
+        )
+        assert (
+            ToolCallAccuracyEvaluator._TOOL_DEFINITIONS_MISSING_MESSAGE
+            in result[f"{ToolCallAccuracyEvaluator._RESULT_KEY}_reason"]
         )
-        assert result[ToolCallAccuracyEvaluator._RESULT_KEY] == ToolCallAccuracyEvaluator._NOT_APPLICABLE_RESULT
-        assert ToolCallAccuracyEvaluator._TOOL_DEFINITIONS_MISSING_MESSAGE in result[f"{ToolCallAccuracyEvaluator._RESULT_KEY}_reason"]
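
The tests above pass `tool_calls` directly, but `_parse_tools_from_response` in the first file implies a particular agent-message shape for the `response` path. A hedged, illustrative example of that shape (all field values are made up; only the structure follows the parsing code):

# Assistant messages carry "tool_call" content items; "tool" role messages
# carry a "tool_result" item, matched to the call through tool_call_id.
response = [
    {
        "role": "assistant",
        "content": [
            {
                "type": "tool_call",
                "tool_call_id": "call_1",
                "name": "fetch_weather",
                "arguments": {"location": "Tokyo"},
            }
        ],
    },
    {
        "role": "tool",
        "tool_call_id": "call_1",
        "content": [{"type": "tool_result", "tool_result": "Rainy, 18°C"}],
    },
]

# Per the parsing logic, the extracted tool_call dict gains a "tool_result"
# key holding "Rainy, 18°C", which the evaluator then scores against the
# provided tool definitions.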
