Skip to content

More builtin OpenAI tool types and test for Lite LLM custom provider #988

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 14 commits into from
Jul 25, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ dev = [
"pytest-cov>=6.1.1",
"httpx>=0.28.1",
"pytest-pretty>=1.3.0",
"openai-agents[litellm] >= 0.2.3,<0.3"
]

[tool.poe.tasks]
Expand Down
65 changes: 48 additions & 17 deletions temporalio/contrib/openai_agents/_invoke_model_activity.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,16 @@
import json
from dataclasses import dataclass
from datetime import timedelta
from typing import Any, Optional, Union, cast
from typing import Any, Optional, Union

from agents import (
AgentOutputSchemaBase,
CodeInterpreterTool,
FileSearchTool,
FunctionTool,
Handoff,
HostedMCPTool,
ImageGenerationTool,
ModelProvider,
ModelResponse,
ModelSettings,
Expand All @@ -25,13 +28,12 @@
UserError,
WebSearchTool,
)
from agents.models.multi_provider import MultiProvider
from openai import (
APIStatusError,
AsyncOpenAI,
AuthenticationError,
PermissionDeniedError,
)
from openai.types.responses.tool_param import Mcp
from pydantic_core import to_json
from typing_extensions import Required, TypedDict

from temporalio import activity
Expand All @@ -41,7 +43,9 @@

@dataclass
class HandoffInput:
"""Data conversion friendly representation of a Handoff."""
"""Data conversion friendly representation of a Handoff. Contains only the fields which are needed by the model
execution to determine what to handoff to, not the actual handoff invocation, which remains in the workflow context.
"""

tool_name: str
tool_description: str
Expand All @@ -52,15 +56,33 @@ class HandoffInput:

@dataclass
class FunctionToolInput:
"""Data conversion friendly representation of a FunctionTool."""
"""Data conversion friendly representation of a FunctionTool. Contains only the fields which are needed by the model
execution to determine what tool to call, not the actual tool invocation, which remains in the workflow context.
"""

name: str
description: str
params_json_schema: dict[str, Any]
strict_json_schema: bool = True


ToolInput = Union[FunctionToolInput, FileSearchTool, WebSearchTool]
@dataclass
class HostedMCPToolInput:
"""Data conversion friendly representation of a HostedMCPTool. Contains only the fields which are needed by the model
execution to determine what tool to call, not the actual tool invocation, which remains in the workflow context.
"""

tool_config: Mcp


ToolInput = Union[
FunctionToolInput,
FileSearchTool,
WebSearchTool,
ImageGenerationTool,
CodeInterpreterTool,
HostedMCPToolInput,
]


