Skip to content

Commit 5a95f8e

Browse files
authored
Adding tests for guardrail samples (#929)
* Adding tests for guardrail samples * Fix lint * Remove | syntax
1 parent 8fc55e7 commit 5a95f8e

File tree

1 file changed

+341
-0
lines changed

1 file changed

+341
-0
lines changed

tests/contrib/test_openai.py

Lines changed: 341 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
from typing import Any, Optional, Union, no_type_check
77

88
import pytest
9+
from pydantic import ConfigDict, Field
910

1011
from temporalio import activity, workflow
1112
from temporalio.client import Client, WorkflowFailureError, WorkflowHandle
@@ -29,7 +30,9 @@
2930
from agents import (
3031
Agent,
3132
AgentOutputSchemaBase,
33+
GuardrailFunctionOutput,
3234
Handoff,
35+
InputGuardrailTripwireTriggered,
3336
ItemHelpers,
3437
MessageOutputItem,
3538
Model,
@@ -38,13 +41,16 @@
3841
ModelSettings,
3942
ModelTracing,
4043
OpenAIResponsesModel,
44+
OutputGuardrailTripwireTriggered,
4145
RunContextWrapper,
4246
Runner,
4347
Tool,
4448
TResponseInputItem,
4549
Usage,
4650
function_tool,
4751
handoff,
52+
input_guardrail,
53+
output_guardrail,
4854
trace,
4955
)
5056
from agents.extensions.handoff_prompt import RECOMMENDED_PROMPT_PREFIX
@@ -1151,3 +1157,338 @@ async def test_customer_service_workflow(client: Client):
11511157
.activity_task_completed_event_attributes.result.payloads[0]
11521158
.data.decode()
11531159
)
1160+
1161+
1162+
# Module-level cursor into InputGuardrailModel.guardrail_responses; reset in
# InputGuardrailModel.__init__ and advanced on each guardrail model call.
guardrail_response_index: int = 0
1163+
1164+
1165+
class InputGuardrailModel(OpenAIResponsesModel):
    """Scripted stand-in for the OpenAI Responses model used by the input-guardrail test.

    Serves two canned response streams: ``responses`` for the main agent and
    ``guardrail_responses`` for the guardrail classifier agent, dispatching on
    the system instructions of each call.
    """

    # Prevent pytest from collecting this class as a test case.
    __test__ = False
    # Replies for the main "Customer support agent": one geography answer,
    # then one math answer (which the guardrail is expected to block).
    responses: list[ModelResponse] = [
        ModelResponse(
            output=[
                ResponseOutputMessage(
                    id="",
                    content=[
                        ResponseOutputText(
                            text="The capital of California is Sacramento.",
                            annotations=[],
                            type="output_text",
                        )
                    ],
                    role="assistant",
                    status="completed",
                    type="message",
                )
            ],
            usage=Usage(),
            response_id=None,
        ),
        ModelResponse(
            output=[
                ResponseOutputMessage(
                    id="",
                    content=[
                        ResponseOutputText(
                            text="x=3",
                            annotations=[],
                            type="output_text",
                        )
                    ],
                    role="assistant",
                    status="completed",
                    type="message",
                )
            ],
            usage=Usage(),
            response_id=None,
        ),
    ]
    # Structured JSON verdicts for the guardrail classifier agent, consumed in
    # order: first call -> not math homework, second call -> math homework.
    guardrail_responses = [
        ModelResponse(
            output=[
                ResponseOutputMessage(
                    id="",
                    content=[
                        ResponseOutputText(
                            text='{"is_math_homework":false,"reasoning":"The question asked is about the capital of California, which is a geography-related query, not math."}',
                            annotations=[],
                            type="output_text",
                        )
                    ],
                    role="assistant",
                    status="completed",
                    type="message",
                )
            ],
            usage=Usage(),
            response_id=None,
        ),
        ModelResponse(
            output=[
                ResponseOutputMessage(
                    id="",
                    content=[
                        ResponseOutputText(
                            text='{"is_math_homework":true,"reasoning":"The question involves solving an equation for a variable, which is a typical math homework problem."}',
                            annotations=[],
                            type="output_text",
                        )
                    ],
                    role="assistant",
                    status="completed",
                    type="message",
                )
            ],
            usage=Usage(),
            response_id=None,
        ),
    ]

    def __init__(
        self,
        model: str,
        openai_client: AsyncOpenAI,
    ) -> None:
        """Reset both module-level response cursors so each test run starts fresh."""
        global response_index
        response_index = 0
        global guardrail_response_index
        guardrail_response_index = 0
        super().__init__(model, openai_client)

    async def get_response(
        self,
        system_instructions: Union[str, None],
        input: Union[str, list[TResponseInputItem]],
        model_settings: ModelSettings,
        tools: list[Tool],
        output_schema: Union[AgentOutputSchemaBase, None],
        handoffs: list[Handoff],
        tracing: ModelTracing,
        previous_response_id: Union[str, None],
        prompt: Union[ResponsePromptParam, None] = None,
    ) -> ModelResponse:
        """Return the next canned response for whichever agent is calling.

        Dispatches on ``system_instructions``: a guardrail-agent call (matching
        the guardrail agent's instructions string) consumes from
        ``guardrail_responses``; any other call consumes from ``responses``.
        """
        if (
            system_instructions
            == "Check if the user is asking you to do their math homework."
        ):
            global guardrail_response_index
            response = self.guardrail_responses[guardrail_response_index]
            guardrail_response_index += 1
            return response
        else:
            global response_index
            response = self.responses[response_index]
            response_index += 1
            return response
