Skip to content

Commit 61d5072

Browse files
authored
Support for more activity tool inputs (#923)
* Processing activity tool input through the function schema's pydantic model * Fix issues with multiple arguments * Undo prompt change * Add explicit application error for invalid json tool input * Change error return * Change error base class * Fix linting * Update core
1 parent 5a95f8e commit 61d5072

File tree

2 files changed

+124
-25
lines changed

2 files changed

+124
-25
lines changed

temporalio/contrib/openai_agents/temporal_tools.py

Lines changed: 34 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,23 @@
11
"""Support for using Temporal activities as OpenAI agents tools."""
22

3+
import json
34
from datetime import timedelta
45
from typing import Any, Callable, Optional
56

67
from temporalio import activity, workflow
78
from temporalio.common import Priority, RetryPolicy
8-
from temporalio.exceptions import ApplicationError
9+
from temporalio.exceptions import ApplicationError, TemporalError
910
from temporalio.workflow import ActivityCancellationType, VersioningIntent, unsafe
1011

1112
with unsafe.imports_passed_through():
1213
from agents import FunctionTool, RunContextWrapper, Tool
1314
from agents.function_schema import function_schema
1415

1516

17+
class ToolSerializationError(TemporalError):
18+
"""Error that occurs when a tool output could not be serialized."""
19+
20+
1621
def activity_as_tool(
1722
fn: Callable,
1823
*,
@@ -69,32 +74,40 @@ def activity_as_tool(
6974
"Bare function without tool and activity decorators is not supported",
7075
"invalid_tool",
7176
)
77+
schema = function_schema(fn)
7278

7379
async def run_activity(ctx: RunContextWrapper[Any], input: str) -> Any:
7480
try:
75-
return str(
76-
await workflow.execute_activity(
77-
fn,
78-
input,
79-
task_queue=task_queue,
80-
schedule_to_close_timeout=schedule_to_close_timeout,
81-
schedule_to_start_timeout=schedule_to_start_timeout,
82-
start_to_close_timeout=start_to_close_timeout,
83-
heartbeat_timeout=heartbeat_timeout,
84-
retry_policy=retry_policy,
85-
cancellation_type=cancellation_type,
86-
activity_id=activity_id,
87-
versioning_intent=versioning_intent,
88-
summary=summary,
89-
priority=priority,
90-
)
91-
)
92-
except Exception:
81+
json_data = json.loads(input)
82+
except Exception as e:
9383
raise ApplicationError(
84+
f"Invalid JSON input for tool {schema.name}: {input}"
85+
) from e
86+
87+
# Activities don't support keyword only arguments, so we can ignore the kwargs_dict return
88+
args, _ = schema.to_call_args(schema.params_pydantic_model(**json_data))
89+
result = await workflow.execute_activity(
90+
fn,
91+
args=args,
92+
task_queue=task_queue,
93+
schedule_to_close_timeout=schedule_to_close_timeout,
94+
schedule_to_start_timeout=schedule_to_start_timeout,
95+
start_to_close_timeout=start_to_close_timeout,
96+
heartbeat_timeout=heartbeat_timeout,
97+
retry_policy=retry_policy,
98+
cancellation_type=cancellation_type,
99+
activity_id=activity_id,
100+
versioning_intent=versioning_intent,
101+
summary=summary,
102+
priority=priority,
103+
)
104+
try:
105+
return str(result)
106+
except Exception as e:
107+
raise ToolSerializationError(
94108
"You must return a string representation of the tool output, or something we can call str() on"
95-
)
109+
) from e
96110

97-
schema = function_schema(fn)
98111
return FunctionTool(
99112
name=schema.name,
100113
description=schema.description or "",

tests/contrib/test_openai.py

Lines changed: 90 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -219,6 +219,29 @@ async def get_weather(city: str) -> Weather:
219219
return Weather(city=city, temperature_range="14-20C", conditions="Sunny with wind.")
220220

221221

222+
@activity.defn
223+
async def get_weather_country(city: str, country: str) -> Weather:
224+
"""
225+
Get the weather for a given city in a country.
226+
"""
227+
return Weather(city=city, temperature_range="14-20C", conditions="Sunny with wind.")
228+
229+
230+
@dataclass
231+
class WeatherInput:
232+
city: str
233+
234+
235+
@activity.defn
236+
async def get_weather_object(input: WeatherInput) -> Weather:
237+
"""
238+
Get the weather for a given city.
239+
"""
240+
return Weather(
241+
city=input.city, temperature_range="14-20C", conditions="Sunny with wind."
242+
)
243+
244+
222245
class TestWeatherModel(TestModel):
223246
responses = [
224247
ModelResponse(
@@ -235,6 +258,34 @@ class TestWeatherModel(TestModel):
235258
usage=Usage(),
236259
response_id=None,
237260
),
261+
ModelResponse(
262+
output=[
263+
ResponseFunctionToolCall(
264+
arguments='{"input":{"city":"Tokyo"}}',
265+
call_id="call",
266+
name="get_weather_object",
267+
type="function_call",
268+
id="id",
269+
status="completed",
270+
)
271+
],
272+
usage=Usage(),
273+
response_id=None,
274+
),
275+
ModelResponse(
276+
output=[
277+
ResponseFunctionToolCall(
278+
arguments='{"city":"Tokyo","country":"Japan"}',
279+
call_id="call",
280+
name="get_weather_country",
281+
type="function_call",
282+
id="id",
283+
status="completed",
284+
)
285+
],
286+
usage=Usage(),
287+
response_id=None,
288+
),
238289
ModelResponse(
239290
output=[
240291
ResponseOutputMessage(
@@ -267,7 +318,13 @@ async def run(self, question: str) -> str:
267318
tools=[
268319
activity_as_tool(
269320
get_weather, start_to_close_timeout=timedelta(seconds=10)
270-
)
321+
),
322+
activity_as_tool(
323+
get_weather_object, start_to_close_timeout=timedelta(seconds=10)
324+
),
325+
activity_as_tool(
326+
get_weather_country, start_to_close_timeout=timedelta(seconds=10)
327+
),
271328
],
272329
) # type: Agent
273330
result = await Runner.run(starting_agent=agent, input=question)
@@ -291,7 +348,12 @@ async def test_tool_workflow(client: Client):
291348
async with new_worker(
292349
client,
293350
ToolsWorkflow,
294-
activities=[model_activity.invoke_model_activity, get_weather],
351+
activities=[
352+
model_activity.invoke_model_activity,
353+
get_weather,
354+
get_weather_object,
355+
get_weather_country,
356+
],
295357
interceptors=[OpenAIAgentsTracingInterceptor()],
296358
) as worker:
297359
workflow_handle = await client.start_workflow(
@@ -309,7 +371,7 @@ async def test_tool_workflow(client: Client):
309371
if e.HasField("activity_task_completed_event_attributes"):
310372
events.append(e)
311373

312-
assert len(events) == 3
374+
assert len(events) == 7
313375
assert (
314376
"function_call"
315377
in events[0]
@@ -323,11 +385,35 @@ async def test_tool_workflow(client: Client):
323385
.data.decode()
324386
)
325387
assert (
326-
"Test weather result"
388+
"function_call"
327389
in events[2]
328390
.activity_task_completed_event_attributes.result.payloads[0]
329391
.data.decode()
330392
)
393+
assert (
394+
"Sunny with wind"
395+
in events[3]
396+
.activity_task_completed_event_attributes.result.payloads[0]
397+
.data.decode()
398+
)
399+
assert (
400+
"function_call"
401+
in events[4]
402+
.activity_task_completed_event_attributes.result.payloads[0]
403+
.data.decode()
404+
)
405+
assert (
406+
"Sunny with wind"
407+
in events[5]
408+
.activity_task_completed_event_attributes.result.payloads[0]
409+
.data.decode()
410+
)
411+
assert (
412+
"Test weather result"
413+
in events[6]
414+
.activity_task_completed_event_attributes.result.payloads[0]
415+
.data.decode()
416+
)
331417

332418

333419
class TestPlannerModel(OpenAIResponsesModel):

0 commit comments

Comments
 (0)