💥 Streamline OpenAI module layout #947

Merged
merged 14 commits on Jul 8, 2025
Changes from 7 commits
22 changes: 22 additions & 0 deletions temporalio/contrib/openai_agents/__init__.py
@@ -7,3 +7,25 @@
This module is experimental and may change in future versions.
Use with caution in production environments.
"""

from temporalio.contrib.openai_agents.invoke_model_activity import ModelActivity
from temporalio.contrib.openai_agents.model_parameters import ModelActivityParameters
from temporalio.contrib.openai_agents.open_ai_data_converter import (
open_ai_data_converter,
)
from temporalio.contrib.openai_agents.temporal_openai_agents import (
set_open_ai_agent_temporal_overrides,
workflow,
)
from temporalio.contrib.openai_agents.trace_interceptor import (
OpenAIAgentsTracingInterceptor,
)

__all__ = [
"open_ai_data_converter",
"ModelActivity",
"ModelActivityParameters",
"workflow",
"set_open_ai_agent_temporal_overrides",
"OpenAIAgentsTracingInterceptor",
]
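
With the streamlined layout, these names are importable from the package root instead of the individual submodules. A minimal client-setup sketch using only the exports above (the server address is illustrative, and attaching the tracing interceptor at the client rather than the worker is an assumption, not something this diff prescribes):

from temporalio.client import Client
from temporalio.contrib.openai_agents import (
    OpenAIAgentsTracingInterceptor,
    open_ai_data_converter,
)


async def connect_client() -> Client:
    # The OpenAI-aware data converter and the tracing interceptor both come
    # straight from the package root after this change.
    return await Client.connect(
        "localhost:7233",
        data_converter=open_ai_data_converter,
        interceptors=[OpenAIAgentsTracingInterceptor()],
    )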
275 changes: 270 additions & 5 deletions temporalio/contrib/openai_agents/temporal_openai_agents.py
@@ -1,18 +1,51 @@
"""Initialize Temporal OpenAI Agents overrides."""

import json
from contextlib import contextmanager
from typing import Optional
from datetime import timedelta
from typing import Any, AsyncIterator, Callable, Optional, Union, overload

from agents import set_trace_provider
from agents.run import get_default_agent_runner, set_default_agent_runner
from agents.tracing import get_trace_provider
from agents.tracing.provider import DefaultTraceProvider
from agents.items import TResponseStreamEvent
from openai.types.responses import ResponsePromptParam

from temporalio import activity
from temporalio import workflow as temporal_workflow
from temporalio.common import Priority, RetryPolicy
from temporalio.contrib.openai_agents._openai_runner import TemporalOpenAIRunner
from temporalio.contrib.openai_agents._temporal_trace_provider import (
TemporalTraceProvider,
)
from temporalio.contrib.openai_agents.model_parameters import ModelActivityParameters
from temporalio.exceptions import ApplicationError, TemporalError
from temporalio.workflow import ActivityCancellationType, VersioningIntent, unsafe

with unsafe.imports_passed_through():
from agents import (
Agent,
AgentOutputSchemaBase,
Handoff,
Model,
ModelProvider,
ModelResponse,
ModelSettings,
ModelTracing,
RunContextWrapper,
Tool,
TResponseInputItem,
set_trace_provider,
)
from agents.function_schema import DocstringStyle, function_schema
from agents.run import get_default_agent_runner, set_default_agent_runner
from agents.tool import (
FunctionTool,
ToolErrorFunction,
ToolFunction,
default_tool_error_function,
function_tool,
)
from agents.tracing import get_trace_provider
from agents.tracing.provider import DefaultTraceProvider
from agents.util._types import MaybeAwaitable


@contextmanager
@@ -68,3 +101,235 @@ def set_open_ai_agent_temporal_overrides(
finally:
set_default_agent_runner(previous_runner)
set_trace_provider(previous_trace_provider or DefaultTraceProvider())
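
Typical use wraps worker startup in the context manager so the Agents SDK runner and trace provider are the Temporal-aware versions for the lifetime of the worker, then restored by the finally block above. A hedged sketch (passing a ModelActivityParameters positionally and its start_to_close_timeout field are assumptions based on this module's naming, not confirmed by the diff):

from datetime import timedelta

from temporalio.contrib.openai_agents import (
    ModelActivityParameters,
    set_open_ai_agent_temporal_overrides,
)

# Illustrative model-invocation options; the field name is assumed.
params = ModelActivityParameters(start_to_close_timeout=timedelta(seconds=60))

with set_open_ai_agent_temporal_overrides(params):
    # Within this block, agent runs are routed through Temporal activities and
    # traces flow through the TemporalTraceProvider.
    ...  # create the Client and run the Worker here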


class TestModelProvider(ModelProvider):
"""Test model provider which simply returns the given module."""

def __init__(self, model: Model):
"""Initialize a test model provider with a model."""
self._model = model

def get_model(self, model_name: Union[str, None]) -> Model:
"""Get a model from the model provider."""
return self._model


class TestModel(Model):
"""Test model for use mocking model responses."""

def __init__(self, fn: Callable[[], ModelResponse]) -> None:
"""Initialize a test model with a callable."""
self.fn = fn

async def get_response(
self,
system_instructions: Union[str, None],
input: Union[str, list[TResponseInputItem]],
model_settings: ModelSettings,
tools: list[Tool],
output_schema: Union[AgentOutputSchemaBase, None],
handoffs: list[Handoff],
tracing: ModelTracing,
*,
previous_response_id: Union[str, None],
prompt: Union[ResponsePromptParam, None] = None,
) -> ModelResponse:
"""Get a response from the model."""
return self.fn()

def stream_response(
self,
system_instructions: Union[str, None],
input: Union[str, list[TResponseInputItem]],
model_settings: ModelSettings,
tools: list[Tool],
output_schema: Union[AgentOutputSchemaBase, None],
handoffs: list[Handoff],
tracing: ModelTracing,
*,
previous_response_id: Union[str, None],
prompt: Union[ResponsePromptParam, None],
) -> AsyncIterator[TResponseStreamEvent]:
"""Get a streamed response from the model. Unimplemented."""
raise NotImplementedError()
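
In tests, TestModel lets a canned response stand in for a real model call, and TestModelProvider resolves every model name to that single model. A small sketch continuing in this module's context (how the canned ModelResponse is built, and where the provider is plugged in, e.g. a ModelActivity, are left out because they are not shown in this diff):

def canned_response() -> ModelResponse:
    # Build whatever fixed ModelResponse the test needs; elided here.
    ...


test_model = TestModel(canned_response)
provider = TestModelProvider(test_model)

# Every lookup, by name or None, returns the same test model.
assert provider.get_model("gpt-4o") is test_model
assert provider.get_model(None) is test_model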


class ToolSerializationError(TemporalError):
"""Error that occurs when a tool output could not be serialized."""


class workflow:
"""Encapsulates workflow specific primitives for working with the OpenAI Agents SDK in a workflow context"""

@classmethod
def activity_as_tool(
cls,
fn: Callable,
*,
task_queue: Optional[str] = None,
schedule_to_close_timeout: Optional[timedelta] = None,
schedule_to_start_timeout: Optional[timedelta] = None,
start_to_close_timeout: Optional[timedelta] = None,
heartbeat_timeout: Optional[timedelta] = None,
retry_policy: Optional[RetryPolicy] = None,
cancellation_type: ActivityCancellationType = ActivityCancellationType.TRY_CANCEL,
activity_id: Optional[str] = None,
versioning_intent: Optional[VersioningIntent] = None,
summary: Optional[str] = None,
priority: Priority = Priority.default,
) -> Tool:
"""Convert a single Temporal activity function to an OpenAI agent tool.

.. warning::
This API is experimental and may change in future versions.
Use with caution in production environments.

This function takes a Temporal activity function and converts it into an
OpenAI agent tool that can be used by the agent to execute the activity
during workflow execution. The tool will automatically handle the conversion
of inputs and outputs between the agent and the activity. Note that if the activity takes a context,
mutations to that context will not be persisted, as the activity may not run in the same location.

Args:
fn: A Temporal activity function to convert to a tool.
For the other arguments, refer to :py:meth:`temporalio.workflow.start_activity`.

Returns:
An OpenAI agent tool that wraps the provided activity.

Raises:
ApplicationError: If the function is not properly decorated as a Temporal activity.

Example:
>>> @activity.defn
... def process_data(input: str) -> str:
...     return f"Processed: {input}"
>>>
>>> # Create tool with custom activity options
>>> tool = workflow.activity_as_tool(
... process_data,
... start_to_close_timeout=timedelta(seconds=30),
... retry_policy=RetryPolicy(maximum_attempts=3),
... heartbeat_timeout=timedelta(seconds=10)
... )
>>> # Use tool with an OpenAI agent
"""
ret = activity._Definition.from_callable(fn)
if not ret:
raise ApplicationError(
"Bare function without tool and activity decorators is not supported",
"invalid_tool",
)
schema = function_schema(fn)

async def run_activity(ctx: RunContextWrapper[Any], input: str) -> Any:
try:
json_data = json.loads(input)
except Exception as e:
raise ApplicationError(
f"Invalid JSON input for tool {schema.name}: {input}"
) from e

# Activities don't support keyword-only arguments, so we can ignore the kwargs_dict return
args, _ = schema.to_call_args(schema.params_pydantic_model(**json_data))

# Add the context to the arguments if it takes that
if schema.takes_context:
args = [ctx] + args

result = await temporal_workflow.execute_activity(
fn,
args=args,
task_queue=task_queue,
schedule_to_close_timeout=schedule_to_close_timeout,
schedule_to_start_timeout=schedule_to_start_timeout,
start_to_close_timeout=start_to_close_timeout,
heartbeat_timeout=heartbeat_timeout,
retry_policy=retry_policy,
cancellation_type=cancellation_type,
activity_id=activity_id,
versioning_intent=versioning_intent,
summary=summary,
priority=priority,
)
try:
return str(result)
except Exception as e:
raise ToolSerializationError(
"You must return a string representation of the tool output, or something we can call str() on"
) from e

return FunctionTool(
name=schema.name,
description=schema.description or "",
params_json_schema=schema.params_json_schema,
on_invoke_tool=run_activity,
strict_json_schema=True,
)
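
A hedged end-to-end sketch of how the wrapped activity surfaces to an agent inside a workflow. Only activity_as_tool itself comes from this diff; the agent name, instructions, Runner usage, and the requirement that the overrides above are active on the worker are illustrative assumptions:

from datetime import timedelta

from agents import Agent, Runner

from temporalio import activity
from temporalio import workflow as temporal_workflow
from temporalio.contrib.openai_agents import workflow as agents_workflow


@activity.defn
async def fetch_weather(city: str) -> str:
    # Real I/O belongs in the activity, not in workflow code.
    return f"Sunny in {city}"


@temporal_workflow.defn
class WeatherWorkflow:
    @temporal_workflow.run
    async def run(self, city: str) -> str:
        agent = Agent(
            name="Weather agent",
            instructions="Answer questions about the weather.",
            tools=[
                agents_workflow.activity_as_tool(
                    fetch_weather,
                    start_to_close_timeout=timedelta(seconds=30),
                )
            ],
        )
        result = await Runner.run(agent, input=city)
        return str(result.final_output)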

@classmethod
@overload
def tool(
cls,
*,
name_override: Union[str, None] = None,
description_override: Union[str, None] = None,
docstring_style: Union[DocstringStyle, None] = None,
use_docstring_info: bool = True,
failure_error_function: Union[
ToolErrorFunction, None
] = default_tool_error_function,
strict_mode: bool = True,
is_enabled: Union[
bool, Callable[[RunContextWrapper[Any], Agent[Any]], MaybeAwaitable[bool]]
] = True,
) -> Callable[[ToolFunction[...]], FunctionTool]: ...

@classmethod
@overload
def tool(
cls,
func: ToolFunction[...],
*,
name_override: Union[str, None] = None,
description_override: Union[str, None] = None,
docstring_style: Union[DocstringStyle, None] = None,
use_docstring_info: bool = True,
failure_error_function: Union[
ToolErrorFunction, None
] = default_tool_error_function,
strict_mode: bool = True,
is_enabled: Union[
bool, Callable[[RunContextWrapper[Any], Agent[Any]], MaybeAwaitable[bool]]
] = True,
) -> FunctionTool: ...

@classmethod
def tool(
cls,
func: Union[ToolFunction[...], None] = None,
*,
name_override: Union[str, None] = None,
description_override: Union[str, None] = None,
docstring_style: Union[DocstringStyle, None] = None,
use_docstring_info: bool = True,
failure_error_function: Union[
ToolErrorFunction, None
] = default_tool_error_function,
strict_mode: bool = True,
is_enabled: Union[
bool, Callable[[RunContextWrapper[Any], Agent[Any]], MaybeAwaitable[bool]]
] = True,
) -> Union[FunctionTool, Callable[[ToolFunction[...]], FunctionTool]]:
"""A temporal specific wrapper for OpenAI's @function_tool. This exists to ensure the user is aware that the function tool is workflow level code and must be deterministic."""
return function_tool(
func, # type: ignore
name_override=name_override,
description_override=description_override,
docstring_style=docstring_style,
use_docstring_info=use_docstring_info,
failure_error_function=failure_error_function,
strict_mode=strict_mode,
is_enabled=is_enabled,
)
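
Because this is a pass-through to function_tool, it is used exactly like the upstream decorator, only accessed through the workflow namespace to signal that the body is workflow-level code and must stay deterministic. A brief sketch:

from temporalio.contrib.openai_agents import workflow as agents_workflow


@agents_workflow.tool
def add(a: int, b: int) -> int:
    """Add two numbers; deterministic, no I/O or randomness."""
    return a + b

# add is now a FunctionTool and can go straight into an Agent's tools list.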