From ab89a650d293febf746d8c6e52e6d1c291ad1380 Mon Sep 17 00:00:00 2001 From: Tim Conley Date: Thu, 10 Jul 2025 09:36:01 -0700 Subject: [PATCH 1/2] Openai Agents plugin PoC --- temporalio/contrib/openai_agents/__init__.py | 2 + .../openai_agents/temporal_openai_agents.py | 47 +- temporalio/worker/_worker.py | 44 +- tests/contrib/openai_agents/test_openai.py | 734 +++++++++--------- 4 files changed, 432 insertions(+), 395 deletions(-) diff --git a/temporalio/contrib/openai_agents/__init__.py b/temporalio/contrib/openai_agents/__init__.py index 027ad44ad..2c20effc7 100644 --- a/temporalio/contrib/openai_agents/__init__.py +++ b/temporalio/contrib/openai_agents/__init__.py @@ -14,6 +14,7 @@ OpenAIAgentsTracingInterceptor, ) from temporalio.contrib.openai_agents.temporal_openai_agents import ( + Plugin, TestModel, TestModelProvider, set_open_ai_agent_temporal_overrides, @@ -21,6 +22,7 @@ ) __all__ = [ + "Plugin", "ModelActivity", "ModelActivityParameters", "workflow", diff --git a/temporalio/contrib/openai_agents/temporal_openai_agents.py b/temporalio/contrib/openai_agents/temporal_openai_agents.py index b9bf57499..9b574d708 100644 --- a/temporalio/contrib/openai_agents/temporal_openai_agents.py +++ b/temporalio/contrib/openai_agents/temporal_openai_agents.py @@ -3,10 +3,9 @@ import json from contextlib import contextmanager from datetime import timedelta -from typing import Any, AsyncIterator, Callable, Optional, Union, overload +from typing import Any, AsyncIterator, Callable, Optional, Union from agents import ( - Agent, AgentOutputSchemaBase, Handoff, Model, @@ -19,31 +18,34 @@ TResponseInputItem, set_trace_provider, ) -from agents.function_schema import DocstringStyle, function_schema +from agents.function_schema import function_schema from agents.items import TResponseStreamEvent from agents.run import get_default_agent_runner, set_default_agent_runner from agents.tool import ( FunctionTool, - ToolErrorFunction, - ToolFunction, - ToolParams, - default_tool_error_function, - function_tool, ) from agents.tracing import get_trace_provider from agents.tracing.provider import DefaultTraceProvider -from agents.util._types import MaybeAwaitable from openai.types.responses import ResponsePromptParam +import temporalio.client +import temporalio.worker from temporalio import activity from temporalio import workflow as temporal_workflow +from temporalio.client import ClientConfig from temporalio.common import Priority, RetryPolicy +from temporalio.contrib.openai_agents import ( + ModelActivity, + OpenAIAgentsTracingInterceptor, +) from temporalio.contrib.openai_agents._model_parameters import ModelActivityParameters from temporalio.contrib.openai_agents._openai_runner import TemporalOpenAIRunner from temporalio.contrib.openai_agents._temporal_trace_provider import ( TemporalTraceProvider, ) +from temporalio.contrib.pydantic import pydantic_data_converter from temporalio.exceptions import ApplicationError, TemporalError +from temporalio.worker import Worker, WorkerConfig from temporalio.workflow import ActivityCancellationType, VersioningIntent @@ -154,6 +156,33 @@ def stream_response( raise NotImplementedError() +class Plugin(temporalio.client.Plugin, temporalio.worker.Plugin): + def __init__( + self, + model_params: Optional[ModelActivityParameters] = None, + model_provider: Optional[ModelProvider] = None, + ) -> None: + self._model_params = model_params + self._model_provider = model_provider + + def on_create_client(self, config: ClientConfig) -> ClientConfig: + config["data_converter"] = pydantic_data_converter + return super().on_create_client(config) + + def on_create_worker(self, config: WorkerConfig) -> WorkerConfig: + config["interceptors"] = list(config.get("interceptors") or []) + [ + OpenAIAgentsTracingInterceptor() + ] + config["activities"] = list(config.get("activities") or []) + [ + ModelActivity(self._model_provider).invoke_model_activity + ] + return super().on_create_worker(config) + + async def run_worker(self, worker: Worker) -> None: + with set_open_ai_agent_temporal_overrides(self._model_params): + await super().run_worker(worker) + + class ToolSerializationError(TemporalError): """Error that occurs when a tool output could not be serialized.""" diff --git a/temporalio/worker/_worker.py b/temporalio/worker/_worker.py index bc9d07353..fca8a35a4 100644 --- a/temporalio/worker/_worker.py +++ b/temporalio/worker/_worker.py @@ -385,7 +385,8 @@ def __init__( ) plugins_from_client = cast( - List[Plugin], [p for p in client.config()["plugins"] if isinstance(p, Plugin)] + List[Plugin], + [p for p in client.config()["plugins"] if isinstance(p, Plugin)], ) plugins = plugins_from_client + list(plugins) @@ -402,7 +403,11 @@ def _init_from_config(self, config: WorkerConfig): self._config = config # TODO(nexus-preview): max_concurrent_nexus_tasks / tuner support - if not (config["activities"] or config["nexus_service_handlers"] or config["workflows"]): + if not ( + config["activities"] + or config["nexus_service_handlers"] + or config["workflows"] + ): raise ValueError( "At least one activity, Nexus service, or workflow must be specified" ) @@ -410,7 +415,9 @@ def _init_from_config(self, config: WorkerConfig): raise ValueError( "build_id must be specified when use_worker_versioning is True" ) - if config["deployment_config"] and (config["build_id"] or config["use_worker_versioning"]): + if config["deployment_config"] and ( + config["build_id"] or config["use_worker_versioning"] + ): raise ValueError( "deployment_config cannot be used with build_id or use_worker_versioning" ) @@ -506,9 +513,13 @@ def check_activity(activity): unsandboxed_workflow_runner=config["unsandboxed_workflow_runner"], data_converter=client_config["data_converter"], interceptors=interceptors, - workflow_failure_exception_types=config["workflow_failure_exception_types"], + workflow_failure_exception_types=config[ + "workflow_failure_exception_types" + ], debug_mode=config["debug_mode"], - disable_eager_activity_execution=config["disable_eager_activity_execution"], + disable_eager_activity_execution=config[ + "disable_eager_activity_execution" + ], metric_meter=self._runtime.metric_meter, on_eviction_hook=None, disable_safe_eviction=config["disable_safe_workflow_eviction"], @@ -519,7 +530,7 @@ def check_activity(activity): ) tuner = config["tuner"] - if config["tuner"] is not None: + if tuner is not None: if ( config["max_concurrent_workflow_tasks"] or config["max_concurrent_activities"] @@ -540,9 +551,9 @@ def check_activity(activity): versioning_strategy: temporalio.bridge.worker.WorkerVersioningStrategy if config["deployment_config"]: - versioning_strategy = ( - config["deployment_config"]._to_bridge_worker_deployment_options() - ) + versioning_strategy = config[ + "deployment_config" + ]._to_bridge_worker_deployment_options() elif config["use_worker_versioning"]: build_id = config["build_id"] or load_default_build_id() versioning_strategy = ( @@ -586,9 +597,11 @@ def check_activity(activity): # We have to disable remote activities if a user asks _or_ if we # are not running an activity worker at all. Otherwise shutdown # will not proceed properly. - no_remote_activities=config["no_remote_activities"] or not config["activities"], + no_remote_activities=config["no_remote_activities"] + or not config["activities"], sticky_queue_schedule_to_start_timeout_millis=int( - 1000 * config["sticky_queue_schedule_to_start_timeout"].total_seconds() + 1000 + * config["sticky_queue_schedule_to_start_timeout"].total_seconds() ), max_heartbeat_throttle_interval_millis=int( 1000 * config["max_heartbeat_throttle_interval"].total_seconds() @@ -597,7 +610,9 @@ def check_activity(activity): 1000 * config["default_heartbeat_throttle_interval"].total_seconds() ), max_activities_per_second=config["max_activities_per_second"], - max_task_queue_activities_per_second=config["max_task_queue_activities_per_second"], + max_task_queue_activities_per_second=config[ + "max_task_queue_activities_per_second" + ], graceful_shutdown_period_millis=int( 1000 * config["graceful_shutdown_timeout"].total_seconds() ), @@ -614,7 +629,9 @@ def check_activity(activity): versioning_strategy=versioning_strategy, workflow_task_poller_behavior=workflow_task_poller_behavior._to_bridge(), activity_task_poller_behavior=activity_task_poller_behavior._to_bridge(), - nexus_task_poller_behavior=config["nexus_task_poller_behavior"]._to_bridge(), + nexus_task_poller_behavior=config[ + "nexus_task_poller_behavior" + ]._to_bridge(), ), ) @@ -902,6 +919,7 @@ class WorkerConfig(TypedDict, total=False): activity_task_poller_behavior: PollerBehavior nexus_task_poller_behavior: PollerBehavior + @dataclass class WorkerDeploymentConfig: """Options for configuring the Worker Versioning feature. diff --git a/tests/contrib/openai_agents/test_openai.py b/tests/contrib/openai_agents/test_openai.py index cfc74eb6b..d2224761c 100644 --- a/tests/contrib/openai_agents/test_openai.py +++ b/tests/contrib/openai_agents/test_openai.py @@ -124,26 +124,28 @@ async def test_hello_world_agent(client: Client, use_local_model: bool): if not use_local_model and not os.environ.get("OPENAI_API_KEY"): pytest.skip("No openai API key") new_config = client.config() - new_config["data_converter"] = pydantic_data_converter + new_config["plugins"] = [ + openai_agents.Plugin( + model_params=ModelActivityParameters( + start_to_close_timeout=timedelta(seconds=30) + ), + model_provider=TestModelProvider(TestHelloModel()) + if use_local_model + else None, + ) + ] client = Client(**new_config) - model_params = ModelActivityParameters(start_to_close_timeout=timedelta(seconds=30)) - with set_open_ai_agent_temporal_overrides(model_params): - model_activity = ModelActivity( - TestModelProvider(TestHelloModel()) if use_local_model else None + async with new_worker(client, HelloWorldAgent) as worker: + result = await client.execute_workflow( + HelloWorldAgent.run, + "Tell me about recursion in programming.", + id=f"hello-workflow-{uuid.uuid4()}", + task_queue=worker.task_queue, + execution_timeout=timedelta(seconds=5), ) - async with new_worker( - client, HelloWorldAgent, activities=[model_activity.invoke_model_activity] - ) as worker: - result = await client.execute_workflow( - HelloWorldAgent.run, - "Tell me about recursion in programming.", - id=f"hello-workflow-{uuid.uuid4()}", - task_queue=worker.task_queue, - execution_timeout=timedelta(seconds=5), - ) - if use_local_model: - assert result == "test" + if use_local_model: + assert result == "test" @dataclass @@ -305,103 +307,100 @@ async def test_tool_workflow(client: Client, use_local_model: bool): if not use_local_model and not os.environ.get("OPENAI_API_KEY"): pytest.skip("No openai API key") new_config = client.config() - new_config["data_converter"] = pydantic_data_converter + new_config["plugins"] = [ + openai_agents.Plugin( + model_params=ModelActivityParameters( + start_to_close_timeout=timedelta(seconds=30) + ), + model_provider=TestModelProvider(TestWeatherModel()) + if use_local_model + else None, + ) + ] client = Client(**new_config) - model_params = ModelActivityParameters(start_to_close_timeout=timedelta(seconds=30)) - with set_open_ai_agent_temporal_overrides(model_params): - model_activity = ModelActivity( - TestModelProvider( - TestWeatherModel( # type: ignore - ) - ) - if use_local_model - else None + async with new_worker( + client, + ToolsWorkflow, + activities=[ + get_weather, + get_weather_object, + get_weather_country, + get_weather_context, + ], + ) as worker: + workflow_handle = await client.start_workflow( + ToolsWorkflow.run, + "What is the weather in Tokio?", + id=f"tools-workflow-{uuid.uuid4()}", + task_queue=worker.task_queue, + execution_timeout=timedelta(seconds=30), ) - async with new_worker( - client, - ToolsWorkflow, - activities=[ - model_activity.invoke_model_activity, - get_weather, - get_weather_object, - get_weather_country, - get_weather_context, - ], - interceptors=[OpenAIAgentsTracingInterceptor()], - ) as worker: - workflow_handle = await client.start_workflow( - ToolsWorkflow.run, - "What is the weather in Tokio?", - id=f"tools-workflow-{uuid.uuid4()}", - task_queue=worker.task_queue, - execution_timeout=timedelta(seconds=30), + result = await workflow_handle.result() + + if use_local_model: + assert result == "Test weather result" + + events = [] + async for e in workflow_handle.fetch_history_events(): + if e.HasField("activity_task_completed_event_attributes"): + events.append(e) + + assert len(events) == 9 + assert ( + "function_call" + in events[0] + .activity_task_completed_event_attributes.result.payloads[0] + .data.decode() + ) + assert ( + "Sunny with wind" + in events[1] + .activity_task_completed_event_attributes.result.payloads[0] + .data.decode() + ) + assert ( + "function_call" + in events[2] + .activity_task_completed_event_attributes.result.payloads[0] + .data.decode() + ) + assert ( + "Sunny with wind" + in events[3] + .activity_task_completed_event_attributes.result.payloads[0] + .data.decode() + ) + assert ( + "function_call" + in events[4] + .activity_task_completed_event_attributes.result.payloads[0] + .data.decode() + ) + assert ( + "Sunny with wind" + in events[5] + .activity_task_completed_event_attributes.result.payloads[0] + .data.decode() + ) + assert ( + "function_call" + in events[6] + .activity_task_completed_event_attributes.result.payloads[0] + .data.decode() + ) + assert ( + "Stormy" + in events[7] + .activity_task_completed_event_attributes.result.payloads[0] + .data.decode() + ) + assert ( + "Test weather result" + in events[8] + .activity_task_completed_event_attributes.result.payloads[0] + .data.decode() ) - result = await workflow_handle.result() - - if use_local_model: - assert result == "Test weather result" - - events = [] - async for e in workflow_handle.fetch_history_events(): - if e.HasField("activity_task_completed_event_attributes"): - events.append(e) - - assert len(events) == 9 - assert ( - "function_call" - in events[0] - .activity_task_completed_event_attributes.result.payloads[0] - .data.decode() - ) - assert ( - "Sunny with wind" - in events[1] - .activity_task_completed_event_attributes.result.payloads[0] - .data.decode() - ) - assert ( - "function_call" - in events[2] - .activity_task_completed_event_attributes.result.payloads[0] - .data.decode() - ) - assert ( - "Sunny with wind" - in events[3] - .activity_task_completed_event_attributes.result.payloads[0] - .data.decode() - ) - assert ( - "function_call" - in events[4] - .activity_task_completed_event_attributes.result.payloads[0] - .data.decode() - ) - assert ( - "Sunny with wind" - in events[5] - .activity_task_completed_event_attributes.result.payloads[0] - .data.decode() - ) - assert ( - "function_call" - in events[6] - .activity_task_completed_event_attributes.result.payloads[0] - .data.decode() - ) - assert ( - "Stormy" - in events[7] - .activity_task_completed_event_attributes.result.payloads[0] - .data.decode() - ) - assert ( - "Test weather result" - in events[8] - .activity_task_completed_event_attributes.result.payloads[0] - .data.decode() - ) @no_type_check @@ -491,63 +490,60 @@ async def test_research_workflow(client: Client, use_local_model: bool): if not use_local_model and not os.environ.get("OPENAI_API_KEY"): pytest.skip("No openai API key") new_config = client.config() - new_config["data_converter"] = pydantic_data_converter + new_config["plugins"] = [ + openai_agents.Plugin( + model_params=ModelActivityParameters( + start_to_close_timeout=timedelta(seconds=30) + ), + model_provider=TestModelProvider(TestResearchModel()) + if use_local_model + else None, + ) + ] client = Client(**new_config) - global response_index - response_index = 0 - - model_params = ModelActivityParameters( - start_to_close_timeout=timedelta(seconds=120) - ) - with set_open_ai_agent_temporal_overrides(model_params): - model_activity = ModelActivity( - TestModelProvider(TestResearchModel()) if use_local_model else None + async with new_worker( + client, + ResearchWorkflow, + ) as worker: + workflow_handle = await client.start_workflow( + ResearchWorkflow.run, + "Caribbean vacation spots in April, optimizing for surfing, hiking and water sports", + id=f"research-workflow-{uuid.uuid4()}", + task_queue=worker.task_queue, + execution_timeout=timedelta(seconds=120), ) - async with new_worker( - client, - ResearchWorkflow, - activities=[model_activity.invoke_model_activity, get_weather], - interceptors=[OpenAIAgentsTracingInterceptor()], - ) as worker: - workflow_handle = await client.start_workflow( - ResearchWorkflow.run, - "Caribbean vacation spots in April, optimizing for surfing, hiking and water sports", - id=f"research-workflow-{uuid.uuid4()}", - task_queue=worker.task_queue, - execution_timeout=timedelta(seconds=120), + result = await workflow_handle.result() + + if use_local_model: + assert result == "report" + + events = [] + async for e in workflow_handle.fetch_history_events(): + if e.HasField("activity_task_completed_event_attributes"): + events.append(e) + + assert len(events) == 12 + assert ( + '"type":"output_text"' + in events[0] + .activity_task_completed_event_attributes.result.payloads[0] + .data.decode() ) - result = await workflow_handle.result() - - if use_local_model: - assert result == "report" - - events = [] - async for e in workflow_handle.fetch_history_events(): - if e.HasField("activity_task_completed_event_attributes"): - events.append(e) - - assert len(events) == 12 + for i in range(1, 11): assert ( - '"type":"output_text"' - in events[0] + "web_search_call" + in events[i] .activity_task_completed_event_attributes.result.payloads[0] .data.decode() ) - for i in range(1, 11): - assert ( - "web_search_call" - in events[i] - .activity_task_completed_event_attributes.result.payloads[0] - .data.decode() - ) - assert ( - '"type":"output_text"' - in events[11] - .activity_task_completed_event_attributes.result.payloads[0] - .data.decode() - ) + assert ( + '"type":"output_text"' + in events[11] + .activity_task_completed_event_attributes.result.payloads[0] + .data.decode() + ) def orchestrator_agent() -> Agent: @@ -708,67 +704,64 @@ async def test_agents_as_tools_workflow(client: Client, use_local_model: bool): if not use_local_model and not os.environ.get("OPENAI_API_KEY"): pytest.skip("No openai API key") new_config = client.config() - new_config["data_converter"] = pydantic_data_converter + new_config["plugins"] = [ + openai_agents.Plugin( + model_params=ModelActivityParameters( + start_to_close_timeout=timedelta(seconds=30) + ), + model_provider=TestModelProvider(AgentAsToolsModel()) + if use_local_model + else None, + ) + ] client = Client(**new_config) - model_params = ModelActivityParameters(start_to_close_timeout=timedelta(seconds=30)) - with set_open_ai_agent_temporal_overrides(model_params): - model_activity = ModelActivity( - TestModelProvider( - AgentAsToolsModel( # type: ignore - ) - ) - if use_local_model - else None + async with new_worker( + client, + AgentsAsToolsWorkflow, + ) as worker: + workflow_handle = await client.start_workflow( + AgentsAsToolsWorkflow.run, + "Translate to Spanish: 'I am full'", + id=f"agents-as-tools-workflow-{uuid.uuid4()}", + task_queue=worker.task_queue, + execution_timeout=timedelta(seconds=30), ) - async with new_worker( - client, - AgentsAsToolsWorkflow, - activities=[model_activity.invoke_model_activity], - interceptors=[OpenAIAgentsTracingInterceptor()], - ) as worker: - workflow_handle = await client.start_workflow( - AgentsAsToolsWorkflow.run, - "Translate to Spanish: 'I am full'", - id=f"agents-as-tools-workflow-{uuid.uuid4()}", - task_queue=worker.task_queue, - execution_timeout=timedelta(seconds=30), + result = await workflow_handle.result() + + if use_local_model: + assert result == 'The translation to Spanish is: "Estoy lleno."' + + events = [] + async for e in workflow_handle.fetch_history_events(): + if e.HasField("activity_task_completed_event_attributes"): + events.append(e) + + assert len(events) == 4 + assert ( + "function_call" + in events[0] + .activity_task_completed_event_attributes.result.payloads[0] + .data.decode() + ) + assert ( + "Estoy lleno" + in events[1] + .activity_task_completed_event_attributes.result.payloads[0] + .data.decode() + ) + assert ( + "The translation to Spanish is:" + in events[2] + .activity_task_completed_event_attributes.result.payloads[0] + .data.decode() + ) + assert ( + "The translation to Spanish is:" + in events[3] + .activity_task_completed_event_attributes.result.payloads[0] + .data.decode() ) - result = await workflow_handle.result() - - if use_local_model: - assert result == 'The translation to Spanish is: "Estoy lleno."' - - events = [] - async for e in workflow_handle.fetch_history_events(): - if e.HasField("activity_task_completed_event_attributes"): - events.append(e) - - assert len(events) == 4 - assert ( - "function_call" - in events[0] - .activity_task_completed_event_attributes.result.payloads[0] - .data.decode() - ) - assert ( - "Estoy lleno" - in events[1] - .activity_task_completed_event_attributes.result.payloads[0] - .data.decode() - ) - assert ( - "The translation to Spanish is:" - in events[2] - .activity_task_completed_event_attributes.result.payloads[0] - .data.decode() - ) - assert ( - "The translation to Spanish is:" - in events[3] - .activity_task_completed_event_attributes.result.payloads[0] - .data.decode() - ) class AirlineAgentContext(BaseModel): @@ -1063,97 +1056,94 @@ async def test_customer_service_workflow(client: Client, use_local_model: bool): if not use_local_model and not os.environ.get("OPENAI_API_KEY"): pytest.skip("No openai API key") new_config = client.config() - new_config["data_converter"] = pydantic_data_converter + new_config["plugins"] = [ + openai_agents.Plugin( + model_params=ModelActivityParameters( + start_to_close_timeout=timedelta(seconds=30) + ), + model_provider=TestModelProvider(CustomerServiceModel()) + if use_local_model + else None, + ) + ] client = Client(**new_config) questions = ["Hello", "Book me a flight to PDX", "11111", "Any window seat"] - model_params = ModelActivityParameters(start_to_close_timeout=timedelta(seconds=30)) - with set_open_ai_agent_temporal_overrides(model_params): - model_activity = ModelActivity( - TestModelProvider( - CustomerServiceModel( # type: ignore - ) - ) - if use_local_model - else None + async with new_worker( + client, + CustomerServiceWorkflow, + ) as worker: + workflow_handle = await client.start_workflow( + CustomerServiceWorkflow.run, + id=f"customer-service-{uuid.uuid4()}", + task_queue=worker.task_queue, + execution_timeout=timedelta(seconds=30), ) - async with new_worker( - client, - CustomerServiceWorkflow, - activities=[model_activity.invoke_model_activity], - interceptors=[OpenAIAgentsTracingInterceptor()], - ) as worker: - workflow_handle = await client.start_workflow( - CustomerServiceWorkflow.run, - id=f"customer-service-{uuid.uuid4()}", - task_queue=worker.task_queue, - execution_timeout=timedelta(seconds=30), + history: list[Any] = [] + for q in questions: + message_input = ProcessUserMessageInput( + user_input=q, chat_length=len(history) + ) + new_history = await workflow_handle.execute_update( + CustomerServiceWorkflow.process_user_message, message_input + ) + history.extend(new_history) + print(*new_history, sep="\n") + + await workflow_handle.cancel() + + with pytest.raises(WorkflowFailureError) as err: + await workflow_handle.result() + assert isinstance(err.value.cause, CancelledError) + + if use_local_model: + events = [] + async for e in WorkflowHandle( + client, + workflow_handle.id, + run_id=workflow_handle._first_execution_run_id, + ).fetch_history_events(): + if e.HasField("activity_task_completed_event_attributes"): + events.append(e) + + assert len(events) == 6 + assert ( + "Hi there! How can I assist you today?" + in events[0] + .activity_task_completed_event_attributes.result.payloads[0] + .data.decode() + ) + assert ( + "transfer_to_seat_booking_agent" + in events[1] + .activity_task_completed_event_attributes.result.payloads[0] + .data.decode() + ) + assert ( + "Could you please provide your confirmation number?" + in events[2] + .activity_task_completed_event_attributes.result.payloads[0] + .data.decode() + ) + assert ( + "Thanks! What seat number would you like to change to?" + in events[3] + .activity_task_completed_event_attributes.result.payloads[0] + .data.decode() + ) + assert ( + "update_seat" + in events[4] + .activity_task_completed_event_attributes.result.payloads[0] + .data.decode() + ) + assert ( + "Your seat has been updated to a window seat. If there's anything else you need, feel free to let me know!" + in events[5] + .activity_task_completed_event_attributes.result.payloads[0] + .data.decode() ) - history: list[Any] = [] - for q in questions: - message_input = ProcessUserMessageInput( - user_input=q, chat_length=len(history) - ) - new_history = await workflow_handle.execute_update( - CustomerServiceWorkflow.process_user_message, message_input - ) - history.extend(new_history) - print(*new_history, sep="\n") - - await workflow_handle.cancel() - - with pytest.raises(WorkflowFailureError) as err: - await workflow_handle.result() - assert isinstance(err.value.cause, CancelledError) - - if use_local_model: - events = [] - async for e in WorkflowHandle( - client, - workflow_handle.id, - run_id=workflow_handle._first_execution_run_id, - ).fetch_history_events(): - if e.HasField("activity_task_completed_event_attributes"): - events.append(e) - - assert len(events) == 6 - assert ( - "Hi there! How can I assist you today?" - in events[0] - .activity_task_completed_event_attributes.result.payloads[0] - .data.decode() - ) - assert ( - "transfer_to_seat_booking_agent" - in events[1] - .activity_task_completed_event_attributes.result.payloads[0] - .data.decode() - ) - assert ( - "Could you please provide your confirmation number?" - in events[2] - .activity_task_completed_event_attributes.result.payloads[0] - .data.decode() - ) - assert ( - "Thanks! What seat number would you like to change to?" - in events[3] - .activity_task_completed_event_attributes.result.payloads[0] - .data.decode() - ) - assert ( - "update_seat" - in events[4] - .activity_task_completed_event_attributes.result.payloads[0] - .data.decode() - ) - assert ( - "Your seat has been updated to a window seat. If there's anything else you need, feel free to let me know!" - in events[5] - .activity_task_completed_event_attributes.result.payloads[0] - .data.decode() - ) guardrail_response_index: int = 0 @@ -1356,42 +1346,40 @@ async def test_input_guardrail(client: Client, use_local_model: bool): if not use_local_model and not os.environ.get("OPENAI_API_KEY"): pytest.skip("No openai API key") new_config = client.config() - new_config["data_converter"] = pydantic_data_converter - client = Client(**new_config) - - model_params = ModelActivityParameters(start_to_close_timeout=timedelta(seconds=30)) - with set_open_ai_agent_temporal_overrides(model_params): - model_activity = ModelActivity( - TestModelProvider( - InputGuardrailModel( # type: ignore - "", openai_client=AsyncOpenAI(api_key="Fake key") - ) + new_config["plugins"] = [ + openai_agents.Plugin( + model_params=ModelActivityParameters( + start_to_close_timeout=timedelta(seconds=30) + ), + model_provider=TestModelProvider( + InputGuardrailModel("", openai_client=AsyncOpenAI(api_key="Fake key")) ) if use_local_model - else None + else None, ) - async with new_worker( - client, - InputGuardrailWorkflow, - activities=[model_activity.invoke_model_activity], - interceptors=[OpenAIAgentsTracingInterceptor()], - ) as worker: - workflow_handle = await client.start_workflow( - InputGuardrailWorkflow.run, - [ - "What's the capital of California?", - "Can you help me solve for x: 2x + 5 = 11", - ], - id=f"input-guardrail-{uuid.uuid4()}", - task_queue=worker.task_queue, - execution_timeout=timedelta(seconds=10), - ) - result = await workflow_handle.result() + ] + client = Client(**new_config) - if use_local_model: - assert len(result) == 2 - assert result[0] == "The capital of California is Sacramento." - assert result[1] == "Sorry, I can't help you with your math homework." + async with new_worker( + client, + InputGuardrailWorkflow, + ) as worker: + workflow_handle = await client.start_workflow( + InputGuardrailWorkflow.run, + [ + "What's the capital of California?", + "Can you help me solve for x: 2x + 5 = 11", + ], + id=f"input-guardrail-{uuid.uuid4()}", + task_queue=worker.task_queue, + execution_timeout=timedelta(seconds=10), + ) + result = await workflow_handle.result() + + if use_local_model: + assert len(result) == 2 + assert result[0] == "The capital of California is Sacramento." + assert result[1] == "Sorry, I can't help you with your math homework." class OutputGuardrailModel(StaticTestModel): @@ -1473,35 +1461,32 @@ async def test_output_guardrail(client: Client, use_local_model: bool): if not use_local_model and not os.environ.get("OPENAI_API_KEY"): pytest.skip("No openai API key") new_config = client.config() - new_config["data_converter"] = pydantic_data_converter + new_config["plugins"] = [ + openai_agents.Plugin( + model_params=ModelActivityParameters( + start_to_close_timeout=timedelta(seconds=30) + ), + model_provider=TestModelProvider(OutputGuardrailModel()) + if use_local_model + else None, + ) + ] client = Client(**new_config) - model_params = ModelActivityParameters(start_to_close_timeout=timedelta(seconds=30)) - with set_open_ai_agent_temporal_overrides(model_params): - model_activity = ModelActivity( - TestModelProvider( - OutputGuardrailModel( # type: ignore - ) - ) - if use_local_model - else None + async with new_worker( + client, + OutputGuardrailWorkflow, + ) as worker: + workflow_handle = await client.start_workflow( + OutputGuardrailWorkflow.run, + id=f"output-guardrail-{uuid.uuid4()}", + task_queue=worker.task_queue, + execution_timeout=timedelta(seconds=10), ) - async with new_worker( - client, - OutputGuardrailWorkflow, - activities=[model_activity.invoke_model_activity], - interceptors=[OpenAIAgentsTracingInterceptor()], - ) as worker: - workflow_handle = await client.start_workflow( - OutputGuardrailWorkflow.run, - id=f"output-guardrail-{uuid.uuid4()}", - task_queue=worker.task_queue, - execution_timeout=timedelta(seconds=10), - ) - result = await workflow_handle.result() + result = await workflow_handle.result() - if use_local_model: - assert not result + if use_local_model: + assert not result class WorkflowToolModel(StaticTestModel): @@ -1564,21 +1549,24 @@ async def run_tool(self): async def test_workflow_method_tools(client: Client): new_config = client.config() - new_config["data_converter"] = pydantic_data_converter + new_config["plugins"] = [ + openai_agents.Plugin( + model_params=ModelActivityParameters( + start_to_close_timeout=timedelta(seconds=30) + ), + model_provider=TestModelProvider(WorkflowToolModel()), + ) + ] client = Client(**new_config) - with set_open_ai_agent_temporal_overrides(): - model_activity = ModelActivity(TestModelProvider(WorkflowToolModel())) - async with new_worker( - client, - WorkflowToolWorkflow, - activities=[model_activity.invoke_model_activity], - interceptors=[OpenAIAgentsTracingInterceptor()], - ) as worker: - workflow_handle = await client.start_workflow( - WorkflowToolWorkflow.run, - id=f"workflow-tool-{uuid.uuid4()}", - task_queue=worker.task_queue, - execution_timeout=timedelta(seconds=10), - ) - await workflow_handle.result() + async with new_worker( + client, + WorkflowToolWorkflow, + ) as worker: + workflow_handle = await client.start_workflow( + WorkflowToolWorkflow.run, + id=f"workflow-tool-{uuid.uuid4()}", + task_queue=worker.task_queue, + execution_timeout=timedelta(seconds=10), + ) + await workflow_handle.result() From 25e98213d7deb9385acfaf7f08614f585a738a43 Mon Sep 17 00:00:00 2001 From: Tim Conley Date: Thu, 10 Jul 2025 09:38:50 -0700 Subject: [PATCH 2/2] Remove extra import/exports --- temporalio/contrib/openai_agents/__init__.py | 8 -------- .../contrib/openai_agents/temporal_openai_agents.py | 8 ++++---- tests/contrib/openai_agents/test_openai.py | 4 ---- 3 files changed, 4 insertions(+), 16 deletions(-) diff --git a/temporalio/contrib/openai_agents/__init__.py b/temporalio/contrib/openai_agents/__init__.py index 2c20effc7..43636fa17 100644 --- a/temporalio/contrib/openai_agents/__init__.py +++ b/temporalio/contrib/openai_agents/__init__.py @@ -8,26 +8,18 @@ Use with caution in production environments. """ -from temporalio.contrib.openai_agents._invoke_model_activity import ModelActivity from temporalio.contrib.openai_agents._model_parameters import ModelActivityParameters -from temporalio.contrib.openai_agents._trace_interceptor import ( - OpenAIAgentsTracingInterceptor, -) from temporalio.contrib.openai_agents.temporal_openai_agents import ( Plugin, TestModel, TestModelProvider, - set_open_ai_agent_temporal_overrides, workflow, ) __all__ = [ "Plugin", - "ModelActivity", "ModelActivityParameters", "workflow", - "set_open_ai_agent_temporal_overrides", - "OpenAIAgentsTracingInterceptor", "TestModel", "TestModelProvider", ] diff --git a/temporalio/contrib/openai_agents/temporal_openai_agents.py b/temporalio/contrib/openai_agents/temporal_openai_agents.py index 9b574d708..04d68a5af 100644 --- a/temporalio/contrib/openai_agents/temporal_openai_agents.py +++ b/temporalio/contrib/openai_agents/temporal_openai_agents.py @@ -34,15 +34,15 @@ from temporalio import workflow as temporal_workflow from temporalio.client import ClientConfig from temporalio.common import Priority, RetryPolicy -from temporalio.contrib.openai_agents import ( - ModelActivity, - OpenAIAgentsTracingInterceptor, -) +from temporalio.contrib.openai_agents._invoke_model_activity import ModelActivity from temporalio.contrib.openai_agents._model_parameters import ModelActivityParameters from temporalio.contrib.openai_agents._openai_runner import TemporalOpenAIRunner from temporalio.contrib.openai_agents._temporal_trace_provider import ( TemporalTraceProvider, ) +from temporalio.contrib.openai_agents._trace_interceptor import ( + OpenAIAgentsTracingInterceptor, +) from temporalio.contrib.pydantic import pydantic_data_converter from temporalio.exceptions import ApplicationError, TemporalError from temporalio.worker import Worker, WorkerConfig diff --git a/tests/contrib/openai_agents/test_openai.py b/tests/contrib/openai_agents/test_openai.py index d2224761c..3fc283e1c 100644 --- a/tests/contrib/openai_agents/test_openai.py +++ b/tests/contrib/openai_agents/test_openai.py @@ -50,14 +50,10 @@ from temporalio.client import Client, WorkflowFailureError, WorkflowHandle from temporalio.contrib import openai_agents from temporalio.contrib.openai_agents import ( - ModelActivity, ModelActivityParameters, - OpenAIAgentsTracingInterceptor, TestModel, TestModelProvider, - set_open_ai_agent_temporal_overrides, ) -from temporalio.contrib.pydantic import pydantic_data_converter from temporalio.exceptions import CancelledError from tests.contrib.openai_agents.research_agents.research_manager import ( ResearchManager,