Skip to content

Commit 4cf0697

Browse files
authored
Adding named params to openai activity configurations (#917)
* Adding named params to openai activity configurations * Remove a few unused imports * Move model parameters into an object, remove activity id * Refer argument docs to , early check timeout * Removing example usage because of pydoctor errors
1 parent 40f1624 commit 4cf0697

File tree

7 files changed

+148
-60
lines changed

7 files changed

+148
-60
lines changed

temporalio/contrib/openai_agents/README.md

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -96,6 +96,7 @@ from datetime import timedelta
9696

9797
from temporalio.client import Client
9898
from temporalio.contrib.openai_agents.invoke_model_activity import ModelActivity
99+
from temporalio.contrib.openai_agents.model_parameters import ModelActivityParameters
99100
from temporalio.contrib.openai_agents.open_ai_data_converter import open_ai_data_converter
100101
from temporalio.contrib.openai_agents.temporal_openai_agents import set_open_ai_agent_temporal_overrides
101102
from temporalio.worker import Worker
@@ -105,9 +106,10 @@ from hello_world_workflow import HelloWorldAgent
105106
async def worker_main():
106107
# Configure the OpenAI Agents SDK to use Temporal activities for LLM API calls
107108
# and for tool calls.
108-
with set_open_ai_agent_temporal_overrides(
109+
model_params = ModelActivityParameters(
109110
start_to_close_timeout=timedelta(seconds=10)
110-
):
111+
)
112+
with set_open_ai_agent_temporal_overrides(model_params):
111113
# Create a Temporal client connected to server at the given address
112114
# Use the OpenAI data converter to ensure proper serialization/deserialization
113115
client = await Client.connect(

temporalio/contrib/openai_agents/_openai_runner.py

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from dataclasses import replace
2-
from typing import Union
2+
from datetime import timedelta
3+
from typing import Optional, Union
34

45
from agents import (
56
Agent,
@@ -13,7 +14,10 @@
1314
from agents.run import DEFAULT_AGENT_RUNNER, DEFAULT_MAX_TURNS, AgentRunner
1415

1516
from temporalio import workflow
17+
from temporalio.common import Priority, RetryPolicy
1618
from temporalio.contrib.openai_agents._temporal_model_stub import _TemporalModelStub
19+
from temporalio.contrib.openai_agents.model_parameters import ModelActivityParameters
20+
from temporalio.workflow import ActivityCancellationType, VersioningIntent
1721

1822

1923
class TemporalOpenAIRunner(AgentRunner):
@@ -23,10 +27,10 @@ class TemporalOpenAIRunner(AgentRunner):
2327
2428
"""
2529

26-
def __init__(self, **kwargs) -> None:
30+
def __init__(self, model_params: ModelActivityParameters) -> None:
2731
"""Initialize the Temporal OpenAI Runner."""
2832
self._runner = DEFAULT_AGENT_RUNNER or AgentRunner()
29-
self.kwargs = kwargs
33+
self.model_params = model_params
3034

3135
async def run(
3236
self,
@@ -56,7 +60,11 @@ async def run(
5660
"Temporal workflows require a model name to be a string in the run config."
5761
)
5862
updated_run_config = replace(
59-
run_config, model=_TemporalModelStub(run_config.model, **self.kwargs)
63+
run_config,
64+
model=_TemporalModelStub(
65+
run_config.model,
66+
model_params=self.model_params,
67+
),
6068
)
6169

6270
with workflow.unsafe.imports_passed_through():

temporalio/contrib/openai_agents/_temporal_model_stub.py

Lines changed: 23 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,13 @@
11
from __future__ import annotations
22

33
import logging
4+
from datetime import timedelta
5+
from typing import Optional
46

57
from temporalio import workflow
8+
from temporalio.common import Priority, RetryPolicy
9+
from temporalio.contrib.openai_agents.model_parameters import ModelActivityParameters
10+
from temporalio.workflow import ActivityCancellationType, VersioningIntent
611

712
logger = logging.getLogger(__name__)
813

@@ -41,9 +46,14 @@
4146
class _TemporalModelStub(Model):
4247
"""A stub that allows invoking models as Temporal activities."""
4348

44-
def __init__(self, model_name: Optional[str], **kwargs) -> None:
49+
def __init__(
50+
self,
51+
model_name: Optional[str],
52+
*,
53+
model_params: ModelActivityParameters,
54+
) -> None:
4555
self.model_name = model_name
46-
self.kwargs = kwargs
56+
self.model_params = model_params
4757

4858
async def get_response(
4959
self,
@@ -141,11 +151,20 @@ def make_tool_info(tool: Tool) -> ToolInput:
141151
previous_response_id=previous_response_id,
142152
prompt=prompt,
143153
)
154+
144155
return await workflow.execute_activity_method(
145156
ModelActivity.invoke_model_activity,
146157
activity_input,
147-
summary=get_summary(input),
148-
**self.kwargs,
158+
summary=self.model_params.summary_override or get_summary(input),
159+
task_queue=self.model_params.task_queue,
160+
schedule_to_close_timeout=self.model_params.schedule_to_close_timeout,
161+
schedule_to_start_timeout=self.model_params.schedule_to_start_timeout,
162+
start_to_close_timeout=self.model_params.start_to_close_timeout,
163+
heartbeat_timeout=self.model_params.heartbeat_timeout,
164+
retry_policy=self.model_params.retry_policy,
165+
cancellation_type=self.model_params.cancellation_type,
166+
versioning_intent=self.model_params.versioning_intent,
167+
priority=self.model_params.priority,
149168
)
150169

151170
def stream_response(
Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
"""Parameters for configuring Temporal activity execution for model calls."""
2+
3+
from dataclasses import dataclass
4+
from datetime import timedelta
5+
from typing import Optional
6+
7+
from temporalio.common import Priority, RetryPolicy
8+
from temporalio.workflow import ActivityCancellationType, VersioningIntent
9+
10+
11+
@dataclass
12+
class ModelActivityParameters:
13+
"""Parameters for configuring Temporal activity execution for model calls.
14+
15+
This class encapsulates all the parameters that can be used to configure
16+
how Temporal activities are executed when making model calls through the
17+
OpenAI Agents integration.
18+
"""
19+
20+
task_queue: Optional[str] = None
21+
"""Specific task queue to use for model activities."""
22+
23+
schedule_to_close_timeout: Optional[timedelta] = None
24+
"""Maximum time from scheduling to completion."""
25+
26+
schedule_to_start_timeout: Optional[timedelta] = None
27+
"""Maximum time from scheduling to starting."""
28+
29+
start_to_close_timeout: Optional[timedelta] = None
30+
"""Maximum time for the activity to complete."""
31+
32+
heartbeat_timeout: Optional[timedelta] = None
33+
"""Maximum time between heartbeats."""
34+
35+
retry_policy: Optional[RetryPolicy] = None
36+
"""Policy for retrying failed activities."""
37+
38+
cancellation_type: ActivityCancellationType = ActivityCancellationType.TRY_CANCEL
39+
"""How the activity handles cancellation."""
40+
41+
versioning_intent: Optional[VersioningIntent] = None
42+
"""Versioning intent for the activity."""
43+
44+
summary_override: Optional[str] = None
45+
"""Summary for the activity execution."""
46+
47+
priority: Priority = Priority.default
48+
"""Priority for the activity execution."""

temporalio/contrib/openai_agents/temporal_openai_agents.py

Lines changed: 19 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1,21 +1,27 @@
11
"""Initialize Temporal OpenAI Agents overrides."""
22

33
from contextlib import contextmanager
4+
from datetime import timedelta
45
from typing import Optional
56

67
from agents import set_trace_provider
7-
from agents.run import AgentRunner, get_default_agent_runner, set_default_agent_runner
8-
from agents.tracing import TraceProvider, get_trace_provider
8+
from agents.run import get_default_agent_runner, set_default_agent_runner
9+
from agents.tracing import get_trace_provider
910
from agents.tracing.provider import DefaultTraceProvider
1011

12+
from temporalio.common import Priority, RetryPolicy
1113
from temporalio.contrib.openai_agents._openai_runner import TemporalOpenAIRunner
1214
from temporalio.contrib.openai_agents._temporal_trace_provider import (
1315
TemporalTraceProvider,
1416
)
17+
from temporalio.contrib.openai_agents.model_parameters import ModelActivityParameters
18+
from temporalio.workflow import ActivityCancellationType, VersioningIntent
1519

1620

1721
@contextmanager
18-
def set_open_ai_agent_temporal_overrides(**kwargs):
22+
def set_open_ai_agent_temporal_overrides(
23+
model_params: ModelActivityParameters,
24+
):
1925
"""Configure Temporal-specific overrides for OpenAI agents.
2026
2127
.. warning::
@@ -33,34 +39,26 @@ def set_open_ai_agent_temporal_overrides(**kwargs):
3339
3. Restoring previous settings when the context exits
3440
3541
Args:
36-
**kwargs: Additional arguments to pass to the TemporalOpenAIRunner constructor.
37-
These arguments are forwarded to workflow.execute_activity_method when
38-
executing model calls. Common options include:
39-
- start_to_close_timeout: Maximum time for the activity to complete
40-
- schedule_to_close_timeout: Maximum time from scheduling to completion
41-
- retry_policy: Policy for retrying failed activities
42-
- task_queue: Specific task queue to use for model activities
43-
44-
Example usage:
45-
with set_open_ai_agent_temporal_overrides(
46-
start_to_close_timeout=timedelta(seconds=30),
47-
retry_policy=RetryPolicy(maximum_attempts=3)
48-
):
49-
# Initialize Temporal client and worker here
50-
client = await Client.connect("localhost:7233")
51-
worker = Worker(client, task_queue="my-task-queue")
52-
await worker.run()
42+
model_params: Configuration parameters for Temporal activity execution of model calls.
5343
5444
Returns:
5545
A context manager that yields the configured TemporalTraceProvider.
5646
5747
"""
48+
if (
49+
not model_params.start_to_close_timeout
50+
and not model_params.schedule_to_close_timeout
51+
):
52+
raise ValueError(
53+
"Activity must have start_to_close_timeout or schedule_to_close_timeout"
54+
)
55+
5856
previous_runner = get_default_agent_runner()
5957
previous_trace_provider = get_trace_provider()
6058
provider = TemporalTraceProvider()
6159

6260
try:
63-
set_default_agent_runner(TemporalOpenAIRunner(**kwargs))
61+
set_default_agent_runner(TemporalOpenAIRunner(model_params))
6462
set_trace_provider(provider)
6563
yield provider
6664
finally:

temporalio/contrib/openai_agents/temporal_tools.py

Lines changed: 31 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,33 @@
11
"""Support for using Temporal activities as OpenAI agents tools."""
22

3-
from typing import Any, Callable
3+
from datetime import timedelta
4+
from typing import Any, Callable, Optional
45

56
from temporalio import activity, workflow
7+
from temporalio.common import Priority, RetryPolicy
68
from temporalio.exceptions import ApplicationError
7-
from temporalio.workflow import unsafe
9+
from temporalio.workflow import ActivityCancellationType, VersioningIntent, unsafe
810

911
with unsafe.imports_passed_through():
1012
from agents import FunctionTool, RunContextWrapper, Tool
1113
from agents.function_schema import function_schema
1214

1315

14-
def activity_as_tool(fn: Callable, **kwargs) -> Tool:
16+
def activity_as_tool(
17+
fn: Callable,
18+
*,
19+
task_queue: Optional[str] = None,
20+
schedule_to_close_timeout: Optional[timedelta] = None,
21+
schedule_to_start_timeout: Optional[timedelta] = None,
22+
start_to_close_timeout: Optional[timedelta] = None,
23+
heartbeat_timeout: Optional[timedelta] = None,
24+
retry_policy: Optional[RetryPolicy] = None,
25+
cancellation_type: ActivityCancellationType = ActivityCancellationType.TRY_CANCEL,
26+
activity_id: Optional[str] = None,
27+
versioning_intent: Optional[VersioningIntent] = None,
28+
summary: Optional[str] = None,
29+
priority: Priority = Priority.default,
30+
) -> Tool:
1531
"""Convert a single Temporal activity function to an OpenAI agent tool.
1632
1733
.. warning::
@@ -25,16 +41,7 @@ def activity_as_tool(fn: Callable, **kwargs) -> Tool:
2541
2642
Args:
2743
fn: A Temporal activity function to convert to a tool.
28-
**kwargs: Additional arguments to pass to workflow.execute_activity.
29-
These arguments configure how the activity is executed. Common options include:
30-
- start_to_close_timeout: Maximum time for the activity to complete
31-
- schedule_to_close_timeout: Maximum time from scheduling to completion
32-
- schedule_to_start_timeout: Maximum time from scheduling to starting
33-
- heartbeat_timeout: Maximum time between heartbeats
34-
- retry_policy: Policy for retrying failed activities
35-
- task_queue: Specific task queue to use for this activity
36-
- cancellation_type: How the activity handles cancellation
37-
- workflow_id_reuse_policy: Policy for workflow ID reuse
44+
For other arguments, refer to :py:mod:`workflow` :py:meth:`start_activity`
3845
3946
Returns:
4047
An OpenAI agent tool that wraps the provided activity.
@@ -69,7 +76,17 @@ async def run_activity(ctx: RunContextWrapper[Any], input: str) -> Any:
6976
await workflow.execute_activity(
7077
fn,
7178
input,
72-
**kwargs,
79+
task_queue=task_queue,
80+
schedule_to_close_timeout=schedule_to_close_timeout,
81+
schedule_to_start_timeout=schedule_to_start_timeout,
82+
start_to_close_timeout=start_to_close_timeout,
83+
heartbeat_timeout=heartbeat_timeout,
84+
retry_policy=retry_policy,
85+
cancellation_type=cancellation_type,
86+
activity_id=activity_id,
87+
versioning_intent=versioning_intent,
88+
summary=summary,
89+
priority=priority,
7390
)
7491
)
7592
except Exception:

tests/contrib/test_openai.py

Lines changed: 11 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
from temporalio.contrib.openai_agents.invoke_model_activity import (
1212
ModelActivity,
1313
)
14+
from temporalio.contrib.openai_agents.model_parameters import ModelActivityParameters
1415
from temporalio.contrib.openai_agents.open_ai_data_converter import (
1516
open_ai_data_converter,
1617
)
@@ -144,9 +145,8 @@ async def test_hello_world_agent(client: Client):
144145
new_config["data_converter"] = open_ai_data_converter
145146
client = Client(**new_config)
146147

147-
with set_open_ai_agent_temporal_overrides(
148-
start_to_close_timeout=timedelta(seconds=10)
149-
):
148+
model_params = ModelActivityParameters(start_to_close_timeout=timedelta(seconds=10))
149+
with set_open_ai_agent_temporal_overrides(model_params):
150150
model_activity = ModelActivity(
151151
TestProvider(
152152
TestHelloModel( # type: ignore
@@ -242,9 +242,8 @@ async def test_tool_workflow(client: Client):
242242
new_config["data_converter"] = open_ai_data_converter
243243
client = Client(**new_config)
244244

245-
with set_open_ai_agent_temporal_overrides(
246-
start_to_close_timeout=timedelta(seconds=10)
247-
):
245+
model_params = ModelActivityParameters(start_to_close_timeout=timedelta(seconds=10))
246+
with set_open_ai_agent_temporal_overrides(model_params):
248247
model_activity = ModelActivity(
249248
TestProvider(
250249
TestWeatherModel( # type: ignore
@@ -464,9 +463,8 @@ async def test_research_workflow(client: Client):
464463
global response_index
465464
response_index = 0
466465

467-
with set_open_ai_agent_temporal_overrides(
468-
start_to_close_timeout=timedelta(seconds=10)
469-
):
466+
model_params = ModelActivityParameters(start_to_close_timeout=timedelta(seconds=10))
467+
with set_open_ai_agent_temporal_overrides(model_params):
470468
model_activity = ModelActivity(
471469
TestProvider(
472470
TestResearchModel( # type: ignore
@@ -675,9 +673,8 @@ async def test_agents_as_tools_workflow(client: Client):
675673
new_config["data_converter"] = open_ai_data_converter
676674
client = Client(**new_config)
677675

678-
with set_open_ai_agent_temporal_overrides(
679-
start_to_close_timeout=timedelta(seconds=10)
680-
):
676+
model_params = ModelActivityParameters(start_to_close_timeout=timedelta(seconds=10))
677+
with set_open_ai_agent_temporal_overrides(model_params):
681678
model_activity = ModelActivity(
682679
TestProvider(
683680
AgentAsToolsModel( # type: ignore
@@ -1033,9 +1030,8 @@ async def test_customer_service_workflow(client: Client):
10331030

10341031
questions = ["Hello", "Book me a flight to PDX", "11111", "Any window seat"]
10351032

1036-
with set_open_ai_agent_temporal_overrides(
1037-
start_to_close_timeout=timedelta(seconds=10)
1038-
):
1033+
model_params = ModelActivityParameters(start_to_close_timeout=timedelta(seconds=10))
1034+
with set_open_ai_agent_temporal_overrides(model_params):
10391035
model_activity = ModelActivity(
10401036
TestProvider(
10411037
CustomerServiceModel( # type: ignore

0 commit comments

Comments
 (0)