25
25
RunnableConfig ,
26
26
RunnableSequence ,
27
27
)
28
- from . import guardrails
29
- from .guardrails .base import GuardrailIO , Guardrail , RunInfo
28
+ from ads .llm import guardrails
29
+ from ads .llm .serialize import load , dump
30
+ from ads .llm .guardrails .base import GuardrailIO , Guardrail , RunInfo , BlockedByGuardrail
30
31
31
32
32
33
logger = logging.getLogger(__name__)

# Keys used in the serialized spec (dictionary) of a chain step.
# _run_step builds step parameters with these keys; keep them in sync.
SPEC_CLASS = "class"
SPEC_PATH = "path"
SPEC_SPEC = "spec"
@@ -47,6 +47,20 @@ class GuardrailSequence(RunnableSequence):
47
47
    # First runnable in the sequence; None for an empty sequence.
    first: Optional[Runnable] = None
    # Last runnable in the sequence; None for an empty sequence.
    last: Optional[Runnable] = None

    raise_exception: bool = False
    """The ``raise_exception`` property indicate whether an exception should be raised
    if the content is blocked by one of the guardrails.
    This property is set to ``False`` by default.
    Note that each guardrail also has its own ``raise_exception`` property.
    This property on GuardrailSequence has no effect
    when the ``raise_exception`` is set to False on the individual guardrail.

    When this is ``False``, instead of raising an exception,
    the custom message from the guardrail will be returned as the output.

    When this is ``True``, the ``BlockedByGuardrail`` exception from the guardrail will be raised.
    """
50
64
@property
51
65
def steps (self ) -> List [Runnable [Any , Any ]]:
52
66
"""Steps in the sequence."""
@@ -99,6 +113,31 @@ def _invoke_llm(self, llm, texts, num_generations, **kwargs):
99
113
output = llm .batch (texts , ** kwargs )
100
114
return output
101
115
116
+ def _run_step (
117
+ self , step : Runnable , obj : GuardrailIO , num_generations : int , ** kwargs
118
+ ):
119
+ if not isinstance (step , Guardrail ):
120
+ # Invoke the step as a LangChain component
121
+ spec = {}
122
+ with RunInfo (name = step .__class__ .__name__ , input = obj .data ) as info :
123
+ if isinstance (step , LLM ):
124
+ output = self ._invoke_llm (step , obj .data , num_generations , ** kwargs )
125
+ spec .update (kwargs )
126
+ spec ["num_generations" ] = num_generations
127
+ else :
128
+ output = step .batch (obj .data )
129
+ info .output = output
130
+ info .parameters = {
131
+ "class" : step .__class__ .__name__ ,
132
+ "path" : step .__module__ ,
133
+ "spec" : spec ,
134
+ }
135
+ obj .info .append (info )
136
+ obj .data = output
137
+ else :
138
+ obj = step .invoke (obj )
139
+ return obj
140
+
102
141
def run (self , input : Any , num_generations : int = 1 , ** kwargs ) -> GuardrailIO :
103
142
"""Runs the guardrail sequence.
104
143
@@ -120,36 +159,13 @@ def run(self, input: Any, num_generations: int = 1, **kwargs) -> GuardrailIO:
120
159
"""
121
160
obj = GuardrailIO (data = [input ])
122
161
123
- for i , step in enumerate (self .steps ):
124
- if not isinstance (step , Guardrail ):
125
- # Invoke the step as a LangChain component
126
- spec = {}
127
- with RunInfo (name = step .__class__ .__name__ , input = obj .data ) as info :
128
- if isinstance (step , LLM ):
129
- output = self ._invoke_llm (
130
- step , obj .data , num_generations , ** kwargs
131
- )
132
- spec .update (kwargs )
133
- spec ["num_generations" ] = num_generations
134
- else :
135
- output = step .batch (obj .data )
136
- info .output = output
137
- info .parameters = {
138
- "class" : step .__class__ .__name__ ,
139
- "path" : step .__module__ ,
140
- "spec" : spec ,
141
- }
142
- obj .info .append (info )
143
- obj .data = output
144
- else :
145
- obj = step .invoke (obj )
146
- if not obj .data :
147
- default_msg = f"Blocked by { step .__class__ .__name__ } "
148
- msg = getattr (step , BLOCKED_MESSAGE , default_msg )
149
- if msg is None :
150
- msg = default_msg
151
- obj .data = [msg ]
152
- return obj
162
+ try :
163
+ for i , step in enumerate (self .steps ):
164
+ obj = self ._run_step (step , obj , num_generations , ** kwargs )
165
+ except BlockedByGuardrail as ex :
166
+ if self .raise_exception :
167
+ raise ex
168
+ obj .data = [ex .message ]
153
169
return obj
154
170
155
171
def _save_to_file (self , chain_dict , filename , overwrite = False ):
0 commit comments