Skip to content

Commit fe82b07

Browse files
authored
Add exception catching logic to model activity (#982)
* Add exception catching logic to model activity
* Fix timeout
* Retry after + fixes
* Linting
* Only set timeouts for default model provider
1 parent 9e7dc7a commit fe82b07

File tree

6 files changed

+154
-29
lines changed

6 files changed

+154
-29
lines changed

temporalio/contrib/openai_agents/__init__.py

Lines changed: 4 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -9,11 +9,14 @@
99
"""
1010

1111
from temporalio.contrib.openai_agents._model_parameters import ModelActivityParameters
12-
from temporalio.contrib.openai_agents.temporal_openai_agents import (
12+
from temporalio.contrib.openai_agents._temporal_openai_agents import (
1313
OpenAIAgentsPlugin,
1414
TestModel,
1515
TestModelProvider,
1616
)
17+
from temporalio.contrib.openai_agents._trace_interceptor import (
18+
OpenAIAgentsTracingInterceptor,
19+
)
1720

1821
from . import workflow
1922

temporalio/contrib/openai_agents/_invoke_model_activity.py

Lines changed: 63 additions & 13 deletions
Original file line number | Diff line number | Diff line change
@@ -6,6 +6,7 @@
66
import enum
77
import json
88
from dataclasses import dataclass
9+
from datetime import timedelta
910
from typing import Any, Optional, Union, cast
1011

1112
from agents import (
@@ -17,17 +18,25 @@
1718
ModelResponse,
1819
ModelSettings,
1920
ModelTracing,
21+
OpenAIProvider,
2022
RunContextWrapper,
2123
Tool,
2224
TResponseInputItem,
2325
UserError,
2426
WebSearchTool,
2527
)
2628
from agents.models.multi_provider import MultiProvider
29+
from openai import (
30+
APIStatusError,
31+
AsyncOpenAI,
32+
AuthenticationError,
33+
PermissionDeniedError,
34+
)
2735
from typing_extensions import Required, TypedDict
2836

2937
from temporalio import activity
3038
from temporalio.contrib.openai_agents._heartbeat_decorator import _auto_heartbeater
39+
from temporalio.exceptions import ApplicationError
3140

3241

3342
@dataclass
@@ -117,11 +126,15 @@ class ActivityModelInput(TypedDict, total=False):
117126

118127

119128
class ModelActivity:
120-
"""Class wrapper for model invocation activities to allow model customization."""
129+
"""Class wrapper for model invocation activities to allow model customization. By default, we use an OpenAIProvider with retries disabled.
130+
Disabling retries in your model of choice is recommended to allow activity retries to define the retry model.
131+
"""
121132

122133
def __init__(self, model_provider: Optional[ModelProvider] = None):
123134
"""Initialize the activity with a model provider."""
124-
self._model_provider = model_provider or MultiProvider()
135+
self._model_provider = model_provider or OpenAIProvider(
136+
openai_client=AsyncOpenAI(max_retries=0)
137+
)
125138

126139
@activity.defn
127140
@_auto_heartbeater
@@ -171,14 +184,51 @@ def make_tool(tool: ToolInput) -> Tool:
171184
)
172185
for x in input.get("handoffs", [])
173186
]
174-
return await model.get_response(
175-
system_instructions=input.get("system_instructions"),
176-
input=input_input,
177-
model_settings=input["model_settings"],
178-
tools=tools,
179-
output_schema=input.get("output_schema"),
180-
handoffs=handoffs,
181-
tracing=ModelTracing(input["tracing"]),
182-
previous_response_id=input.get("previous_response_id"),
183-
prompt=input.get("prompt"),
184-
)
187+
188+
try:
189+
return await model.get_response(
190+
system_instructions=input.get("system_instructions"),
191+
input=input_input,
192+
model_settings=input["model_settings"],
193+
tools=tools,
194+
output_schema=input.get("output_schema"),
195+
handoffs=handoffs,
196+
tracing=ModelTracing(input["tracing"]),
197+
previous_response_id=input.get("previous_response_id"),
198+
prompt=input.get("prompt"),
199+
)
200+
except APIStatusError as e:
201+
# Listen to server hints
202+
retry_after = None
203+
retry_after_ms_header = e.response.headers.get("retry-after-ms")
204+
if retry_after_ms_header is not None:
205+
retry_after = timedelta(milliseconds=float(retry_after_ms_header))
206+
207+
if retry_after is None:
208+
retry_after_header = e.response.headers.get("retry-after")
209+
if retry_after_header is not None:
210+
retry_after = timedelta(seconds=float(retry_after_header))
211+
212+
should_retry_header = e.response.headers.get("x-should-retry")
213+
if should_retry_header == "true":
214+
raise e
215+
if should_retry_header == "false":
216+
raise ApplicationError(
217+
"Non retryable OpenAI error",
218+
non_retryable=True,
219+
next_retry_delay=retry_after,
220+
) from e
221+
222+
# Specifically retryable status codes
223+
if e.response.status_code in [408, 409, 429, 500]:
224+
raise ApplicationError(
225+
"Retryable OpenAI status code",
226+
non_retryable=False,
227+
next_retry_delay=retry_after,
228+
) from e
229+
230+
raise ApplicationError(
231+
"Non retryable OpenAI status code",
232+
non_retryable=True,
233+
next_retry_delay=retry_after,
234+
) from e

temporalio/contrib/openai_agents/_model_parameters.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -20,13 +20,13 @@ class ModelActivityParameters:
2020
task_queue: Optional[str] = None
2121
"""Specific task queue to use for model activities."""
2222

23-
schedule_to_close_timeout: Optional[timedelta] = timedelta(seconds=60)
23+
schedule_to_close_timeout: Optional[timedelta] = None
2424
"""Maximum time from scheduling to completion."""
2525

2626
schedule_to_start_timeout: Optional[timedelta] = None
2727
"""Maximum time from scheduling to starting."""
2828

29-
start_to_close_timeout: Optional[timedelta] = None
29+
start_to_close_timeout: Optional[timedelta] = timedelta(seconds=60)
3030
"""Maximum time for the activity to complete."""
3131

3232
heartbeat_timeout: Optional[timedelta] = None

temporalio/contrib/openai_agents/temporal_openai_agents.py renamed to temporalio/contrib/openai_agents/_temporal_openai_agents.py

Lines changed: 18 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -1,6 +1,7 @@
11
"""Initialize Temporal OpenAI Agents overrides."""
22

33
from contextlib import contextmanager
4+
from datetime import timedelta
45
from typing import AsyncIterator, Callable, Optional, Union
56

67
from agents import (
@@ -39,7 +40,7 @@
3940

4041
@contextmanager
4142
def set_open_ai_agent_temporal_overrides(
42-
model_params: Optional[ModelActivityParameters] = None,
43+
model_params: ModelActivityParameters,
4344
auto_close_tracing_in_workflows: bool = False,
4445
):
4546
"""Configure Temporal-specific overrides for OpenAI agents.
@@ -69,14 +70,6 @@ def set_open_ai_agent_temporal_overrides(
6970
if model_params is None:
7071
model_params = ModelActivityParameters()
7172

72-
if (
73-
not model_params.start_to_close_timeout
74-
and not model_params.schedule_to_close_timeout
75-
):
76-
raise ValueError(
77-
"Activity must have start_to_close_timeout or schedule_to_close_timeout"
78-
)
79-
8073
previous_runner = get_default_agent_runner()
8174
previous_trace_provider = get_trace_provider()
8275
provider = TemporalTraceProvider(
@@ -208,6 +201,22 @@ def __init__(
208201
model_provider: Optional model provider for custom model implementations.
209202
Useful for testing or custom model integrations.
210203
"""
204+
if model_params is None:
205+
model_params = ModelActivityParameters()
206+
207+
# For the default provider, we provide a default start_to_close_timeout of 60 seconds.
208+
# Other providers will need to define their own.
209+
if (
210+
model_params.start_to_close_timeout is None
211+
and model_params.schedule_to_close_timeout is None
212+
):
213+
if model_provider is None:
214+
model_params.start_to_close_timeout = timedelta(seconds=60)
215+
else:
216+
raise ValueError(
217+
"When configuring a custom provider, the model activity must have start_to_close_timeout or schedule_to_close_timeout"
218+
)
219+
211220
self._model_params = model_params
212221
self._model_provider = model_provider
213222

tests/contrib/openai_agents/test_openai.py

Lines changed: 64 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -37,7 +37,7 @@
3737
ToolCallItem,
3838
ToolCallOutputItem,
3939
)
40-
from openai import AsyncOpenAI, BaseModel
40+
from openai import APIStatusError, AsyncOpenAI, BaseModel
4141
from openai.types.responses import (
4242
ResponseFunctionToolCall,
4343
ResponseFunctionWebSearch,
@@ -48,16 +48,18 @@
4848
from openai.types.responses.response_prompt_param import ResponsePromptParam
4949
from pydantic import ConfigDict, Field, TypeAdapter
5050

51+
import temporalio.api.cloud.namespace.v1
5152
from temporalio import activity, workflow
5253
from temporalio.client import Client, WorkflowFailureError, WorkflowHandle
54+
from temporalio.common import RetryPolicy, SearchAttributeValueType
5355
from temporalio.contrib import openai_agents
5456
from temporalio.contrib.openai_agents import (
5557
ModelActivityParameters,
5658
TestModel,
5759
TestModelProvider,
5860
)
5961
from temporalio.contrib.pydantic import pydantic_data_converter
60-
from temporalio.exceptions import CancelledError
62+
from temporalio.exceptions import ApplicationError, CancelledError
6163
from temporalio.testing import WorkflowEnvironment
6264
from tests.contrib.openai_agents.research_agents.research_manager import (
6365
ResearchManager,
@@ -1777,3 +1779,63 @@ async def test_response_serialization():
17771779
response_id="",
17781780
)
17791781
encoded = await pydantic_data_converter.encode([model_response])
1782+
1783+
1784+
async def assert_status_retry_behavior(status: int, client: Client, should_retry: bool):
1785+
def status_error(status: int):
1786+
with workflow.unsafe.imports_passed_through():
1787+
with workflow.unsafe.sandbox_unrestricted():
1788+
import httpx
1789+
raise APIStatusError(
1790+
message="Something went wrong.",
1791+
response=httpx.Response(
1792+
status_code=status, request=httpx.Request("GET", url="")
1793+
),
1794+
body=None,
1795+
)
1796+
1797+
new_config = client.config()
1798+
new_config["plugins"] = [
1799+
openai_agents.OpenAIAgentsPlugin(
1800+
model_params=ModelActivityParameters(
1801+
retry_policy=RetryPolicy(maximum_attempts=2),
1802+
),
1803+
model_provider=TestModelProvider(TestModel(lambda: status_error(status))),
1804+
)
1805+
]
1806+
client = Client(**new_config)
1807+
1808+
async with new_worker(
1809+
client,
1810+
HelloWorldAgent,
1811+
) as worker:
1812+
workflow_handle = await client.start_workflow(
1813+
HelloWorldAgent.run,
1814+
"Input",
1815+
id=f"workflow-tool-{uuid.uuid4()}",
1816+
task_queue=worker.task_queue,
1817+
execution_timeout=timedelta(seconds=10),
1818+
)
1819+
with pytest.raises(WorkflowFailureError) as e:
1820+
await workflow_handle.result()
1821+
1822+
found = False
1823+
async for event in workflow_handle.fetch_history_events():
1824+
if event.HasField("activity_task_started_event_attributes"):
1825+
found = True
1826+
if should_retry:
1827+
assert event.activity_task_started_event_attributes.attempt == 2
1828+
else:
1829+
assert event.activity_task_started_event_attributes.attempt == 1
1830+
assert found
1831+
1832+
1833+
async def test_exception_handling(client: Client):
1834+
await assert_status_retry_behavior(408, client, should_retry=True)
1835+
await assert_status_retry_behavior(409, client, should_retry=True)
1836+
await assert_status_retry_behavior(429, client, should_retry=True)
1837+
await assert_status_retry_behavior(500, client, should_retry=True)
1838+
1839+
await assert_status_retry_behavior(400, client, should_retry=False)
1840+
await assert_status_retry_behavior(403, client, should_retry=False)
1841+
await assert_status_retry_behavior(404, client, should_retry=False)

tests/contrib/openai_agents/test_openai_replay.py

Lines changed: 3 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -3,7 +3,8 @@
33
import pytest
44

55
from temporalio.client import WorkflowHistory
6-
from temporalio.contrib.openai_agents.temporal_openai_agents import (
6+
from temporalio.contrib.openai_agents import ModelActivityParameters
7+
from temporalio.contrib.openai_agents._temporal_openai_agents import (
78
set_open_ai_agent_temporal_overrides,
89
)
910
from temporalio.contrib.pydantic import pydantic_data_converter
@@ -35,7 +36,7 @@ async def test_replay(file_name: str) -> None:
3536
with (Path(__file__).with_name("histories") / file_name).open("r") as f:
3637
history_json = f.read()
3738

38-
with set_open_ai_agent_temporal_overrides():
39+
with set_open_ai_agent_temporal_overrides(ModelActivityParameters()):
3940
await Replayer(
4041
workflows=[
4142
ResearchWorkflow,

0 commit comments

Comments
 (0)