Skip to content

OpenAI/plugin #955

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 2 additions & 8 deletions temporalio/contrib/openai_agents/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,24 +8,18 @@
Use with caution in production environments.
"""

from temporalio.contrib.openai_agents._invoke_model_activity import ModelActivity
from temporalio.contrib.openai_agents._model_parameters import ModelActivityParameters
from temporalio.contrib.openai_agents._trace_interceptor import (
OpenAIAgentsTracingInterceptor,
)
from temporalio.contrib.openai_agents.temporal_openai_agents import (
Plugin,
TestModel,
TestModelProvider,
set_open_ai_agent_temporal_overrides,
workflow,
)

__all__ = [
"ModelActivity",
"Plugin",
"ModelActivityParameters",
"workflow",
"set_open_ai_agent_temporal_overrides",
"OpenAIAgentsTracingInterceptor",
"TestModel",
"TestModelProvider",
]
47 changes: 38 additions & 9 deletions temporalio/contrib/openai_agents/temporal_openai_agents.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,9 @@
import json
from contextlib import contextmanager
from datetime import timedelta
from typing import Any, AsyncIterator, Callable, Optional, Union, overload
from typing import Any, AsyncIterator, Callable, Optional, Union

from agents import (
Agent,
AgentOutputSchemaBase,
Handoff,
Model,
Expand All @@ -19,31 +18,34 @@
TResponseInputItem,
set_trace_provider,
)
from agents.function_schema import DocstringStyle, function_schema
from agents.function_schema import function_schema
from agents.items import TResponseStreamEvent
from agents.run import get_default_agent_runner, set_default_agent_runner
from agents.tool import (
FunctionTool,
ToolErrorFunction,
ToolFunction,
ToolParams,
default_tool_error_function,
function_tool,
)
from agents.tracing import get_trace_provider
from agents.tracing.provider import DefaultTraceProvider
from agents.util._types import MaybeAwaitable
from openai.types.responses import ResponsePromptParam

import temporalio.client
import temporalio.worker
from temporalio import activity
from temporalio import workflow as temporal_workflow
from temporalio.client import ClientConfig
from temporalio.common import Priority, RetryPolicy
from temporalio.contrib.openai_agents._invoke_model_activity import ModelActivity
from temporalio.contrib.openai_agents._model_parameters import ModelActivityParameters
from temporalio.contrib.openai_agents._openai_runner import TemporalOpenAIRunner
from temporalio.contrib.openai_agents._temporal_trace_provider import (
TemporalTraceProvider,
)
from temporalio.contrib.openai_agents._trace_interceptor import (
OpenAIAgentsTracingInterceptor,
)
from temporalio.contrib.pydantic import pydantic_data_converter
from temporalio.exceptions import ApplicationError, TemporalError
from temporalio.worker import Worker, WorkerConfig
from temporalio.workflow import ActivityCancellationType, VersioningIntent


Expand Down Expand Up @@ -154,6 +156,33 @@ def stream_response(
raise NotImplementedError()


class Plugin(temporalio.client.Plugin, temporalio.worker.Plugin):
def __init__(
self,
model_params: Optional[ModelActivityParameters] = None,
model_provider: Optional[ModelProvider] = None,
) -> None:
self._model_params = model_params
self._model_provider = model_provider

def on_create_client(self, config: ClientConfig) -> ClientConfig:
config["data_converter"] = pydantic_data_converter
return super().on_create_client(config)

def on_create_worker(self, config: WorkerConfig) -> WorkerConfig:
config["interceptors"] = list(config.get("interceptors") or []) + [
OpenAIAgentsTracingInterceptor()
]
config["activities"] = list(config.get("activities") or []) + [
ModelActivity(self._model_provider).invoke_model_activity
]
return super().on_create_worker(config)

async def run_worker(self, worker: Worker) -> None:
with set_open_ai_agent_temporal_overrides(self._model_params):
await super().run_worker(worker)


class ToolSerializationError(TemporalError):
"""Error that occurs when a tool output could not be serialized."""

Expand Down
44 changes: 31 additions & 13 deletions temporalio/worker/_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -385,7 +385,8 @@ def __init__(
)

plugins_from_client = cast(
List[Plugin], [p for p in client.config()["plugins"] if isinstance(p, Plugin)]
List[Plugin],
[p for p in client.config()["plugins"] if isinstance(p, Plugin)],
)
plugins = plugins_from_client + list(plugins)

Expand All @@ -402,15 +403,21 @@ def _init_from_config(self, config: WorkerConfig):
self._config = config

# TODO(nexus-preview): max_concurrent_nexus_tasks / tuner support
if not (config["activities"] or config["nexus_service_handlers"] or config["workflows"]):
if not (
config["activities"]
or config["nexus_service_handlers"]
or config["workflows"]
):
raise ValueError(
"At least one activity, Nexus service, or workflow must be specified"
)
if config["use_worker_versioning"] and not config["build_id"]:
raise ValueError(
"build_id must be specified when use_worker_versioning is True"
)
if config["deployment_config"] and (config["build_id"] or config["use_worker_versioning"]):
if config["deployment_config"] and (
config["build_id"] or config["use_worker_versioning"]
):
raise ValueError(
"deployment_config cannot be used with build_id or use_worker_versioning"
)
Expand Down Expand Up @@ -506,9 +513,13 @@ def check_activity(activity):
unsandboxed_workflow_runner=config["unsandboxed_workflow_runner"],
data_converter=client_config["data_converter"],
interceptors=interceptors,
workflow_failure_exception_types=config["workflow_failure_exception_types"],
workflow_failure_exception_types=config[
"workflow_failure_exception_types"
],
debug_mode=config["debug_mode"],
disable_eager_activity_execution=config["disable_eager_activity_execution"],
disable_eager_activity_execution=config[
"disable_eager_activity_execution"
],
metric_meter=self._runtime.metric_meter,
on_eviction_hook=None,
disable_safe_eviction=config["disable_safe_workflow_eviction"],
Expand All @@ -519,7 +530,7 @@ def check_activity(activity):
)

tuner = config["tuner"]
if config["tuner"] is not None:
if tuner is not None:
if (
config["max_concurrent_workflow_tasks"]
or config["max_concurrent_activities"]
Expand All @@ -540,9 +551,9 @@ def check_activity(activity):

versioning_strategy: temporalio.bridge.worker.WorkerVersioningStrategy
if config["deployment_config"]:
versioning_strategy = (
config["deployment_config"]._to_bridge_worker_deployment_options()
)
versioning_strategy = config[
"deployment_config"
]._to_bridge_worker_deployment_options()
elif config["use_worker_versioning"]:
build_id = config["build_id"] or load_default_build_id()
versioning_strategy = (
Expand Down Expand Up @@ -586,9 +597,11 @@ def check_activity(activity):
# We have to disable remote activities if a user asks _or_ if we
# are not running an activity worker at all. Otherwise shutdown
# will not proceed properly.
no_remote_activities=config["no_remote_activities"] or not config["activities"],
no_remote_activities=config["no_remote_activities"]
or not config["activities"],
sticky_queue_schedule_to_start_timeout_millis=int(
1000 * config["sticky_queue_schedule_to_start_timeout"].total_seconds()
1000
* config["sticky_queue_schedule_to_start_timeout"].total_seconds()
),
max_heartbeat_throttle_interval_millis=int(
1000 * config["max_heartbeat_throttle_interval"].total_seconds()
Expand All @@ -597,7 +610,9 @@ def check_activity(activity):
1000 * config["default_heartbeat_throttle_interval"].total_seconds()
),
max_activities_per_second=config["max_activities_per_second"],
max_task_queue_activities_per_second=config["max_task_queue_activities_per_second"],
max_task_queue_activities_per_second=config[
"max_task_queue_activities_per_second"
],
graceful_shutdown_period_millis=int(
1000 * config["graceful_shutdown_timeout"].total_seconds()
),
Expand All @@ -614,7 +629,9 @@ def check_activity(activity):
versioning_strategy=versioning_strategy,
workflow_task_poller_behavior=workflow_task_poller_behavior._to_bridge(),
activity_task_poller_behavior=activity_task_poller_behavior._to_bridge(),
nexus_task_poller_behavior=config["nexus_task_poller_behavior"]._to_bridge(),
nexus_task_poller_behavior=config[
"nexus_task_poller_behavior"
]._to_bridge(),
),
)

Expand Down Expand Up @@ -902,6 +919,7 @@ class WorkerConfig(TypedDict, total=False):
activity_task_poller_behavior: PollerBehavior
nexus_task_poller_behavior: PollerBehavior


@dataclass
class WorkerDeploymentConfig:
"""Options for configuring the Worker Versioning feature.
Expand Down
Loading
Loading