Skip to content

Commit 5127d30

Browse files
committed
Update base.py to fix bug on selecting outputs based on guardrail metrics.
1 parent 9e82210 commit 5127d30

File tree

1 file changed

+8
-5
lines changed

1 file changed

+8
-5
lines changed

ads/llm/guardrails/base.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -104,15 +104,16 @@ def check(func):
104104
def wrapper(self: "Guardrail", metrics: dict, data: list, *args, **kwargs):
105105
if self.metric_key not in metrics:
106106
raise KeyError(
107-
f"This method requires the metrics contains {self.metric_key}."
107+
f"Method requires the metrics contains {self.metric_key}."
108108
)
109109
if not isinstance(metrics[self.metric_key], list):
110110
raise ValueError(
111-
f"This method requires the value of {self.metric_key} in metrics."
111+
f"Method requires the value of {self.metric_key} in metrics."
112112
)
113113
if len(metrics[self.metric_key]) != len(data):
114114
raise ValueError(
115-
f"This method requires the value of {self.metric_key} in metrics to have the same size as data."
115+
f"Method requires the value of {self.metric_key} in metrics "
116+
"to have the same size as data."
116117
)
117118
return func(self, metrics, data, *args, **kwargs)
118119

@@ -327,6 +328,8 @@ def apply_select(self, metrics: dict, data: list):
327328
"""
328329
if self.select not in self._SELECT_OPERATOR:
329330
raise ValueError(f"select='{self.select}' is not supported.")
331+
if not data:
332+
return data
330333
func = self._SELECT_OPERATOR[self.select]
331334
values = metrics[self.metric_key]
332335
idx = values.index(func(values))
@@ -379,11 +382,11 @@ def filter_and_select(self, metrics: dict, data: list):
379382
The selected candidate in a list.
380383
"""
381384
filtered_data = self.apply_filter(metrics, data)
382-
passed_idx = [i for i in range(len(metrics["passed"])) if metrics["passed"]]
385+
passed_idx = [i for i in range(len(metrics["passed"])) if metrics["passed"][i]]
383386
filtered_metrics = {
384387
self.metric_key: [metrics[self.metric_key][i] for i in passed_idx]
385388
}
386-
return self.apply_select(filtered_data, filtered_metrics)
389+
return self.apply_select(filtered_metrics, filtered_data)
387390

388391
def single_metric_moderate(self, metrics: dict, data=None, **kwargs) -> List[str]:
389392
"""Applies moderation (filter and/or select) using the metrics."""

0 commit comments

Comments
 (0)