Skip to content

Commit 97d105a

Browse files
committed
Update chain.py to handle BlockedByGuardrail exception.
1 parent 70f6555 commit 97d105a

File tree

1 file changed

+49
-33
lines changed

1 file changed

+49
-33
lines changed

ads/llm/chain.py

Lines changed: 49 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -25,12 +25,12 @@
2525
RunnableConfig,
2626
RunnableSequence,
2727
)
28-
from . import guardrails
29-
from .guardrails.base import GuardrailIO, Guardrail, RunInfo
28+
from ads.llm import guardrails
29+
from ads.llm.serialize import load, dump
30+
from ads.llm.guardrails.base import GuardrailIO, Guardrail, RunInfo, BlockedByGuardrail
3031

3132

3233
logger = logging.getLogger(__name__)
33-
BLOCKED_MESSAGE = "custom_msg"
3434
SPEC_CLASS = "class"
3535
SPEC_PATH = "path"
3636
SPEC_SPEC = "spec"
@@ -47,6 +47,20 @@ class GuardrailSequence(RunnableSequence):
4747
first: Optional[Runnable] = None
4848
last: Optional[Runnable] = None
4949

50+
raise_exception: bool = False
51+
"""The ``raise_exception`` property indicates whether an exception should be raised
52+
if the content is blocked by one of the guardrails.
53+
This property is set to ``False`` by default.
54+
Note that each guardrail also has its own ``raise_exception`` property.
55+
This property on GuardrailSequence has no effect
56+
when ``raise_exception`` is set to ``False`` on the individual guardrail.
57+
58+
When this is ``False``, instead of raising an exception,
59+
the custom message from the guardrail will be returned as the output.
60+
61+
When this is ``True``, the ``BlockedByGuardrail`` exception from the guardrail will be raised.
62+
"""
63+
5064
@property
5165
def steps(self) -> List[Runnable[Any, Any]]:
5266
"""Steps in the sequence."""
@@ -99,6 +113,31 @@ def _invoke_llm(self, llm, texts, num_generations, **kwargs):
99113
output = llm.batch(texts, **kwargs)
100114
return output
101115

116+
def _run_step(
117+
self, step: Runnable, obj: GuardrailIO, num_generations: int, **kwargs
118+
):
119+
if not isinstance(step, Guardrail):
120+
# Invoke the step as a LangChain component
121+
spec = {}
122+
with RunInfo(name=step.__class__.__name__, input=obj.data) as info:
123+
if isinstance(step, LLM):
124+
output = self._invoke_llm(step, obj.data, num_generations, **kwargs)
125+
spec.update(kwargs)
126+
spec["num_generations"] = num_generations
127+
else:
128+
output = step.batch(obj.data)
129+
info.output = output
130+
info.parameters = {
131+
"class": step.__class__.__name__,
132+
"path": step.__module__,
133+
"spec": spec,
134+
}
135+
obj.info.append(info)
136+
obj.data = output
137+
else:
138+
obj = step.invoke(obj)
139+
return obj
140+
102141
def run(self, input: Any, num_generations: int = 1, **kwargs) -> GuardrailIO:
103142
"""Runs the guardrail sequence.
104143
@@ -120,36 +159,13 @@ def run(self, input: Any, num_generations: int = 1, **kwargs) -> GuardrailIO:
120159
"""
121160
obj = GuardrailIO(data=[input])
122161

123-
for i, step in enumerate(self.steps):
124-
if not isinstance(step, Guardrail):
125-
# Invoke the step as a LangChain component
126-
spec = {}
127-
with RunInfo(name=step.__class__.__name__, input=obj.data) as info:
128-
if isinstance(step, LLM):
129-
output = self._invoke_llm(
130-
step, obj.data, num_generations, **kwargs
131-
)
132-
spec.update(kwargs)
133-
spec["num_generations"] = num_generations
134-
else:
135-
output = step.batch(obj.data)
136-
info.output = output
137-
info.parameters = {
138-
"class": step.__class__.__name__,
139-
"path": step.__module__,
140-
"spec": spec,
141-
}
142-
obj.info.append(info)
143-
obj.data = output
144-
else:
145-
obj = step.invoke(obj)
146-
if not obj.data:
147-
default_msg = f"Blocked by {step.__class__.__name__}"
148-
msg = getattr(step, BLOCKED_MESSAGE, default_msg)
149-
if msg is None:
150-
msg = default_msg
151-
obj.data = [msg]
152-
return obj
162+
try:
163+
for i, step in enumerate(self.steps):
164+
obj = self._run_step(step, obj, num_generations, **kwargs)
165+
except BlockedByGuardrail as ex:
166+
if self.raise_exception:
167+
raise ex
168+
obj.data = [ex.message]
153169
return obj
154170

155171
def _save_to_file(self, chain_dict, filename, overwrite=False):

0 commit comments

Comments
 (0)