1284+
1285+
1286+
### 1. An agent-based guardrail that is triggered if the user is asking to do math homework
1287+
class MathHomeworkOutput(BaseModel):
    """Structured verdict produced by the guardrail classifier agent."""

    # The classifier's explanation for its verdict.
    reasoning: str
    # True when the user input is asking for math homework help.
    is_math_homework: bool
    # Forbid unknown keys so the model's JSON output must match this schema exactly.
    model_config = ConfigDict(extra="forbid")
1291+
1292+
1293+
# Classifier agent used by the input guardrail. Note: its instructions string
# is what InputGuardrailModel.get_response matches on to serve guardrail
# responses instead of main-agent responses.
guardrail_agent: Agent = Agent(
    name="Guardrail check",
    instructions="Check if the user is asking you to do their math homework.",
    output_type=MathHomeworkOutput,
)
1298+
1299+
1300+
@input_guardrail
async def math_guardrail(
    context: RunContextWrapper[None],
    agent: Agent,
    input: Union[str, list[TResponseInputItem]],
) -> GuardrailFunctionOutput:
    """Input guardrail that delegates classification to an agent.

    Runs ``guardrail_agent`` over the incoming input and trips the guardrail
    whenever the classifier decides the user is asking for math homework help.
    """
    run_result = await Runner.run(guardrail_agent, input, context=context.context)
    verdict = run_result.final_output_as(MathHomeworkOutput)
    return GuardrailFunctionOutput(
        output_info=verdict,
        tripwire_triggered=verdict.is_math_homework,
    )
1316+
1317+
1318+
@workflow.defn
class InputGuardrailWorkflow:
    """Workflow that feeds a conversation through an agent with an input guardrail.

    Each user message is appended to the running conversation and the agent is
    run; a tripped guardrail is converted into a canned refusal instead of
    failing the workflow.
    """

    @workflow.run
    async def run(self, messages: list[str]) -> list[str]:
        """Process ``messages`` in order and return one reply per message."""
        agent = Agent(
            name="Customer support agent",
            instructions="You are a customer support agent. You help customers with their questions.",
            input_guardrails=[math_guardrail],
        )

        # Accumulated conversation history passed to each agent run.
        input_data: list[TResponseInputItem] = []
        # One final reply (agent answer or refusal) per user message.
        results: list[str] = []

        for user_input in messages:
            input_data.append(
                {
                    "role": "user",
                    "content": user_input,
                }
            )

            try:
                result = await Runner.run(agent, input_data)
                results.append(result.final_output)
                # If the guardrail didn't trigger, we use the result as the input for the next run
                input_data = result.to_input_list()
            except InputGuardrailTripwireTriggered:
                # If the guardrail triggered, we instead add a refusal message to the input
                message = "Sorry, I can't help you with your math homework."
                results.append(message)
                input_data.append(
                    {
                        "role": "assistant",
                        "content": message,
                    }
                )
        return results
1355+
1356+
1357+
async def test_input_guardrail(client: Client):
    """End-to-end test: the first question is answered, the second (math) is refused.

    Wires the scripted InputGuardrailModel into a worker and runs
    InputGuardrailWorkflow over two user messages.
    """
    # Re-create the client with the OpenAI data converter so agent payloads round-trip.
    new_config = client.config()
    new_config["data_converter"] = open_ai_data_converter
    client = Client(**new_config)

    model_params = ModelActivityParameters(start_to_close_timeout=timedelta(seconds=10))
    with set_open_ai_agent_temporal_overrides(model_params):
        # Model activity backed by the scripted guardrail-aware mock model.
        model_activity = ModelActivity(
            TestProvider(
                InputGuardrailModel(  # type: ignore
                    "", openai_client=AsyncOpenAI(api_key="Fake key")
                )
            )
        )
        async with new_worker(
            client,
            InputGuardrailWorkflow,
            activities=[model_activity.invoke_model_activity],
            interceptors=[OpenAIAgentsTracingInterceptor()],
        ) as worker:
            workflow_handle = await client.start_workflow(
                InputGuardrailWorkflow.run,
                [
                    "What's the capital of California?",
                    "Can you help me solve for x: 2x + 5 = 11",
                ],
                id=f"input-guardrail-{uuid.uuid4()}",
                task_queue=worker.task_queue,
                execution_timeout=timedelta(seconds=10),
            )
            result = await workflow_handle.result()
            # First message passes the guardrail; second trips it and gets the refusal.
            assert len(result) == 2
            assert result[0] == "The capital of California is Sacramento."
            assert result[1] == "Sorry, I can't help you with your math homework."
