Skip to content

Commit 39cde09

Browse files
committed
Fix broken Gemini tests.
1 parent 8af2f44 commit 39cde09

File tree

2 files changed

+33
-40
lines changed

2 files changed

+33
-40
lines changed

tests/integration/guardrails/test_guardrails_gemini.py

Lines changed: 33 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
"""Test the guardrails from file with the Gemini route."""
22

3+
# pylint: disable=protected-access
4+
35
import os
46
import sys
57
import time
@@ -10,7 +12,6 @@
1012

1113
import pytest
1214
import requests
13-
import tenacity
1415
from google import genai
1516
from utils import add_guardrail_to_dataset, create_dataset, get_gemini_client
1617

@@ -46,17 +47,16 @@ async def test_message_content_guardrail_from_file(
4647
}
4748

4849
if not do_stream:
49-
with pytest.raises(tenacity.RetryError) as exc_info:
50+
with pytest.raises(genai.errors.ClientError) as e:
5051
response = client.models.generate_content(
5152
**request,
5253
)
53-
original_error = exc_info.value.last_attempt.exception()
54-
assert isinstance(original_error, genai.errors.ClientError)
55-
assert original_error.code == 400
56-
assert "[Invariant] The response did not pass the guardrails" in str(
57-
original_error
54+
assert e._excinfo[1].code == 400
55+
assert (
56+
"[Invariant] The response did not pass the guardrails"
57+
in e._excinfo[1].message
5858
)
59-
assert "Dublin detected in the response" in str(original_error)
59+
assert "Dublin detected in the response" in str(e._excinfo[1])
6060

6161
else:
6262
response = client.models.generate_content_stream(**request)
@@ -151,17 +151,16 @@ def get_capital(country_name: str) -> str:
151151
}
152152

153153
if not do_stream:
154-
with pytest.raises(tenacity.RetryError) as exc_info:
154+
with pytest.raises(genai.errors.ClientError) as e:
155155
client.models.generate_content(
156156
**request,
157157
)
158-
original_error = exc_info.value.last_attempt.exception()
159-
assert isinstance(original_error, genai.errors.ClientError)
160-
assert original_error.code == 400
161-
assert "[Invariant] The response did not pass the guardrails" in str(
162-
original_error
158+
assert e._excinfo[1].code == 400
159+
assert (
160+
"[Invariant] The response did not pass the guardrails"
161+
in e._excinfo[1].message
163162
)
164-
assert "get_capital is called with Germany as argument" in str(original_error)
163+
assert "get_capital is called with Germany as argument" in str(e._excinfo[1])
165164

166165
else:
167166
response = client.models.generate_content_stream(
@@ -250,17 +249,15 @@ async def test_input_from_guardrail_from_file(
250249
}
251250

252251
if not do_stream:
253-
with pytest.raises(tenacity.RetryError) as exc_info:
252+
with pytest.raises(genai.errors.ClientError) as e:
254253
client.models.generate_content(**request)
255-
256-
original_error = exc_info.value.last_attempt.exception()
257-
assert isinstance(original_error, genai.errors.ClientError)
258-
assert original_error.code == 400
259-
assert "[Invariant] The request did not pass the guardrails" in str(
260-
original_error
254+
assert e._excinfo[1].code == 400
255+
assert (
256+
"[Invariant] The request did not pass the guardrails"
257+
in e._excinfo[1].message
261258
)
262259
assert "Users must not mention the magic phrase 'Fight Club'" in str(
263-
original_error
260+
e._excinfo[1]
264261
)
265262

266263
else:
@@ -379,18 +376,16 @@ async def test_with_guardrails_from_explorer(explorer_api_url, gateway_url, do_s
379376
},
380377
}
381378
if not do_stream:
382-
with pytest.raises(tenacity.RetryError) as exc_info:
379+
with pytest.raises(genai.errors.ClientError) as e:
383380
client.models.generate_content(**shrek_request)
384-
385-
original_error = exc_info.value.last_attempt.exception()
386-
assert isinstance(original_error, genai.errors.ClientError)
387-
assert original_error.code == 400
388-
assert "[Invariant] The response did not pass the guardrails" in str(
389-
original_error
381+
assert e._excinfo[1].code == 400
382+
assert (
383+
"[Invariant] The response did not pass the guardrails"
384+
in e._excinfo[1].message
390385
)
391386
# Only the block guardrail should be triggered here
392-
assert "ogre detected in response" in str(original_error)
393-
assert "Fiona detected in response" not in str(original_error)
387+
assert "ogre detected in response" in str(e._excinfo[1])
388+
assert "Fiona detected in response" not in str(e._excinfo[1])
394389
else:
395390
response = client.models.generate_content_stream(**shrek_request)
396391

@@ -502,15 +497,14 @@ async def test_preguardrailing_with_guardrails_from_explorer(
502497
],
503498
)
504499
else:
505-
with pytest.raises(tenacity.RetryError) as exc_info:
500+
with pytest.raises(genai.errors.ClientError) as e:
506501
chat_response = client.models.generate_content(**request)
507-
original_error = exc_info.value.last_attempt.exception()
508-
assert isinstance(original_error, genai.errors.ClientError)
509-
assert original_error.code == 400
510-
assert "[Invariant] The request did not pass the guardrails" in str(
511-
original_error
502+
assert e._excinfo[1].code == 400
503+
assert (
504+
"[Invariant] The request did not pass the guardrails"
505+
in e._excinfo[1].message
512506
)
513-
assert "pun detected in user message" in str(original_error)
507+
assert "pun detected in user message" in str(e._excinfo[1])
514508
else:
515509
if do_stream:
516510
response = client.models.generate_content_stream(**request)

tests/integration/requirements.txt

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,5 +9,4 @@ pytest
99
pytest-asyncio
1010
pytest-timeout
1111
tavily-python
12-
tenacity
1312
uv

0 commit comments

Comments
 (0)