Skip to content

Commit b6b0973

Browse files
authored
More builtin OpenAI tool types and test for Lite LLM custom provider (#988)
* WIP - additional tool types and LiteLLM * Adding a few more tool types and tests * Remove stale commented code * Lazy import LiteLLM, skip on 3.9 * Addressing some PR comments * Cleanup image generation test, required moving serialization test to new type * Lint * Fixing serialization issues * Move json options down a level * Extend code interpreter timeout
1 parent 3d9bfee commit b6b0973

File tree

7 files changed

+1513
-42
lines changed

7 files changed

+1513
-42
lines changed

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,7 @@ dev = [
5757
"pytest-cov>=6.1.1",
5858
"httpx>=0.28.1",
5959
"pytest-pretty>=1.3.0",
60+
"openai-agents[litellm] >= 0.2.3,<0.3"
6061
]
6162

6263
[tool.poe.tasks]

temporalio/contrib/openai_agents/_invoke_model_activity.py

Lines changed: 48 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -7,13 +7,16 @@
77
import json
88
from dataclasses import dataclass
99
from datetime import timedelta
10-
from typing import Any, Optional, Union, cast
10+
from typing import Any, Optional, Union
1111

1212
from agents import (
1313
AgentOutputSchemaBase,
14+
CodeInterpreterTool,
1415
FileSearchTool,
1516
FunctionTool,
1617
Handoff,
18+
HostedMCPTool,
19+
ImageGenerationTool,
1720
ModelProvider,
1821
ModelResponse,
1922
ModelSettings,
@@ -25,13 +28,12 @@
2528
UserError,
2629
WebSearchTool,
2730
)
28-
from agents.models.multi_provider import MultiProvider
2931
from openai import (
3032
APIStatusError,
3133
AsyncOpenAI,
32-
AuthenticationError,
33-
PermissionDeniedError,
3434
)
35+
from openai.types.responses.tool_param import Mcp
36+
from pydantic_core import to_json
3537
from typing_extensions import Required, TypedDict
3638

3739
from temporalio import activity
@@ -41,7 +43,9 @@
4143

4244
@dataclass
4345
class HandoffInput:
44-
"""Data conversion friendly representation of a Handoff."""
46+
"""Data conversion friendly representation of a Handoff. Contains only the fields which are needed by the model
47+
execution to determine what to handoff to, not the actual handoff invocation, which remains in the workflow context.
48+
"""
4549

4650
tool_name: str
4751
tool_description: str
@@ -52,15 +56,33 @@ class HandoffInput:
5256

5357
@dataclass
5458
class FunctionToolInput:
55-
"""Data conversion friendly representation of a FunctionTool."""
59+
"""Data conversion friendly representation of a FunctionTool. Contains only the fields which are needed by the model
60+
execution to determine what tool to call, not the actual tool invocation, which remains in the workflow context.
61+
"""
5662

5763
name: str
5864
description: str
5965
params_json_schema: dict[str, Any]
6066
strict_json_schema: bool = True
6167

6268

63-
ToolInput = Union[FunctionToolInput, FileSearchTool, WebSearchTool]
69+
@dataclass
70+
class HostedMCPToolInput:
71+
"""Data conversion friendly representation of a HostedMCPTool. Contains only the fields which are needed by the model
72+
execution to determine what tool to call, not the actual tool invocation, which remains in the workflow context.
73+
"""
74+
75+
tool_config: Mcp
76+
77+
78+
ToolInput = Union[
79+
FunctionToolInput,
80+
FileSearchTool,
81+
WebSearchTool,
82+
ImageGenerationTool,
83+
CodeInterpreterTool,
84+
HostedMCPToolInput,
85+
]
6486

6587

6688
@dataclass
@@ -152,22 +174,31 @@ async def empty_on_invoke_handoff(
152174

153175
# workaround for https://github.com/pydantic/pydantic/issues/9541
154176
# ValidatorIterator returned
155-
input_json = json.dumps(input["input"], default=str)
177+
input_json = to_json(input["input"])
156178
input_input = json.loads(input_json)
157179

158180
def make_tool(tool: ToolInput) -> Tool:
159-
if isinstance(tool, FileSearchTool):
160-
return cast(FileSearchTool, tool)
161-
elif isinstance(tool, WebSearchTool):
162-
return cast(WebSearchTool, tool)
181+
if isinstance(
182+
tool,
183+
(
184+
FileSearchTool,
185+
WebSearchTool,
186+
ImageGenerationTool,
187+
CodeInterpreterTool,
188+
),
189+
):
190+
return tool
191+
elif isinstance(tool, HostedMCPToolInput):
192+
return HostedMCPTool(
193+
tool_config=tool.tool_config,
194+
)
163195
elif isinstance(tool, FunctionToolInput):
164-
t = cast(FunctionToolInput, tool)
165196
return FunctionTool(
166-
name=t.name,
167-
description=t.description,
168-
params_json_schema=t.params_json_schema,
197+
name=tool.name,
198+
description=tool.description,
199+
params_json_schema=tool.params_json_schema,
169200
on_invoke_tool=empty_on_invoke_tool,
170-
strict_json_schema=t.strict_json_schema,
201+
strict_json_schema=tool.strict_json_schema,
171202
)
172203
else:
173204
raise UserError(f"Unknown tool type: {tool.name}")

temporalio/contrib/openai_agents/_temporal_model_stub.py

Lines changed: 16 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -13,10 +13,12 @@
1313
from agents import (
1414
AgentOutputSchema,
1515
AgentOutputSchemaBase,
16-
ComputerTool,
16+
CodeInterpreterTool,
1717
FileSearchTool,
1818
FunctionTool,
1919
Handoff,
20+
HostedMCPTool,
21+
ImageGenerationTool,
2022
Model,
2123
ModelResponse,
2224
ModelSettings,
@@ -33,6 +35,7 @@
3335
AgentOutputSchemaInput,
3436
FunctionToolInput,
3537
HandoffInput,
38+
HostedMCPToolInput,
3639
ModelActivity,
3740
ModelTracingInput,
3841
ToolInput,
@@ -65,12 +68,18 @@ async def get_response(
6568
prompt: Optional[ResponsePromptParam],
6669
) -> ModelResponse:
6770
def make_tool_info(tool: Tool) -> ToolInput:
68-
if isinstance(tool, (FileSearchTool, WebSearchTool)):
71+
if isinstance(
72+
tool,
73+
(
74+
FileSearchTool,
75+
WebSearchTool,
76+
ImageGenerationTool,
77+
CodeInterpreterTool,
78+
),
79+
):
6980
return tool
70-
elif isinstance(tool, ComputerTool):
71-
raise NotImplementedError(
72-
"Computer search preview is not supported in Temporal model"
73-
)
81+
elif isinstance(tool, HostedMCPTool):
82+
return HostedMCPToolInput(tool_config=tool.tool_config)
7483
elif isinstance(tool, FunctionTool):
7584
return FunctionToolInput(
7685
name=tool.name,
@@ -79,7 +88,7 @@ def make_tool_info(tool: Tool) -> ToolInput:
7988
strict_json_schema=tool.strict_json_schema,
8089
)
8190
else:
82-
raise ValueError(f"Unknown tool type: {tool.name}")
91+
raise ValueError(f"Unsupported tool type: {tool.name}")
8392

8493
tool_infos = [make_tool_info(x) for x in tools]
8594
handoff_infos = [

temporalio/contrib/openai_agents/_temporal_openai_agents.py

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,13 @@
3434
from temporalio.contrib.openai_agents._trace_interceptor import (
3535
OpenAIAgentsTracingInterceptor,
3636
)
37-
from temporalio.contrib.pydantic import pydantic_data_converter
37+
from temporalio.contrib.pydantic import (
38+
PydanticPayloadConverter,
39+
ToJsonOptions,
40+
)
41+
from temporalio.converter import (
42+
DataConverter,
43+
)
3844
from temporalio.worker import Worker, WorkerConfig
3945

4046

@@ -137,6 +143,11 @@ def stream_response(
137143
raise NotImplementedError()
138144

139145

146+
class _OpenAIPayloadConverter(PydanticPayloadConverter):
147+
def __init__(self) -> None:
148+
super().__init__(ToJsonOptions(exclude_unset=True))
149+
150+
140151
class OpenAIAgentsPlugin(temporalio.client.Plugin, temporalio.worker.Plugin):
141152
"""Temporal plugin for integrating OpenAI agents with Temporal workflows.
142153
@@ -232,7 +243,9 @@ def configure_client(self, config: ClientConfig) -> ClientConfig:
232243
Returns:
233244
The modified client configuration.
234245
"""
235-
config["data_converter"] = pydantic_data_converter
246+
config["data_converter"] = DataConverter(
247+
payload_converter_class=_OpenAIPayloadConverter
248+
)
236249
return super().configure_client(config)
237250

238251
def configure_worker(self, config: WorkerConfig) -> WorkerConfig:

temporalio/contrib/pydantic.py

Lines changed: 25 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -13,10 +13,12 @@
1313
Pydantic v1 is not supported.
1414
"""
1515

16+
from dataclasses import dataclass
1617
from typing import Any, Optional, Type
1718

1819
from pydantic import TypeAdapter
19-
from pydantic_core import to_json
20+
from pydantic_core import SchemaSerializer, to_json
21+
from pydantic_core.core_schema import any_schema
2022

2123
import temporalio.api.common.v1
2224
from temporalio.converter import (
@@ -31,6 +33,13 @@
3133
# implements __get_pydantic_core_schema__ so that pydantic unwraps proxied types.
3234

3335

36+
@dataclass
37+
class ToJsonOptions:
38+
"""Options for converting to JSON with pydantic."""
39+
40+
exclude_unset: bool = False
41+
42+
3443
class PydanticJSONPlainPayloadConverter(EncodingPayloadConverter):
3544
"""Pydantic JSON payload converter.
3645
@@ -44,6 +53,11 @@ class PydanticJSONPlainPayloadConverter(EncodingPayloadConverter):
4453
See https://docs.pydantic.dev/latest/api/standard_library_types/
4554
"""
4655

56+
def __init__(self, to_json_options: Optional[ToJsonOptions] = None):
57+
"""Create a new payload converter."""
58+
self._schema_serializer = SchemaSerializer(any_schema())
59+
self._to_json_options = to_json_options
60+
4761
@property
4862
def encoding(self) -> str:
4963
"""See base class."""
@@ -57,8 +71,15 @@ def to_payload(self, value: Any) -> Optional[temporalio.api.common.v1.Payload]:
5771
See
5872
https://docs.pydantic.dev/latest/api/pydantic_core/#pydantic_core.to_json.
5973
"""
74+
data = (
75+
self._schema_serializer.to_json(
76+
value, exclude_unset=self._to_json_options.exclude_unset
77+
)
78+
if self._to_json_options
79+
else to_json(value)
80+
)
6081
return temporalio.api.common.v1.Payload(
61-
metadata={"encoding": self.encoding.encode()}, data=to_json(value)
82+
metadata={"encoding": self.encoding.encode()}, data=data
6283
)
6384

6485
def from_payload(
@@ -85,9 +106,9 @@ class PydanticPayloadConverter(CompositePayloadConverter):
85106
:py:class:`PydanticJSONPlainPayloadConverter`.
86107
"""
87108

88-
def __init__(self) -> None:
109+
def __init__(self, to_json_options: Optional[ToJsonOptions] = None) -> None:
89110
"""Initialize object"""
90-
json_payload_converter = PydanticJSONPlainPayloadConverter()
111+
json_payload_converter = PydanticJSONPlainPayloadConverter(to_json_options)
91112
super().__init__(
92113
*(
93114
c

0 commit comments

Comments
 (0)