diff --git a/temporalio/contrib/openai_agents/README.md b/temporalio/contrib/openai_agents/README.md index 3edcf294a..05b1a2331 100644 --- a/temporalio/contrib/openai_agents/README.md +++ b/temporalio/contrib/openai_agents/README.md @@ -56,11 +56,7 @@ The first file, `hello_world_workflow.py`, defines an OpenAI agent within a Temp ```python # File: hello_world_workflow.py from temporalio import workflow - -# Trusted imports bypass the Temporal sandbox, which otherwise -# prevents imports which may result in non-deterministic execution. -with workflow.unsafe.imports_passed_through(): - from agents import Agent, Runner +from agents import Agent, Runner @workflow.defn class HelloWorldAgent: @@ -80,11 +76,6 @@ We annotate the `HelloWorldAgent` class with `@workflow.defn` to define a workfl We use the `Agent` class to define a simple agent, one which always responds with haikus. Within the workflow, we start the agent using the `Runner`, as is typical, passing through `prompt` as an argument. -Perhaps the most interesting thing about this code is the `workflow.unsafe.imports_passed_through()` context manager that precedes the OpenAI Agents SDK imports. -This statement tells Temporal to skip sandboxing for these trusted libraries. -This is important because Python's dynamic nature forces Temporal's Python's sandbox to re-validate imports every time a workflow runs, which comes at a performance cost. -The OpenAI Agents SDK also contains certain code that Temporal is not able to validate automatically for determinism. - The second file, `run_worker.py`, launches a Temporal worker. This is a program that connects to the Temporal server and receives work to run, in this case `HelloWorldAgent` invocations. @@ -95,14 +86,13 @@ import asyncio from datetime import timedelta from temporalio.client import Client -from temporalio.contrib.openai_agents.invoke_model_activity import ModelActivity -from temporalio.contrib.openai_agents.model_parameters import ModelActivityParameters -from temporalio.contrib.openai_agents.open_ai_data_converter import open_ai_data_converter -from temporalio.contrib.openai_agents.temporal_openai_agents import set_open_ai_agent_temporal_overrides +from temporalio.contrib.openai_agents import ModelActivity, ModelActivityParameters, set_open_ai_agent_temporal_overrides +from temporalio.contrib.pydantic import pydantic_data_converter from temporalio.worker import Worker from hello_world_workflow import HelloWorldAgent + async def worker_main(): # Configure the OpenAI Agents SDK to use Temporal activities for LLM API calls # and for tool calls. @@ -114,17 +104,17 @@ async def worker_main(): # Use the OpenAI data converter to ensure proper serialization/deserialization client = await Client.connect( "localhost:7233", - data_converter=open_ai_data_converter, + data_converter=pydantic_data_converter, ) - model_activity = ModelActivity(model_provider=None) - worker = Worker( - client, - task_queue="my-task-queue", - workflows=[HelloWorldAgent], - activities=[model_activity.invoke_model_activity], - ) - await worker.run() + worker = Worker( + client, + task_queue="my-task-queue", + workflows=[HelloWorldAgent], + activities=[ModelActivity().invoke_model_activity], + ) + await worker.run() + if __name__ == "__main__": asyncio.run(worker_main()) @@ -132,8 +122,8 @@ if __name__ == "__main__": We wrap the entire `worker_main` function body in the `set_open_ai_agent_temporal_overrides()` context manager. This causes a Temporal activity to be invoked whenever the OpenAI Agents SDK invokes an LLM or calls a tool. -We also pass the `open_ai_data_converter` to the Temporal Client, which ensures proper serialization of OpenAI Agents SDK data. -We create a `ModelActivity` which serves as a generic wrapper for LLM calls, and we register this wrapper's invocation point, `model_activity.invoke_model_activity`, with the worker. +We also pass the `pydantic_data_converter` to the Temporal Client, which ensures proper serialization of pydantic models in OpenAI Agents SDK data. +We create a `ModelActivity` which serves as a generic wrapper for LLM calls, and we register this wrapper's invocation point, `ModelActivity().invoke_model_activity`, with the worker. In order to launch the agent, use the standard Temporal workflow invocation: @@ -144,7 +134,7 @@ import asyncio from temporalio.client import Client from temporalio.common import WorkflowIDReusePolicy -from temporalio.contrib.openai_agents.open_ai_data_converter import open_ai_data_converter +from temporalio.contrib.pydantic import pydantic_data_converter from hello_world_workflow import HelloWorldAgent @@ -152,7 +142,7 @@ async def main(): # Create client connected to server at the given address client = await Client.connect( "localhost:7233", - data_converter=open_ai_data_converter, + data_converter=pydantic_data_converter, ) # Execute a workflow @@ -171,7 +161,7 @@ if __name__ == "__main__": This launcher script executes the Temporal workflow to start the agent. -Note that this basic example works without providing the `open_ai_data_converter` to the Temporal client that executes the workflow, but we include it because more complex uses will generally need it. +Note that this basic example works without providing the `pydantic_data_converter` to the Temporal client that executes the workflow, but we include it because more complex uses will generally need it. ## Using Temporal Activities as OpenAI Agents Tools @@ -186,10 +176,8 @@ We then pass this through the `activity_as_tool` helper function to create an Op from dataclasses import dataclass from datetime import timedelta from temporalio import activity, workflow -from temporalio.contrib.openai_agents.temporal_tools import activity_as_tool - -with workflow.unsafe.imports_passed_through(): - from agents import Agent, Runner +from temporalio.contrib import openai_agents +from agents import Agent, Runner @dataclass class Weather: @@ -210,7 +198,7 @@ class WeatherAgent: name="Weather Assistant", instructions="You are a helpful weather agent.", tools=[ - activity_as_tool( + openai_agents.workflow.activity_as_tool( get_weather, start_to_close_timeout=timedelta(seconds=10) ) diff --git a/temporalio/contrib/openai_agents/__init__.py b/temporalio/contrib/openai_agents/__init__.py index faf025df0..027ad44ad 100644 --- a/temporalio/contrib/openai_agents/__init__.py +++ b/temporalio/contrib/openai_agents/__init__.py @@ -7,3 +7,25 @@ This module is experimental and may change in future versions. Use with caution in production environments. """ + +from temporalio.contrib.openai_agents._invoke_model_activity import ModelActivity +from temporalio.contrib.openai_agents._model_parameters import ModelActivityParameters +from temporalio.contrib.openai_agents._trace_interceptor import ( + OpenAIAgentsTracingInterceptor, +) +from temporalio.contrib.openai_agents.temporal_openai_agents import ( + TestModel, + TestModelProvider, + set_open_ai_agent_temporal_overrides, + workflow, +) + +__all__ = [ + "ModelActivity", + "ModelActivityParameters", + "workflow", + "set_open_ai_agent_temporal_overrides", + "OpenAIAgentsTracingInterceptor", + "TestModel", + "TestModelProvider", +] diff --git a/temporalio/contrib/openai_agents/invoke_model_activity.py b/temporalio/contrib/openai_agents/_invoke_model_activity.py similarity index 100% rename from temporalio/contrib/openai_agents/invoke_model_activity.py rename to temporalio/contrib/openai_agents/_invoke_model_activity.py diff --git a/temporalio/contrib/openai_agents/model_parameters.py b/temporalio/contrib/openai_agents/_model_parameters.py similarity index 100% rename from temporalio/contrib/openai_agents/model_parameters.py rename to temporalio/contrib/openai_agents/_model_parameters.py diff --git a/temporalio/contrib/openai_agents/_openai_runner.py b/temporalio/contrib/openai_agents/_openai_runner.py index f5f431c81..f43d01388 100644 --- a/temporalio/contrib/openai_agents/_openai_runner.py +++ b/temporalio/contrib/openai_agents/_openai_runner.py @@ -15,8 +15,8 @@ from temporalio import workflow from temporalio.common import Priority, RetryPolicy +from temporalio.contrib.openai_agents._model_parameters import ModelActivityParameters from temporalio.contrib.openai_agents._temporal_model_stub import _TemporalModelStub -from temporalio.contrib.openai_agents.model_parameters import ModelActivityParameters from temporalio.workflow import ActivityCancellationType, VersioningIntent diff --git a/temporalio/contrib/openai_agents/_temporal_model_stub.py b/temporalio/contrib/openai_agents/_temporal_model_stub.py index d823e4b09..1092c9ada 100644 --- a/temporalio/contrib/openai_agents/_temporal_model_stub.py +++ b/temporalio/contrib/openai_agents/_temporal_model_stub.py @@ -6,7 +6,7 @@ from temporalio import workflow from temporalio.common import Priority, RetryPolicy -from temporalio.contrib.openai_agents.model_parameters import ModelActivityParameters +from temporalio.contrib.openai_agents._model_parameters import ModelActivityParameters from temporalio.workflow import ActivityCancellationType, VersioningIntent logger = logging.getLogger(__name__) @@ -31,7 +31,7 @@ from agents.items import TResponseStreamEvent from openai.types.responses.response_prompt_param import ResponsePromptParam -from temporalio.contrib.openai_agents.invoke_model_activity import ( +from temporalio.contrib.openai_agents._invoke_model_activity import ( ActivityModelInput, AgentOutputSchemaInput, FunctionToolInput, diff --git a/temporalio/contrib/openai_agents/trace_interceptor.py b/temporalio/contrib/openai_agents/_trace_interceptor.py similarity index 100% rename from temporalio/contrib/openai_agents/trace_interceptor.py rename to temporalio/contrib/openai_agents/_trace_interceptor.py diff --git a/temporalio/contrib/openai_agents/open_ai_data_converter.py b/temporalio/contrib/openai_agents/open_ai_data_converter.py deleted file mode 100644 index 59ba76085..000000000 --- a/temporalio/contrib/openai_agents/open_ai_data_converter.py +++ /dev/null @@ -1,11 +0,0 @@ -"""DataConverter that supports conversion of types used by OpenAI Agents SDK. - -These are mostly Pydantic types. Some of them should be explicitly imported. -""" - -from __future__ import annotations - -import temporalio.contrib.pydantic - -open_ai_data_converter = temporalio.contrib.pydantic.pydantic_data_converter -"""DEPRECATED, use temporalio.contrib.pydantic.pydantic_data_converter""" diff --git a/temporalio/contrib/openai_agents/temporal_openai_agents.py b/temporalio/contrib/openai_agents/temporal_openai_agents.py index 7dc4ab2aa..d5c096539 100644 --- a/temporalio/contrib/openai_agents/temporal_openai_agents.py +++ b/temporalio/contrib/openai_agents/temporal_openai_agents.py @@ -1,18 +1,50 @@ """Initialize Temporal OpenAI Agents overrides.""" +import json from contextlib import contextmanager -from typing import Optional +from datetime import timedelta +from typing import Any, AsyncIterator, Callable, Optional, Union, overload -from agents import set_trace_provider +from agents import ( + Agent, + AgentOutputSchemaBase, + Handoff, + Model, + ModelProvider, + ModelResponse, + ModelSettings, + ModelTracing, + RunContextWrapper, + Tool, + TResponseInputItem, + set_trace_provider, +) +from agents.function_schema import DocstringStyle, function_schema +from agents.items import TResponseStreamEvent from agents.run import get_default_agent_runner, set_default_agent_runner +from agents.tool import ( + FunctionTool, + ToolErrorFunction, + ToolFunction, + ToolParams, + default_tool_error_function, + function_tool, +) from agents.tracing import get_trace_provider from agents.tracing.provider import DefaultTraceProvider +from agents.util._types import MaybeAwaitable +from openai.types.responses import ResponsePromptParam +from temporalio import activity +from temporalio import workflow as temporal_workflow +from temporalio.common import Priority, RetryPolicy +from temporalio.contrib.openai_agents._model_parameters import ModelActivityParameters from temporalio.contrib.openai_agents._openai_runner import TemporalOpenAIRunner from temporalio.contrib.openai_agents._temporal_trace_provider import ( TemporalTraceProvider, ) -from temporalio.contrib.openai_agents.model_parameters import ModelActivityParameters +from temporalio.exceptions import ApplicationError, TemporalError +from temporalio.workflow import ActivityCancellationType, VersioningIntent @contextmanager @@ -68,3 +100,235 @@ def set_open_ai_agent_temporal_overrides( finally: set_default_agent_runner(previous_runner) set_trace_provider(previous_trace_provider or DefaultTraceProvider()) + + +class TestModelProvider(ModelProvider): + """Test model provider which simply returns the given module.""" + + def __init__(self, model: Model): + """Initialize a test model provider with a model.""" + self._model = model + + def get_model(self, model_name: Union[str, None]) -> Model: + """Get a model from the model provider.""" + return self._model + + +class TestModel(Model): + """Test model for use mocking model responses.""" + + def __init__(self, fn: Callable[[], ModelResponse]) -> None: + """Initialize a test model with a callable.""" + self.fn = fn + + async def get_response( + self, + system_instructions: Union[str, None], + input: Union[str, list[TResponseInputItem]], + model_settings: ModelSettings, + tools: list[Tool], + output_schema: Union[AgentOutputSchemaBase, None], + handoffs: list[Handoff], + tracing: ModelTracing, + *, + previous_response_id: Union[str, None], + prompt: Union[ResponsePromptParam, None] = None, + ) -> ModelResponse: + """Get a response from the model.""" + return self.fn() + + def stream_response( + self, + system_instructions: Optional[str], + input: Union[str, list[TResponseInputItem]], + model_settings: ModelSettings, + tools: list[Tool], + output_schema: Optional[AgentOutputSchemaBase], + handoffs: list[Handoff], + tracing: ModelTracing, + *, + previous_response_id: Optional[str], + prompt: Optional[ResponsePromptParam], + ) -> AsyncIterator[TResponseStreamEvent]: + """Get a streamed response from the model. Unimplemented.""" + raise NotImplementedError() + + +class ToolSerializationError(TemporalError): + """Error that occurs when a tool output could not be serialized.""" + + +class workflow: + """Encapsulates workflow specific primitives for working with the OpenAI Agents SDK in a workflow context""" + + @classmethod + def activity_as_tool( + cls, + fn: Callable, + *, + task_queue: Optional[str] = None, + schedule_to_close_timeout: Optional[timedelta] = None, + schedule_to_start_timeout: Optional[timedelta] = None, + start_to_close_timeout: Optional[timedelta] = None, + heartbeat_timeout: Optional[timedelta] = None, + retry_policy: Optional[RetryPolicy] = None, + cancellation_type: ActivityCancellationType = ActivityCancellationType.TRY_CANCEL, + activity_id: Optional[str] = None, + versioning_intent: Optional[VersioningIntent] = None, + summary: Optional[str] = None, + priority: Priority = Priority.default, + ) -> Tool: + """Convert a single Temporal activity function to an OpenAI agent tool. + + .. warning:: + This API is experimental and may change in future versions. + Use with caution in production environments. + + This function takes a Temporal activity function and converts it into an + OpenAI agent tool that can be used by the agent to execute the activity + during workflow execution. The tool will automatically handle the conversion + of inputs and outputs between the agent and the activity. Note that if you take a context, + mutation will not be persisted, as the activity may not be running in the same location. + + Args: + fn: A Temporal activity function to convert to a tool. + For other arguments, refer to :py:mod:`workflow` :py:meth:`start_activity` + + Returns: + An OpenAI agent tool that wraps the provided activity. + + Raises: + ApplicationError: If the function is not properly decorated as a Temporal activity. + + Example: + >>> @activity.defn + >>> def process_data(input: str) -> str: + ... return f"Processed: {input}" + >>> + >>> # Create tool with custom activity options + >>> tool = activity_as_tool( + ... process_data, + ... start_to_close_timeout=timedelta(seconds=30), + ... retry_policy=RetryPolicy(maximum_attempts=3), + ... heartbeat_timeout=timedelta(seconds=10) + ... ) + >>> # Use tool with an OpenAI agent + """ + ret = activity._Definition.from_callable(fn) + if not ret: + raise ApplicationError( + "Bare function without tool and activity decorators is not supported", + "invalid_tool", + ) + schema = function_schema(fn) + + async def run_activity(ctx: RunContextWrapper[Any], input: str) -> Any: + try: + json_data = json.loads(input) + except Exception as e: + raise ApplicationError( + f"Invalid JSON input for tool {schema.name}: {input}" + ) from e + + # Activities don't support keyword only arguments, so we can ignore the kwargs_dict return + args, _ = schema.to_call_args(schema.params_pydantic_model(**json_data)) + + # Add the context to the arguments if it takes that + if schema.takes_context: + args = [ctx] + args + + result = await temporal_workflow.execute_activity( + fn, + args=args, + task_queue=task_queue, + schedule_to_close_timeout=schedule_to_close_timeout, + schedule_to_start_timeout=schedule_to_start_timeout, + start_to_close_timeout=start_to_close_timeout, + heartbeat_timeout=heartbeat_timeout, + retry_policy=retry_policy, + cancellation_type=cancellation_type, + activity_id=activity_id, + versioning_intent=versioning_intent, + summary=summary, + priority=priority, + ) + try: + return str(result) + except Exception as e: + raise ToolSerializationError( + "You must return a string representation of the tool output, or something we can call str() on" + ) from e + + return FunctionTool( + name=schema.name, + description=schema.description or "", + params_json_schema=schema.params_json_schema, + on_invoke_tool=run_activity, + strict_json_schema=True, + ) + + @classmethod + @overload + def tool( + cls, + *, + name_override: Union[str, None] = None, + description_override: Union[str, None] = None, + docstring_style: Union[DocstringStyle, None] = None, + use_docstring_info: bool = True, + failure_error_function: Union[ + ToolErrorFunction, None + ] = default_tool_error_function, + strict_mode: bool = True, + is_enabled: Union[ + bool, Callable[[RunContextWrapper[Any], Agent[Any]], MaybeAwaitable[bool]] + ] = True, + ) -> Callable[[ToolFunction[ToolParams]], FunctionTool]: ... + + @classmethod + @overload + def tool( + cls, + func: ToolFunction[ToolParams], + *, + name_override: Union[str, None] = None, + description_override: Union[str, None] = None, + docstring_style: Union[DocstringStyle, None] = None, + use_docstring_info: bool = True, + failure_error_function: Union[ + ToolErrorFunction, None + ] = default_tool_error_function, + strict_mode: bool = True, + is_enabled: Union[ + bool, Callable[[RunContextWrapper[Any], Agent[Any]], MaybeAwaitable[bool]] + ] = True, + ) -> FunctionTool: ... + + @classmethod + def tool( + cls, + func: Union[ToolFunction[ToolParams], None] = None, + *, + name_override: Union[str, None] = None, + description_override: Union[str, None] = None, + docstring_style: Union[DocstringStyle, None] = None, + use_docstring_info: bool = True, + failure_error_function: Union[ + ToolErrorFunction, None + ] = default_tool_error_function, + strict_mode: bool = True, + is_enabled: Union[ + bool, Callable[[RunContextWrapper[Any], Agent[Any]], MaybeAwaitable[bool]] + ] = True, + ) -> Union[FunctionTool, Callable[[ToolFunction[ToolParams]], FunctionTool]]: + """A temporal specific wrapper for OpenAI's @function_tool. This exists to ensure the user is aware that the function tool is workflow level code and must be deterministic.""" + return function_tool( + func, # type: ignore + name_override=name_override, + description_override=description_override, + docstring_style=docstring_style, + use_docstring_info=use_docstring_info, + failure_error_function=failure_error_function, + strict_mode=strict_mode, + is_enabled=is_enabled, + ) diff --git a/temporalio/contrib/openai_agents/temporal_tools.py b/temporalio/contrib/openai_agents/temporal_tools.py deleted file mode 100644 index 6aacf1d08..000000000 --- a/temporalio/contrib/openai_agents/temporal_tools.py +++ /dev/null @@ -1,122 +0,0 @@ -"""Support for using Temporal activities as OpenAI agents tools.""" - -import json -from datetime import timedelta -from typing import Any, Callable, Optional - -from agents import FunctionTool, RunContextWrapper, Tool -from agents.function_schema import function_schema - -from temporalio import activity, workflow -from temporalio.common import Priority, RetryPolicy -from temporalio.exceptions import ApplicationError, TemporalError -from temporalio.workflow import ActivityCancellationType, VersioningIntent, unsafe - - -class ToolSerializationError(TemporalError): - """Error that occurs when a tool output could not be serialized.""" - - -def activity_as_tool( - fn: Callable, - *, - task_queue: Optional[str] = None, - schedule_to_close_timeout: Optional[timedelta] = None, - schedule_to_start_timeout: Optional[timedelta] = None, - start_to_close_timeout: Optional[timedelta] = None, - heartbeat_timeout: Optional[timedelta] = None, - retry_policy: Optional[RetryPolicy] = None, - cancellation_type: ActivityCancellationType = ActivityCancellationType.TRY_CANCEL, - activity_id: Optional[str] = None, - versioning_intent: Optional[VersioningIntent] = None, - summary: Optional[str] = None, - priority: Priority = Priority.default, -) -> Tool: - """Convert a single Temporal activity function to an OpenAI agent tool. - - .. warning:: - This API is experimental and may change in future versions. - Use with caution in production environments. - - This function takes a Temporal activity function and converts it into an - OpenAI agent tool that can be used by the agent to execute the activity - during workflow execution. The tool will automatically handle the conversion - of inputs and outputs between the agent and the activity. Note that if you take a context, - mutation will not be persisted, as the activity may not be running in the same location. - - Args: - fn: A Temporal activity function to convert to a tool. - For other arguments, refer to :py:mod:`workflow` :py:meth:`start_activity` - - Returns: - An OpenAI agent tool that wraps the provided activity. - - Raises: - ApplicationError: If the function is not properly decorated as a Temporal activity. - - Example: - >>> @activity.defn - >>> def process_data(input: str) -> str: - ... return f"Processed: {input}" - >>> - >>> # Create tool with custom activity options - >>> tool = activity_as_tool( - ... process_data, - ... start_to_close_timeout=timedelta(seconds=30), - ... retry_policy=RetryPolicy(maximum_attempts=3), - ... heartbeat_timeout=timedelta(seconds=10) - ... ) - >>> # Use tool with an OpenAI agent - """ - ret = activity._Definition.from_callable(fn) - if not ret: - raise ApplicationError( - "Bare function without tool and activity decorators is not supported", - "invalid_tool", - ) - schema = function_schema(fn) - - async def run_activity(ctx: RunContextWrapper[Any], input: str) -> Any: - try: - json_data = json.loads(input) - except Exception as e: - raise ApplicationError( - f"Invalid JSON input for tool {schema.name}: {input}" - ) from e - - # Activities don't support keyword only arguments, so we can ignore the kwargs_dict return - args, _ = schema.to_call_args(schema.params_pydantic_model(**json_data)) - - # Add the context to the arguments if it takes that - if schema.takes_context: - args = [ctx] + args - - result = await workflow.execute_activity( - fn, - args=args, - task_queue=task_queue, - schedule_to_close_timeout=schedule_to_close_timeout, - schedule_to_start_timeout=schedule_to_start_timeout, - start_to_close_timeout=start_to_close_timeout, - heartbeat_timeout=heartbeat_timeout, - retry_policy=retry_policy, - cancellation_type=cancellation_type, - activity_id=activity_id, - versioning_intent=versioning_intent, - summary=summary, - priority=priority, - ) - try: - return str(result) - except Exception as e: - raise ToolSerializationError( - "You must return a string representation of the tool output, or something we can call str() on" - ) from e - - return FunctionTool( - name=schema.name, - description=schema.description or "", - params_json_schema=schema.params_json_schema, - on_invoke_tool=run_activity, - strict_json_schema=True, - ) diff --git a/tests/contrib/openai_agents/research_agents/planner_agent.py b/tests/contrib/openai_agents/research_agents/planner_agent.py index 3e2f26de3..8289d8636 100644 --- a/tests/contrib/openai_agents/research_agents/planner_agent.py +++ b/tests/contrib/openai_agents/research_agents/planner_agent.py @@ -3,7 +3,7 @@ PROMPT = ( "You are a helpful research assistant. Given a query, come up with a set of web searches " - "to perform to best answer the query. Output between 5 and 20 terms to query for." + "to perform to best answer the query. Output between 2 and 3 terms to query for." ) diff --git a/tests/contrib/openai_agents/research_agents/research_manager.py b/tests/contrib/openai_agents/research_agents/research_manager.py index c5d343651..de721f9b9 100644 --- a/tests/contrib/openai_agents/research_agents/research_manager.py +++ b/tests/contrib/openai_agents/research_agents/research_manager.py @@ -61,7 +61,7 @@ async def _search(self, item: WebSearchItem) -> str | None: ) return str(result.final_output) except Exception: - return None + raise async def _write_report(self, query: str, search_results: list[str]) -> ReportData: input = f"Original query: {query}\nSummarized search results: {search_results}" diff --git a/tests/contrib/openai_agents/test_openai.py b/tests/contrib/openai_agents/test_openai.py index 990849d62..897a598c8 100644 --- a/tests/contrib/openai_agents/test_openai.py +++ b/tests/contrib/openai_agents/test_openai.py @@ -25,7 +25,6 @@ Tool, TResponseInputItem, Usage, - function_tool, handoff, input_guardrail, output_guardrail, @@ -37,7 +36,6 @@ ToolCallItem, ToolCallOutputItem, ) -from agents.run import DEFAULT_AGENT_RUNNER, AgentRunner from openai import AsyncOpenAI, BaseModel from openai.types.responses import ( ResponseFunctionToolCall, @@ -51,72 +49,44 @@ from temporalio import activity, workflow from temporalio.client import Client, WorkflowFailureError, WorkflowHandle -from temporalio.contrib.openai_agents.invoke_model_activity import ( +from temporalio.contrib import openai_agents +from temporalio.contrib.openai_agents import ( ModelActivity, -) -from temporalio.contrib.openai_agents.model_parameters import ModelActivityParameters -from temporalio.contrib.openai_agents.open_ai_data_converter import ( - open_ai_data_converter, -) -from temporalio.contrib.openai_agents.temporal_openai_agents import ( - set_open_ai_agent_temporal_overrides, -) -from temporalio.contrib.openai_agents.temporal_tools import activity_as_tool -from temporalio.contrib.openai_agents.trace_interceptor import ( + ModelActivityParameters, OpenAIAgentsTracingInterceptor, + TestModel, + TestModelProvider, + set_open_ai_agent_temporal_overrides, ) +from temporalio.contrib.pydantic import pydantic_data_converter from temporalio.exceptions import CancelledError from tests.contrib.openai_agents.research_agents.research_manager import ( ResearchManager, ) from tests.helpers import new_worker - -class TestProvider(ModelProvider): - __test__ = False - - def __init__(self, model: Model): - self._model = model - - def get_model(self, model_name: Union[str, None]) -> Model: - return self._model - - response_index: int = 0 -class TestModel(OpenAIResponsesModel): +class StaticTestModel(TestModel): __test__ = False responses: list[ModelResponse] = [] + def response(self): + global response_index + response = self.responses[response_index] + response_index += 1 + return response + def __init__( self, - model: str, - openai_client: AsyncOpenAI, ) -> None: global response_index response_index = 0 - super().__init__(model, openai_client) - - async def get_response( - self, - system_instructions: Union[str, None], - input: Union[str, list[TResponseInputItem]], - model_settings: ModelSettings, - tools: list[Tool], - output_schema: Union[AgentOutputSchemaBase, None], - handoffs: list[Handoff], - tracing: ModelTracing, - previous_response_id: Union[str, None], - prompt: Union[ResponsePromptParam, None] = None, - ) -> ModelResponse: - global response_index - response = self.responses[response_index] - response_index += 1 - return response + super().__init__(self.response) -class TestHelloModel(TestModel): +class TestHelloModel(StaticTestModel): responses = [ ModelResponse( output=[ @@ -155,19 +125,13 @@ async def test_hello_world_agent(client: Client, use_local_model: bool): if not use_local_model and not os.environ.get("OPENAI_API_KEY"): pytest.skip("No openai API key") new_config = client.config() - new_config["data_converter"] = open_ai_data_converter + new_config["data_converter"] = pydantic_data_converter client = Client(**new_config) model_params = ModelActivityParameters(start_to_close_timeout=timedelta(seconds=30)) with set_open_ai_agent_temporal_overrides(model_params): model_activity = ModelActivity( - TestProvider( - TestHelloModel( # type: ignore - "", openai_client=AsyncOpenAI(api_key="Fake key") - ) - ) - if use_local_model - else None + TestModelProvider(TestHelloModel()) if use_local_model else None ) async with new_worker( client, HelloWorldAgent, activities=[model_activity.invoke_model_activity] @@ -229,7 +193,7 @@ async def get_weather_context(ctx: RunContextWrapper[str], city: str) -> Weather return Weather(city=city, temperature_range="14-20C", conditions=ctx.context) -class TestWeatherModel(TestModel): +class TestWeatherModel(StaticTestModel): responses = [ ModelResponse( output=[ @@ -317,16 +281,16 @@ async def run(self, question: str) -> str: name="Tools Workflow", instructions="You are a helpful agent.", tools=[ - activity_as_tool( + openai_agents.workflow.activity_as_tool( get_weather, start_to_close_timeout=timedelta(seconds=10) ), - activity_as_tool( + openai_agents.workflow.activity_as_tool( get_weather_object, start_to_close_timeout=timedelta(seconds=10) ), - activity_as_tool( + openai_agents.workflow.activity_as_tool( get_weather_country, start_to_close_timeout=timedelta(seconds=10) ), - activity_as_tool( + openai_agents.workflow.activity_as_tool( get_weather_context, start_to_close_timeout=timedelta(seconds=10) ), ], @@ -342,15 +306,14 @@ async def test_tool_workflow(client: Client, use_local_model: bool): if not use_local_model and not os.environ.get("OPENAI_API_KEY"): pytest.skip("No openai API key") new_config = client.config() - new_config["data_converter"] = open_ai_data_converter + new_config["data_converter"] = pydantic_data_converter client = Client(**new_config) model_params = ModelActivityParameters(start_to_close_timeout=timedelta(seconds=30)) with set_open_ai_agent_temporal_overrides(model_params): model_activity = ModelActivity( - TestProvider( + TestModelProvider( TestWeatherModel( # type: ignore - "", openai_client=AsyncOpenAI(api_key="Fake key") ) ) if use_local_model @@ -442,94 +405,8 @@ async def test_tool_workflow(client: Client, use_local_model: bool): ) -class TestPlannerModel(OpenAIResponsesModel): - __test__ = False - - def __init__( - self, - model: str, - openai_client: AsyncOpenAI, - ) -> None: - super().__init__(model, openai_client) - - async def get_response( - self, - system_instructions: Union[str, None], - input: Union[str, list[TResponseInputItem]], - model_settings: ModelSettings, - tools: list[Tool], - output_schema: Union[AgentOutputSchemaBase, None], - handoffs: list[Handoff], - tracing: ModelTracing, - previous_response_id: Union[str, None], - prompt: Union[ResponsePromptParam, None] = None, - ) -> ModelResponse: - return ModelResponse( - output=[ - ResponseOutputMessage( - id="", - content=[ - ResponseOutputText( - text='{"searches":[{"query":"best Caribbean surfing spots April","reason":"Identify locations with optimal surfing conditions in the Caribbean during April."},{"query":"top Caribbean islands for hiking April","reason":"Find Caribbean islands with excellent hiking opportunities that are ideal in April."},{"query":"Caribbean water sports destinations April","reason":"Locate Caribbean destinations offering a variety of water sports activities in April."},{"query":"surfing conditions Caribbean April","reason":"Understand the surfing conditions and which islands are suitable for surfing in April."},{"query":"Caribbean adventure travel hiking surfing","reason":"Explore adventure travel options that combine hiking and surfing in the Caribbean."},{"query":"best beaches for surfing Caribbean April","reason":"Identify which Caribbean beaches are renowned for surfing in April."},{"query":"Caribbean islands with national parks hiking","reason":"Find islands with national parks or reserves that offer hiking trails."},{"query":"Caribbean weather April surfing conditions","reason":"Research the weather conditions in April affecting surfing in the Caribbean."},{"query":"Caribbean water sports rentals April","reason":"Look for places where water sports equipment can be rented in the Caribbean during April."},{"query":"Caribbean multi-activity vacation packages","reason":"Look for vacation packages that offer a combination of surfing, hiking, and water sports."}]}', - annotations=[], - type="output_text", - ) - ], - role="assistant", - status="completed", - type="message", - ) - ], - usage=Usage(), - response_id=None, - ) - - -class TestReportModel(OpenAIResponsesModel): - __test__ = False - - def __init__( - self, - model: str, - openai_client: AsyncOpenAI, - ) -> None: - super().__init__(model, openai_client) - - async def get_response( - self, - system_instructions: Union[str, None], - input: Union[str, list[TResponseInputItem]], - model_settings: ModelSettings, - tools: list[Tool], - output_schema: Union[AgentOutputSchemaBase, None], - handoffs: list[Handoff], - tracing: ModelTracing, - previous_response_id: Union[str, None], - prompt: Union[ResponsePromptParam, None] = None, - ) -> ModelResponse: - return ModelResponse( - output=[ - ResponseOutputMessage( - id="", - content=[ - ResponseOutputText( - text="report", - annotations=[], - type="output_text", - ) - ], - role="assistant", - status="completed", - type="message", - ) - ], - usage=Usage(), - response_id=None, - ) - - @no_type_check -class TestResearchModel(TestModel): +class TestResearchModel(StaticTestModel): responses = [ ModelResponse( output=[ @@ -615,7 +492,7 @@ async def test_research_workflow(client: Client, use_local_model: bool): if not use_local_model and not os.environ.get("OPENAI_API_KEY"): pytest.skip("No openai API key") new_config = client.config() - new_config["data_converter"] = open_ai_data_converter + new_config["data_converter"] = pydantic_data_converter client = Client(**new_config) global response_index @@ -626,13 +503,7 @@ async def test_research_workflow(client: Client, use_local_model: bool): ) with set_open_ai_agent_temporal_overrides(model_params): model_activity = ModelActivity( - TestProvider( - TestResearchModel( # type: ignore - "", openai_client=AsyncOpenAI(api_key="Fake key") - ) - ) - if use_local_model - else None + TestModelProvider(TestResearchModel()) if use_local_model else None ) async with new_worker( client, @@ -757,7 +628,7 @@ async def run(self, msg: str) -> str: return synthesizer_result.final_output -class AgentAsToolsModel(TestModel): +class AgentAsToolsModel(StaticTestModel): responses = [ ModelResponse( output=[ @@ -838,15 +709,14 @@ async def test_agents_as_tools_workflow(client: Client, use_local_model: bool): if not use_local_model and not os.environ.get("OPENAI_API_KEY"): pytest.skip("No openai API key") new_config = client.config() - new_config["data_converter"] = open_ai_data_converter + new_config["data_converter"] = pydantic_data_converter client = Client(**new_config) model_params = ModelActivityParameters(start_to_close_timeout=timedelta(seconds=30)) with set_open_ai_agent_temporal_overrides(model_params): model_activity = ModelActivity( - TestProvider( + TestModelProvider( AgentAsToolsModel( # type: ignore - "", openai_client=AsyncOpenAI(api_key="Fake key") ) ) if use_local_model @@ -909,7 +779,7 @@ class AirlineAgentContext(BaseModel): flight_number: Optional[str] = None -@function_tool( +@openai_agents.workflow.tool( name_override="faq_lookup_tool", description_override="Lookup frequently asked questions.", ) @@ -931,19 +801,12 @@ async def faq_lookup_tool(question: str) -> str: return "I'm sorry, I don't know the answer to that question." -@function_tool +@openai_agents.workflow.tool async def update_seat( context: RunContextWrapper[AirlineAgentContext], confirmation_number: str, new_seat: str, ) -> str: - """ - Update the seat for a given confirmation number. - - Args: - confirmation_number: The confirmation number for the flight. - new_seat: The new seat to update to. - """ # Update the context based on the customer's input context.context.confirmation_number = confirmation_number context.context.seat_number = new_seat @@ -1020,7 +883,7 @@ class ProcessUserMessageInput(BaseModel): chat_length: int -class CustomerServiceModel(TestModel): +class CustomerServiceModel(StaticTestModel): responses = [ ModelResponse( output=[ @@ -1201,7 +1064,7 @@ async def test_customer_service_workflow(client: Client, use_local_model: bool): if not use_local_model and not os.environ.get("OPENAI_API_KEY"): pytest.skip("No openai API key") new_config = client.config() - new_config["data_converter"] = open_ai_data_converter + new_config["data_converter"] = pydantic_data_converter client = Client(**new_config) questions = ["Hello", "Book me a flight to PDX", "11111", "Any window seat"] @@ -1209,9 +1072,8 @@ async def test_customer_service_workflow(client: Client, use_local_model: bool): model_params = ModelActivityParameters(start_to_close_timeout=timedelta(seconds=30)) with set_open_ai_agent_temporal_overrides(model_params): model_activity = ModelActivity( - TestProvider( + TestModelProvider( CustomerServiceModel( # type: ignore - "", openai_client=AsyncOpenAI(api_key="Fake key") ) ) if use_local_model @@ -1495,13 +1357,13 @@ async def test_input_guardrail(client: Client, use_local_model: bool): if not use_local_model and not os.environ.get("OPENAI_API_KEY"): pytest.skip("No openai API key") new_config = client.config() - new_config["data_converter"] = open_ai_data_converter + new_config["data_converter"] = pydantic_data_converter client = Client(**new_config) model_params = ModelActivityParameters(start_to_close_timeout=timedelta(seconds=30)) with set_open_ai_agent_temporal_overrides(model_params): model_activity = ModelActivity( - TestProvider( + TestModelProvider( InputGuardrailModel( # type: ignore "", openai_client=AsyncOpenAI(api_key="Fake key") ) @@ -1533,7 +1395,7 @@ async def test_input_guardrail(client: Client, use_local_model: bool): assert result[1] == "Sorry, I can't help you with your math homework." -class OutputGuardrailModel(TestModel): +class OutputGuardrailModel(StaticTestModel): responses = [ ModelResponse( output=[ @@ -1612,15 +1474,14 @@ async def test_output_guardrail(client: Client, use_local_model: bool): if not use_local_model and not os.environ.get("OPENAI_API_KEY"): pytest.skip("No openai API key") new_config = client.config() - new_config["data_converter"] = open_ai_data_converter + new_config["data_converter"] = pydantic_data_converter client = Client(**new_config) model_params = ModelActivityParameters(start_to_close_timeout=timedelta(seconds=30)) with set_open_ai_agent_temporal_overrides(model_params): model_activity = ModelActivity( - TestProvider( + TestModelProvider( OutputGuardrailModel( # type: ignore - "", openai_client=AsyncOpenAI(api_key="Fake key") ) ) if use_local_model diff --git a/tests/contrib/openai_agents/test_openai_replay.py b/tests/contrib/openai_agents/test_openai_replay.py index 8f8f66cc9..01ea04407 100644 --- a/tests/contrib/openai_agents/test_openai_replay.py +++ b/tests/contrib/openai_agents/test_openai_replay.py @@ -1,16 +1,12 @@ -from datetime import timedelta from pathlib import Path import pytest from temporalio.client import WorkflowHistory -from temporalio.contrib.openai_agents.model_parameters import ModelActivityParameters -from temporalio.contrib.openai_agents.open_ai_data_converter import ( - open_ai_data_converter, -) from temporalio.contrib.openai_agents.temporal_openai_agents import ( set_open_ai_agent_temporal_overrides, ) +from temporalio.contrib.pydantic import pydantic_data_converter from temporalio.worker import Replayer from tests.contrib.openai_agents.test_openai import ( AgentsAsToolsWorkflow, @@ -39,10 +35,7 @@ async def test_replay(file_name: str) -> None: with (Path(__file__).with_name("histories") / file_name).open("r") as f: history_json = f.read() - model_params = ModelActivityParameters( - start_to_close_timeout=timedelta(seconds=120) - ) - with set_open_ai_agent_temporal_overrides(model_params): + with set_open_ai_agent_temporal_overrides(): await Replayer( workflows=[ ResearchWorkflow, @@ -53,5 +46,5 @@ async def test_replay(file_name: str) -> None: InputGuardrailWorkflow, OutputGuardrailWorkflow, ], - data_converter=open_ai_data_converter, + data_converter=pydantic_data_converter, ).replay_workflow(WorkflowHistory.from_json("fake", history_json))