|
2 | 2 |
|
3 | 3 | import os
|
4 | 4 | import sys
|
5 |
| -import uuid |
6 | 5 | import time
|
| 6 | +import uuid |
7 | 7 |
|
8 | 8 | # Add integration folder (parent) to sys.path
|
9 | 9 | sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
10 | 10 |
|
11 |
| -from utils import get_gemini_client, create_dataset, add_guardrail_to_dataset |
12 |
| - |
13 | 11 | import pytest
|
14 | 12 | import requests
|
| 13 | +import tenacity |
15 | 14 | from google import genai
|
| 15 | +from utils import add_guardrail_to_dataset, create_dataset, get_gemini_client |
16 | 16 |
|
17 | 17 | # Pytest plugins
|
18 | 18 | pytest_plugins = ("pytest_asyncio",)
|
@@ -46,14 +46,17 @@ async def test_message_content_guardrail_from_file(
|
46 | 46 | }
|
47 | 47 |
|
48 | 48 | if not do_stream:
|
49 |
| - with pytest.raises(genai.errors.ClientError) as exc_info: |
| 49 | + with pytest.raises(tenacity.RetryError) as exc_info: |
50 | 50 | response = client.models.generate_content(
|
51 | 51 | **request,
|
52 | 52 | )
|
53 |
| - assert "[Invariant] The response did not pass the guardrails" in str( |
54 |
| - exc_info |
55 |
| - ) |
56 |
| - assert "Dublin detected in the response" in str(exc_info) |
| 53 | + original_error = exc_info.value.last_attempt.exception() |
| 54 | + assert isinstance(original_error, genai.errors.ClientError) |
| 55 | + assert original_error.code == 400 |
| 56 | + assert "[Invariant] The response did not pass the guardrails" in str( |
| 57 | + original_error |
| 58 | + ) |
| 59 | + assert "Dublin detected in the response" in str(original_error) |
57 | 60 |
|
58 | 61 | else:
|
59 | 62 | response = client.models.generate_content_stream(**request)
|
@@ -148,16 +151,17 @@ def get_capital(country_name: str) -> str:
|
148 | 151 | }
|
149 | 152 |
|
150 | 153 | if not do_stream:
|
151 |
| - with pytest.raises(genai.errors.ClientError) as exc_info: |
| 154 | + with pytest.raises(tenacity.RetryError) as exc_info: |
152 | 155 | client.models.generate_content(
|
153 | 156 | **request,
|
154 | 157 | )
|
155 |
| - |
156 |
| - assert exc_info.value.status_code == 400 |
157 |
| - assert "[Invariant] The response did not pass the guardrails" in str( |
158 |
| - exc_info |
159 |
| - ) |
160 |
| - assert "get_capital is called with Germany as argument" in str(exc_info) |
| 158 | + original_error = exc_info.value.last_attempt.exception() |
| 159 | + assert isinstance(original_error, genai.errors.ClientError) |
| 160 | + assert original_error.code == 400 |
| 161 | + assert "[Invariant] The response did not pass the guardrails" in str( |
| 162 | + original_error |
| 163 | + ) |
| 164 | + assert "get_capital is called with Germany as argument" in str(original_error) |
161 | 165 |
|
162 | 166 | else:
|
163 | 167 | response = client.models.generate_content_stream(
|
@@ -246,14 +250,17 @@ async def test_input_from_guardrail_from_file(
|
246 | 250 | }
|
247 | 251 |
|
248 | 252 | if not do_stream:
|
249 |
| - with pytest.raises(genai.errors.ClientError) as exc_info: |
| 253 | + with pytest.raises(tenacity.RetryError) as exc_info: |
250 | 254 | client.models.generate_content(**request)
|
251 | 255 |
|
| 256 | + original_error = exc_info.value.last_attempt.exception() |
| 257 | + assert isinstance(original_error, genai.errors.ClientError) |
| 258 | + assert original_error.code == 400 |
252 | 259 | assert "[Invariant] The request did not pass the guardrails" in str(
|
253 |
| - exc_info.value |
| 260 | + original_error |
254 | 261 | )
|
255 | 262 | assert "Users must not mention the magic phrase 'Fight Club'" in str(
|
256 |
| - exc_info.value |
| 263 | + original_error |
257 | 264 | )
|
258 | 265 |
|
259 | 266 | else:
|
@@ -372,15 +379,18 @@ async def test_with_guardrails_from_explorer(explorer_api_url, gateway_url, do_s
|
372 | 379 | },
|
373 | 380 | }
|
374 | 381 | if not do_stream:
|
375 |
| - with pytest.raises(genai.errors.ClientError) as exc_info: |
| 382 | + with pytest.raises(tenacity.RetryError) as exc_info: |
376 | 383 | client.models.generate_content(**shrek_request)
|
377 | 384 |
|
| 385 | + original_error = exc_info.value.last_attempt.exception() |
| 386 | + assert isinstance(original_error, genai.errors.ClientError) |
| 387 | + assert original_error.code == 400 |
378 | 388 | assert "[Invariant] The response did not pass the guardrails" in str(
|
379 |
| - exc_info.value |
| 389 | + original_error |
380 | 390 | )
|
381 | 391 | # Only the block guardrail should be triggered here
|
382 |
| - assert "ogre detected in response" in str(exc_info.value) |
383 |
| - assert "Fiona detected in response" not in str(exc_info.value) |
| 392 | + assert "ogre detected in response" in str(original_error) |
| 393 | + assert "Fiona detected in response" not in str(original_error) |
384 | 394 | else:
|
385 | 395 | response = client.models.generate_content_stream(**shrek_request)
|
386 | 396 |
|
@@ -492,12 +502,15 @@ async def test_preguardrailing_with_guardrails_from_explorer(
|
492 | 502 | ],
|
493 | 503 | )
|
494 | 504 | else:
|
495 |
| - with pytest.raises(genai.errors.ClientError) as exc_info: |
| 505 | + with pytest.raises(tenacity.RetryError) as exc_info: |
496 | 506 | chat_response = client.models.generate_content(**request)
|
| 507 | + original_error = exc_info.value.last_attempt.exception() |
| 508 | + assert isinstance(original_error, genai.errors.ClientError) |
| 509 | + assert original_error.code == 400 |
497 | 510 | assert "[Invariant] The request did not pass the guardrails" in str(
|
498 |
| - exc_info.value |
| 511 | + original_error |
499 | 512 | )
|
500 |
| - assert "pun detected in user message" in str(exc_info.value) |
| 513 | + assert "pun detected in user message" in str(original_error) |
501 | 514 | else:
|
502 | 515 | if do_stream:
|
503 | 516 | response = client.models.generate_content_stream(**request)
|
|
0 commit comments