@@ -52,6 +52,9 @@ class GuardrailSequence(RunnableSequence):
52
52
This option can also be turned on if the environment variable LOG_ADS_GUARDRAIL_INFO is set to "1".
53
53
"""
54
54
55
+ max_retry : int = 1
56
+ """Maximum number of retry for running the Guardrail sequence again if the output is blocked by a guardrail."""
57
+
55
58
@property
56
59
def steps (self ) -> List [Runnable [Any , Any ]]:
57
60
"""Steps in the sequence."""
@@ -69,9 +72,11 @@ def type() -> str:
69
72
return "ads_guardrail_sequence"
70
73
71
74
@classmethod
72
- def from_sequence (cls , sequence : RunnableSequence ):
75
+ def from_sequence (cls , sequence : RunnableSequence , ** kwargs ):
73
76
"""Creates a GuardrailSequence from a LangChain runnable sequence."""
74
- return cls (first = sequence .first , middle = sequence .middle , last = sequence .last )
77
+ return cls (
78
+ first = sequence .first , middle = sequence .middle , last = sequence .last , ** kwargs
79
+ )
75
80
76
81
def __or__ (self , other ) -> "GuardrailSequence" :
77
82
"""Adds another component to the end of this sequence.
@@ -154,16 +159,22 @@ def run(self, input: Any, num_generations: int = 1, **kwargs) -> GuardrailIO:
154
159
Contains the outputs and metrics from each step.
155
160
The final output is stored in GuardrailIO.data property.
156
161
"""
157
- obj = GuardrailIO (data = [input ])
158
-
159
- try :
160
- for i , step in enumerate (self .steps ):
161
- obj = self ._run_step (step , obj , num_generations , ** kwargs )
162
- except BlockedByGuardrail as ex :
163
- if self .raise_exception :
164
- raise ex
165
- obj .data = [ex .message ]
166
- obj .info .append (ex .info )
162
+ retry_count = 0
163
+ while True :
164
+ retry_count += 1
165
+ obj = GuardrailIO (data = [input ])
166
+ try :
167
+ for i , step in enumerate (self .steps ):
168
+ obj = self ._run_step (step , obj , num_generations , ** kwargs )
169
+ break
170
+ except BlockedByGuardrail as ex :
171
+ if retry_count < self .max_retry :
172
+ continue
173
+ if self .raise_exception :
174
+ raise ex
175
+ obj .data = [ex .message ]
176
+ obj .info .append (ex .info )
177
+ break
167
178
if self .log_info or os .environ .get (LOG_ADS_GUARDRAIL_INFO ) == "1" :
168
179
# LOG_ADS_GUARDRAIL_INFO is set to "1" in score.py by default.
169
180
print (obj .dict ())
0 commit comments