4
4
5
5
import inspect
6
6
from abc import ABC , abstractmethod
7
- from typing import Any , Callable , Dict , Generic , List , TypedDict , TypeVar , Union , cast , final , Optional
7
+ from typing import (
8
+ Any ,
9
+ Callable ,
10
+ Dict ,
11
+ Generic ,
12
+ List ,
13
+ TypedDict ,
14
+ TypeVar ,
15
+ Union ,
16
+ cast ,
17
+ final ,
18
+ Optional ,
19
+ )
8
20
9
21
from azure .ai .evaluation ._legacy ._adapters .utils import async_run_allowing_running_loop
10
22
from typing_extensions import ParamSpec , TypeAlias , get_overloads
11
23
12
- from azure .ai .evaluation ._exceptions import ErrorBlame , ErrorCategory , ErrorTarget , EvaluationException
24
+ from azure .ai .evaluation ._exceptions import (
25
+ ErrorBlame ,
26
+ ErrorCategory ,
27
+ ErrorTarget ,
28
+ EvaluationException ,
29
+ )
13
30
from azure .ai .evaluation ._common .utils import remove_optional_singletons
14
- from azure .ai .evaluation ._constants import _AggregationType , EVALUATION_PASS_FAIL_MAPPING
31
+ from azure .ai .evaluation ._constants import (
32
+ _AggregationType ,
33
+ EVALUATION_PASS_FAIL_MAPPING ,
34
+ )
15
35
from azure .ai .evaluation ._model_configurations import Conversation
16
36
from azure .ai .evaluation ._common ._experimental import experimental
17
37
@@ -101,14 +121,18 @@ def __init__(
101
121
not_singleton_inputs : List [str ] = ["conversation" , "kwargs" ],
102
122
eval_last_turn : bool = False ,
103
123
conversation_aggregation_type : _AggregationType = _AggregationType .MEAN ,
104
- conversation_aggregator_override : Optional [Callable [[List [float ]], float ]] = None ,
124
+ conversation_aggregator_override : Optional [
125
+ Callable [[List [float ]], float ]
126
+ ] = None ,
105
127
_higher_is_better : Optional [bool ] = True ,
106
128
):
107
129
self ._not_singleton_inputs = not_singleton_inputs
108
130
self ._eval_last_turn = eval_last_turn
109
131
self ._singleton_inputs = self ._derive_singleton_inputs ()
110
132
self ._async_evaluator = AsyncEvaluatorBase (self ._real_call )
111
- self ._conversation_aggregation_function = GetAggregator (conversation_aggregation_type )
133
+ self ._conversation_aggregation_function = GetAggregator (
134
+ conversation_aggregation_type
135
+ )
112
136
self ._higher_is_better = _higher_is_better
113
137
self ._threshold = threshold
114
138
if conversation_aggregator_override is not None :
@@ -170,13 +194,18 @@ def _derive_singleton_inputs(self) -> List[str]:
170
194
singletons = []
171
195
for call_signature in call_signatures :
172
196
params = call_signature .parameters
173
- if any (not_singleton_input in params for not_singleton_input in self ._not_singleton_inputs ):
197
+ if any (
198
+ not_singleton_input in params
199
+ for not_singleton_input in self ._not_singleton_inputs
200
+ ):
174
201
continue
175
202
# exclude self since it is not a singleton input
176
203
singletons .extend ([p for p in params if p != "self" ])
177
204
return singletons
178
205
179
- def _derive_conversation_converter (self ) -> Callable [[Dict ], List [DerivedEvalInput ]]:
206
+ def _derive_conversation_converter (
207
+ self ,
208
+ ) -> Callable [[Dict ], List [DerivedEvalInput ]]:
180
209
"""Produce the function that will be used to convert conversations to a list of evaluable inputs.
181
210
This uses the inputs derived from the _derive_singleton_inputs function to determine which
182
211
aspects of a conversation ought to be extracted.
@@ -235,7 +264,9 @@ def converter(conversation: Dict) -> List[DerivedEvalInput]:
235
264
236
265
return converter
237
266
238
- def _derive_multi_modal_conversation_converter (self ) -> Callable [[Dict ], List [Dict [str , Any ]]]:
267
+ def _derive_multi_modal_conversation_converter (
268
+ self ,
269
+ ) -> Callable [[Dict ], List [Dict [str , Any ]]]:
239
270
"""Produce the function that will be used to convert multi-modal conversations to a list of evaluable inputs.
240
271
This uses the inputs derived from the _derive_singleton_inputs function to determine which
241
272
aspects of a conversation ought to be extracted.
@@ -269,12 +300,16 @@ def multi_modal_converter(conversation: Dict) -> List[Dict[str, Any]]:
269
300
if len (user_messages ) != len (assistant_messages ):
270
301
raise EvaluationException (
271
302
message = "Mismatched number of user and assistant messages." ,
272
- internal_message = ("Mismatched number of user and assistant messages." ),
303
+ internal_message = (
304
+ "Mismatched number of user and assistant messages."
305
+ ),
273
306
)
274
307
if len (assistant_messages ) > 1 :
275
308
raise EvaluationException (
276
309
message = "Conversation can have only one assistant message." ,
277
- internal_message = ("Conversation can have only one assistant message." ),
310
+ internal_message = (
311
+ "Conversation can have only one assistant message."
312
+ ),
278
313
)
279
314
eval_conv_inputs = []
280
315
for user_msg , assist_msg in zip (user_messages , assistant_messages ):
@@ -283,12 +318,16 @@ def multi_modal_converter(conversation: Dict) -> List[Dict[str, Any]]:
283
318
conv_messages .append (system_messages [0 ])
284
319
conv_messages .append (user_msg )
285
320
conv_messages .append (assist_msg )
286
- eval_conv_inputs .append ({"conversation" : Conversation (messages = conv_messages )})
321
+ eval_conv_inputs .append (
322
+ {"conversation" : Conversation (messages = conv_messages )}
323
+ )
287
324
return eval_conv_inputs
288
325
289
326
return multi_modal_converter
290
327
291
- def _convert_kwargs_to_eval_input (self , ** kwargs ) -> Union [List [Dict ], List [DerivedEvalInput ], Dict [str , Any ]]:
328
+ def _convert_kwargs_to_eval_input (
329
+ self , ** kwargs
330
+ ) -> Union [List [Dict ], List [DerivedEvalInput ], Dict [str , Any ]]:
292
331
"""Convert an arbitrary input into a list of inputs for evaluators.
293
332
It is assumed that evaluators generally make use of their inputs in one of two ways.
294
333
Either they receive a collection of keyname inputs that are all single values
@@ -353,11 +392,17 @@ def _is_multi_modal_conversation(self, conversation: Dict) -> bool:
353
392
if "content" in message :
354
393
content = message .get ("content" , "" )
355
394
if isinstance (content , list ):
356
- if any (item .get ("type" ) == "image_url" and "url" in item .get ("image_url" , {}) for item in content ):
395
+ if any (
396
+ item .get ("type" ) == "image_url"
397
+ and "url" in item .get ("image_url" , {})
398
+ for item in content
399
+ ):
357
400
return True
358
401
return False
359
402
360
- def _aggregate_results (self , per_turn_results : List [DoEvalResult [T_EvalValue ]]) -> AggregateResult [T_EvalValue ]:
403
+ def _aggregate_results (
404
+ self , per_turn_results : List [DoEvalResult [T_EvalValue ]]
405
+ ) -> AggregateResult [T_EvalValue ]:
361
406
"""Aggregate the evaluation results of each conversation turn into a single result.
362
407
363
408
Exact implementation might need to vary slightly depending on the results produced.
@@ -387,12 +432,16 @@ def _aggregate_results(self, per_turn_results: List[DoEvalResult[T_EvalValue]])
387
432
# Find and average all numeric values
388
433
for metric , values in evaluation_per_turn .items ():
389
434
if all (isinstance (value , (int , float )) for value in values ):
390
- aggregated [metric ] = self ._conversation_aggregation_function (cast (List [Union [int , float ]], values ))
435
+ aggregated [metric ] = self ._conversation_aggregation_function (
436
+ cast (List [Union [int , float ]], values )
437
+ )
391
438
# Slap the per-turn results back in.
392
439
aggregated ["evaluation_per_turn" ] = evaluation_per_turn
393
440
return aggregated
394
441
395
- async def _real_call (self , ** kwargs ) -> Union [DoEvalResult [T_EvalValue ], AggregateResult [T_EvalValue ]]:
442
+ async def _real_call (
443
+ self , ** kwargs
444
+ ) -> Union [DoEvalResult [T_EvalValue ], AggregateResult [T_EvalValue ]]:
396
445
"""The asynchronous call where real end-to-end evaluation logic is performed.
397
446
398
447
:keyword kwargs: The inputs to evaluate.
@@ -445,7 +494,9 @@ def _to_async(self) -> "AsyncEvaluatorBase":
445
494
446
495
@experimental
447
496
@final
448
- def _set_conversation_aggregation_type (self , conversation_aggregation_type : _AggregationType ) -> None :
497
+ def _set_conversation_aggregation_type (
498
+ self , conversation_aggregation_type : _AggregationType
499
+ ) -> None :
449
500
"""Input a conversation aggregation type to re-assign the aggregator function used by this evaluator for
450
501
multi-turn conversations. This aggregator is used to combine numeric outputs from each evaluation of a
451
502
multi-turn conversation into a single top-level result.
@@ -454,11 +505,15 @@ def _set_conversation_aggregation_type(self, conversation_aggregation_type: _Agg
454
505
results of a conversation to produce a single result.
455
506
:type conversation_aggregation_type: ~azure.ai.evaluation._AggregationType
456
507
"""
457
- self ._conversation_aggregation_function = GetAggregator (conversation_aggregation_type )
508
+ self ._conversation_aggregation_function = GetAggregator (
509
+ conversation_aggregation_type
510
+ )
458
511
459
512
@experimental
460
513
@final
461
- def _set_conversation_aggregator (self , aggregator : Callable [[List [float ]], float ]) -> None :
514
+ def _set_conversation_aggregator (
515
+ self , aggregator : Callable [[List [float ]], float ]
516
+ ) -> None :
462
517
"""Set the conversation aggregator function directly. This function will be applied to all numeric outputs
463
518
of an evaluator when it evaluates a conversation with multiple-turns thus ends up with multiple results per
464
519
evaluation that is needs to coalesce into a single result. Use when built-in aggregators do not
@@ -488,7 +543,9 @@ class AsyncEvaluatorBase:
488
543
to ensure that no one ever needs to extend or otherwise modify this class directly.
489
544
"""
490
545
491
- def __init__ (self , real_call ): # DO NOT ADD TYPEHINT PROMPT FLOW WILL SCREAM AT YOU ABOUT META GENERATION
546
+ def __init__ (
547
+ self , real_call
548
+ ): # DO NOT ADD TYPEHINT PROMPT FLOW WILL SCREAM AT YOU ABOUT META GENERATION
492
549
self ._real_call = real_call
493
550
494
551
# Don't look at my shame. Nothing to see here....
0 commit comments