1391+
1392+
1393+
class OutputGuardrailModel(TestModel):
    """Scripted model for the output-guardrail test.

    Returns a single canned MessageOutput-shaped JSON reply that mentions the
    area code "650" in both its reasoning and response fields, so the
    sensitive_data_check output guardrail is guaranteed to trip.
    """

    responses = [
        ModelResponse(
            output=[
                ResponseOutputMessage(
                    id="",
                    content=[
                        ResponseOutputText(
                            text='{"reasoning":"The phone number\'s area code (650) is associated with a region. However, the exact location is not definitive, but it\'s commonly linked to the San Francisco Peninsula in California, including cities like San Mateo, Palo Alto, and parts of Silicon Valley. It\'s important to note that area codes don\'t always guarantee a specific location due to mobile number portability.","response":"The area code 650 is typically associated with California, particularly the San Francisco Peninsula, including cities like Palo Alto and San Mateo.","user_name":null}',
                            annotations=[],
                            type="output_text",
                        )
                    ],
                    role="assistant",
                    status="completed",
                    type="message",
                )
            ],
            usage=Usage(),
            response_id=None,
        )
    ]
1415+
1416+
1417+
# The agent's output type
1418+
class MessageOutput(BaseModel):
    """Structured output type for the output-guardrail assistant agent."""

    reasoning: str = Field(
        description="Thoughts on how to respond to the user's message"
    )
    response: str = Field(description="The response to the user's message")
    user_name: Optional[str] = Field(
        description="The name of the user who sent the message, if known"
    )
    # Forbid unknown keys so the model's JSON output must match this schema exactly.
    model_config = ConfigDict(extra="forbid")
1427+
1428+
1429+
@output_guardrail
async def sensitive_data_check(
    context: RunContextWrapper, agent: Agent, output: MessageOutput
) -> GuardrailFunctionOutput:
    """Output guardrail that trips when the phone-number fragment "650"
    appears in either the response text or the reasoning text.
    """
    leaked_in_response = "650" in output.response
    leaked_in_reasoning = "650" in output.reasoning
    findings = {
        "phone_number_in_response": leaked_in_response,
        "phone_number_in_reasoning": leaked_in_reasoning,
    }
    return GuardrailFunctionOutput(
        output_info=findings,
        tripwire_triggered=leaked_in_response or leaked_in_reasoning,
    )
1443+
1444+
1445+
# Assistant agent whose structured output is screened by sensitive_data_check.
output_guardrail_agent = Agent(
    name="Assistant",
    instructions="You are a helpful assistant.",
    output_type=MessageOutput,
    output_guardrails=[sensitive_data_check],
)
1451+
1452+
1453+
@workflow.defn
class OutputGuardrailWorkflow:
    """Workflow that reports whether the output guardrail allowed the reply."""

    @workflow.run
    async def run(self) -> bool:
        """Return True if the run completes, False if the output guardrail trips."""
        try:
            await Runner.run(
                output_guardrail_agent,
                "My phone number is 650-123-4567. Where do you think I live?",
            )
            return True
        except OutputGuardrailTripwireTriggered:
            return False
1465+
1466+
1467+
async def test_output_guardrail(client: Client):
    """End-to-end test: the scripted "650" reply must trip the output guardrail.

    Wires OutputGuardrailModel into a worker, runs OutputGuardrailWorkflow, and
    expects a False result (guardrail tripped).
    """
    # Re-create the client with the OpenAI data converter so agent payloads round-trip.
    new_config = client.config()
    new_config["data_converter"] = open_ai_data_converter
    client = Client(**new_config)

    model_params = ModelActivityParameters(start_to_close_timeout=timedelta(seconds=10))
    with set_open_ai_agent_temporal_overrides(model_params):
        model_activity = ModelActivity(
            TestProvider(
                OutputGuardrailModel(  # type: ignore
                    "", openai_client=AsyncOpenAI(api_key="Fake key")
                )
            )
        )
        async with new_worker(
            client,
            OutputGuardrailWorkflow,
            activities=[model_activity.invoke_model_activity],
            interceptors=[OpenAIAgentsTracingInterceptor()],
        ) as worker:
            workflow_handle = await client.start_workflow(
                OutputGuardrailWorkflow.run,
                id=f"output-guardrail-{uuid.uuid4()}",
                task_queue=worker.task_queue,
                execution_timeout=timedelta(seconds=10),
            )
            result = await workflow_handle.result()
            # False means OutputGuardrailTripwireTriggered was raised, as expected.
            assert not result

0 commit comments

Comments
 (0)