Skip to content

Commit ecfb72f

Browse files
authored
OpenAI tool context (#942)
* Toying with context input to tools * Doc update * Remove import
1 parent e2b2337 commit ecfb72f

File tree

2 files changed

+50
-4
lines changed

2 files changed

+50
-4
lines changed

temporalio/contrib/openai_agents/temporal_tools.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,8 @@ def activity_as_tool(
4141
This function takes a Temporal activity function and converts it into an
4242
OpenAI agent tool that can be used by the agent to execute the activity
4343
during workflow execution. The tool will automatically handle the conversion
44-
of inputs and outputs between the agent and the activity.
44+
of inputs and outputs between the agent and the activity. Note that if you take a context,
45+
mutation will not be persisted, as the activity may not be running in the same location.
4546
4647
Args:
4748
fn: A Temporal activity function to convert to a tool.
@@ -85,6 +86,11 @@ async def run_activity(ctx: RunContextWrapper[Any], input: str) -> Any:
8586

8687
# Activities don't support keyword only arguments, so we can ignore the kwargs_dict return
8788
args, _ = schema.to_call_args(schema.params_pydantic_model(**json_data))
89+
90+
# Add the context to the arguments if it takes that
91+
if schema.takes_context:
92+
args = [ctx] + args
93+
8894
result = await workflow.execute_activity(
8995
fn,
9096
args=args,

tests/contrib/openai_agents/test_openai.py

Lines changed: 43 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -221,6 +221,14 @@ async def get_weather_object(input: WeatherInput) -> Weather:
221221
)
222222

223223

224+
@activity.defn
225+
async def get_weather_context(ctx: RunContextWrapper[str], city: str) -> Weather:
226+
"""
227+
Get the weather for a given city.
228+
"""
229+
return Weather(city=city, temperature_range="14-20C", conditions=ctx.context)
230+
231+
224232
class TestWeatherModel(TestModel):
225233
responses = [
226234
ModelResponse(
@@ -265,6 +273,20 @@ class TestWeatherModel(TestModel):
265273
usage=Usage(),
266274
response_id=None,
267275
),
276+
ModelResponse(
277+
output=[
278+
ResponseFunctionToolCall(
279+
arguments='{"city":"Tokyo"}',
280+
call_id="call",
281+
name="get_weather_context",
282+
type="function_call",
283+
id="id",
284+
status="completed",
285+
)
286+
],
287+
usage=Usage(),
288+
response_id=None,
289+
),
268290
ModelResponse(
269291
output=[
270292
ResponseOutputMessage(
@@ -304,9 +326,14 @@ async def run(self, question: str) -> str:
304326
activity_as_tool(
305327
get_weather_country, start_to_close_timeout=timedelta(seconds=10)
306328
),
329+
activity_as_tool(
330+
get_weather_context, start_to_close_timeout=timedelta(seconds=10)
331+
),
307332
],
308333
) # type: Agent
309-
result = await Runner.run(starting_agent=agent, input=question)
334+
result = await Runner.run(
335+
starting_agent=agent, input=question, context="Stormy"
336+
)
310337
return result.final_output
311338

312339

@@ -337,6 +364,7 @@ async def test_tool_workflow(client: Client, use_local_model: bool):
337364
get_weather,
338365
get_weather_object,
339366
get_weather_country,
367+
get_weather_context,
340368
],
341369
interceptors=[OpenAIAgentsTracingInterceptor()],
342370
) as worker:
@@ -357,7 +385,7 @@ async def test_tool_workflow(client: Client, use_local_model: bool):
357385
if e.HasField("activity_task_completed_event_attributes"):
358386
events.append(e)
359387

360-
assert len(events) == 7
388+
assert len(events) == 9
361389
assert (
362390
"function_call"
363391
in events[0]
@@ -395,11 +423,23 @@ async def test_tool_workflow(client: Client, use_local_model: bool):
395423
.data.decode()
396424
)
397425
assert (
398-
"Test weather result"
426+
"function_call"
399427
in events[6]
400428
.activity_task_completed_event_attributes.result.payloads[0]
401429
.data.decode()
402430
)
431+
assert (
432+
"Stormy"
433+
in events[7]
434+
.activity_task_completed_event_attributes.result.payloads[0]
435+
.data.decode()
436+
)
437+
assert (
438+
"Test weather result"
439+
in events[8]
440+
.activity_task_completed_event_attributes.result.payloads[0]
441+
.data.decode()
442+
)
403443

404444

405445
class TestPlannerModel(OpenAIResponsesModel):

0 commit comments

Comments
 (0)