Skip to content

Commit df33199

Browse files
committed
Use tenacity.RetryError instead of genai.errors.ClientError for gemini guardrailing errors.
1 parent 70091b7 commit df33199

File tree

2 files changed

+39
-25
lines changed

2 files changed

+39
-25
lines changed

tests/integration/guardrails/test_guardrails_gemini.py

Lines changed: 38 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -2,17 +2,17 @@
22

33
import os
44
import sys
5-
import uuid
65
import time
6+
import uuid
77

88
# Add integration folder (parent) to sys.path
99
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
1010

11-
from utils import get_gemini_client, create_dataset, add_guardrail_to_dataset
12-
1311
import pytest
1412
import requests
13+
import tenacity
1514
from google import genai
15+
from utils import add_guardrail_to_dataset, create_dataset, get_gemini_client
1616

1717
# Pytest plugins
1818
pytest_plugins = ("pytest_asyncio",)
@@ -46,14 +46,17 @@ async def test_message_content_guardrail_from_file(
4646
}
4747

4848
if not do_stream:
49-
with pytest.raises(genai.errors.ClientError) as exc_info:
49+
with pytest.raises(tenacity.RetryError) as exc_info:
5050
response = client.models.generate_content(
5151
**request,
5252
)
53-
assert "[Invariant] The response did not pass the guardrails" in str(
54-
exc_info
55-
)
56-
assert "Dublin detected in the response" in str(exc_info)
53+
original_error = exc_info.value.last_attempt.exception()
54+
assert isinstance(original_error, genai.errors.ClientError)
55+
assert original_error.code == 400
56+
assert "[Invariant] The response did not pass the guardrails" in str(
57+
original_error
58+
)
59+
assert "Dublin detected in the response" in str(original_error)
5760

5861
else:
5962
response = client.models.generate_content_stream(**request)
@@ -148,16 +151,17 @@ def get_capital(country_name: str) -> str:
148151
}
149152

150153
if not do_stream:
151-
with pytest.raises(genai.errors.ClientError) as exc_info:
154+
with pytest.raises(tenacity.RetryError) as exc_info:
152155
client.models.generate_content(
153156
**request,
154157
)
155-
156-
assert exc_info.value.status_code == 400
157-
assert "[Invariant] The response did not pass the guardrails" in str(
158-
exc_info
159-
)
160-
assert "get_capital is called with Germany as argument" in str(exc_info)
158+
original_error = exc_info.value.last_attempt.exception()
159+
assert isinstance(original_error, genai.errors.ClientError)
160+
assert original_error.code == 400
161+
assert "[Invariant] The response did not pass the guardrails" in str(
162+
original_error
163+
)
164+
assert "get_capital is called with Germany as argument" in str(original_error)
161165

162166
else:
163167
response = client.models.generate_content_stream(
@@ -246,14 +250,17 @@ async def test_input_from_guardrail_from_file(
246250
}
247251

248252
if not do_stream:
249-
with pytest.raises(genai.errors.ClientError) as exc_info:
253+
with pytest.raises(tenacity.RetryError) as exc_info:
250254
client.models.generate_content(**request)
251255

256+
original_error = exc_info.value.last_attempt.exception()
257+
assert isinstance(original_error, genai.errors.ClientError)
258+
assert original_error.code == 400
252259
assert "[Invariant] The request did not pass the guardrails" in str(
253-
exc_info.value
260+
original_error
254261
)
255262
assert "Users must not mention the magic phrase 'Fight Club'" in str(
256-
exc_info.value
263+
original_error
257264
)
258265

259266
else:
@@ -372,15 +379,18 @@ async def test_with_guardrails_from_explorer(explorer_api_url, gateway_url, do_s
372379
},
373380
}
374381
if not do_stream:
375-
with pytest.raises(genai.errors.ClientError) as exc_info:
382+
with pytest.raises(tenacity.RetryError) as exc_info:
376383
client.models.generate_content(**shrek_request)
377384

385+
original_error = exc_info.value.last_attempt.exception()
386+
assert isinstance(original_error, genai.errors.ClientError)
387+
assert original_error.code == 400
378388
assert "[Invariant] The response did not pass the guardrails" in str(
379-
exc_info.value
389+
original_error
380390
)
381391
# Only the block guardrail should be triggered here
382-
assert "ogre detected in response" in str(exc_info.value)
383-
assert "Fiona detected in response" not in str(exc_info.value)
392+
assert "ogre detected in response" in str(original_error)
393+
assert "Fiona detected in response" not in str(original_error)
384394
else:
385395
response = client.models.generate_content_stream(**shrek_request)
386396

@@ -492,12 +502,15 @@ async def test_preguardrailing_with_guardrails_from_explorer(
492502
],
493503
)
494504
else:
495-
with pytest.raises(genai.errors.ClientError) as exc_info:
505+
with pytest.raises(tenacity.RetryError) as exc_info:
496506
chat_response = client.models.generate_content(**request)
507+
original_error = exc_info.value.last_attempt.exception()
508+
assert isinstance(original_error, genai.errors.ClientError)
509+
assert original_error.code == 400
497510
assert "[Invariant] The request did not pass the guardrails" in str(
498-
exc_info.value
511+
original_error
499512
)
500-
assert "pun detected in user message" in str(exc_info.value)
513+
assert "pun detected in user message" in str(original_error)
501514
else:
502515
if do_stream:
503516
response = client.models.generate_content_stream(**request)

tests/integration/requirements.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,4 +9,5 @@ pytest
99
pytest-asyncio
1010
pytest-timeout
1111
tavily-python
12+
tenacity
1213
uv

0 commit comments

Comments
 (0)