Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 2 additions & 5 deletions temporalio/contrib/openai_agents/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,15 +21,13 @@
from temporalio.contrib.openai_agents._temporal_openai_agents import (
OpenAIAgentsPlugin,
OpenAIPayloadConverter,
TestModel,
TestModelProvider,
)
from temporalio.contrib.openai_agents._trace_interceptor import (
OpenAIAgentsTracingInterceptor,
)
from temporalio.contrib.openai_agents.workflow import AgentsWorkflowError

from . import workflow
from . import testing, workflow

__all__ = [
"AgentsWorkflowError",
Expand All @@ -38,7 +36,6 @@
"OpenAIPayloadConverter",
"StatelessMCPServerProvider",
"StatefulMCPServerProvider",
"TestModel",
"TestModelProvider",
"testing",
"workflow",
]
66 changes: 1 addition & 65 deletions temporalio/contrib/openai_agents/_temporal_openai_agents.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,19 +6,7 @@
from datetime import timedelta
from typing import AsyncIterator, Callable, Optional, Sequence, Union

from agents import (
AgentOutputSchemaBase,
Handoff,
Model,
ModelProvider,
ModelResponse,
ModelSettings,
ModelTracing,
Tool,
TResponseInputItem,
set_trace_provider,
)
from agents.items import TResponseStreamEvent
from agents import ModelProvider, set_trace_provider
from agents.run import get_default_agent_runner, set_default_agent_runner
from agents.tracing import get_trace_provider
from agents.tracing.provider import DefaultTraceProvider
Expand Down Expand Up @@ -103,58 +91,6 @@ def set_open_ai_agent_temporal_overrides(
set_trace_provider(previous_trace_provider or DefaultTraceProvider())


class TestModelProvider(ModelProvider):
"""Test model provider which simply returns the given module."""

__test__ = False

def __init__(self, model: Model):
"""Initialize a test model provider with a model."""
self._model = model

def get_model(self, model_name: Union[str, None]) -> Model:
"""Get a model from the model provider."""
return self._model


class TestModel(Model):
"""Test model for use mocking model responses."""

__test__ = False

def __init__(self, fn: Callable[[], ModelResponse]) -> None:
"""Initialize a test model with a callable."""
self.fn = fn

async def get_response(
self,
system_instructions: Union[str, None],
input: Union[str, list[TResponseInputItem]],
model_settings: ModelSettings,
tools: list[Tool],
output_schema: Union[AgentOutputSchemaBase, None],
handoffs: list[Handoff],
tracing: ModelTracing,
**kwargs,
) -> ModelResponse:
"""Get a response from the model."""
return self.fn()

def stream_response(
self,
system_instructions: Optional[str],
input: Union[str, list[TResponseInputItem]],
model_settings: ModelSettings,
tools: list[Tool],
output_schema: Optional[AgentOutputSchemaBase],
handoffs: list[Handoff],
tracing: ModelTracing,
**kwargs,
) -> AsyncIterator[TResponseStreamEvent]:
"""Get a streamed response from the model. Unimplemented."""
raise NotImplementedError()


class OpenAIPayloadConverter(PydanticPayloadConverter):
"""PayloadConverter for OpenAI agents."""

Expand Down
175 changes: 175 additions & 0 deletions temporalio/contrib/openai_agents/testing.py
Copy link
Member

@cretz cretz Oct 17, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

A bit hard to see from this PR how this looks from a user POV. One reason we did "ActivityEnvironment" and "WorkflowEnvironment" instead of only the building blocks is because users like the nice simplicity of one-liners and reusable constructs. I'm wondering if there's an opportunity to design something here. If not too much trouble, can I see what tests/openai_agents/basic/test_hello_world_workflow.py will look like using these utilities?

Part of me wonders if we can have an AgentEnvironment that basically accepts everything the plugin accepts and also some of this mock stuff. So maybe something like:

from temporalio.contrib.openai_agents.testing import AgentEnvironment

# ...

async def test_hello_world_agent_workflow(client: Client):

    async def on_model_call(req: WhateverOpenAIRequestType) -> WhateverOpenAIResponseType:
        # Do some stuff

    # on_model_call is just an advanced example, accepting direct mocks can
    # in this constructor be allowed too
    async with AgentEnvironment(on_model_call=on_model_call) as env:
        # Applies plugin and such (which is also available on env.plugin if you want it)
        client = env.applied_on_client(client)
        # Rest of the stuff w/ worker and such

Copy link
Author

@donald-pinckney donald-pinckney Oct 17, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Currently (with the change to static factory method I just pushed), that test would look like:

@pytest.fixture
def test_model():
    return TestModel.returning_responses(
        [ResponseBuilders.output_message("This is a haiku (not really)")]
    )

async def test_execute_workflow(client: Client):
    task_queue_name = str(uuid.uuid4())

    async with Worker(
        client,
        task_queue=task_queue_name,
        workflows=[HelloWorldAgent],
        activity_executor=ThreadPoolExecutor(5),
    ):
        result = await client.execute_workflow(
            HelloWorldAgent.run,
            "Write a recursive haiku about recursive haikus.",
            id=str(uuid.uuid4()),
            task_queue=task_queue_name,
        )
        assert isinstance(result, str)
        assert len(result) > 0

client is a fixture that depends on the test_model fixture, so you can override the test_model fixture per test or per module.

Copy link
Member

@cretz cretz Oct 17, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think for most users this is missing the client and plugin configuration, which I think we should make easy for testers too. To show the full code for comparison, you'd have to include your other fixtures, like the client configuration and plugin creation. Those fixtures are a little pytest-specific and external to the test, and not really how we have done test helpers in the past. I was thinking of something you could easily configure inside each test (but still share if you want). Basically, you need an easy way to configure an existing client with the plugin and model stuff.

Original file line number Diff line number Diff line change
@@ -0,0 +1,175 @@
"""Testing utilities for OpenAI agents."""

from typing import AsyncIterator, Callable, Optional, Union

from agents import (
AgentOutputSchemaBase,
Handoff,
Model,
ModelProvider,
ModelResponse,
ModelSettings,
ModelTracing,
Tool,
TResponseInputItem,
Usage,
)
from agents.items import TResponseOutputItem, TResponseStreamEvent
from openai.types.responses import (
ResponseFunctionToolCall,
ResponseOutputMessage,
ResponseOutputText,
)


class ResponseBuilders:
    """Builders for creating model responses for testing.

    .. warning::
        This API is experimental and may change in the future.
    """

    @staticmethod
    def model_response(output: TResponseOutputItem) -> ModelResponse:
        """Create a ModelResponse wrapping a single output item.

        .. warning::
            This API is experimental and may change in the future.
        """
        return ModelResponse(output=[output], usage=Usage(), response_id=None)

    @staticmethod
    def response_output_message(text: str) -> ResponseOutputMessage:
        """Create a completed assistant ResponseOutputMessage with text content.

        .. warning::
            This API is experimental and may change in the future.
        """
        text_part = ResponseOutputText(
            text=text,
            annotations=[],
            type="output_text",
        )
        return ResponseOutputMessage(
            id="",
            content=[text_part],
            role="assistant",
            status="completed",
            type="message",
        )

    @staticmethod
    def tool_call(arguments: str, name: str) -> ModelResponse:
        """Create a ModelResponse containing a single function tool call.

        .. warning::
            This API is experimental and may change in the future.
        """
        call = ResponseFunctionToolCall(
            arguments=arguments,
            call_id="call",
            name=name,
            type="function_call",
            id="id",
            status="completed",
        )
        return ResponseBuilders.model_response(call)

    @staticmethod
    def output_message(text: str) -> ModelResponse:
        """Create a ModelResponse containing a single text output message.

        .. warning::
            This API is experimental and may change in the future.
        """
        message = ResponseBuilders.response_output_message(text)
        return ResponseBuilders.model_response(message)


class TestModelProvider(ModelProvider):
    """Test model provider which simply returns the given model.

    Useful together with :class:`TestModel` to replace real model calls
    in tests.

    .. warning::
        This API is experimental and may change in the future.
    """

    # Prevent pytest from collecting this class as a test case.
    __test__ = False

    def __init__(self, model: Model):
        """Initialize a test model provider with a model.

        Args:
            model: The model returned by every ``get_model`` call,
                regardless of the requested model name.

        .. warning::
            This API is experimental and may change in the future.
        """
        self._model = model

    def get_model(self, model_name: Union[str, None]) -> Model:
        """Get a model from the model provider.

        The ``model_name`` argument is ignored; the model supplied at
        construction time is always returned.

        .. warning::
            This API is experimental and may change in the future.
        """
        return self._model


class TestModel(Model):
    """Test model for use mocking model responses.

    Each call to :meth:`get_response` invokes the callable supplied at
    construction and returns its result. Streaming is not supported.

    .. warning::
        This API is experimental and may change in the future.
    """

    # Prevent pytest from collecting this class as a test case.
    __test__ = False

    def __init__(self, fn: Callable[[], ModelResponse]) -> None:
        """Initialize a test model with a callable.

        Args:
            fn: Zero-argument callable invoked once per model call; its
                return value is used as the model's response.

        .. warning::
            This API is experimental and may change in the future.
        """
        self.fn = fn

    async def get_response(
        self,
        system_instructions: Union[str, None],
        input: Union[str, list[TResponseInputItem]],
        model_settings: ModelSettings,
        tools: list[Tool],
        output_schema: Union[AgentOutputSchemaBase, None],
        handoffs: list[Handoff],
        tracing: ModelTracing,
        **kwargs,
    ) -> ModelResponse:
        """Get a response from the mocked model, by calling the callable passed to the constructor."""
        return self.fn()

    def stream_response(
        self,
        system_instructions: Optional[str],
        input: Union[str, list[TResponseInputItem]],
        model_settings: ModelSettings,
        tools: list[Tool],
        output_schema: Optional[AgentOutputSchemaBase],
        handoffs: list[Handoff],
        tracing: ModelTracing,
        **kwargs,
    ) -> AsyncIterator[TResponseStreamEvent]:
        """Get a streamed response from the model. Unimplemented."""
        raise NotImplementedError()

    @staticmethod
    def returning_responses(responses: list[ModelResponse]) -> "TestModel":
        """Create a mock model which sequentially returns responses from a list.

        Args:
            responses: Responses returned one at a time, in order, for each
                successive model call.

        Raises:
            RuntimeError: (from the returned model's ``get_response``) when
                more model calls are made than there are responses.

        .. warning::
            This API is experimental and may change in the future.
        """
        response_iter = iter(responses)

        def _next_response() -> ModelResponse:
            try:
                return next(response_iter)
            except StopIteration:
                # A bare StopIteration escaping the async get_response would
                # be converted by the interpreter into an opaque
                # "coroutine raised StopIteration" RuntimeError (PEP 479
                # semantics); raise a clearly-worded error instead.
                raise RuntimeError(
                    "TestModel has no more responses to return"
                ) from None

        return TestModel(_next_response)
Loading
Loading