Skip to content

OpenAI/plugin #956

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 4 commits into
base: plugins
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 2 additions & 8 deletions temporalio/contrib/openai_agents/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,24 +8,18 @@
Use with caution in production environments.
"""

from temporalio.contrib.openai_agents._invoke_model_activity import ModelActivity
from temporalio.contrib.openai_agents._model_parameters import ModelActivityParameters
from temporalio.contrib.openai_agents._trace_interceptor import (
OpenAIAgentsTracingInterceptor,
)
from temporalio.contrib.openai_agents.temporal_openai_agents import (
Plugin,
TestModel,
TestModelProvider,
set_open_ai_agent_temporal_overrides,
workflow,
)

__all__ = [
"ModelActivity",
"Plugin",
"ModelActivityParameters",
"workflow",
"set_open_ai_agent_temporal_overrides",
"OpenAIAgentsTracingInterceptor",
"TestModel",
"TestModelProvider",
]
47 changes: 38 additions & 9 deletions temporalio/contrib/openai_agents/temporal_openai_agents.py
Copy link
Member

@cretz cretz Jul 10, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This file should start with an underscore IMO as should every file in this module (until we add a workflow.py file to move from static class to module for that stuff). Granted this can be done separately. There may need to be general file rename/cleanup anyways (it's a bit scattered/inconsistent in some ways).

Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,9 @@
import json
from contextlib import contextmanager
from datetime import timedelta
from typing import Any, AsyncIterator, Callable, Optional, Union, overload
from typing import Any, AsyncIterator, Callable, Optional, Union

from agents import (
Agent,
AgentOutputSchemaBase,
Handoff,
Model,
Expand All @@ -19,31 +18,34 @@
TResponseInputItem,
set_trace_provider,
)
from agents.function_schema import DocstringStyle, function_schema
from agents.function_schema import function_schema
from agents.items import TResponseStreamEvent
from agents.run import get_default_agent_runner, set_default_agent_runner
from agents.tool import (
FunctionTool,
ToolErrorFunction,
ToolFunction,
ToolParams,
default_tool_error_function,
function_tool,
)
from agents.tracing import get_trace_provider
from agents.tracing.provider import DefaultTraceProvider
from agents.util._types import MaybeAwaitable
from openai.types.responses import ResponsePromptParam

import temporalio.client
import temporalio.worker
from temporalio import activity
from temporalio import workflow as temporal_workflow
from temporalio.client import ClientConfig
from temporalio.common import Priority, RetryPolicy
from temporalio.contrib.openai_agents._invoke_model_activity import ModelActivity
from temporalio.contrib.openai_agents._model_parameters import ModelActivityParameters
from temporalio.contrib.openai_agents._openai_runner import TemporalOpenAIRunner
from temporalio.contrib.openai_agents._temporal_trace_provider import (
TemporalTraceProvider,
)
from temporalio.contrib.openai_agents._trace_interceptor import (
OpenAIAgentsTracingInterceptor,
)
from temporalio.contrib.pydantic import pydantic_data_converter
from temporalio.exceptions import ApplicationError, TemporalError
from temporalio.worker import Worker, WorkerConfig
from temporalio.workflow import ActivityCancellationType, VersioningIntent


Expand Down Expand Up @@ -154,6 +156,33 @@ def stream_response(
raise NotImplementedError()


class Plugin(temporalio.client.Plugin, temporalio.worker.Plugin):
def __init__(
self,
model_params: Optional[ModelActivityParameters] = None,
model_provider: Optional[ModelProvider] = None,
) -> None:
self._model_params = model_params
self._model_provider = model_provider

def on_create_client(self, config: ClientConfig) -> ClientConfig:
config["data_converter"] = pydantic_data_converter
return super().on_create_client(config)

def on_create_worker(self, config: WorkerConfig) -> WorkerConfig:
config["interceptors"] = list(config.get("interceptors") or []) + [
OpenAIAgentsTracingInterceptor()
]
config["activities"] = list(config.get("activities") or []) + [
ModelActivity(self._model_provider).invoke_model_activity
]
return super().on_create_worker(config)

async def run_worker(self, worker: Worker) -> None:
with set_open_ai_agent_temporal_overrides(self._model_params):
await super().run_worker(worker)


class ToolSerializationError(TemporalError):
"""Error that occurs when a tool output could not be serialized."""

Expand Down
Loading
Loading