Skip to content

Commit 3b04340

Browse files
committed
Update guardrail sequence to take max_retry.
1 parent 5127d30 commit 3b04340

File tree

1 file changed

+23
-12
lines changed

1 file changed

+23
-12
lines changed

ads/llm/chain.py

Lines changed: 23 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,9 @@ class GuardrailSequence(RunnableSequence):
5252
This option can also be turned on if the environment variable LOG_ADS_GUARDRAIL_INFO is set to "1".
5353
"""
5454

55+
max_retry: int = 1
56+
"""Maximum number of retry for running the Guardrail sequence again if the output is blocked by a guardrail."""
57+
5558
@property
5659
def steps(self) -> List[Runnable[Any, Any]]:
5760
"""Steps in the sequence."""
@@ -69,9 +72,11 @@ def type() -> str:
6972
return "ads_guardrail_sequence"
7073

7174
@classmethod
72-
def from_sequence(cls, sequence: RunnableSequence):
75+
def from_sequence(cls, sequence: RunnableSequence, **kwargs):
7376
"""Creates a GuardrailSequence from a LangChain runnable sequence."""
74-
return cls(first=sequence.first, middle=sequence.middle, last=sequence.last)
77+
return cls(
78+
first=sequence.first, middle=sequence.middle, last=sequence.last, **kwargs
79+
)
7580

7681
def __or__(self, other) -> "GuardrailSequence":
7782
"""Adds another component to the end of this sequence.
@@ -154,16 +159,22 @@ def run(self, input: Any, num_generations: int = 1, **kwargs) -> GuardrailIO:
154159
Contains the outputs and metrics from each step.
155160
The final output is stored in GuardrailIO.data property.
156161
"""
157-
obj = GuardrailIO(data=[input])
158-
159-
try:
160-
for i, step in enumerate(self.steps):
161-
obj = self._run_step(step, obj, num_generations, **kwargs)
162-
except BlockedByGuardrail as ex:
163-
if self.raise_exception:
164-
raise ex
165-
obj.data = [ex.message]
166-
obj.info.append(ex.info)
162+
retry_count = 0
163+
while True:
164+
retry_count += 1
165+
obj = GuardrailIO(data=[input])
166+
try:
167+
for i, step in enumerate(self.steps):
168+
obj = self._run_step(step, obj, num_generations, **kwargs)
169+
break
170+
except BlockedByGuardrail as ex:
171+
if retry_count < self.max_retry:
172+
continue
173+
if self.raise_exception:
174+
raise ex
175+
obj.data = [ex.message]
176+
obj.info.append(ex.info)
177+
break
167178
if self.log_info or os.environ.get(LOG_ADS_GUARDRAIL_INFO) == "1":
168179
# LOG_ADS_GUARDRAIL_INFO is set to "1" in score.py by default.
169180
print(obj.dict())

0 commit comments

Comments
 (0)