@@ -104,15 +104,16 @@ def check(func):
104
104
def wrapper (self : "Guardrail" , metrics : dict , data : list , * args , ** kwargs ):
105
105
if self .metric_key not in metrics :
106
106
raise KeyError (
107
- f"This method requires the metrics contains { self .metric_key } ."
107
+ f"Method requires the metrics contains { self .metric_key } ."
108
108
)
109
109
if not isinstance (metrics [self .metric_key ], list ):
110
110
raise ValueError (
111
- f"This method requires the value of { self .metric_key } in metrics."
111
+ f"Method requires the value of { self .metric_key } in metrics."
112
112
)
113
113
if len (metrics [self .metric_key ]) != len (data ):
114
114
raise ValueError (
115
- f"This method requires the value of { self .metric_key } in metrics to have the same size as data."
115
+ f"Method requires the value of { self .metric_key } in metrics "
116
+ "to have the same size as data."
116
117
)
117
118
return func (self , metrics , data , * args , ** kwargs )
118
119
@@ -327,6 +328,8 @@ def apply_select(self, metrics: dict, data: list):
327
328
"""
328
329
if self .select not in self ._SELECT_OPERATOR :
329
330
raise ValueError (f"select='{ self .select } ' is not supported." )
331
+ if not data :
332
+ return data
330
333
func = self ._SELECT_OPERATOR [self .select ]
331
334
values = metrics [self .metric_key ]
332
335
idx = values .index (func (values ))
@@ -379,11 +382,11 @@ def filter_and_select(self, metrics: dict, data: list):
379
382
The selected candidate in a list.
380
383
"""
381
384
filtered_data = self .apply_filter (metrics , data )
382
- passed_idx = [i for i in range (len (metrics ["passed" ])) if metrics ["passed" ]]
385
+ passed_idx = [i for i in range (len (metrics ["passed" ])) if metrics ["passed" ][ i ] ]
383
386
filtered_metrics = {
384
387
self .metric_key : [metrics [self .metric_key ][i ] for i in passed_idx ]
385
388
}
386
- return self .apply_select (filtered_data , filtered_metrics )
389
+ return self .apply_select (filtered_metrics , filtered_data )
387
390
388
391
def single_metric_moderate (self , metrics : dict , data = None , ** kwargs ) -> List [str ]:
389
392
"""Applies moderation (filter and/or select) using the metrics."""
0 commit comments