Commit c8c5b03

Amazon bedrock guardrails (#17281)
1 parent d1fc12e commit c8c5b03

File tree

8 files changed: +152 −10 lines changed

llama-index-integrations/llms/llama-index-llms-bedrock-converse/llama_index/llms/bedrock_converse/base.py

Lines changed: 33 additions & 0 deletions
@@ -114,6 +114,18 @@ class BedrockConverse(FunctionCallingLLM):
         default=60.0,
         description="The timeout for the Bedrock API request in seconds. It will be used for both connect and read timeouts.",
     )
+    guardrail_identifier: Optional[str] = Field(
+        default=None,
+        description="The unique identifier of the guardrail that you want to use. If you don’t provide a value, no guardrail is applied to the invocation.",
+    )
+    guardrail_version: Optional[str] = Field(
+        default=None,
+        description="The version number for the guardrail. The value can also be DRAFT.",
+    )
+    trace: Optional[str] = Field(
+        default=None,
+        description="Specifies whether to enable or disable the Bedrock trace. If enabled, you can see the full Bedrock trace.",
+    )
     additional_kwargs: Dict[str, Any] = Field(
         default_factory=dict,
         description="Additional kwargs for the bedrock invokeModel request.",
@@ -145,6 +160,9 @@ def __init__(
         completion_to_prompt: Optional[Callable[[str], str]] = None,
         pydantic_program_mode: PydanticProgramMode = PydanticProgramMode.DEFAULT,
         output_parser: Optional[BaseOutputParser] = None,
+        guardrail_identifier: Optional[str] = None,
+        guardrail_version: Optional[str] = None,
+        trace: Optional[str] = None,
     ) -> None:
         additional_kwargs = additional_kwargs or {}
         callback_manager = callback_manager or CallbackManager([])
@@ -178,6 +196,9 @@ def __init__(
             region_name=region_name,
             botocore_session=botocore_session,
             botocore_config=botocore_config,
+            guardrail_identifier=guardrail_identifier,
+            guardrail_version=guardrail_version,
+            trace=trace,
         )

         self._config = None
@@ -292,6 +313,9 @@ def chat(self, messages: Sequence[ChatMessage], **kwargs: Any) -> ChatResponse:
             system_prompt=self.system_prompt,
             max_retries=self.max_retries,
             stream=False,
+            guardrail_identifier=self.guardrail_identifier,
+            guardrail_version=self.guardrail_version,
+            trace=self.trace,
             **all_kwargs,
         )

@@ -336,6 +360,9 @@ def stream_chat(
             system_prompt=self.system_prompt,
             max_retries=self.max_retries,
             stream=True,
+            guardrail_identifier=self.guardrail_identifier,
+            guardrail_version=self.guardrail_version,
+            trace=self.trace,
             **all_kwargs,
         )

@@ -416,6 +443,9 @@ async def achat(
             system_prompt=self.system_prompt,
             max_retries=self.max_retries,
             stream=False,
+            guardrail_identifier=self.guardrail_identifier,
+            guardrail_version=self.guardrail_version,
+            trace=self.trace,
             **all_kwargs,
         )

@@ -461,6 +491,9 @@ async def astream_chat(
             system_prompt=self.system_prompt,
             max_retries=self.max_retries,
             stream=True,
+            guardrail_identifier=self.guardrail_identifier,
+            guardrail_version=self.guardrail_version,
+            trace=self.trace,
             **all_kwargs,
         )
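With these additions, guardrails can be switched on from the BedrockConverse constructor and are then forwarded on every sync, async, and streaming chat path. A minimal usage sketch (the model ID, region, and guardrail values below are illustrative placeholders, not values from this commit):

```python
from llama_index.llms.bedrock_converse import BedrockConverse

# All guardrail values below are placeholders; substitute your own.
llm = BedrockConverse(
    model="anthropic.claude-3-haiku-20240307-v1:0",
    region_name="us-east-1",
    guardrail_identifier="gr-abc123",  # guardrail ID or ARN
    guardrail_version="DRAFT",         # a numbered version or "DRAFT"
    trace="ENABLED",                   # include the guardrail trace in responses
)

print(llm.complete("Summarize our refund policy."))
```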

llama-index-integrations/llms/llama-index-llms-bedrock-converse/llama_index/llms/bedrock_converse/utils.py

Lines changed: 40 additions & 2 deletions
@@ -307,6 +307,9 @@ def converse_with_retry(
     max_tokens: int = 1000,
     temperature: float = 0.1,
     stream: bool = False,
+    guardrail_identifier: Optional[str] = None,
+    guardrail_version: Optional[str] = None,
+    trace: Optional[str] = None,
     **kwargs: Any,
 ) -> Any:
     """Use tenacity to retry the completion call."""
@@ -323,8 +326,20 @@
         converse_kwargs["system"] = [{"text": system_prompt}]
     if tool_config := kwargs.get("tools"):
         converse_kwargs["toolConfig"] = tool_config
+    if guardrail_identifier and guardrail_version:
+        converse_kwargs["guardrailConfig"] = {
+            "guardrailIdentifier": guardrail_identifier,
+            "guardrailVersion": guardrail_version,
+        }
+        if trace:
+            converse_kwargs["guardrailConfig"]["trace"] = trace
     converse_kwargs = join_two_dicts(
-        converse_kwargs, {k: v for k, v in kwargs.items() if k != "tools"}
+        converse_kwargs,
+        {
+            k: v
+            for k, v in kwargs.items()
+            if k not in ("tools", "guardrail_identifier", "guardrail_version", "trace")
+        },
     )

     @retry_decorator
@@ -346,6 +365,9 @@ async def converse_with_retry_async(
     max_tokens: int = 1000,
     temperature: float = 0.1,
     stream: bool = False,
+    guardrail_identifier: Optional[str] = None,
+    guardrail_version: Optional[str] = None,
+    trace: Optional[str] = None,
     **kwargs: Any,
 ) -> Any:
     """Use tenacity to retry the completion call."""
@@ -362,8 +384,20 @@
         converse_kwargs["system"] = [{"text": system_prompt}]
     if tool_config := kwargs.get("tools"):
         converse_kwargs["toolConfig"] = tool_config
+    if guardrail_identifier and guardrail_version:
+        converse_kwargs["guardrailConfig"] = {
+            "guardrailIdentifier": guardrail_identifier,
+            "guardrailVersion": guardrail_version,
+        }
+        if trace:
+            converse_kwargs["guardrailConfig"]["trace"] = trace
     converse_kwargs = join_two_dicts(
-        converse_kwargs, {k: v for k, v in kwargs.items() if k != "tools"}
+        converse_kwargs,
+        {
+            k: v
+            for k, v in kwargs.items()
+            if k not in ("tools", "guardrail_identifier", "guardrail_version", "trace")
+        },
     )

     ## NOTE: Returning the generator directly from converse_stream doesn't work
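Both helpers assemble the same guardrailConfig block inside the kwargs eventually handed to the Converse API. A sketch of the resulting payload shape, with illustrative placeholder values (the trace key is only present when the caller supplied one):

```python
# Illustrative shape of the kwargs handed to the Bedrock Converse call.
converse_kwargs = {
    "modelId": "anthropic.claude-3-haiku-20240307-v1:0",
    "messages": [{"role": "user", "content": [{"text": "Hello"}]}],
    "guardrailConfig": {
        "guardrailIdentifier": "gr-abc123",  # placeholder
        "guardrailVersion": "DRAFT",
        "trace": "enabled",  # passed through verbatim from the `trace` argument
    },
}
```

One caveat worth double-checking: the two Bedrock APIs use different casings for the trace flag (the Converse API documents "enabled"/"disabled", while invokeModel documents "ENABLED"/"DISABLED"), and the value is forwarded verbatim here.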

llama-index-integrations/llms/llama-index-llms-bedrock-converse/pyproject.toml

Lines changed: 1 addition & 1 deletion
@@ -27,7 +27,7 @@ exclude = ["**/BUILD"]
 license = "MIT"
 name = "llama-index-llms-bedrock-converse"
 readme = "README.md"
-version = "0.4.1"
+version = "0.4.2"

 [tool.poetry.dependencies]
 python = ">=3.9,<4.0"

llama-index-integrations/llms/llama-index-llms-bedrock-converse/tests/test_llms_bedrock_converse.py

Lines changed: 6 additions & 0 deletions
@@ -14,6 +14,9 @@
 EXP_MAX_TOKENS = 100
 EXP_TEMPERATURE = 0.7
 EXP_MODEL = "anthropic.claude-v2"
+EXP_GUARDRAIL_ID = "IDENTIFIER"
+EXP_GUARDRAIL_VERSION = "DRAFT"
+EXP_GUARDRAIL_TRACE = "ENABLED"

 # Reused chat message and prompt
 messages = [ChatMessage(role=MessageRole.USER, content="Test")]
@@ -88,6 +91,9 @@ def bedrock_converse(mock_boto3_session, mock_aioboto3_session):
         model=EXP_MODEL,
         max_tokens=EXP_MAX_TOKENS,
         temperature=EXP_TEMPERATURE,
+        guardrail_identifier=EXP_GUARDRAIL_ID,
+        guardrail_version=EXP_GUARDRAIL_VERSION,
+        trace=EXP_GUARDRAIL_TRACE,
         callback_manager=CallbackManager(),
     )

llama-index-integrations/llms/llama-index-llms-bedrock/llama_index/llms/bedrock/base.py

Lines changed: 27 additions & 0 deletions
@@ -94,6 +94,18 @@ class Bedrock(LLM):
         default=60.0,
         description="The timeout for the Bedrock API request in seconds. It will be used for both connect and read timeouts.",
     )
+    guardrail_identifier: Optional[str] = Field(
+        default=None,
+        description="The unique identifier of the guardrail that you want to use. If you don’t provide a value, no guardrail is applied to the invocation.",
+    )
+    guardrail_version: Optional[str] = Field(
+        default=None,
+        description="The version number for the guardrail. The value can also be DRAFT.",
+    )
+    trace: Optional[str] = Field(
+        default=None,
+        description="Specifies whether to enable or disable the Bedrock trace. If enabled, you can see the full Bedrock trace.",
+    )
     additional_kwargs: Dict[str, Any] = Field(
         default_factory=dict,
         description="Additional kwargs for the bedrock invokeModel request.",
@@ -125,6 +140,9 @@ def __init__(
         completion_to_prompt: Optional[Callable[[str], str]] = None,
         pydantic_program_mode: PydanticProgramMode = PydanticProgramMode.DEFAULT,
         output_parser: Optional[BaseOutputParser] = None,
+        guardrail_identifier: Optional[str] = None,
+        guardrail_version: Optional[str] = None,
+        trace: Optional[str] = None,
         **kwargs: Any,
     ) -> None:
         if context_size is None and model not in BEDROCK_FOUNDATION_LLMS:
@@ -187,6 +205,9 @@ def __init__(
             completion_to_prompt=completion_to_prompt,
             pydantic_program_mode=pydantic_program_mode,
             output_parser=output_parser,
+            guardrail_identifier=guardrail_identifier,
+            guardrail_version=guardrail_version,
+            trace=trace,
         )
         self._provider = get_provider(model)
         self.messages_to_prompt = (
@@ -257,6 +278,9 @@ def complete(
             model=self.model,
             request_body=request_body_str,
             max_retries=self.max_retries,
+            guardrail_identifier=self.guardrail_identifier,
+            guardrail_version=self.guardrail_version,
+            trace=self.trace,
             **all_kwargs,
         )
         response_body = response["body"].read()
@@ -287,6 +311,9 @@ def stream_complete(
             request_body=request_body_str,
             max_retries=self.max_retries,
             stream=True,
+            guardrail_identifier=self.guardrail_identifier,
+            guardrail_version=self.guardrail_version,
+            trace=self.trace,
             **all_kwargs,
         )
         response_body = response["body"]
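The legacy invokeModel-based Bedrock integration gains the same three settings, stored on the model and forwarded from both complete and stream_complete. A minimal sketch with placeholder values:

```python
from llama_index.llms.bedrock import Bedrock

# Placeholder guardrail values; substitute a real identifier and version.
llm = Bedrock(
    model="anthropic.claude-v2",
    region_name="us-east-1",
    guardrail_identifier="gr-abc123",
    guardrail_version="1",
    trace="ENABLED",  # invokeModel expects "ENABLED" or "DISABLED"
)

print(llm.complete("Hello"))
```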

llama-index-integrations/llms/llama-index-llms-bedrock/llama_index/llms/bedrock/utils.py

Lines changed: 27 additions & 4 deletions
@@ -299,6 +299,9 @@ def completion_with_retry(
     request_body: str,
     max_retries: int,
     stream: bool = False,
+    guardrail_identifier: Optional[str] = None,
+    guardrail_version: Optional[str] = None,
+    trace: Optional[str] = None,
     **kwargs: Any,
 ) -> Any:
     """Use tenacity to retry the completion call."""
@@ -307,9 +310,15 @@
     @retry_decorator
     def _completion_with_retry(**kwargs: Any) -> Any:
-        if stream:
-            return client.invoke_model_with_response_stream(
-                modelId=model, body=request_body
-            )
-        return client.invoke_model(modelId=model, body=request_body)
+        request_kwargs = {"modelId": model, "body": request_body}
+        # Only attach guardrail fields when both identifier and version are set;
+        # trace is optional and must be omitted when not provided.
+        if guardrail_identifier is not None and guardrail_version is not None:
+            request_kwargs["guardrailIdentifier"] = guardrail_identifier
+            request_kwargs["guardrailVersion"] = guardrail_version
+            if trace is not None:
+                request_kwargs["trace"] = trace
+        if stream:
+            return client.invoke_model_with_response_stream(**request_kwargs)
+        return client.invoke_model(**request_kwargs)

     return _completion_with_retry(**kwargs)
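At the boto3 level the helper simply forwards the guardrail fields as top-level parameters of invoke_model / invoke_model_with_response_stream. An equivalent raw call, with placeholder guardrail values and a model-specific request body, would look roughly like:

```python
import json

import boto3

client = boto3.client("bedrock-runtime", region_name="us-east-1")

# Placeholder guardrail values; the body shape below is specific to claude-v2.
response = client.invoke_model(
    modelId="anthropic.claude-v2",
    body=json.dumps(
        {"prompt": "\n\nHuman: Hello\n\nAssistant:", "max_tokens_to_sample": 100}
    ),
    guardrailIdentifier="gr-abc123",
    guardrailVersion="1",
    trace="ENABLED",
)
print(json.loads(response["body"].read()))
```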

llama-index-integrations/llms/llama-index-llms-bedrock/pyproject.toml

Lines changed: 1 addition & 1 deletion
@@ -27,7 +27,7 @@ exclude = ["**/BUILD"]
 license = "MIT"
 name = "llama-index-llms-bedrock"
 readme = "README.md"
-version = "0.3.2"
+version = "0.3.3"

 [tool.poetry.dependencies]
 python = ">=3.9,<4.0"

llama-index-integrations/llms/llama-index-llms-bedrock/tests/test_bedrock.py

Lines changed: 17 additions & 2 deletions
@@ -147,6 +147,9 @@ def test_model_basic(
         profile_name=None,
         region_name="us-east-1",
         aws_access_key_id="test",
+        guardrail_identifier="test",
+        guardrail_version="test",
+        trace="ENABLED",
     )

     bedrock_stubber = Stubber(llm._client)
@@ -155,13 +158,25 @@ def test_model_basic(
     bedrock_stubber.add_response(
         "invoke_model",
         get_invoke_model_response(response_body),
-        {"body": complete_request, "modelId": model},
+        {
+            "body": complete_request,
+            "modelId": model,
+            "guardrailIdentifier": "test",
+            "guardrailVersion": "test",
+            "trace": "ENABLED",
+        },
     )
     # response for llm.chat()
     bedrock_stubber.add_response(
         "invoke_model",
         get_invoke_model_response(response_body),
-        {"body": chat_request, "modelId": model},
+        {
+            "body": chat_request,
+            "modelId": model,
+            "guardrailIdentifier": "test",
+            "guardrailVersion": "test",
+            "trace": "ENABLED",
+        },
     )

     bedrock_stubber.activate()
