|
1 | 1 | """Test the guardrails from file with the Gemini route."""
|
2 | 2 |
|
| 3 | +# pylint: disable=protected-access |
| 4 | + |
3 | 5 | import os
|
4 | 6 | import sys
|
5 | 7 | import time
|
|
10 | 12 |
|
11 | 13 | import pytest
|
12 | 14 | import requests
|
13 |
| -import tenacity |
14 | 15 | from google import genai
|
15 | 16 | from utils import add_guardrail_to_dataset, create_dataset, get_gemini_client
|
16 | 17 |
|
@@ -46,17 +47,16 @@ async def test_message_content_guardrail_from_file(
|
46 | 47 | }
|
47 | 48 |
|
48 | 49 | if not do_stream:
|
49 |
| - with pytest.raises(tenacity.RetryError) as exc_info: |
| 50 | + with pytest.raises(genai.errors.ClientError) as e: |
50 | 51 | response = client.models.generate_content(
|
51 | 52 | **request,
|
52 | 53 | )
|
53 |
| - original_error = exc_info.value.last_attempt.exception() |
54 |
| - assert isinstance(original_error, genai.errors.ClientError) |
55 |
| - assert original_error.code == 400 |
56 |
| - assert "[Invariant] The response did not pass the guardrails" in str( |
57 |
| - original_error |
| 54 | + assert e._excinfo[1].code == 400 |
| 55 | + assert ( |
| 56 | + "[Invariant] The response did not pass the guardrails" |
| 57 | + in e._excinfo[1].message |
58 | 58 | )
|
59 |
| - assert "Dublin detected in the response" in str(original_error) |
| 59 | + assert "Dublin detected in the response" in str(e._excinfo[1]) |
60 | 60 |
|
61 | 61 | else:
|
62 | 62 | response = client.models.generate_content_stream(**request)
|
@@ -151,17 +151,16 @@ def get_capital(country_name: str) -> str:
|
151 | 151 | }
|
152 | 152 |
|
153 | 153 | if not do_stream:
|
154 |
| - with pytest.raises(tenacity.RetryError) as exc_info: |
| 154 | + with pytest.raises(genai.errors.ClientError) as e: |
155 | 155 | client.models.generate_content(
|
156 | 156 | **request,
|
157 | 157 | )
|
158 |
| - original_error = exc_info.value.last_attempt.exception() |
159 |
| - assert isinstance(original_error, genai.errors.ClientError) |
160 |
| - assert original_error.code == 400 |
161 |
| - assert "[Invariant] The response did not pass the guardrails" in str( |
162 |
| - original_error |
| 158 | + assert e._excinfo[1].code == 400 |
| 159 | + assert ( |
| 160 | + "[Invariant] The response did not pass the guardrails" |
| 161 | + in e._excinfo[1].message |
163 | 162 | )
|
164 |
| - assert "get_capital is called with Germany as argument" in str(original_error) |
| 163 | + assert "get_capital is called with Germany as argument" in str(e._excinfo[1]) |
165 | 164 |
|
166 | 165 | else:
|
167 | 166 | response = client.models.generate_content_stream(
|
@@ -250,17 +249,15 @@ async def test_input_from_guardrail_from_file(
|
250 | 249 | }
|
251 | 250 |
|
252 | 251 | if not do_stream:
|
253 |
| - with pytest.raises(tenacity.RetryError) as exc_info: |
| 252 | + with pytest.raises(genai.errors.ClientError) as e: |
254 | 253 | client.models.generate_content(**request)
|
255 |
| - |
256 |
| - original_error = exc_info.value.last_attempt.exception() |
257 |
| - assert isinstance(original_error, genai.errors.ClientError) |
258 |
| - assert original_error.code == 400 |
259 |
| - assert "[Invariant] The request did not pass the guardrails" in str( |
260 |
| - original_error |
| 254 | + assert e._excinfo[1].code == 400 |
| 255 | + assert ( |
| 256 | + "[Invariant] The request did not pass the guardrails" |
| 257 | + in e._excinfo[1].message |
261 | 258 | )
|
262 | 259 | assert "Users must not mention the magic phrase 'Fight Club'" in str(
|
263 |
| - original_error |
| 260 | + e._excinfo[1] |
264 | 261 | )
|
265 | 262 |
|
266 | 263 | else:
|
@@ -379,18 +376,16 @@ async def test_with_guardrails_from_explorer(explorer_api_url, gateway_url, do_s
|
379 | 376 | },
|
380 | 377 | }
|
381 | 378 | if not do_stream:
|
382 |
| - with pytest.raises(tenacity.RetryError) as exc_info: |
| 379 | + with pytest.raises(genai.errors.ClientError) as e: |
383 | 380 | client.models.generate_content(**shrek_request)
|
384 |
| - |
385 |
| - original_error = exc_info.value.last_attempt.exception() |
386 |
| - assert isinstance(original_error, genai.errors.ClientError) |
387 |
| - assert original_error.code == 400 |
388 |
| - assert "[Invariant] The response did not pass the guardrails" in str( |
389 |
| - original_error |
| 381 | + assert e._excinfo[1].code == 400 |
| 382 | + assert ( |
| 383 | + "[Invariant] The response did not pass the guardrails" |
| 384 | + in e._excinfo[1].message |
390 | 385 | )
|
391 | 386 | # Only the block guardrail should be triggered here
|
392 |
| - assert "ogre detected in response" in str(original_error) |
393 |
| - assert "Fiona detected in response" not in str(original_error) |
| 387 | + assert "ogre detected in response" in str(e._excinfo[1]) |
| 388 | + assert "Fiona detected in response" not in str(e._excinfo[1]) |
394 | 389 | else:
|
395 | 390 | response = client.models.generate_content_stream(**shrek_request)
|
396 | 391 |
|
@@ -502,15 +497,14 @@ async def test_preguardrailing_with_guardrails_from_explorer(
|
502 | 497 | ],
|
503 | 498 | )
|
504 | 499 | else:
|
505 |
| - with pytest.raises(tenacity.RetryError) as exc_info: |
| 500 | + with pytest.raises(genai.errors.ClientError) as e: |
506 | 501 | chat_response = client.models.generate_content(**request)
|
507 |
| - original_error = exc_info.value.last_attempt.exception() |
508 |
| - assert isinstance(original_error, genai.errors.ClientError) |
509 |
| - assert original_error.code == 400 |
510 |
| - assert "[Invariant] The request did not pass the guardrails" in str( |
511 |
| - original_error |
| 502 | + assert e._excinfo[1].code == 400 |
| 503 | + assert ( |
| 504 | + "[Invariant] The request did not pass the guardrails" |
| 505 | + in e._excinfo[1].message |
512 | 506 | )
|
513 |
| - assert "pun detected in user message" in str(original_error) |
| 507 | + assert "pun detected in user message" in str(e._excinfo[1]) |
514 | 508 | else:
|
515 | 509 | if do_stream:
|
516 | 510 | response = client.models.generate_content_stream(**request)
|
|
0 commit comments