from typing import Dict, List, Union, TypeVar, cast
from typing_extensions import overload, override
from azure.ai.evaluation._evaluators._common import PromptyEvaluatorBase
-from azure.ai.evaluation._exceptions import ErrorBlame, ErrorCategory, ErrorTarget, EvaluationException
+from azure.ai.evaluation._exceptions import (
+    ErrorBlame,
+    ErrorCategory,
+    ErrorTarget,
+    EvaluationException,
+)
from ..._common.utils import check_score_is_valid
from azure.ai.evaluation._common._experimental import experimental
@@ -74,7 +79,9 @@ class ToolCallAccuracyEvaluator(PromptyEvaluatorBase[Union[str, float]]):

    _NO_TOOL_CALLS_MESSAGE = "No tool calls found in response or provided tool_calls."
    _NO_TOOL_DEFINITIONS_MESSAGE = "Tool definitions must be provided."
-    _TOOL_DEFINITIONS_MISSING_MESSAGE = "Tool definitions for all tool calls must be provided."
+    _TOOL_DEFINITIONS_MISSING_MESSAGE = (
+        "Tool definitions for all tool calls must be provided."
+    )
    _INVALID_SCORE_MESSAGE = "Tool call accuracy score must be between 1 and 5."

    _LLM_SCORE_KEY = "tool_calls_success_level"
@@ -83,11 +90,18 @@ class ToolCallAccuracyEvaluator(PromptyEvaluatorBase[Union[str, float]]):
    """Evaluator identifier, experimental and to be used only with evaluation in cloud."""

    @override
-    def __init__(self, model_config, *, threshold=_DEFAULT_TOOL_CALL_ACCURACY_SCORE, **kwargs):
+    def __init__(
+        self, model_config, *, threshold=_DEFAULT_TOOL_CALL_ACCURACY_SCORE, **kwargs
+    ):
        current_dir = os.path.dirname(__file__)
        prompty_path = os.path.join(current_dir, self._PROMPTY_FILE)
        self.threshold = threshold
-        super().__init__(model_config=model_config, prompty_file=prompty_path, result_key=self._RESULT_KEY, **kwargs)
+        super().__init__(
+            model_config=model_config,
+            prompty_file=prompty_path,
+            result_key=self._RESULT_KEY,
+            **kwargs,
+        )

    @overload
    def __call__(
@@ -164,7 +178,9 @@ def _convert_kwargs_to_eval_input(self, **kwargs):
            tool_definitions = [tool_definitions]

        try:
-            needed_tool_definitions = self._extract_needed_tool_definitions(tool_calls, tool_definitions)
+            needed_tool_definitions = self._extract_needed_tool_definitions(
+                tool_calls, tool_definitions
+            )
        except EvaluationException as e:
            return {"error_message": self._TOOL_DEFINITIONS_MISSING_MESSAGE}
        if len(needed_tool_definitions) == 0:
@@ -173,9 +189,8 @@ def _convert_kwargs_to_eval_input(self, **kwargs):
        return {
            "query": query,
            "tool_calls": tool_calls,
-            "tool_definitions": needed_tool_definitions
+            "tool_definitions": needed_tool_definitions,
        }
-

    @override
    async def _do_eval(self, eval_input: Dict) -> Dict[str, Union[float, str]]:  # type: ignore[override]
@@ -192,35 +207,39 @@ async def _do_eval(self, eval_input: Dict) -> Dict[str, Union[float, str]]: # t

        if isinstance(llm_output, dict):
            score = llm_output.get(self._LLM_SCORE_KEY, None)
-            if not score or not check_score_is_valid(score, ToolCallAccuracyEvaluator._MIN_TOOL_CALL_ACCURACY_SCORE, ToolCallAccuracyEvaluator._MAX_TOOL_CALL_ACCURACY_SCORE):
+            if not score or not check_score_is_valid(
+                score,
+                ToolCallAccuracyEvaluator._MIN_TOOL_CALL_ACCURACY_SCORE,
+                ToolCallAccuracyEvaluator._MAX_TOOL_CALL_ACCURACY_SCORE,
+            ):
                raise EvaluationException(
                    message=f"Invalid score value: {score}. Expected a number in range [{ToolCallAccuracyEvaluator._MIN_TOOL_CALL_ACCURACY_SCORE}, {ToolCallAccuracyEvaluator._MAX_TOOL_CALL_ACCURACY_SCORE}].",
                    internal_message="Invalid score value.",
                    category=ErrorCategory.FAILED_EXECUTION,
                    blame=ErrorBlame.SYSTEM_ERROR,
                )
-
+
            # Format the output
            reason = llm_output.get("chain_of_thought", "")
            score = float(score)
-            score_result = 'pass' if score >= self.threshold else 'fail'
+            score_result = "pass" if score >= self.threshold else "fail"
            response_dict = {
                self._result_key: score,
                f"{self._result_key}_result": score_result,
                f"{self._result_key}_threshold": self.threshold,
                f"{self._result_key}_reason": reason,
-                'details': llm_output.get('details', {}),
+                "details": llm_output.get("details", {}),
            }
            return response_dict
-
+
        else:
            raise EvaluationException(
-                message="Tool call accuracy evaluator returned invalid output.",
-                blame=ErrorBlame.SYSTEM_ERROR,
-                category=ErrorCategory.FAILED_EXECUTION,
-                target=ErrorTarget.TOOL_CALL_ACCURACY_EVALUATOR,
-            )
-
+                message="Tool call accuracy evaluator returned invalid output.",
+                blame=ErrorBlame.SYSTEM_ERROR,
+                category=ErrorCategory.FAILED_EXECUTION,
+                target=ErrorTarget.TOOL_CALL_ACCURACY_EVALUATOR,
+            )
+
    async def _real_call(self, **kwargs):
        """The asynchronous call where real end-to-end evaluation logic is performed.
@@ -231,14 +250,14 @@ async def _real_call(self, **kwargs):
        """
        # Convert inputs into list of evaluable inputs.
        eval_input = self._convert_kwargs_to_eval_input(**kwargs)
-        if isinstance(eval_input, dict) and eval_input.get('error_message'):
+        if isinstance(eval_input, dict) and eval_input.get("error_message"):
            # If there is an error message, return not applicable result
-            return self._not_applicable_result(eval_input.get('error_message'))
+            return self._not_applicable_result(eval_input.get("error_message"))
        # Do the evaluation
        result = await self._do_eval(eval_input)
        # Return the result
        return result
-
+
    def _not_applicable_result(self, error_message):
        """Return a result indicating that the tool call is not applicable for evaluation.
        :param eval_input: The input to the evaluator.
@@ -249,13 +268,12 @@ def _not_applicable_result(self, error_message):
        # If no tool calls were made or tool call type is not supported, return not applicable result
        return {
            self._result_key: self._NOT_APPLICABLE_RESULT,
-            f"{self._result_key}_result": 'pass',
+            f"{self._result_key}_result": "pass",
            f"{self._result_key}_threshold": self.threshold,
            f"{self._result_key}_reason": error_message,
            "details": {},
-
        }
-
+
    def _parse_tools_from_response(self, response):
        """Parse the response to extract tool calls and results.
        :param response: The response to parse.
@@ -266,29 +284,40 @@ def _parse_tools_from_response(self, response):
        tool_calls = []
        tool_results_map = {}
        if isinstance(response, list):
-            for message in response:
+            for message in response:
                # Extract tool calls from assistant messages
-                if message.get("role") == "assistant" and isinstance(message.get("content"), list):
+                if message.get("role") == "assistant" and isinstance(
+                    message.get("content"), list
+                ):
                    for content_item in message.get("content"):
-                        if isinstance(content_item, dict) and content_item.get("type") == "tool_call":
+                        if (
+                            isinstance(content_item, dict)
+                            and content_item.get("type") == "tool_call"
+                        ):
                            tool_calls.append(content_item)

                # Extract tool results from tool messages
                elif message.get("role") == "tool" and message.get("tool_call_id"):
                    tool_call_id = message.get("tool_call_id")
-                    if isinstance(message.get("content"), list) and len(message.get("content")) > 0:
+                    if (
+                        isinstance(message.get("content"), list)
+                        and len(message.get("content")) > 0
+                    ):
                        result_content = message.get("content")[0]
-                        if isinstance(result_content, dict) and result_content.get("type") == "tool_result":
+                        if (
+                            isinstance(result_content, dict)
+                            and result_content.get("type") == "tool_result"
+                        ):
                            tool_results_map[tool_call_id] = result_content

        # Attach results to their corresponding calls
        for tool_call in tool_calls:
            tool_call_id = tool_call.get("tool_call_id")
            if tool_call_id in tool_results_map:
-                tool_call["tool_result"] = tool_results_map[tool_call_id]['tool_result']
+                tool_call["tool_result"] = tool_results_map[tool_call_id]["tool_result"]

        return tool_calls
-
+
    def _extract_needed_tool_definitions(self, tool_calls, tool_definitions):
        """Extract the tool definitions that are needed for the provided tool calls.
        :param tool_calls: List of tool calls to evaluate.
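Note: a minimal sketch of the message shape that _parse_tools_from_response walks, based only on the keys checked in the hunk above (role, content, type, tool_call_id, tool_result); the tool name, id, and payload values are made-up placeholders, not taken from this PR.

# Hypothetical agent response; only the structure matters, all values are illustrative.
response = [
    {
        "role": "assistant",
        "content": [
            {
                "type": "tool_call",
                "tool_call_id": "call_1",  # placeholder id
                "name": "fetch_weather",   # placeholder tool name
                "arguments": {"location": "Seattle"},
            }
        ],
    },
    {
        "role": "tool",
        "tool_call_id": "call_1",
        "content": [{"type": "tool_result", "tool_result": {"forecast": "sunny"}}],
    },
]
# _parse_tools_from_response(response) would return the single tool_call entry,
# with the matching "tool_result" value attached to it.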
@@ -302,9 +331,12 @@ def _extract_needed_tool_definitions(self, tool_calls, tool_definitions):
        for tool_call in tool_calls:
            if isinstance(tool_call, dict) and tool_call.get("type") == "tool_call":
                tool_name = tool_call.get("name")
-                tool_definition = [tool for tool in tool_definitions
-                                   if tool.get("name") == tool_name and
-                                   tool.get("type", "function") == "function"]
+                tool_definition = [
+                    tool
+                    for tool in tool_definitions
+                    if tool.get("name") == tool_name
+                    and tool.get("type", "function") == "function"
+                ]
                if len(tool_definition) > 0:
                    needed_tool_definitions.extend(tool_definition)
                else:
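A rough end-to-end usage sketch, assuming the public ToolCallAccuracyEvaluator export and an Azure OpenAI style model_config; the endpoint, deployment, tool, and definition values are placeholders, and the result keys mirror the response_dict built in _do_eval above.

# Hedged sketch; config fields and tool data are placeholders, not taken from this diff.
from azure.ai.evaluation import ToolCallAccuracyEvaluator

model_config = {
    "azure_endpoint": "https://<your-resource>.openai.azure.com",  # placeholder
    "azure_deployment": "<deployment-name>",                       # placeholder
    "api_key": "<api-key>",                                        # placeholder
}

evaluator = ToolCallAccuracyEvaluator(model_config, threshold=3)
result = evaluator(
    query="What is the weather in Seattle?",
    tool_calls=[
        {
            "type": "tool_call",
            "tool_call_id": "call_1",
            "name": "fetch_weather",
            "arguments": {"location": "Seattle"},
        }
    ],
    tool_definitions=[
        {
            "name": "fetch_weather",
            "type": "function",
            "description": "Fetch the weather for a location.",
            "parameters": {"type": "object", "properties": {"location": {"type": "string"}}},
        }
    ],
)
# Per _do_eval, the result carries the numeric score under the evaluator's result key,
# plus "<result_key>_result" ("pass"/"fail" against the threshold),
# "<result_key>_threshold", "<result_key>_reason", and "details".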