@dataclass
Expand Down Expand Up @@ -152,22 +174,31 @@ async def empty_on_invoke_handoff(

# workaround for https://github.com/pydantic/pydantic/issues/9541
# ValidatorIterator returned
input_json = json.dumps(input["input"], default=str)
input_json = to_json(input["input"])
input_input = json.loads(input_json)

def make_tool(tool: ToolInput) -> Tool:
if isinstance(tool, FileSearchTool):
return cast(FileSearchTool, tool)
elif isinstance(tool, WebSearchTool):
return cast(WebSearchTool, tool)
if isinstance(
tool,
(
FileSearchTool,
WebSearchTool,
ImageGenerationTool,
CodeInterpreterTool,
),
):
return tool
elif isinstance(tool, HostedMCPToolInput):
return HostedMCPTool(
tool_config=tool.tool_config,
)
elif isinstance(tool, FunctionToolInput):
t = cast(FunctionToolInput, tool)
return FunctionTool(
name=t.name,
description=t.description,
params_json_schema=t.params_json_schema,
name=tool.name,
description=tool.description,
params_json_schema=tool.params_json_schema,
on_invoke_tool=empty_on_invoke_tool,
strict_json_schema=t.strict_json_schema,
strict_json_schema=tool.strict_json_schema,
)
else:
raise UserError(f"Unknown tool type: {tool.name}")
Expand Down
23 changes: 16 additions & 7 deletions temporalio/contrib/openai_agents/_temporal_model_stub.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,12 @@
from agents import (
AgentOutputSchema,
AgentOutputSchemaBase,
ComputerTool,
CodeInterpreterTool,
FileSearchTool,
FunctionTool,
Handoff,
HostedMCPTool,
ImageGenerationTool,
Model,
ModelResponse,
ModelSettings,
Expand All @@ -33,6 +35,7 @@
AgentOutputSchemaInput,
FunctionToolInput,
HandoffInput,
HostedMCPToolInput,
ModelActivity,
ModelTracingInput,
ToolInput,
Expand Down Expand Up @@ -65,12 +68,18 @@ async def get_response(
prompt: Optional[ResponsePromptParam],
) -> ModelResponse:
def make_tool_info(tool: Tool) -> ToolInput:
if isinstance(tool, (FileSearchTool, WebSearchTool)):
if isinstance(
tool,
(
FileSearchTool,
WebSearchTool,
ImageGenerationTool,
CodeInterpreterTool,
),
):
return tool
elif isinstance(tool, ComputerTool):
raise NotImplementedError(
"Computer search preview is not supported in Temporal model"
)
elif isinstance(tool, HostedMCPTool):
return HostedMCPToolInput(tool_config=tool.tool_config)
elif isinstance(tool, FunctionTool):
return FunctionToolInput(
name=tool.name,
Expand All @@ -79,7 +88,7 @@ def make_tool_info(tool: Tool) -> ToolInput:
strict_json_schema=tool.strict_json_schema,
)
else:
raise ValueError(f"Unknown tool type: {tool.name}")
raise ValueError(f"Unsupported tool type: {tool.name}")

tool_infos = [make_tool_info(x) for x in tools]
handoff_infos = [
Expand Down
17 changes: 15 additions & 2 deletions temporalio/contrib/openai_agents/_temporal_openai_agents.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,13 @@
from temporalio.contrib.openai_agents._trace_interceptor import (
OpenAIAgentsTracingInterceptor,
)
from temporalio.contrib.pydantic import pydantic_data_converter
from temporalio.contrib.pydantic import (
PydanticPayloadConverter,
ToJsonOptions,
)
from temporalio.converter import (
DataConverter,
)
from temporalio.worker import Worker, WorkerConfig


Expand Down Expand Up @@ -137,6 +143,11 @@ def stream_response(
raise NotImplementedError()


class _OpenAIPayloadConverter(PydanticPayloadConverter):
def __init__(self) -> None:
super().__init__(ToJsonOptions(exclude_unset=True))


class OpenAIAgentsPlugin(temporalio.client.Plugin, temporalio.worker.Plugin):
"""Temporal plugin for integrating OpenAI agents with Temporal workflows.

Expand Down Expand Up @@ -232,7 +243,9 @@ def configure_client(self, config: ClientConfig) -> ClientConfig:
Returns:
The modified client configuration.
"""
config["data_converter"] = pydantic_data_converter
config["data_converter"] = DataConverter(
payload_converter_class=_OpenAIPayloadConverter
)
return super().configure_client(config)

def configure_worker(self, config: WorkerConfig) -> WorkerConfig:
Expand Down
29 changes: 25 additions & 4 deletions temporalio/contrib/pydantic.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,12 @@
Pydantic v1 is not supported.
"""

from dataclasses import dataclass
from typing import Any, Optional, Type

from pydantic import TypeAdapter
from pydantic_core import to_json
from pydantic_core import SchemaSerializer, to_json
from pydantic_core.core_schema import any_schema

import temporalio.api.common.v1
from temporalio.converter import (
Expand All @@ -31,6 +33,13 @@
# implements __get_pydantic_core_schema__ so that pydantic unwraps proxied types.


@dataclass
class ToJsonOptions:
"""Options for converting to JSON with pydantic."""

exclude_unset: bool = False


class PydanticJSONPlainPayloadConverter(EncodingPayloadConverter):
"""Pydantic JSON payload converter.

Expand All @@ -44,6 +53,11 @@ class PydanticJSONPlainPayloadConverter(EncodingPayloadConverter):
See https://docs.pydantic.dev/latest/api/standard_library_types/
"""

def __init__(self, to_json_options: Optional[ToJsonOptions] = None):
"""Create a new payload converter."""
self._schema_serializer = SchemaSerializer(any_schema())
self._to_json_options = to_json_options

@property
def encoding(self) -> str:
"""See base class."""
Expand All @@ -57,8 +71,15 @@ def to_payload(self, value: Any) -> Optional[temporalio.api.common.v1.Payload]:
See
https://docs.pydantic.dev/latest/api/pydantic_core/#pydantic_core.to_json.
"""
data = (
self._schema_serializer.to_json(
value, exclude_unset=self._to_json_options.exclude_unset
)
if self._to_json_options
else to_json(value)
)
return temporalio.api.common.v1.Payload(
metadata={"encoding": self.encoding.encode()}, data=to_json(value)
metadata={"encoding": self.encoding.encode()}, data=data
)

def from_payload(
Expand All @@ -85,9 +106,9 @@ class PydanticPayloadConverter(CompositePayloadConverter):
:py:class:`PydanticJSONPlainPayloadConverter`.
"""

def __init__(self) -> None:
def __init__(self, to_json_options: Optional[ToJsonOptions] = None) -> None:
"""Initialize object"""
json_payload_converter = PydanticJSONPlainPayloadConverter()
json_payload_converter = PydanticJSONPlainPayloadConverter(to_json_options)
super().__init__(
*(
c
Expand Down
Loading
Loading